Welcome to eincheck’s documentation!

Getting Started

To install eincheck, run

pip install eincheck

eincheck is compatible with NumPy, PyTorch, TensorFlow, and JAX, as well as any tensor object with a shape attribute that returns a Sequence[int | None]. None of these libraries is required to use eincheck, but you will most likely want to install at least one of them.
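
The duck-typing requirement above can be written down as a small protocol. The sketch below is only an illustration of what "has a shape attribute returning a Sequence[int | None]" means in practice. HasShape, Spectrogram, and looks_like_tensor are hypothetical names, not part of eincheck, and eincheck's own internal check may differ.

```python
from collections.abc import Sequence as SequenceABC
from typing import Optional, Protocol, Sequence, runtime_checkable


@runtime_checkable
class HasShape(Protocol):
    """Hypothetical protocol: any object exposing a `shape` attribute."""

    @property
    def shape(self) -> Sequence[Optional[int]]: ...


class Spectrogram:
    """Toy tensor-like object whose first dimension may be unknown (None)."""

    def __init__(self, n_frames: Optional[int], n_bins: int) -> None:
        self.shape = (n_frames, n_bins)


def looks_like_tensor(obj: object) -> bool:
    # isinstance on a runtime_checkable protocol only confirms that the
    # attribute exists, so the element types are validated explicitly.
    if not isinstance(obj, HasShape):
        return False
    shape = obj.shape
    if not isinstance(shape, SequenceABC):
        return False
    return all(d is None or isinstance(d, int) for d in shape)


print(looks_like_tensor(Spectrogram(None, 128)))  # object with a partly unknown shape
print(looks_like_tensor([1, 2, 3]))               # plain list has no shape attribute
```

A None entry stands for a dimension whose size is not known ahead of time, as can happen with dynamically shaped tensors.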

How to use eincheck

eincheck has three key functions, each described in API:

  • check_shapes: compares tensors against shape specifications

  • check_func: decorates functions to add shape checks on the inputs and outputs

  • check_data: decorates classes to add shape checks to class fields

eincheck is inspired by Einstein notation, so basic functionality should be intuitive to anyone familiar with einsum or einops.

from typing import Any, NamedTuple

import numpy as np
import numpy.typing as npt
from eincheck import check_func, check_data

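@check_func("*x -> *x")
def softmax(x):
    # Normalize over the last axis; subtracting the max keeps exp() stable.
    y = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return y / y.sum(axis=-1, keepdims=True)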

@check_data(tokens="n q d", scores="n q k")
class AttentionOutputs(NamedTuple):
    tokens: npt.NDArray[Any]
    scores: npt.NDArray[Any]

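# "$" refers to the shapes declared on AttentionOutputs via check_data.
@check_func("n q c, n k c, n k d -> $")
def attention(query, key, value):
    # Dot-product scores between every query and every key: shape (n, q, k).
    coeffs = np.einsum("n q c, n k c -> n q k", query, key)
    weights = softmax(coeffs / np.sqrt(query.shape[-1]))
    # Weighted sum of values over the key axis: (n, q, k, 1) * (n, 1, k, d) -> (n, q, d).
    outputs = (
        np.expand_dims(weights, -1) *
        np.expand_dims(value, 1)
    ).sum(2)
    return AttentionOutputs(outputs, weights)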

Resources

  • Specifying Shapes contains information on how to format shape specifications (e.g. "... i j")

  • API contains information on the check_* functions

  • Performance contains information on making code with shape checks run faster
