Source code for eincheck.checks.data

import functools
import itertools
import sys
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, TypeVar

from eincheck.checks.shapes import check_shapes
from eincheck.parser.grammar import ShapeArg, create_shape_spec
from eincheck.parser.shape_spec import ShapeSpec
from eincheck.utils import get_object, parse_dot_name

# Type variable used by ``wrap`` so it is annotated as returning the same type
# it receives.
_T = TypeVar("_T")


class DataWrapper(ABC):
    """Abstract base for adapters that attach shape checks to a data class.

    Subclasses decide which classes they handle (``is_match``) and how the
    check is installed on such a class (``wrap``).
    """

    # Module that must already be present in ``sys.modules`` for this wrapper
    # to be usable; ``None`` means there is no such requirement.
    module_name: Optional[str] = None

    @classmethod
    def can_load(cls) -> bool:
        """Return True if this wrapper has no module requirement or it is met."""
        required = cls.module_name
        if required is None:
            return True
        return required in sys.modules

    @abstractmethod
    def is_match(self, x: Any) -> bool:
        """Whether x is a data object of the right type."""

    @abstractmethod
    def wrap(self, cls: _T, shapes: Mapping[str, ShapeSpec]) -> _T:
        """Install shape checks for ``shapes`` on ``cls`` and return it."""

    @staticmethod
    def check_fields(shapes: Mapping[str, ShapeSpec], got: Set[str]) -> None:
        """Verify that every name in ``shapes`` refers to a field in ``got``.

        Only the first element returned by ``parse_dot_name`` for each key is
        compared against ``got``.

        :param shapes: mapping from (possibly dotted) names to shape specs
        :param got: field names available on the data class
        :raises ValueError: if any leading name is not in ``got``
        """
        requested = {parse_dot_name(name)[0] for name in shapes}
        missing = requested - got
        if not missing:
            return
        listing = " ".join(sorted(missing))
        raise ValueError("No field found: [" + listing + "]")

    @staticmethod
    def get_shapes_func(
        shapes: Mapping[str, ShapeSpec]
    ) -> Callable[[Any], Dict[str, Tuple[Any, ShapeSpec]]]:
        """Build a function mapping an instance to ``{name: (value, spec)}``.

        The value for each name is looked up with ``get_object(name, instance)``.
        """

        def get_shapes(self: Any) -> Dict[str, Tuple[Any, ShapeSpec]]:
            collected: Dict[str, Tuple[Any, ShapeSpec]] = {}
            for name, spec in shapes.items():
                collected[name] = (get_object(name, self), spec)
            return collected

        return get_shapes


class NamedTupleWrapper(DataWrapper):
    """Adds shape checks to NamedTuple-style classes by wrapping ``__new__``."""

    def is_match(self, x: Any) -> bool:
        """Return True if ``x`` is a tuple subclass with a ``_fields`` attribute.

        ``issubclass`` raises ``TypeError`` when its first argument is not a
        class, so non-class objects are rejected up front and simply yield
        False instead of raising.
        """
        return isinstance(x, type) and issubclass(x, tuple) and hasattr(x, "_fields")

    def wrap(self, cls: _T, shapes: Mapping[str, ShapeSpec]) -> _T:
        """Replace ``cls.__new__`` with a version that checks shapes.

        :param cls: the class to modify (changed in place)
        :param shapes: mapping from field names to shape specs
        :return: the same ``cls`` object
        :raises ValueError: if ``shapes`` names a field not in ``cls._fields``
        """
        self.check_fields(shapes, set(cls._fields))  # type: ignore[attr-defined]

        # Capture the original constructor before it is replaced below.
        _new = cls.__new__

        cls._get_shapes = DataWrapper.get_shapes_func(  # type: ignore[attr-defined]
            shapes
        )

        @functools.wraps(_new)
        def new_new(*a: Any, **k: Any) -> Any:
            # Build the instance first, then validate the stored field values.
            out = _new(*a, **k)
            check_shapes(**out._get_shapes())
            return out

        cls.__new__ = new_new  # type: ignore[method-assign]

        return cls


def _func_with_check(
    cls: Any, func: str, shapes: Mapping[str, ShapeSpec], append: bool
) -> None:
    """Replace method ``func`` on ``cls`` with a version that checks shapes.

    Also stores a ``_get_shapes`` function built from ``shapes`` on ``cls``.

    :param cls: class to modify in place
    :param func: name of the method to wrap
    :param shapes: mapping from field names to shape specs
    :param append: if True the check runs after the original method,
        otherwise it runs before it
    """
    original = getattr(cls, func)

    cls._get_shapes = DataWrapper.get_shapes_func(shapes)

    def _run_check(instance: Any) -> None:
        check_shapes(**instance._get_shapes())

    @functools.wraps(original)
    def checked(self: Any, *args: Any, **kwargs: Any) -> Any:
        # The original method's return value is intentionally not propagated.
        if not append:
            _run_check(self)
        original(self, *args, **kwargs)
        if append:
            _run_check(self)

    setattr(cls, func, checked)


