Source code for eincheck.checks.shapes

import itertools
from typing import Any, Dict, Optional, Set, Tuple, Union, cast

from eincheck.contexts import _should_do_checks
from eincheck.parser.dim_spec import DimSpec, DimType
from eincheck.parser.expressions import DataExpr, Variable
from eincheck.parser.grammar import ShapeArg, create_shape_spec
from eincheck.parser.shape_spec import ShapeSpec
from eincheck.types import ShapeVariable, Tensor
from eincheck.utils import get_shape


def _is_broadcast_compatible(got: Tuple[int, ...], expected: Tuple[int, ...]) -> bool:
    """Check if got is broadcast-compatible with expected (numpy-style).

    A shorter got is allowed: missing leading dims are implicitly 1.
    """
    if len(got) > len(expected):
        return False
    return all(g == e or g == 1 for g, e in zip(reversed(got), reversed(expected)))


def _check_dim_spec(
    got_shape: Tuple[int, ...],
    d: DimSpec,
    bindings: Dict[str, ShapeVariable],
    name: str,
    msg: str,
    start_idx: int,
) -> None:
    """Check the part of a tensor shape that was matched to one DimSpec.

    ``d.value`` is evaluated with ``bindings`` and compared against
    ``got_shape`` according to ``d.type``:

    * REPEATED: every entry of ``got_shape`` is compared on its own.
    * VARIADIC: ``got_shape`` is compared as a whole tuple.
    * anything else: ``got_shape`` must hold exactly one entry, which is
      compared.

    When ``d.can_broadcast`` is set, a size of 1 is accepted in place of the
    expected size; tuples are compared with ``_is_broadcast_compatible``.

    :param got_shape: Slice of the tensor shape matched to ``d``
    :param d: Dimension spec to check; nothing is checked if its value is None
    :param bindings: Variable values used to evaluate ``d.value``; they are
        also listed in the error message
    :param name: Display name of the tensor, used as error message prefix
    :param msg: Text appended to every error message
    :param start_idx: Index of ``got_shape[0]`` within the full tensor shape,
        used to report which dims failed
    :raise ValueError: If the evaluated value has the wrong kind (int vs
        tuple) for ``d.type``, or if the shape does not match
    :raise RuntimeError: If a non-repeated, non-variadic spec was matched to
        anything other than exactly one dimension
    """
    if d.value is None:
        return

    expected_value = d.value.eval(bindings)

    # Validate the kind of the evaluated value first. After this point a
    # variadic spec always has a tuple value and every other spec an int, so
    # no set of "all broadcastable tuples" (2**n entries) ever needs building.
    if d.type is DimType.VARIADIC and isinstance(expected_value, int):
        raise ValueError(
            f"{name}: expected variadic DimSpec {d} to evaluate to a tuple, "
            f"got {expected_value}{msg}"
        )
    elif d.type is not DimType.VARIADIC and isinstance(expected_value, tuple):
        raise ValueError(
            f"{name}: expected non-variadic DimSpec {d} to evaluate to an integer, "
            f"got {expected_value}{msg}"
        )

    def is_match(g: ShapeVariable) -> bool:
        """Return whether the got value ``g`` satisfies the spec."""
        if not d.can_broadcast:
            return not (g != expected_value)
        if isinstance(g, tuple) and isinstance(expected_value, tuple):
            return _is_broadcast_compatible(g, expected_value)
        # Scalar case: the expected size itself or a broadcastable 1.
        return g in (expected_value, 1)

    def do_check(g: ShapeVariable, indices: ShapeVariable) -> None:
        """Raise ValueError describing the mismatch if ``g`` does not match."""
        if is_match(g):
            return

        if d.can_broadcast:
            first_line = f"expected can broadcast to {d.value}={expected_value} got {g}"
        else:
            first_line = f"expected {d.value}={expected_value} got {g}"

        dim_str = f"dim {indices}" if isinstance(indices, int) else f"dims {indices}"
        raise ValueError(
            f"{name} {dim_str}: {first_line}"
            + "".join(f"\n    {k}={v}" for k, v in bindings.items())
            + msg
        )

    if d.type is DimType.REPEATED:
        for g_idx, g in enumerate(got_shape):
            do_check(g, start_idx + g_idx)
    elif d.type is DimType.VARIADIC:
        do_check(got_shape, tuple(range(start_idx, start_idx + len(got_shape))))
    elif len(got_shape) != 1:
        raise RuntimeError(
            f"{name}: expected a single dimension for {d}, got {got_shape}{msg}"
        )
    else:
        do_check(got_shape[0], start_idx)


