Source code for pythonwrench.typing.checks

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import collections.abc
import functools
import inspect
import logging
import sys
import types
from numbers import Integral
from types import FunctionType, MethodType
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Literal,
    Mapping,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    Union,
)

import typing_extensions
from typing_extensions import (
    NotRequired,
    ParamSpec,
    Required,
    TypeGuard,
    TypeIs,
    TypeVar,
    get_args,
    get_origin,
)

from pythonwrench.typing.classes import (
    BuiltinCollection,
    BuiltinNumber,
    BuiltinScalar,
    DataclassInstance,
    NamedTupleInstance,
    NoneType,
)

T = TypeVar("T")
P = ParamSpec("P")

logger = logging.getLogger(__name__)


def check_args_types(fn: Callable[P, T]) -> Callable[P, T]:
    """Decorator to check argument types before call to a function.

    Each value explicitly passed by the caller is checked against the
    annotation of its parameter with ``isinstance_generic``:

    - parameters without annotation (e.g. ``self``) are not checked;
    - every element received through ``*args`` and every value received
      through ``**kwargs`` is checked against the variadic parameter's annotation;
    - default values are not checked.

    If the call does not match the signature at all (missing or unexpected
    arguments), ``fn`` is called directly so that Python raises its usual error.

    Args:
        fn: Function or method to decorate.

    Returns:
        A wrapper with the metadata of ``fn`` (via ``functools.wraps``) that raises
        ``TypeError`` listing every invalid argument before calling ``fn``.

    Raises:
        TypeError: If ``fn`` is not a function or a method.

    Example
    -------
    >>> import pythonwrench as pw
    >>> @pw.check_args_types
    ... def f(a: int, b: str) -> str:
    ...     return a * b
    >>> f(1, "a")  # pass check
    'a'

    The following call raises TypeError from the decorator:

    >>> f(1, 2)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError: f() has 1/2 invalid argument(s): ...
    """
    if not isinstance(fn, (FunctionType, MethodType)):
        msg = f"Invalid argument type {type(fn)}. (expected function or method)"
        raise TypeError(msg)

    # NOTE(review): string annotations (e.g. under `from __future__ import annotations`)
    # are passed as-is to isinstance_generic, which does not resolve them.
    signature = inspect.signature(fn)

    @functools.wraps(fn)
    def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            bound = signature.bind(*args, **kwargs)
        except TypeError:
            # Arguments do not match the signature: let `fn` raise Python's own error.
            return fn(*args, **kwargs)

        checked = list(_iter_checked_arguments(signature, bound, len(args)))
        msgs = [
            f" - invalid argument {label} with value {value!r}; expected an instance of {annotation}."
            for label, value, annotation in checked
            if annotation is not inspect.Parameter.empty
            and not isinstance_generic(value, annotation)
        ]

        if len(msgs) > 0:
            msgs = [
                f"{fn.__name__}() has {len(msgs)}/{len(checked)} invalid argument(s):",
            ] + msgs
            msg = "\n".join(msgs)
            raise TypeError(msg)

        result = fn(*args, **kwargs)
        return result

    return _wrapper


def _iter_checked_arguments(
    signature: inspect.Signature,
    bound: inspect.BoundArguments,
    num_positional: int,
) -> Iterable[Tuple[str, Any, Any]]:
    """Yield ``(label, value, annotation)`` for every value explicitly passed in a call.

    ``label`` is ``n°<position>`` (1-based) for values passed positionally and
    ``'<name>'`` for values passed by keyword. Values gathered by ``*args`` or
    ``**kwargs`` are yielded one by one with the annotation of the variadic parameter.
    """
    position = 0
    for name, value in bound.arguments.items():
        param = signature.parameters[name]
        if param.kind is inspect.Parameter.VAR_POSITIONAL:
            for item in value:
                position += 1
                yield f"n°{position}", item, param.annotation
        elif param.kind is inspect.Parameter.VAR_KEYWORD:
            for key, item in value.items():
                yield f"'{key}'", item, param.annotation
        elif (
            param.kind is not inspect.Parameter.KEYWORD_ONLY
            and position < num_positional
        ):
            # Positional values fill the leading parameters in declaration order.
            position += 1
            yield f"n°{position}", value, param.annotation
        else:
            yield f"'{name}'", value, param.annotation
