Welcome to eincheck’s documentation!
Getting Started
To install eincheck, run:

```shell
pip install eincheck
```
eincheck is compatible with numpy, pytorch, tensorflow, jax, or any tensor object that has a shape field which returns a Sequence[int | None].
While none of these libraries are required to use eincheck, you will most likely want to install at least one of them.
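Because the only requirement stated above is a shape attribute, the contract is structural (duck-typed). The sketch below expresses that contract with typing.Protocol. HasShape and FakeTensor are hypothetical names used only for illustration and are not part of eincheck. Whether eincheck accepts this exact stand-in object is an assumption based on the sentence above, not something the docs show.

```python
from dataclasses import dataclass
from typing import Optional, Protocol, Sequence, Tuple, runtime_checkable


@runtime_checkable
class HasShape(Protocol):
    """Hypothetical structural type: any object exposing a `shape` attribute."""

    @property
    def shape(self) -> Sequence[Optional[int]]: ...


@dataclass
class FakeTensor:
    """Hypothetical stand-in tensor that carries only a shape.

    None marks a dimension whose size is unknown.
    """

    shape: Tuple[Optional[int], ...]


t = FakeTensor(shape=(32, None, 64))
print(isinstance(t, HasShape))          # True: FakeTensor has a `shape` attribute
print(isinstance([1, 2, 3], HasShape))  # False: plain lists have no `shape`
```

The runtime isinstance check only tests that a shape attribute is present. It does not validate the types of the entries.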
How to use eincheck
There are three key functions in eincheck, described in API:
- check_shapes: compares tensors against shape specifications
- check_func: decorates functions to add shape checks on the inputs and outputs
- check_data: decorates classes to add shape checks to class fields
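To make the idea behind check_shapes concrete, here is a toy, stdlib-only sketch of Einstein-style shape matching. It is not eincheck's implementation or API. toy_check_shapes and Shaped are hypothetical names, and the sketch supports only space-separated names and integer literals. It omits eincheck features such as variadic dimensions ("...", "*x"). Each name is bound to a size the first time it appears, and every later occurrence must agree.

```python
from typing import Any, Dict, Optional, Sequence, Tuple


def toy_check_shapes(*pairs: Tuple[Any, str]) -> Dict[str, int]:
    """Check (tensor, spec) pairs such as (x, "n q c"); return the bound sizes.

    Hypothetical helper for illustration only; not eincheck's API.
    """
    bindings: Dict[str, int] = {}
    for index, (tensor, spec) in enumerate(pairs):
        shape: Sequence[Optional[int]] = tensor.shape
        names = spec.split()
        if len(names) != len(shape):
            raise ValueError(
                f"argument {index}: spec {spec!r} has {len(names)} dims, "
                f"shape {tuple(shape)} has {len(shape)}"
            )
        for name, size in zip(names, shape):
            if size is None:
                continue  # unknown size: nothing to check or bind
            if name.isdigit():
                expected = int(name)  # literal size, e.g. the 3 in "n 3"
            else:
                expected = bindings.setdefault(name, size)  # bind on first use
            if size != expected:
                raise ValueError(
                    f"argument {index}: dim {name!r} is {size}, expected {expected}"
                )
    return bindings


class Shaped:
    """Hypothetical minimal stand-in for a tensor: it only carries a shape."""

    def __init__(self, *shape: Optional[int]) -> None:
        self.shape = shape


query, key = Shaped(2, 5, 8), Shaped(2, 7, 8)
print(toy_check_shapes((query, "n q c"), (key, "n k c")))
# {'n': 2, 'q': 5, 'c': 8, 'k': 7}

try:
    toy_check_shapes((query, "n q c"), (Shaped(3, 7, 8), "n k c"))
except ValueError as err:
    print(err)  # argument 1: dim 'n' is 3, expected 2
```

The key design point is that a single bindings dictionary is shared across all arguments. That is what lets "n" in one spec constrain "n" in another.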
eincheck is inspired by Einstein notation, so basic functionality should be intuitive to anyone familiar with einsum or einops.
```python
from typing import Any, NamedTuple

import numpy as np
import numpy.typing as npt

from eincheck import check_func, check_data


@check_func("*x -> *x")
def softmax(x):
    # Softmax over the last axis; the output has the same shape as the input.
    y = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return y / y.sum(axis=-1, keepdims=True)


@check_data(tokens="n q d", scores="n q k")
class AttentionOutputs(NamedTuple):
    tokens: npt.NDArray[Any]
    scores: npt.NDArray[Any]


@check_func("n q c, n k c, n k d -> $")
def attention(query, key, value):
    # Scaled dot-product scores between every query and every key.
    coeffs = np.einsum("n q c, n k c -> n q k", query, key)
    weights = softmax(coeffs / np.sqrt(query.shape[-1]))
    # Broadcast (n, q, k, 1) * (n, 1, k, d), then sum over k to get (n, q, d).
    outputs = (
        np.expand_dims(weights, -1) *
        np.expand_dims(value, 1)
    ).sum(2)
    return AttentionOutputs(outputs, weights)
```
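To see why attention's outputs match the "n q d" spec on tokens, the stdlib-only sketch below traces shapes through the broadcast-and-sum step. broadcast, expand_dims and drop_axis are hypothetical helpers that mimic NumPy's shape rules: trailing dimensions are aligned and a size of 1 stretches. They operate on shape tuples, not arrays. The sizes n=2, q=5, k=7, d=4 are arbitrary.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast(a: Shape, b: Shape) -> Shape:
    """Result shape of a NumPy-style elementwise op (hypothetical helper)."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + a  # left-pad the shorter shape with 1s
    b = (1,) * (ndim - len(b)) + b
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)  # a size of 1 stretches to the other
    return tuple(out)


def expand_dims(shape: Shape, axis: int) -> Shape:
    """Shape after np.expand_dims(array, axis) for a single int axis."""
    if axis < 0:
        axis += len(shape) + 1
    return shape[:axis] + (1,) + shape[axis:]


def drop_axis(shape: Shape, axis: int) -> Shape:
    """Shape after summing over one non-negative axis, as in .sum(axis)."""
    return shape[:axis] + shape[axis + 1:]


n, q, k, d = 2, 5, 7, 4
weights = (n, q, k)  # "n q k"
value = (n, k, d)    # "n k d"
product = broadcast(expand_dims(weights, -1), expand_dims(value, 1))
print(product)                # (2, 5, 7, 4), i.e. "n q k d"
print(drop_axis(product, 2))  # (2, 5, 4), i.e. "n q d", matching tokens
```

Summing over axis 2 removes k, which is exactly the dimension the softmax normalizes over.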
Resources
- Specifying Shapes contains information on how to format shape specifications (e.g. "... i j").
- API contains information on the check_* functions.
- Performance contains information on making code with shape checks run faster.