def _bind_shape(
    got_shape: Tuple[Optional[int], ...],
    s: ShapeArg,
    bindings: Dict[str, ShapeVariable],
    name: str,
    msg: str,
) -> None:
    """Add bindings for the not-yet-bound variables of a shape spec.

    Walks the dims of the spec built from ``s`` together with the part of
    ``got_shape`` each one is matched to. A dim whose value is a bare
    ``Variable`` that is missing from ``bindings`` gets bound from the
    matched sizes; ``bindings`` is updated in place. Broadcastable dims never
    create a binding.

    :param got_shape: Shape of the tensor; entries may be None
    :param s: Shape spec argument for the tensor
    :param bindings: Known variable values, extended in place
    :param name: Display name of the tensor, used as error message prefix
    :param msg: Text appended to every error message
    :raise ValueError: If a dim that needs binding was matched to a None size
    :raise RuntimeError: If a SINGLE dim was matched to anything other than
        exactly one dimension
    """
    spec = create_shape_spec(s)

    # bindings is mutated while iterating, so the matches are consumed one at
    # a time instead of being collected up front.
    for dim, lo, hi in spec.matched_indices(bindings, len(got_shape)):
        var = dim.value
        if not isinstance(var, Variable):
            continue
        if var.x in bindings:
            continue

        matched = got_shape[lo:hi]
        if None in matched:
            raise ValueError(
                f"{name}: tried to match {matched} to {dim}, found None{msg}"
            )
        matched = cast(Tuple[int, ...], matched)

        if dim.can_broadcast:
            # A broadcastable dim does not pin down the variable's value.
            continue

        if dim.type is DimType.VARIADIC:
            bindings.setdefault(var.x, matched)
        elif dim.type is DimType.REPEATED:
            # A repeated dim matched to zero dimensions gives nothing to bind.
            if matched:
                bindings.setdefault(var.x, matched[0])
        elif dim.type is DimType.SINGLE:
            if len(matched) != 1:
                raise RuntimeError(
                    f"{name}: expected a single dimension for {dim}, "
                    f"got {got_shape}{msg}"
                )
            bindings.setdefault(var.x, matched[0])


def _check_shape(
    got_shape: Tuple[Optional[int], ...],
    s: ShapeArg,
    bindings: Dict[str, ShapeVariable],
    name: str,
    msg: str,
) -> Dict[str, ShapeVariable]:
    """Check a whole tensor shape against a shape spec.

    Checks the rank first, then checks every dim of the spec that has a value
    against the part of ``got_shape`` it is matched to.

    :param got_shape: Shape of the tensor; entries may be None
    :param s: Shape spec argument for the tensor
    :param bindings: Variable values used to evaluate the spec
    :param name: Display name of the tensor, used as error message prefix
    :param msg: Extra text for error messages; when non-empty it is appended
        on a new line
    :raise ValueError: If the rank or a dimension is wrong, or if a dimension
        that has to be checked is None
    :raise RuntimeError: If more than one dim of the spec has an unknown
        number of dimensions
    :return: The ``bindings`` dictionary that was passed in
    """
    spec = create_shape_spec(s)
    suffix = "\n" + msg if msg else msg

    _check_rank(name, got_shape, spec, bindings, suffix)

    ambiguous = spec.unknown_n_dims_indices(bindings)
    if len(ambiguous) > 1:
        ambiguous_dims = [spec.dims[i] for i in ambiguous]
        raise RuntimeError(
            f"{name} has multiple DimSpec of unknown size: {ambiguous_dims}" + suffix
        )

    for dim, lo, hi in spec.matched_indices(bindings, len(got_shape)):
        # Skip dims without a value and repeated dims matched to nothing.
        nothing_to_check = dim.value is None or (
            dim.type is DimType.REPEATED and lo == hi
        )
        if nothing_to_check:
            continue

        matched = got_shape[lo:hi]
        if None not in matched:
            _check_dim_spec(
                cast(Tuple[int, ...], matched), dim, bindings, name, suffix, lo
            )
            continue

        # An unknown (None) size cannot be compared against the spec.
        where = f"dims {tuple(range(lo, hi))}" if hi != lo + 1 else f"dim {lo}"
        raise ValueError(
            f"{name} {where}: tried to check {dim} against {matched}, "
            f"found None{suffix}"
        )

    return bindings