def isinstance_generic(
    obj: Any,
    class_or_tuple: Union[Type[T], None, Tuple[Type[T], ...], Any],
    *,
    check_only_first: bool = False,
) -> TypeIs[T]:
    """Improved isinstance(...) function that supports parametrized Union, TypedDict, Literal, Mapping or Iterable.

    Args:
        obj: Object to check.
        class_or_tuple: Type to check. Can be a parametrized type from `typing`
            (including PEP 604 unions such as ``int | str``), a TypedDict class,
            None, or a tuple of such types (matches if any of them matches).
        check_only_first: If True, check only the first element when checking
            for Iterable[type] or Tuple[type, ...]. An empty iterable is then
            accepted. defaults to False.

    Returns:
        True if obj matches class_or_tuple, False otherwise.

    Raises:
        TypeError: If obj is a generator and class_or_tuple is a parametrized
            container type (its elements cannot be checked), or if
            class_or_tuple is a parametrized Generator type.
        NotImplementedError: If class_or_tuple is a parametrized Callable or
            another unsupported parametrized type.

    Example
    -------
    >>> isinstance_generic({"a": 1, "b": 2}, dict)
    True
    >>> isinstance_generic({"a": 1, "b": 2}, dict[str, int])
    True
    >>> isinstance_generic({"a": 1, "b": 2}, dict[str, str])
    False
    >>> from typing import Literal
    >>> isinstance_generic({"a": 1, "b": 2}, dict[str, Literal[1, 2]])
    True
    """
    if class_or_tuple is Any or class_or_tuple is typing_extensions.Any:
        return True
    if class_or_tuple is None:
        return obj is None
    if isinstance(class_or_tuple, tuple):
        return any(
            isinstance_generic(obj, target_type_i) for target_type_i in class_or_tuple
        )
    if is_typed_dict(class_or_tuple):
        return _isinstance_generic_typed_dict(obj, class_or_tuple)

    origin = get_origin(class_or_tuple)
    if origin is None:
        return isinstance(obj, class_or_tuple)  # type: ignore

    # Special case for empty tuple because get_args(Tuple[()]) returns () and not ((),) in python >= 3.11
    # More info at https://github.com/python/cpython/issues/91137
    if class_or_tuple == Tuple[()]:
        return obj == ()

    args = get_args(class_or_tuple)

    # get_origin(typing.Callable[...]) returns collections.abc.Callable, not typing.Callable.
    if origin is Callable or origin is collections.abc.Callable:
        if len(args) == 0:
            return callable(obj)
        # TODO: impl
        msg = "Function `isinstance_generic` currently does not support parametrized Callable."
        raise NotImplementedError(msg)

    if len(args) == 0:
        return isinstance_generic(obj, origin)

    # PEP 604 unions (`int | str`) have origin types.UnionType on Python >= 3.10.
    if origin is Union or origin is getattr(types, "UnionType", None):
        return any(isinstance_generic(obj, arg) for arg in args)

    if origin is Literal or origin is typing_extensions.Literal:
        return obj in args

    if isinstance(obj, Generator):
        msg = f"Invalid argument type {type(obj)}. (cannot check elements in generator)"
        raise TypeError(msg)

    if not isinstance(origin, type):
        # Special forms (e.g. ClassVar[int]) are not classes: issubclass() below would fail.
        raise _unsupported_type_error(class_or_tuple)

    if issubclass(origin, Generator):
        msg = f"Invalid argument type {origin}. (cannot check generator type)"
        raise TypeError(msg)

    if issubclass(origin, Mapping):
        if len(args) not in (1, 2):
            raise _unsupported_type_error(class_or_tuple)
        if not isinstance_generic(obj, origin):
            return False
        # Single-parameter mappings such as Counter[str] only constrain their keys.
        key_type = args[0]
        value_type = args[1] if len(args) == 2 else Any
        return all(isinstance_generic(k, key_type) for k in obj.keys()) and all(
            isinstance_generic(v, value_type) for v in obj.values()
        )

    if issubclass(origin, Tuple):
        if not isinstance_generic(obj, origin):
            return False
        elif len(args) == 1 and args[0] == ():
            return len(obj) == 0
        elif len(args) == 2 and args[1] is ...:
            # Variadic tuple Tuple[X, ...]: every element (or only the first) must be X.
            if check_only_first:
                args = (args[0],)
            else:
                args = tuple([args[0]] * len(obj))
        elif len(obj) != len(args):
            return False
        return all(isinstance_generic(xi, ti) for xi, ti in zip(obj, args))

    if issubclass(origin, Iterable):
        if not isinstance_generic(obj, origin):
            return False
        if check_only_first:
            sentinel = object()
            first = next(iter(obj), sentinel)
            # An empty iterable has no element to check (consistent with `all([])`).
            return first is sentinel or isinstance_generic(first, args[0])
        return all(isinstance_generic(xi, args[0]) for xi in obj)

    raise _unsupported_type_error(class_or_tuple)


def _unsupported_type_error(class_or_tuple: Any) -> NotImplementedError:
    """Build the error raised for parametrized types isinstance_generic cannot check."""
    msg = f"Unsupported type {class_or_tuple}. (expected unparametrized type or parametrized Union, TypedDict, Literal, Mapping or Iterable)"
    return NotImplementedError(msg)
