Source code for design_by_contract

import logging
from dataclasses import dataclass
from functools import partial, wraps
from inspect import get_annotations, getfullargspec, signature
from typing import (
    Annotated,
    Any,
    Callable,
    Optional,
    ParamSpec,
    TypeVar,
    Union,
    overload,
)

logger = logging.getLogger(__name__)


[docs]class ContractViolationError(Exception): """Raised when a contract is violated"""
[docs]class ContractLogicError(Exception): """Raised when there is a syntactical error"""
[docs]@dataclass class UnresolvedSymbol: """ Placeholder for unknown symbols in contracts. Overrides the equality operator to behave like an assignment. """ name: str value: Optional[Any] = None def __eq__(self, other: Any) -> Union["UnresolvedSymbol", bool]: # type: ignore[override] match other: case UnresolvedSymbol(name=name, value=None): if self.value is None: raise ContractViolationError(f"Symbols `{self.name}` and `{name}` undefined") other.value = self.value case UnresolvedSymbol(name=_, value=value) if self.value is None: self.value = value case UnresolvedSymbol(name=name, value=value) if value != self.value: raise ContractViolationError( f"Symbols `{self.name}` and `{name}` do not match: `{self.value}` != `{value}`" ) case self.value: return True case value if self.value is not None: raise ContractViolationError( f"Symbols `{self.name}` and `{other}` do not match: `{self.value}` != `{other}`" ) case value: self.value = value return self def __bool__(self) -> bool: return self.value is not None
P = ParamSpec("P") R = TypeVar("R") @overload def contract(func: Callable[P, R]) -> Callable[P, R]: ... @overload def contract(*, reserved: str = "x", evaluate: bool = True) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
[docs]def contract( func: Optional[Callable[P, R]] = None, *, reserved: str = "x", evaluate: bool = True ) -> Union[Callable[[Callable[P, R]], Callable[P, R]], Callable[P, R]]: """ A decorator for enabling design by contract using :class:`typing.Annotated`. Define contract conditions as lambdas together with their type annotation. The decorator is overloaded so you can call it eigher with `@contract` or `@contract(...)` with our without arguments. Note that positional keywords are not allowed (i.e., you need to use keyword arguments) Parameters ---------- reserved : str, optional This symbol gets always replaced by the current argument name, by default "x". This is a keyword only argument. evaluate : bool, optional If False, the contracts are not evaluated, by default True. This is a keyword only argument. """ def wrapper(func: Callable[P, R], *args: Any, **kw: Any) -> R: """The actual logic""" if not evaluate: return func(*args, **kw) annotations = get_annotations(func) return_annotation = annotations.pop("return", None) if reserved in annotations.keys(): raise ValueError(f"Argument cannot be the reserved identifier `{reserved}`") # Resolved function arguments passed to func injectables = dict(zip(annotations.keys(), args)) logger.debug("injectables: %s", injectables) def evaluate_annotations(annotations: dict[str, Any]) -> None: nonlocal injectables for arg_name, annotation in annotations.items(): # Filter for typing.Annotation objects with extra annotations if hasattr(annotation, "__metadata__"): for meta in annotation.__metadata__: # Only consider lambdas/callables if callable(meta): meta_args = getfullargspec(meta).args # Only if the original argument's name is among its argument names # TODO we shold remove that if arg_name in meta_args or reserved in meta_args: # the reserved identifier is a shortcut injectables[reserved] = injectables[arg_name] dependencies = set(injectables.keys()).intersection(meta_args) logger.debug( "contract for `%s`, resolved: `%s`", arg_name, {i: injectables[i] for i in dependencies}, ) # Look for arguments that cannot be injected if unresolved := set(meta_args) - set(injectables.keys()): symbols = {i: UnresolvedSymbol(i) for i in unresolved} logger.debug( "contract for `%s`, unresolved: `%s`, %s", arg_name, unresolved, symbols ) if not meta(*[(symbols | injectables)[i] for i in meta_args]): raise ContractViolationError(f"Contract violated for argument: `{arg_name}`") if any([i.value is None for i in symbols.values()]): raise ContractLogicError( f"Not all symbols were resolved `{symbols}`", ) injectables |= {k: v.value for k, v in symbols.items()} else: # Evaluate contract by injecting values into the lambda if not meta(*(_args := [injectables[i] for i in meta_args])): raise ContractViolationError(f"Contract violated for argument: `{arg_name}`") logger.debug("Contract fulfilled for argument `%s`", arg_name) evaluate_annotations(annotations) result = func(*args, **kw) if return_annotation is not None: injectables["return"] = result logger.debug(injectables) evaluate_annotations({"return": return_annotation}) return result def decorator(func: Callable[P, R]) -> Callable[P, R]: return wraps(func)(partial(wrapper, func)) if func is not None: if not callable(func): raise TypeError("Not a callable. Did you use a non-keyword argument?") return wraps(func)(partial(wrapper, func)) return decorator