Welcome to eincheck’s documentation!
Getting Started
To install eincheck, run
pip install eincheck
eincheck is compatible with numpy, pytorch, tensorflow, jax, or any tensor object with a shape field that returns a Sequence[int | None].
While none of these libraries are required to use eincheck, you will most likely want to install at least one of them.
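As a rough illustration of the "any tensor object that has a shape field" requirement, the sketch below expresses it as a structural type using only the standard library. The names HasShape and FakeTensor are hypothetical and are not part of eincheck. Reading None as "size unknown" is an assumption based on the Sequence[int | None] type, not something stated here.

```python
from typing import Optional, Protocol, Sequence, runtime_checkable


@runtime_checkable
class HasShape(Protocol):
    """Hypothetical protocol: any object exposing a `shape` attribute."""

    @property
    def shape(self) -> Sequence[Optional[int]]: ...


class FakeTensor:
    """Illustrative stand-in for a tensor type; not part of eincheck."""

    def __init__(self, *dims: Optional[int]) -> None:
        # Assumption: None stands for a dimension whose size is not known.
        self.shape = tuple(dims)


t = FakeTensor(None, 3, 4)
print(isinstance(t, HasShape))          # True: it has a `shape` attribute
print(isinstance([1, 2, 3], HasShape))  # False: a plain list has no `shape`
```

The runtime isinstance check only confirms that a shape attribute exists. It does not validate the element types, so it is a description of the duck-typing idea rather than a full validator.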
How to use eincheck
There are three key functions in eincheck, described in API:
check_shapes: compares tensors against shape specifications
check_func: decorates functions to add shape checks on the inputs and outputs
check_data: decorates classes to add shape checks to class fields
eincheck is inspired by Einstein notation, so basic functionality should be intuitive to anyone familiar with einsum or einops.
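    import numpy as np
    from eincheck import check_func

    # x has any leading dimensions followed by i; m is (j, i); b is (j,).
    @check_func("... i, j i, j -> ... j")
    def linear(x, m, b):
        return (x[..., None, :] * m).sum(-1) + b

    # The output must have exactly the same shape as the input.
    @check_func("*x -> *x")
    def softmax(x):
        # Subtract the global max for numerical stability.
        # Note: this normalizes over all elements of x.
        y = np.exp(x - np.max(x))
        return y / y.sum()

    # Inputs: query, key, value. Outputs: attended values and weights.
    @check_func("n q c, n k c, n k d -> n q d, n q k")
    def attention(query, key, value):
        coeffs = np.einsum("n q c, n k c -> n q k", query, key)
        weights = softmax(coeffs / np.sqrt(query.shape[-1]))
        # Weighted sum of the values over the key axis k.
        outputs = (
            np.expand_dims(weights, -1) *
            np.expand_dims(value, 1)
        ).sum(2)
        return outputs, weights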
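To make concrete what a spec such as "n q c, n k c, n k d" asks for, here is a toy checker written with the standard library only. It is a simplified sketch of the general idea (each name binds to one size, and repeated names must agree), not eincheck's implementation. The function toy_check is hypothetical, and it handles plain names only, with no support for "..." or "*x".

```python
from typing import Dict, Sequence


def toy_check(shape: Sequence[int], spec: str, bound: Dict[str, int]) -> Dict[str, int]:
    """Match `shape` against a space-separated spec such as "n q c".

    Toy illustration only: plain dimension names, no "..." or "*x".
    """
    names = spec.split()
    if len(names) != len(shape):
        raise ValueError(f"expected rank {len(names)}, got shape {tuple(shape)}")
    for name, dim in zip(names, shape):
        # The first sighting binds the name; later sightings must agree.
        if bound.setdefault(name, dim) != dim:
            raise ValueError(f"{name}={bound[name]} conflicts with {dim}")
    return bound


bound: Dict[str, int] = {}
toy_check((2, 5, 8), "n q c", bound)  # query
toy_check((2, 7, 8), "n k c", bound)  # key
toy_check((2, 7, 3), "n k d", bound)  # value
print(bound)  # {'n': 2, 'q': 5, 'c': 8, 'k': 7, 'd': 3}
```

Passing a key of shape (2, 7, 9) instead would raise a ValueError in this sketch, because c is already bound to 8 by the query.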
Resources
Specifying Shapes contains information on how to format shape specifications (e.g. "... i j")
API contains information on the check_* functions
Performance contains information on making code with shape checks run faster