class DataclassWrapper(DataWrapper):
    """Adds shape checks to dataclasses."""

    # On interpreters older than 3.8 this wrapper is only loadable once the
    # ``dataclasses`` module has been imported.
    if sys.version_info < (3, 8):
        module_name = "dataclasses"

    def __init__(self) -> None:
        super().__init__()
        # Imported here rather than at module level and kept on the instance.
        import dataclasses

        self.dataclasses = dataclasses

    def is_match(self, x: Any) -> bool:
        """Return the result of ``dataclasses.is_dataclass(x)``."""
        result = self.dataclasses.is_dataclass(x)
        assert isinstance(result, bool)
        return result

    def wrap(self, cls: _T, shapes: Mapping[str, ShapeSpec]) -> _T:
        """Install the shape check on ``cls`` and return it.

        If ``cls`` has ``__post_init__`` the check runs before it; otherwise
        the check runs after ``__init__``.

        :raises ValueError: if ``shapes`` names a field ``cls`` does not have
        """
        names = {
            f.name for f in self.dataclasses.fields(cls)  # type: ignore[arg-type]
        }
        self.check_fields(shapes, names)

        has_post_init = hasattr(cls, "__post_init__")
        hook = "__post_init__" if has_post_init else "__init__"
        _func_with_check(cls, hook, shapes, not has_post_init)
        return cls


class AttrsWrapper(DataWrapper):
    """Adds shape checks to attrs classes."""

    # Only loadable once ``attrs`` is already present in ``sys.modules``.
    module_name = "attrs"

    def __init__(self) -> None:
        super().__init__()
        # Imported here rather than at module level and kept on the instance.
        import attrs

        self.attrs = attrs

    def is_match(self, x: Any) -> bool:
        """Return the result of ``attrs.has(x)``."""
        return self.attrs.has(x)

    def wrap(self, cls: _T, shapes: Mapping[str, ShapeSpec]) -> _T:
        """Install the shape check on ``cls`` and return it.

        If ``cls`` has ``__attrs_post_init__`` the check runs before it;
        otherwise the check runs after ``__init__``.

        :raises ValueError: if ``shapes`` names a field ``cls`` does not have
        """
        names = {a.name for a in self.attrs.fields(cls)}  # type: ignore[arg-type]
        self.check_fields(shapes, names)

        has_post_init = hasattr(cls, "__attrs_post_init__")
        hook = "__attrs_post_init__" if has_post_init else "__init__"
        _func_with_check(cls, hook, shapes, not has_post_init)

        return cls


# DataWrapper instances created so far; the decorator returned by check_data
# tries these before instantiating any further DataWrapper subclasses.
_wrappers: List[DataWrapper] = []

# Type variable for the class passed through the check_data decorator.
_T_Data = TypeVar("_T_Data")


def check_data(
    shape_dict: Optional[Mapping[str, ShapeArg]] = None, /, **kwargs: ShapeArg
) -> Callable[[_T_Data], _T_Data]:
    """Check the shapes of fields of a data object.

    The currently supported data objects are NamedTuple, dataclasses, and attrs.

    :param shape_dict: shape specs for fields of the data object in a dictionary
    :param kwargs: shape specs for fields of the data object as keywords
    :return: a decorator for the data object class; the decorator raises
        ``TypeError`` if no loadable wrapper matches the decorated class
    """
    if shape_dict is None:
        shape_dict = {}

    # Keyword specs come second in the chain, so they override entries of
    # ``shape_dict`` that use the same key.
    shapes = {
        k: create_shape_spec(v)
        for k, v in itertools.chain(shape_dict.items(), kwargs.items())
    }

    def wrapper(cls: _T_Data) -> _T_Data:
        # First try the wrapper instances that already exist.
        for w in _wrappers:
            if w.is_match(cls):
                return w.wrap(cls, shapes)

        # Then instantiate wrapper classes that have no instance yet. Classes
        # already present in ``_wrappers`` were tried above, so skipping them
        # keeps the cache from accumulating duplicates on every miss.
        loaded = {type(w) for w in _wrappers}
        for w_cls in DataWrapper.__subclasses__():
            if w_cls in loaded or not w_cls.can_load():
                continue
            new_wrapper = w_cls()  # type: ignore[abstract]
            _wrappers.append(new_wrapper)
            if new_wrapper.is_match(cls):
                return new_wrapper.wrap(cls, shapes)

        raise TypeError(f"Unexpected data type {cls}")

    return wrapper