def _isinstance_generic_typed_dict(x: Any, target_type: type) -> bool:
    """Returns True if x is a dict that matches the TypedDict class target_type.

    x must be a dict with str keys. It must contain every required key of
    target_type and no key that target_type does not declare. Every present
    value must match its declared type (checked with isinstance_generic).

    Args:
        x: Object to check.
        target_type: TypedDict class (as detected by is_typed_dict).
    """
    if not isinstance_generic(x, Dict[str, Any]):
        return False

    total: bool = target_type.__total__
    # __required_keys__ (Python >= 3.9) reflects the totality of the class that
    # declared each key. This matters when TypedDicts with different `total`
    # values inherit from each other. Fall back to __total__ when it is missing.
    declared_required_keys = getattr(target_type, "__required_keys__", None)

    required_keys = set()
    field_types = {}
    for key, annotation in target_type.__annotations__.items():
        origin = get_origin(annotation)
        if origin is Required or origin is NotRequired:
            # Explicit qualifiers take precedence; unwrap them to get the value type.
            is_required = origin is Required
            annotation = get_args(annotation)[0]
        elif declared_required_keys is not None:
            is_required = key in declared_required_keys
        else:
            is_required = total
        if is_required:
            required_keys.add(key)
        field_types[key] = annotation

    # Every required key must be present, and undeclared keys are rejected.
    if not required_keys.issubset(x.keys()):
        return False
    if not set(field_types).issuperset(x.keys()):
        return False

    return all(isinstance_generic(value, field_types[key]) for key, value in x.items())
def is_builtin_collection(x: Any, *, strict: bool = False) -> TypeIs[BuiltinCollection]:
    """Check whether x is a list, tuple, dict, set or frozenset instance.

    Args:
        x: Object to check.
        strict: When True, instances of user-defined subclasses of these builtin
            collections are rejected. defaults to False.
    """
    collection_types = (list, tuple, dict, set, frozenset)
    if not isinstance(x, collection_types):
        return False
    return not strict or is_builtin_obj(x)
def is_builtin_number(x: Any, *, strict: bool = False) -> TypeIs[BuiltinNumber]:
    """Check whether x is an int, float, bool or complex instance.

    Args:
        x: Object to check.
        strict: When True, instances of user-defined subclasses of these builtin
            numbers are rejected. defaults to False.
    """
    number_types = (int, float, bool, complex)
    if not isinstance(x, number_types):
        return False
    return not strict or is_builtin_obj(x)
def is_builtin_obj(x: Any) -> bool:
    """Check whether x is an instance (not a class) whose class lives in `builtins`.

    Note:
        Instances of custom subclasses of builtin types are defined outside the
        `builtins` module, so this function returns False for them.
    """
    cls = x.__class__
    if cls.__module__ != "builtins":
        return False
    return not isinstance(x, type)
[docs] def is_builtin_scalar(x: Any, *, strict: bool = False) -> TypeIs[BuiltinScalar]: """Returns True if x is an instance of a builtin scalar type (int, float, bool, complex, NoneType, str, bytes). Args: x: Object to check. strict: If True, it will not consider subtypes of builtins as builtin scalars. defaults to False. """ if strict and not is_builtin_obj(x): return False return isinstance(x, (int, float, bool, complex, NoneType, str, bytes))
def is_dataclass_instance(x: Any) -> TypeIs[DataclassInstance]:
    """Returns True if x is an instance of a dataclass.

    Unlike `dataclasses.is_dataclass`, the dataclass type itself is rejected:
    classes always return False.
    """
    if isinstance(x, type):
        return False
    return isinstance_generic(x, DataclassInstance)
def is_iterable_bool(
    x: Any,
    *,
    accept_generator: bool = True,
) -> TypeIs[Iterable[bool]]:
    """Returns True if x is an iterable whose elements are all bool.

    Args:
        x: Object to check.
        accept_generator: If False, generators are always rejected. If True, a
            generator is consumed to check its elements, so it is exhausted
            afterwards. defaults to True.
    """
    if isinstance(x, Generator):
        # isinstance_generic raises TypeError on generators, so elements are
        # checked here instead (same approach as is_iterable_str).
        return accept_generator and all(isinstance(xi, bool) for xi in x)
    return isinstance_generic(x, Iterable[bool])
def is_iterable_bytes_or_list(
    x: Any,
    *,
    accept_generator: bool = True,
) -> TypeIs[Iterable[Union[bytes, list]]]:
    """Returns True if x is an iterable whose elements are all bytes or list.

    Args:
        x: Object to check.
        accept_generator: If False, generators are always rejected. If True, a
            generator is consumed to check its elements, so it is exhausted
            afterwards. defaults to True.
    """
    if isinstance(x, Generator):
        # isinstance_generic raises TypeError on generators, so elements are
        # checked here instead (same approach as is_iterable_str).
        return accept_generator and all(isinstance(xi, (bytes, list)) for xi in x)
    return isinstance_generic(x, Iterable[Union[bytes, list]])