def check_shapes(
    *args: Tuple[Tensor, ShapeArg],
    **kwargs: Union[ShapeVariable, Tuple[Tensor, ShapeArg]],
) -> Dict[str, ShapeVariable]:
    """Check the shapes of Tensors against ShapeArg specifications.

    Examples:

    .. doctest::

        >>> from numpy.random import randn
        >>> from eincheck import check_shapes
        >>>
        >>> check_shapes((randn(3, 4, 5), "... i j"), (randn(5, 6), "... j k"))
        {'i': 4, 'j': 5, 'k': 6}
        >>> check_shapes(
        ...     x=(randn(8, 2, 7, 3), "*batch t 3"),
        ...     y=(randn(8, 2, 1, 1, 3), "*batch ... 3"),
        ...     batch=(8, 2),
        ... )
        {'batch': (8, 2), 't': 7}

    Pass pairs of (tensor, shape spec) as either args or kwargs. The only
    difference when using a kwarg is the display name in error messages (args
    are called ``arg0``, ``arg1``, etc). Kwargs can also specify the variable
    values (e.g. ``batch=(8, 2)``).

    If all shape checks pass, returns a dictionary with all variable values.

    :param args: Pairs of (tensor, shape spec)
    :param kwargs: Pairs of (tensor, shape spec)
    :raise ValueError: If the shapes are incorrect or cannot be verified
    :return: Values for all bound variables from the shape specs
    """
    if not _should_do_checks():
        return {}

    tensors, bindings = _get_tensors_and_bindings(*args, **kwargs)
    _check_variable_types(tensors, bindings)

    if not tensors:
        return bindings

    # Summary of every tensor's got/expected shape, appended to all errors.
    got_msgs = [str(x) for x, _ in tensors.values()]
    got_len = max(map(len, got_msgs))
    # NOTE(review): the leading whitespace inside this literal may have been
    # collapsed when the source was extracted -- confirm against upstream.
    msg = "\n".join(
        f" {name}: got {g:<{got_len}} expected {s}"
        for (name, (_, s)), g in zip(tensors.items(), got_msgs)
    )

    checked_names = set()
    binded_names = set()

    # Binding one tensor can make another one bindable or checkable, so make
    # repeated passes over the tensors; one pass per tensor is the maximum.
    for _ in range(len(tensors)):
        bindings_len = len(bindings)

        for t_name, (t_got, t_expected) in tensors.items():
            if (
                t_name in checked_names
                or len(t_expected.unknown_n_dims_indices(bindings)) > 1
            ):
                continue

            if t_name not in binded_names:
                _check_rank(t_name, t_got, t_expected, bindings, msg)
                _bind_shape(t_got, t_expected, bindings, t_name, msg)
                binded_names.add(t_name)

            if not t_expected.is_checkable(bindings):
                continue

            _check_shape(t_got, t_expected, bindings, t_name, msg)
            checked_names.add(t_name)

        # No new bindings in this pass: further passes cannot make progress.
        if len(bindings) == bindings_len:
            break

    unbound = set(tensors) - binded_names
    if unbound:
        # Sorted so the message is deterministic, matching the message below.
        raise ValueError(
            f"Unable to determine bindings for: {' '.join(sorted(unbound))}\n{msg}"
        )

    unchecked = set(tensors) - checked_names
    if unchecked:
        # NOTE(review): this lists every variable of the unchecked specs,
        # including ones that already have a binding.
        missing_vars = set.union(*(tensors[k][1].variables for k in unchecked))
        raise ValueError(
            f"Unable to check: [{' '.join(sorted(unchecked))}] "
            f"missing variables: [{' '.join(sorted(missing_vars))}]\n{msg}"
        )

    return bindings
def _check_variable_types(
    tensors: Dict[str, Tuple[Tuple[Optional[int], ...], ShapeSpec]],
    bindings: Dict[str, ShapeVariable],
) -> None:
    """Check that each variable is either an int or a tuple.

    Categorizes each variable present in the ShapeSpec and bindings as being
    either an int or a tuple. If any variables are in both sets, raises an
    error.

    :param tensors: Mapping from display name to (shape, shape spec)
    :param bindings: Variable values given by the caller
    :raise ValueError: If a variable is used both as an int and as a tuple
    """
    int_vars = set()
    tuple_vars = set()

    # Explicit bindings are categorized by the type of their value.
    for k, v in bindings.items():
        if isinstance(v, int):
            int_vars.add(k)
        else:
            tuple_vars.add(k)

    # Variables in specs are categorized by the kind of dim they appear in.
    for _, spec in tensors.values():
        for dim in spec.dims:
            if not dim.value:
                continue

            if dim.type is DimType.VARIADIC:
                tuple_vars.update(dim.value.variables)
            else:
                int_vars.update(dim.value.variables)

    both = int_vars & tuple_vars
    if both:
        raise ValueError(
            "Found variables in both variadic and non-variadic expressions: "
            + " ".join(sorted(both))
        )


