Welcome to eincheck’s documentation!

Getting Started

To install eincheck, run

pip install eincheck

eincheck is compatible with NumPy, PyTorch, TensorFlow, JAX, or any tensor object with a shape attribute that returns a Sequence[int | None]. None of these libraries is required to use eincheck, but you will most likely want to install at least one of them.

How to use eincheck

There are three key functions in eincheck, described in API:

  • check_shapes: compares tensors against shape specifications

  • check_func: decorates functions to add shape checks on the inputs and outputs

  • check_data: decorates classes to add shape checks to class fields

eincheck is inspired by Einstein notation, so basic functionality should be intuitive to anyone familiar with einsum or einops.

import numpy as np
from eincheck import check_func

@check_func("... i, j i, j -> ... j")
def linear(x, m, b):
    return (x[..., None, :] * m).sum(-1) + b

@check_func("*x -> *x")
def softmax(x):
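    # Normalize over the last axis so each row of scores sums to 1.
    y = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return y / y.sum(axis=-1, keepdims=True)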

@check_func("n q c, n k c, n k d -> n q d, n q k")
def attention(query, key, value):
    coeffs = np.einsum("n q c, n k c -> n q k", query, key)
    weights = softmax(coeffs / np.sqrt(query.shape[-1]))
    outputs = (
        np.expand_dims(weights, -1) *
        np.expand_dims(value, 1)
    ).sum(2)
    return outputs, weights

Resources

  • Specifying Shapes contains information on how to format shape specifications (e.g. "... i j")

  • API contains information on the check_* functions

  • Performance contains information on making code with shape checks run faster
