API
- eincheck.check_shapes(*args, **kwargs)
Check the shapes of Tensors against ShapeArg specifications.
Examples:
>>> from numpy.random import randn
>>> from eincheck import check_shapes
>>>
>>> check_shapes((randn(3, 4, 5), "... i j"), (randn(5, 6), "... j k"))
{'i': 4, 'j': 5, 'k': 6}
>>> check_shapes(
...     x=(randn(8, 2, 7, 3), "*batch t 3"),
...     y=(randn(8, 2, 1, 1, 3), "*batch ... 3"),
...     batch=(8, 2),
... )
{'batch': (8, 2), 't': 7}
Pass pairs of (tensor, shape spec) as either positional or keyword arguments. The only difference when using a keyword argument is the display name in error messages (positional arguments are called arg0, arg1, etc.). Keyword arguments can also specify variable values (e.g. batch=(8, 2)).
If all shape checks pass, returns a dictionary with the values of all bound variables.
- Parameters:
args (Tuple[TypeVar(Tensor), Union[str, ShapeSpec, Sequence[Union[DimSpec, str, int, None]]]]) – Pairs of (tensor, shape spec)
kwargs (Union[int, Tuple[int, ...], Tuple[TypeVar(Tensor), Union[str, ShapeSpec, Sequence[Union[DimSpec, str, int, None]]]]]) – Pairs of (tensor, shape spec), or values for shape variables
- Raises:
ValueError – If the shapes are incorrect or cannot be verified
- Return type:
Dict[str, Union[int, Tuple[int, ...]]]
- Returns:
Values for all bound variables from the shape specs
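A small pure-Python sketch can illustrate the binding rule. It is not eincheck's implementation. toy_check_shapes is a hypothetical helper that takes shape tuples instead of tensors and understands only space-separated dimension names and integer literals (no *, ..., _ or $). The first occurrence of a name binds its size, and every later occurrence must agree.

```python
from typing import Dict, Sequence, Tuple


def toy_check_shapes(*pairs: Tuple[Sequence[int], str]) -> Dict[str, int]:
    """Hypothetical sketch: check (shape, spec) pairs and return bound variables.

    Only space-separated dimension names and integer literals are supported.
    """
    bound: Dict[str, int] = {}
    for idx, (shape, spec) in enumerate(pairs):
        dims = spec.split()
        if len(dims) != len(shape):
            raise ValueError(
                f"arg{idx}: expected rank {len(dims)}, got shape {tuple(shape)}"
            )
        for d, (name, size) in enumerate(zip(dims, shape)):
            # Integer literals must match exactly; names bind on first use.
            expected = int(name) if name.isdigit() else bound.setdefault(name, size)
            if expected != size:
                raise ValueError(f"arg{idx} dim {d}: expected {name}={expected} got {size}")
    return bound


print(toy_check_shapes(((3, 4), "i j"), ((4, 5), "j k")))  # {'i': 3, 'j': 4, 'k': 5}
```

Variables are processed in argument order, so the error names the first tensor whose size disagrees with an earlier binding.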
- eincheck.check_func(shapes='', **kwargs)
Check the input and output shapes of a function.
- Parameters:
shapes (str) – string of input and output shape specs
kwargs (Union[str, ShapeSpec, Sequence[Union[DimSpec, str, int, None]]]) – additional shape specs for function inputs
- Return type:
Callable[[TypeVar(_T_Callable, bound=Callable[..., Any])], TypeVar(_T_Callable, bound=Callable[..., Any])]
- Returns:
a function decorator
The check_func decorator adds shape checks to the inputs and outputs of the decorated function.
The simplest way to specify shapes is as a comma separated list of input shape specs, an arrow ->, and a comma separated list of output shape specs.
>>> import numpy as np
>>> from numpy.random import randn
>>> from eincheck import check_func
>>> @check_func("i, i -> i")
... def foo(x, y):
...     return x + y
...
>>> foo(randn(4), randn(4)).shape
(4,)
>>> foo(randn(4), randn(3))
Traceback (most recent call last):
...
ValueError: y dim 0: expected i=4 got 3
i=4
x: got (4,) expected [i]
y: got (3,) expected [i]
Input specs are matched to function parameters in the order the parameters are declared. The function must have at least as many parameters as there are input shape specs. Parameters without a spec (such as y below) are not checked.
>>> @check_func("i -> i")
... def foo(x, y):
...     return x + y
>>> foo(randn(4), randn(4)).shape
(4,)
>>> foo(randn(3, 4), randn(1))
Traceback (most recent call last):
...
ValueError: x: expected rank 1, got shape (3, 4)
x: got (3, 4) expected [i]
>>> foo(randn(3), randn(3, 3))
Traceback (most recent call last):
...
ValueError: output0: expected rank 1, got shape (3, 3)
i = 3
output0: got (3, 3) expected [i]
>>> @check_func("i, i, i -> i")
... def foo(x):
...     return x
Traceback (most recent call last):
...
ValueError: Expected at least 3 input parameters, got 1
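The ordering rule can be sketched with the standard library's inspect.signature. bind_positional_specs is a hypothetical helper, not eincheck's code. It pairs input specs with parameter names in declaration order and rejects a spec string that lists more input specs than the function has parameters.

```python
import inspect
from typing import Callable, Dict, List


def bind_positional_specs(func: Callable, specs: List[str]) -> Dict[str, str]:
    """Hypothetical sketch: map each positional input spec to a parameter name."""
    params = list(inspect.signature(func).parameters)
    if len(specs) > len(params):
        raise ValueError(
            f"Expected at least {len(specs)} input parameters, got {len(params)}"
        )
    # zip stops at the shorter list, so extra parameters get no spec.
    return dict(zip(params, specs))


def foo(x, y):
    return x + y


print(bind_positional_specs(foo, ["i"]))  # {'x': 'i'}
```

Because the binding only looks at the signature, the mismatch can be reported when the decorator is applied, before the function is ever called.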
The shape spec for a variadic positional argument (e.g. *args) or variadic keyword argument (e.g. **kwargs) is compared against each Tensor matching that argument.
>>> @check_func("*x -> _ *x")
... def stack(*x):
...     return np.stack(x, 0)
...
>>> stack(randn(3, 4), randn(3, 4)).shape
(2, 3, 4)
>>> stack(randn(3, 4), randn(3, 5), randn(3, 4))
Traceback (most recent call last):
...
ValueError: x_1 dims (0, 1): expected x=(3, 4) got (3, 5)
x=(3, 4)
x_0: got (3, 4) expected [*x]
x_1: got (3, 5) expected [*x]
x_2: got (3, 4) expected [*x]
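The error above reports each element of *x under its own name (x_0, x_1, x_2). The following hypothetical sketch shows one way to derive such names with inspect.signature(...).bind. It handles only variadic positional parameters. It does not show how eincheck names entries of a **kwargs parameter, and it is not the library's actual code.

```python
import inspect
from typing import Any, Callable, Dict


def name_arguments(func: Callable, *args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical sketch: give every argument a display name.

    Elements of a *args parameter named x become x_0, x_1, ...;
    **kwargs parameters are not expanded here.
    """
    sig = inspect.signature(func)
    named: Dict[str, Any] = {}
    for name, value in sig.bind(*args, **kwargs).arguments.items():
        if sig.parameters[name].kind is inspect.Parameter.VAR_POSITIONAL:
            # The single spec for *x would be applied to each of these entries.
            for i, item in enumerate(value):
                named[f"{name}_{i}"] = item
        else:
            named[name] = value
    return named


def stack(*x):
    return list(x)


print(name_arguments(stack, "a", "b", "c"))  # {'x_0': 'a', 'x_1': 'b', 'x_2': 'c'}
```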
The shapes of function inputs can also be specified with keyword arguments using the parameter name.
>>> @check_func("i -> i", y="i 2", z="i 3")
... def foo(x, y, *, z):
...     return x + y[:, 0] + z[:, 2]
...
>>> foo(randn(3), randn(3, 2), z=randn(3, 3)).shape
(3,)
>>> foo(randn(3), randn(3, 3), z=randn(3, 3)).shape
Traceback (most recent call last):
...
ValueError: y dim 1: expected 2=2 got 3
i=3
x: got (3,) expected [i]
y: got (3, 3) expected [i 2]
z: got (3, 3) expected [i 3]
If all input shapes are given as keyword arguments, you can omit the "->" in the function spec string. A spec string without an arrow contains only output specs.
For example, @check_func("x") is equivalent to @check_func("-> x").
If you specify an input shape with a keyword argument, it must not also appear in the positional shape specs. Doing so raises a ValueError when the decorator is applied.
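The spec-string rules described in this section (comma separated inputs, an optional "->", comma separated outputs) can be summarized with a hypothetical splitter. This sketch only separates the individual specs and does not parse dimensions. It is not eincheck's parser.

```python
from typing import List, Tuple


def split_func_spec(spec: str) -> Tuple[List[str], List[str]]:
    """Hypothetical sketch: split a check_func spec string into (inputs, outputs)."""
    if "->" in spec:
        lhs, rhs = spec.split("->", 1)
    else:
        # No arrow: the whole string holds output specs only.
        lhs, rhs = "", spec

    def parts(side: str) -> List[str]:
        return [p.strip() for p in side.split(",") if p.strip()]

    return parts(lhs), parts(rhs)


print(split_func_spec("i, i -> i"))  # (['i', 'i'], ['i'])
print(split_func_spec("x"))          # ([], ['x'])
```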
>>> @check_func("i", x="i")
... def foo(x):
...     return x
>>> @check_func("i -> i", x="i")
... def bad(x):
...     return x
Traceback (most recent call last):
...
ValueError: Spec for x specified in both args and kwargs.
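Combining positional and keyword input specs, including the duplicate check above, might look like the following hypothetical sketch. It assumes the parameter names are already known (for example from inspect.signature) and is not eincheck's code.

```python
from typing import Dict, List


def merge_input_specs(
    params: List[str], positional: List[str], keyword: Dict[str, str]
) -> Dict[str, str]:
    """Hypothetical sketch: combine positional and keyword input specs."""
    specs = dict(zip(params, positional))
    for name, spec in keyword.items():
        if name in specs:
            raise ValueError(f"Spec for {name} specified in both args and kwargs.")
        specs[name] = spec
    return specs


print(merge_input_specs(["x", "y", "z"], ["i"], {"y": "i 2", "z": "i 3"}))
# {'x': 'i', 'y': 'i 2', 'z': 'i 3'}
```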
Multiple output specs can be used if the function returns a tuple.
>>> @check_func("i j -> i, j")
... def split_sum(x):
...     return x.sum(1), x.sum(0)
...
>>> [i.shape for i in split_sum(randn(3, 4))]
[(3,), (4,)]
As with positional inputs, the function may return more outputs than there are output specs. The extra outputs are not checked.
>>> @check_func("i j -> i")
... def split_sum(x):
...     return x.sum(1), x.sum(0)
...
>>> @check_func("i j -> i, j")
... def bad(x):
...     return x
...
>>> bad(randn(2, 3))
Traceback (most recent call last):
...
ValueError: Expected at least 2 outputs, got 1
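The output rules shown above can be sketched as follows: a tuple return counts as several outputs, any other value counts as one, and outputs beyond the number of specs are ignored. name_outputs is a hypothetical helper. The output0, output1 names follow the error messages in this section, but the code is not eincheck's implementation.

```python
from typing import Any, Dict, List


def name_outputs(result: Any, specs: List[str]) -> Dict[str, Any]:
    """Hypothetical sketch: pair a function's outputs with the output specs."""
    # A tuple return counts as several outputs; anything else is one output.
    outputs = list(result) if isinstance(result, tuple) else [result]
    if len(outputs) < len(specs):
        raise ValueError(f"Expected at least {len(specs)} outputs, got {len(outputs)}")
    # Outputs beyond the number of specs are left unchecked.
    return {f"output{i}": out for i, out in enumerate(outputs[: len(specs)])}


print(name_outputs(("sum1", "sum0"), ["i"]))  # {'output0': 'sum1'}
```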
check_func can also decorate class methods.
Because self is the first parameter of a method, the first input spec corresponds to self.
>>> class Foo:
...     def __init__(self, x):
...         self.x = x
...
...     @check_func("_, i -> i")
...     def foo(self, y):
...         return self.x + y
...
>>> f = Foo(randn(4))
>>> f.foo(randn(4)).shape
(4,)
>>> f.foo(randn(1))
Traceback (most recent call last):
...
ValueError: output0 dim 0: expected i=1 got 4
i=1
output0: got (4,) expected [i]
- eincheck.check_func2(input_shapes, output_shapes='')
Check the input and output shapes of a function.
This function is an alternative to check_func that works better with dictionaries. It takes an input spec and an output spec, where each spec can be either a dictionary or a comma separated string.
If both are strings, check_func2(input_str, output_str) is equivalent to check_func(f"{input_str} -> {output_str}").
If the input spec is a dictionary and the output spec is a string, check_func2(input_dict, output_str) is equivalent to check_func(output_str, **input_dict).
This decorator also supports a dictionary for output shapes, which check_func does not. This enables dotpath names on the returned object.
Examples:
>>> from eincheck import check_func2
>>> from numpy.random import randn
>>> from numpy.typing import NDArray
>>> from typing import NamedTuple, Tuple
>>>
>>> Array = NDArray[float]
>>>
>>> # Three equivalent ways of using check_func2.
>>> @check_func2("i, j", "i j, i j")
... def foo1(x: Array, y: Array) -> Tuple[Array, Array]:
...     return x[:, None] + y, x[:, None] * y
...
>>> _ = foo1(randn(4), randn(5))
>>>
>>> @check_func2("i, j -> i j, i j")
... def foo2(x: Array, y: Array) -> Tuple[Array, Array]:
...     return x[:, None] + y, x[:, None] * y
...
>>> _ = foo2(randn(4), randn(5))
>>>
>>> @check_func2({"x": "i", "y": "j"}, {"0": "i j", "1": "i j"})
... def foo3(x: Array, y: Array) -> Tuple[Array, Array]:
...     return x[:, None] + y, x[:, None] * y
...
>>> _ = foo3(randn(4), randn(5))
>>>
>>> class Pair(NamedTuple):
...     first: Array
...     second: Array
...
>>> @check_func2(
...     {"x.first": "*a", "x.second": "*b", "y.first": "*a", "y.second": "*b"},
...     {"first": "*a", "second": "*b"},
... )
... def add_pairs(x: Pair, y: Pair) -> Pair:
...     return Pair(x.first + y.first, x.second + y.second)
...
>>> _ = add_pairs(Pair(randn(4), randn(5, 6)), Pair(randn(4), randn(5, 6)))
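Dotpath names such as "x.first" or the output keys "0" and "1" can be read as paths into the bound arguments or the returned object. The following sketch of resolve_dotpath is hypothetical. It assumes that the first component names an argument, that all-digit components index into a tuple (as the {"0": ..., "1": ...} output spec of foo3 suggests), and that other components are attribute lookups. eincheck's actual resolution rules may differ.

```python
from functools import reduce
from typing import Any, Mapping, NamedTuple


def resolve_dotpath(root: Any, path: str) -> Any:
    """Hypothetical sketch: follow a dotpath such as "x.first" or "0"."""

    def step(current: Any, part: str) -> Any:
        if isinstance(current, Mapping):
            return current[part]  # e.g. bound function arguments by name
        if part.isdigit():
            return current[int(part)]  # e.g. element of a returned tuple
        return getattr(current, part)  # e.g. NamedTuple / dataclass field

    return reduce(step, path.split("."), root)


class Pair(NamedTuple):
    first: Any
    second: Any


args = {"x": Pair("x1", "x2"), "y": Pair("y1", "y2")}
print(resolve_dotpath(args, "x.second"))  # x2
print(resolve_dotpath(("a", "b"), "1"))   # b
```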
- Parameters:
input_shapes (Union[str, Mapping[str, Union[str, ShapeSpec, Sequence[Union[DimSpec, str, int, None]]]]]) – comma separated string or dictionary of shapes
output_shapes (Union[str, Mapping[str, Union[str, ShapeSpec, Sequence[Union[DimSpec, str, int, None]]]]]) – comma separated string or dictionary of shapes
- Return type:
Callable[[TypeVar(_T_Callable, bound=Callable[..., Any])], TypeVar(_T_Callable, bound=Callable[..., Any])]
- Returns:
a function decorator
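The stated equivalences between check_func2 and check_func can be written down as a small translation function. as_check_func_args is hypothetical and returns the (shapes, kwargs) that check_func would receive. The pass-through of a single "in -> out" string is an assumption based on the foo2 example. A dictionary output spec has no check_func counterpart, so the sketch raises in that case.

```python
from typing import Dict, Mapping, Tuple, Union

Specs = Union[str, Mapping[str, str]]


def as_check_func_args(
    input_shapes: Specs, output_shapes: Specs = ""
) -> Tuple[str, Dict[str, str]]:
    """Hypothetical sketch: translate check_func2 arguments to check_func arguments."""
    if isinstance(input_shapes, str) and isinstance(output_shapes, str):
        if "->" in input_shapes and not output_shapes:
            # Assumption (from foo2): a full "in -> out" string is used as-is.
            return input_shapes, {}
        return f"{input_shapes} -> {output_shapes}", {}
    if isinstance(output_shapes, str):
        # Dictionary inputs become keyword specs; the string holds the outputs.
        return output_shapes, dict(input_shapes)
    raise NotImplementedError("dictionary output specs have no check_func equivalent")


print(as_check_func_args("i, j", "i j, i j"))          # ('i, j -> i j, i j', {})
print(as_check_func_args({"x": "i", "y": "j"}, "i j"))  # ('i j', {'x': 'i', 'y': 'j'})
```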
- eincheck.check_data(shape_dict=None, /, **kwargs)
Check the shapes of fields of a data object.
The currently supported data objects are NamedTuple, dataclasses, and attrs.
- Parameters:
shape_dict (Optional[Mapping[str, Union[str, ShapeSpec, Sequence[Union[DimSpec, str, int, None]]]]]) – shape specs for fields of the data object in a dictionary
kwargs (Union[str, ShapeSpec, Sequence[Union[DimSpec, str, int, None]]]) – shape specs for fields of the data object as keywords
- Return type:
Callable[[TypeVar(_T_Data)], TypeVar(_T_Data)]
- Returns:
a decorator for the data object class
The @check_data decorator can add shape assertions to NamedTuple, dataclass, and attrs classes.
Keyword arguments are matched against fields of the class.
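>>> from typing import NamedTuple
>>> import attrs
>>> import numpy.typing as npt
>>> from numpy.random import randn
>>> from eincheck import check_data, check_func, check_shapes
>>> @check_data(x="*i", y="*i")
... class Foo(NamedTuple):
...     x: npt.NDArray[float]
...     y: npt.NDArray[float]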
...
>>> _ = Foo(randn(3, 4), randn(3, 4))
>>> _ = Foo(randn(4, 5), randn(4))
Traceback (most recent call last):
...
ValueError: y: expected rank 2, got shape (4,)
i = (4, 5)
x: got (4, 5) expected [*i]
y: got (4,) expected [*i]
Not every field needs a shape spec. Fields without one (such as z below) are not checked.
>>> @check_data(x="i", y="i")
... class Foo(NamedTuple):
...     x: npt.NDArray[float]
...     y: npt.NDArray[float]
...     z: npt.NDArray[float]
...
>>> _ = Foo(randn(4), randn(4), randn(4))
>>> _ = Foo(randn(5), randn(5), randn(42))
A dictionary can also be used to specify shapes instead of keyword arguments.
>>> @check_data({"x": "*i", "y": "*i"})
... class Foo(NamedTuple):
...     x: npt.NDArray[float]
...     y: npt.NDArray[float]
...
>>> _ = Foo(randn(3, 4), randn(3, 4))
>>> _ = Foo(randn(4, 5), randn(4))
Traceback (most recent call last):
...
ValueError: y: expected rank 2, got shape (4,)
i = (4, 5)
x: got (4, 5) expected [*i]
y: got (4,) expected [*i]
What if you want to compare the shapes in a @check_data decorated object with other tensors?
The shape spec $ will match @check_data decorated objects and include all the shapes from the object.
For example, the following two check_shapes calls are equivalent.
>>> @check_data(x="i", y="i")
... class Foo(NamedTuple):
...     x: npt.NDArray[float]
...     y: npt.NDArray[float]
...
>>> f = Foo(randn(3), randn(3))
>>> z = randn(3, 3)
>>> check_shapes(
...     **{
...         "f.x": (f.x, "i"),
...         "f.y": (f.y, "i"),
...         "z": (z, "i i"),
...     }
... )
{'i': 3}
>>> check_shapes(f=(f, "$"), z=(z, "i i"))
{'i': 3}
The $ spec also makes it possible to pass @check_data decorated objects to functions and to include them in other classes.
>>> @check_data(x="i", y="i")
... class Foo(NamedTuple):
...     x: npt.NDArray[float]
...     y: npt.NDArray[float]
...
...     @check_func("$, i -> i")
...     def method(self, z: npt.NDArray[float]) -> npt.NDArray[float]:
...         return self.y + z
...
>>> f = Foo(randn(4), randn(4))
>>> f.method(f.x).shape
(4,)
>>> f.method(randn(7))
Traceback (most recent call last):
...
ValueError: z dim 0: expected i=4 got 7
i=4
self.x: got (4,) expected [i]
self.y: got (4,) expected [i]
z: got (7,) expected [i]
>>> @check_func("$, i -> $")
... def add_x(f: Foo, x: npt.NDArray[float]) -> Foo:
...     return Foo(f.x + x, f.y)
...
>>> _ = add_x(f, randn(4))
>>> add_x(f, randn(5))
Traceback (most recent call last):
...
ValueError: x dim 0: expected i=4 got 5
i=4
f.x: got (4,) expected [i]
f.y: got (4,) expected [i]
x: got (5,) expected [i]
>>> @check_data(f="$", g="i i")
... @attrs.frozen
... class Bar:
...     f: Foo
...     g: npt.NDArray[float]
...
>>> _ = Bar(f, randn(4, 4))
>>> Bar(f, randn(5, 5))
Traceback (most recent call last):
...
ValueError: g dim 0: expected i=4 got 5
i=4
f.x: got (4,) expected [i]
f.y: got (4,) expected [i]
g: got (5, 5) expected [i i]
The @check_data decorator only checks shapes at construction time.
If you modify fields after creation, the new values are not checked.
You can use check_shapes to re-check a data object.
>>> @check_data(p="i 2")
... @attrs.define
... class Foo:
...     p: npt.NDArray[float]
...
>>> f = Foo(randn(3, 2))
>>> f.p = randn(4) # unchecked
>>> check_shapes((f, "$"))
Traceback (most recent call last):
...
ValueError: arg0.p: expected rank 2, got shape (4,)
arg0.p: got (4,) expected [i 2]
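The construction-only behaviour can be illustrated with a hypothetical decorator. toy_check_data wraps a class's __init__ so shapes are checked once, right after the object is built, and plain attribute assignment afterwards goes unchecked. The sketch assumes a regular class or dataclass. NamedTuple instances are created through __new__, so this wrapping would not apply to them. It uses a FakeArray stand-in instead of real tensors and supports only dimension names and integer literals. eincheck's actual mechanism may differ.

```python
import functools
from dataclasses import dataclass
from typing import Any, Dict, Tuple


class FakeArray:
    """Stand-in for a tensor that only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape: Tuple[int, ...] = shape


def toy_check_data(**specs: str):
    """Hypothetical sketch: check field shapes once, when __init__ runs."""

    def decorate(cls):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            original_init(self, *args, **kwargs)
            bound: Dict[str, int] = {}
            for field, spec in specs.items():
                shape = getattr(self, field).shape
                dims = spec.split()
                if len(dims) != len(shape):
                    raise ValueError(f"{field}: expected rank {len(dims)}, got shape {shape}")
                for name, size in zip(dims, shape):
                    # Integer literals must match; names bind on first use.
                    expected = int(name) if name.isdigit() else bound.setdefault(name, size)
                    if expected != size:
                        raise ValueError(f"{field}: expected {name}={expected} got {size}")

        cls.__init__ = __init__
        return cls

    return decorate


@toy_check_data(p="i 2")
@dataclass
class Foo:
    p: FakeArray


f = Foo(FakeArray(3, 2))  # checked at construction
f.p = FakeArray(4)        # not checked: only __init__ is wrapped
```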
- eincheck.disable_checks()
Prevent eincheck from doing shape checks.
- Return type:
ContextManager[None]
Context manager to disable eincheck. This can be used to make code run faster once you’re confident the shapes are correct. Inside the context, check_shapes returns an empty dictionary.
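>>> from numpy.random import randn
>>> from eincheck import check_shapes, disable_checks, enable_checks
>>> with disable_checks():
...     # Eincheck is a no-op inside this context.
...     print(check_shapes((randn(2, 3), "i")))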
...
{}
- eincheck.enable_checks()
Enable eincheck to do shape checks.
- Return type:
ContextManager[None]
Context manager to enable eincheck (e.g. if inside a disable_checks context).
>>> with disable_checks():
...     with enable_checks():
...         check_shapes((randn(2, 3), "i"))
...
Traceback (most recent call last):
...
ValueError: arg0: expected rank 1, got shape (2, 3)
arg0: got (2, 3) expected [i]
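The nesting behaviour of disable_checks and enable_checks can be obtained with a module-level flag, toggled by a context manager that restores the previous value on exit. The names below (toy_disable_checks, toy_enable_checks, checks_enabled) are hypothetical, and this is not necessarily how eincheck stores its state. A contextvars.ContextVar would make the setting per-thread and per-task instead of global.

```python
import contextlib
from typing import ContextManager, Iterator

_checks_enabled = True  # module-level switch read by the (hypothetical) checkers


def checks_enabled() -> bool:
    return _checks_enabled


@contextlib.contextmanager
def _set_checks(enabled: bool) -> Iterator[None]:
    global _checks_enabled
    previous = _checks_enabled
    _checks_enabled = enabled
    try:
        yield
    finally:
        # Restore the outer setting so nested contexts unwind correctly.
        _checks_enabled = previous


def toy_disable_checks() -> ContextManager[None]:
    return _set_checks(False)


def toy_enable_checks() -> ContextManager[None]:
    return _set_checks(True)


with toy_disable_checks():
    print(checks_enabled())      # False
    with toy_enable_checks():
        print(checks_enabled())  # True
    print(checks_enabled())      # False
print(checks_enabled())          # True
```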
- eincheck.parser_cache_clear()
- Return type:
None
Clear the lru_cache for parsing shape strings.
- eincheck.parser_cache_info()
- Return type:
CacheInfo
Get the lru_cache cache info for the parser cache.
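>>> from numpy.random import randn
>>> from eincheck import check_shapes, parser_cache_clear, parser_cache_info
>>> parser_cache_clear()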
>>> check_shapes((randn(2, 3), "a b"), (randn(3, 4), "b c"), (randn(2, 3), "a b"))
{'a': 2, 'b': 3, 'c': 4}
>>> parser_cache_info()
CacheInfo(hits=1, misses=2, maxsize=128, currsize=2)
- eincheck.parser_resize_cache(maxsize)
- Return type:
None
Reset the parser cache to a lru_cache with the given size.
This will clear the cache and change the maxsize field in CacheInfo.