.. eincheck documentation master file, created by
   sphinx-quickstart on Mon Feb 13 22:20:22 2023.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to eincheck's documentation!
====================================

.. toctree::
   :maxdepth: 2
   :caption: Contents:
   :hidden:

   specifying_shapes
   api
   performance

Getting Started
---------------

To install eincheck, run

.. code-block:: shell

   pip install eincheck

eincheck works with `numpy `_, `pytorch `_, `tensorflow `_, and `jax `_ tensors, as well as any other object whose ``shape`` attribute is a ``Sequence[int | None]``.
None of these libraries is required to use eincheck, but you will most likely want to install at least one of them.

How to use eincheck
-------------------

There are three key functions in `eincheck`, described in :ref:`API`:

* ``check_shapes``: compares tensors against shape specifications
* ``check_func``: decorates functions to add shape checks on their inputs and outputs
* ``check_data``: decorates classes to add shape checks on their fields

`eincheck` is inspired by Einstein notation, so its basic functionality should be intuitive to anyone familiar with `einsum `_ or `einops `_.

.. testcode::

   from typing import Any, NamedTuple

   import numpy as np
   import numpy.typing as npt

   from eincheck import check_func, check_data


   @check_func("*x -> *x")
   def softmax(x):
       # Normalize over the last axis, subtracting the max for numerical stability.
       y = np.exp(x - np.max(x, axis=-1, keepdims=True))
       return y / y.sum(axis=-1, keepdims=True)


   @check_data(tokens="n q d", scores="n q k")
   class AttentionOutputs(NamedTuple):
       tokens: npt.NDArray[Any]
       scores: npt.NDArray[Any]


   @check_func("n q c, n k c, n k d -> $")
   def attention(query, key, value):
       coeffs = np.einsum("n q c, n k c -> n q k", query, key)
       # Softmax over the key axis k, so each query's weights sum to 1.
       weights = softmax(coeffs / np.sqrt(query.shape[-1]))
       outputs = (
           np.expand_dims(weights, -1) * np.expand_dims(value, 1)
       ).sum(2)
       return AttentionOutputs(outputs, weights)

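In the example above, each name in a specification labels one dimension, and a name that appears in several specifications (such as ``n`` or ``k`` in ``attention``) must have the same size everywhere. The snippet below is a simplified, pure-Python illustration of that binding rule, written under our own assumptions. It is not eincheck's implementation: ``FakeTensor`` and ``bind_dims`` are hypothetical names, only plain dimension names are handled, and skipping ``None`` entries as unknown sizes is a choice made for this sketch. Like eincheck, it reads nothing but the ``shape`` attribute, so no tensor library is needed.

.. code-block:: python

   from dataclasses import dataclass
   from typing import Any, Dict, Optional, Sequence, Tuple


   @dataclass
   class FakeTensor:
       """Hypothetical stand-in: the sketch only reads ``shape``."""

       shape: Sequence[Optional[int]]


   def bind_dims(*pairs: Tuple[Any, str]) -> Dict[str, int]:
       """Toy checker (hypothetical, not eincheck's API).

       Each whitespace-separated name in a spec labels one dimension;
       a repeated name must have the same size in every tensor.
       """
       bindings: Dict[str, int] = {}
       for tensor, spec in pairs:
           names = spec.split()
           shape = tuple(tensor.shape)
           if len(names) != len(shape):
               raise ValueError(
                   f"spec {spec!r} has {len(names)} dims, shape is {shape}"
               )
           for name, size in zip(names, shape):
               if size is None:
                   continue  # unknown size: nothing to check or bind
               if name in bindings and bindings[name] != size:
                   raise ValueError(
                       f"dimension {name!r}: expected {bindings[name]}, got {size}"
                   )
               bindings[name] = size
       return bindings


   # Same dimension names and sizes as the attention example above.
   query = FakeTensor((7, 4, 10))
   key = FakeTensor((7, 5, 10))
   value = FakeTensor((7, 5, 8))
   print(bind_dims((query, "n q c"), (key, "n k c"), (value, "n k d")))
   # {'n': 7, 'q': 4, 'c': 10, 'k': 5, 'd': 8}

   # A conflicting size for "c" (10 vs. 9) raises ValueError.
   try:
       bind_dims((query, "n q c"), (FakeTensor((7, 5, 9)), "n k c"))
   except ValueError as err:
       print(err)  # dimension 'c': expected 10, got 9

Real specifications can also use variadic dimensions such as ``*x`` and other syntax that this sketch ignores; :ref:`Specifying Shapes` describes what eincheck actually accepts.
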

.. testcode::
   :hide:

   out, weights = attention(
       np.random.randn(7, 4, 10),
       np.random.randn(7, 5, 10),
       np.random.randn(7, 5, 8),
   )
   print(out.shape)
   print(weights.shape)

.. testoutput::
   :hide:

   (7, 4, 8)
   (7, 4, 5)

Resources
---------

* :ref:`Specifying Shapes` explains how to write shape specifications (e.g. ``"... i j"``)
* :ref:`API` documents the ``check_*`` functions
* :ref:`Performance` explains how to make code with shape checks run faster

Indices and tables
==================

* :ref:`genindex`
* :ref:`search`