def is_iterable_float(
    x: Any,
    *,
    accept_generator: bool = True,
) -> TypeIs[Iterable[float]]:
    """Returns True if x is an iterable whose elements are all float.

    Args:
        x: Object to check.
        accept_generator: If False, generators are always rejected. If True, a
            generator is consumed to check its elements, so it is exhausted
            afterwards. defaults to True.
    """
    if isinstance(x, Generator):
        # isinstance_generic raises TypeError on generators, so elements are
        # checked here instead (same approach as is_iterable_str).
        return accept_generator and all(isinstance(xi, float) for xi in x)
    return isinstance_generic(x, Iterable[float])
def is_iterable_int(
    x: Any,
    *,
    accept_bool: bool = True,
    accept_generator: bool = True,
) -> TypeIs[Iterable[int]]:
    """Returns True if x is an iterable whose elements are all int.

    Args:
        x: Object to check.
        accept_bool: If False, the iterable is rejected when any element is a
            bool (bool is a subclass of int). An empty iterable is still
            accepted. defaults to True.
        accept_generator: If False, generators are always rejected. If True, a
            generator is consumed to check its elements, so it is exhausted
            afterwards. defaults to True.
    """
    if isinstance(x, Generator):
        if not accept_generator:
            return False
        # isinstance_generic raises TypeError on generators, so elements are
        # checked here instead (same approach as is_iterable_str).
        return all(
            isinstance(xi, int) and (accept_bool or not isinstance(xi, bool))
            for xi in x
        )
    if not isinstance_generic(x, Iterable[int]):
        return False
    return accept_bool or not any(isinstance(xi, bool) for xi in x)
def is_iterable_integral(
    x: Any,
    *,
    accept_generator: bool = True,
) -> TypeIs[Iterable[Integral]]:
    """Returns True if x is an iterable whose elements are all `numbers.Integral`.

    Args:
        x: Object to check.
        accept_generator: If False, generators are always rejected. If True, a
            generator is consumed to check its elements, so it is exhausted
            afterwards. defaults to True.
    """
    if isinstance(x, Generator):
        # isinstance_generic raises TypeError on generators, so elements are
        # checked here instead (same approach as is_iterable_str).
        return accept_generator and all(isinstance(xi, Integral) for xi in x)
    return isinstance_generic(x, Iterable[Integral])
def is_iterable_str(
    x: Any,
    *,
    accept_str: bool = True,
    accept_generator: bool = True,
) -> TypeGuard[Iterable[str]]:
    """Returns True if x is an iterable whose elements are all str.

    Args:
        x: Object to check.
        accept_str: Value returned when x is itself a str. defaults to True.
        accept_generator: If False, generators are rejected. If True, a generator
            is consumed to check its elements. defaults to True.
    """
    if isinstance(x, str):
        return accept_str
    if not isinstance(x, Generator):
        return isinstance_generic(x, Iterable[str])
    # Generators cannot be handed to isinstance_generic; iterate them directly.
    return accept_generator and all(isinstance(item, str) for item in x)
def is_namedtuple_instance(x: Any) -> TypeIs[NamedTupleInstance]:
    """Returns True if x is an instance of a NamedTuple (classes return False)."""
    if isinstance(x, type):
        return False
    return isinstance_generic(x, NamedTupleInstance)
def is_sequence_str(
    x: Any,
    *,
    accept_str: bool = True,
) -> TypeGuard[Sequence[str]]:
    """Returns True if x is a Sequence whose elements are all str.

    Args:
        x: Object to check.
        accept_str: Whether a str (which is itself a sequence of str) is
            accepted. defaults to True.
    """
    if isinstance(x, str):
        return accept_str
    if not isinstance(x, Sequence):
        return False
    return all(isinstance(item, str) for item in x)
def is_typed_dict(x: Any) -> TypeGuard[type]:
    """Returns True if x is a TypedDict class.

    Classes created with either `typing.TypedDict` or `typing_extensions.TypedDict`
    are recognized through `typing_extensions.is_typeddict`. The structural check on
    `__orig_bases__` (or the metaclass name on Python < 3.9) is kept as a fallback.

    Args:
        x: Object to check. TypedDict instances are plain dicts and return False.
    """
    if typing_extensions.is_typeddict(x):
        return True
    if sys.version_info.major == 3 and sys.version_info.minor < 9:
        return x.__class__.__name__ == "_TypedDictMeta"
    orig_bases = getattr(x, "__orig_bases__", None)
    return orig_bases is not None and TypedDict in orig_bases