def _check_rank(
    name: str,
    got_shape: Tuple[Optional[int], ...],
    shape_spec: ShapeSpec,
    bindings: Dict[str, ShapeVariable],
    msg: str,
) -> None:
    """Check the number of dimensions of a shape against a shape spec.

    If the spec still has dims with an unknown number of dimensions, the
    spec's minimum rank is a lower bound; otherwise the rank must be equal.

    :param name: Display name of the tensor, used as error message prefix
    :param got_shape: Shape of the tensor
    :param shape_spec: Spec the shape is checked against
    :param bindings: Known variable values; also listed in the error message
    :param msg: Text appended to the error message
    :raise ValueError: If the rank of ``got_shape`` is not allowed
    """
    expected_rank = shape_spec.min_rank(bindings)

    if shape_spec.unknown_n_dims_indices(bindings):
        bound_text = "at least "
        check = len(got_shape) < expected_rank
    else:
        bound_text = ""
        check = len(got_shape) != expected_rank

    if check:
        # NOTE(review): the leading whitespace inside the per-binding literal
        # may have been collapsed when the source was extracted -- confirm
        # against upstream.
        rows = [
            f"{name}: expected rank {bound_text}{expected_rank}, "
            f"got shape {got_shape}"
        ] + [f" {k} = {v}" for k, v in bindings.items()]
        raise ValueError("\n".join(rows) + "\n" + msg)


def _get_tensors_and_bindings(
    *args: Tuple[Tensor, ShapeArg],
    **kwargs: Union[ShapeVariable, Tuple[Tensor, ShapeArg]],
) -> Tuple[
    Dict[str, Tuple[Tuple[Optional[int], ...], ShapeSpec]], Dict[str, ShapeVariable]
]:
    """Split the arguments of ``check_shapes`` into tensors and bindings.

    Positional pairs are named ``arg0``, ``arg1``, etc. A kwarg whose value
    is an int or a tuple of ints is a variable binding; a kwarg whose value
    is any other 2-tuple is a (tensor, shape spec) pair named after the kwarg.

    :param args: Pairs of (tensor, shape spec)
    :param kwargs: Variable values or pairs of (tensor, shape spec)
    :raise ValueError: If a kwarg is neither a binding nor a pair, or if a
        kwarg name is already used by another tensor
    :return: Mapping from name to (shape, shape spec), and the bindings
    """
    tensors: Dict[str, Tuple[Tuple[Optional[int], ...], ShapeSpec]] = {}
    for idx, (a_tensor, a_shape) in enumerate(args):
        tensors.update(_get_shapes(a_tensor, create_shape_spec(a_shape), f"arg{idx}"))

    bindings: Dict[str, ShapeVariable] = {}
    for k, v in kwargs.items():
        if isinstance(v, int) or (
            isinstance(v, tuple) and all(isinstance(vv, int) for vv in v)
        ):
            bindings[k] = cast(ShapeVariable, v)
        elif isinstance(v, tuple) and len(v) == 2:
            v = cast(Tuple[Tensor, ShapeArg], v)
            # Raise explicitly instead of asserting: an assert is removed
            # under ``python -O`` and the earlier entry would then be
            # silently overwritten (e.g. a kwarg named ``arg0``).
            if k in tensors:
                raise ValueError(f"Duplicate tensor name {k}")
            tensors.update(_get_shapes(v[0], create_shape_spec(v[1]), k))
        else:
            raise ValueError(f"Unexpected kwarg {v}")

    return tensors, bindings


def _get_shapes(
    x: Any, s: ShapeSpec, name: str
) -> Dict[str, Tuple[Tuple[Optional[int], ...], ShapeSpec]]:
    """Collect the (shape, shape spec) pairs to check for one argument.

    For a data-expression spec, ``x._get_shapes()`` is expanded recursively
    and the resulting names are prefixed with ``name``. Otherwise the shape
    of ``x`` is looked up with ``get_shape``; if that returns None, nothing
    is collected for ``x``.

    :param x: Object whose shape(s) should be checked
    :param s: Shape spec for ``x``
    :param name: Display name of ``x``
    :raise ValueError: If a data-expression spec is used on an object without
        a ``_get_shapes`` method, or if the spec for a tensor contains a
        data expression
    :return: Mapping from display name to (shape, shape spec)
    """
    if s.is_data_expr:
        if not hasattr(x, "_get_shapes"):
            raise ValueError(
                f"{name}: spec $ specified, but no _get_shapes method was found. "
                "This should have been added by the @check_data decorator."
            )

        return {
            k2: v2
            for k, (vt, vs) in x._get_shapes().items()
            for k2, v2 in _get_shapes(vt, vs, f"{name}.{k}").items()
        }

    shape = get_shape(x)
    if shape is not None:
        if any(isinstance(d.value, DataExpr) for d in s.dims):
            raise ValueError(
                f"{name}: $ should not be present in the shape spec for a Tensor, "
                f"got {s}"
            )

        return {name: (shape, s)}

    return {}