Runtime type checking

To avoid common pitfalls with function arguments, such as passing the wrong data type or array shape, functions in this library are annotated with type hints and can be type-checked at runtime.

For that, we rely on the jaxtyping and beartype libraries.

By default, no type checking is performed, to avoid additional overhead when calling a function. To enable runtime type checking, you can use jaxtyping.install_import_hook.

Input arguments checking

Let’s take the example of the perpendicular_vectors function:

import inspect

import jax
import jax.numpy as jnp
from jaxtyping import install_import_hook

with install_import_hook("differt", "beartype.beartype"):
    from differt.geometry import perpendicular_vectors

inspect.signature(perpendicular_vectors)
<Signature (u: Union[jaxtyping.Float[Array, '*batch 3'], jaxtyping.Float[ndarray, '*batch 3'], jaxtyping.Float[TypedNdArray, '*batch 3']]) -> jaxtyping.Float[Array, '*batch 3']>

As we can see, its signature expects an array of 3D vectors as input and returns an array of 3D vectors with a matching shape. Here, '*batch 3' means any number of leading batch dimensions, followed by a final axis of size 3.

key = jax.random.key(1234)

arr = jax.random.normal(key, (10, 3))
arr
Array([[ 1.1031016 ,  0.86306226, -0.33868238],
       [ 1.0272458 , -1.0735804 ,  0.32937315],
       [ 0.75820905, -0.30057552, -0.37079313],
       [ 0.5862171 ,  0.88390875, -1.2834665 ],
       [ 0.2709998 ,  1.7730063 , -1.235121  ],
       [-1.8687268 , -0.11319233,  0.12312093],
       [ 0.14007285,  0.7311839 ,  1.9988668 ],
       [ 0.03140245,  0.4228713 ,  0.4774148 ],
       [ 0.04098483, -1.1967784 ,  0.7829566 ],
       [ 1.3783231 , -0.2089781 ,  0.25486845]], dtype=float32)

Hence, if we provide an array of 3D vectors as input, everything works just fine:

perpendicular_vectors(arr)
Array([[ 0.1851116 ,  0.14483057,  0.9719865 ],
       [ 0.73785526,  0.64527345, -0.19796911],
       [ 0.38472942, -0.15251763,  0.9103415 ],
       [ 0.9359692 , -0.19969855,  0.2899693 ],
       [ 0.99222696, -0.1021079 ,  0.07113095],
       [ 0.06550259,  0.00396762,  0.9978445 ],
       [ 0.9978415 , -0.02255976, -0.06167253],
       [ 0.99879   , -0.03260797, -0.03681386],
       [ 0.9995896 ,  0.02397186, -0.01568287],
       [-0.17780961,  0.02695907,  0.98369557]], dtype=float32)

However, if anything other than an array of 3D vectors is provided, an error is raised:

arr = jax.random.normal(key, (2, 10, 4))  # 4D vectors
perpendicular_vectors(arr)
---------------------------------------------------------------------------
BeartypeCallHintParamViolation            Traceback (most recent call last)
    [... skipping hidden 1 frame]

File <@beartype(differt.geometry._utils.perpendicular_vectors) at 0x70c0245cc2c0>:32, in perpendicular_vectors(__beartype_object_123970544057536, __beartype_get_violation, __beartype_conf, __beartype_check_meta, __beartype_func, *args, **kwargs)

BeartypeCallHintParamViolation: Function differt.geometry._utils.perpendicular_vectors() parameter u="JitTracer<float32[2,10,4]>" violates type hint typing.Union[jaxtyping.Float[Array, '*batch 3'], jaxtyping.Float[ndarray, '*batch 3'], jaxtyping.Float[TypedNdArray, '*batch 3']], as <class "jax._src.interpreters.partial_eval.DynamicJaxprTracer"> "JitTracer<float32[2,10,4]>" not <class "jaxtyping.Float[TypedNdArray, '*batch 3']">, <class "jaxtyping.Float[ndarray, '*batch 3']">, or <class "jaxtyping.Float[Array, '*batch 3']">.

During handling of the above exception, another exception occurred:

BeartypeCallHintParamViolation            Traceback (most recent call last)
File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:757, in _get_problem_arg(param_signature, args, kwargs, arguments, module, typechecker)
    756 try:
--> 757     fn(*args, **kwargs)
    758 except Exception as e:

File <@beartype(differt.geometry._utils.check_single_arg) at 0x70c02423f060>:32, in check_single_arg(__beartype_object_123970544057536, __beartype_get_violation, __beartype_conf, __beartype_check_meta, __beartype_func, *args, **kwargs)

BeartypeCallHintParamViolation: Function differt.geometry._utils.check_single_arg() parameter u="JitTracer<float32[2,10,4]>" violates type hint typing.Union[jaxtyping.Float[Array, '*batch 3'], jaxtyping.Float[ndarray, '*batch 3'], jaxtyping.Float[TypedNdArray, '*batch 3']], as <class "jax._src.interpreters.partial_eval.DynamicJaxprTracer"> "JitTracer<float32[2,10,4]>" not <class "jaxtyping.Float[TypedNdArray, '*batch 3']">, <class "jaxtyping.Float[ndarray, '*batch 3']">, or <class "jaxtyping.Float[Array, '*batch 3']">.

The above exception was the direct cause of the following exception:

TypeCheckError                            Traceback (most recent call last)
    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:760, in _get_problem_arg(param_signature, args, kwargs, arguments, module, typechecker)
    759         keep_value = _pformat(arguments[keep_name], short_self=False)
--> 760         raise TypeCheckError(
    761             f"\nThe problem arose whilst typechecking parameter '{keep_name}'.\n"
    762             f"Actual value: {keep_value}\n"
    763             f"Expected type: {keep_annotation}."
    764         ) from e
    765 else:
    766     # Could not localise the problem to a single argument -- probably due to
    767     # e.g. a mismatched typevar, which each individual argument is okay with.

TypeCheckError: 
The problem arose whilst typechecking parameter 'u'.
Actual value: f32[2,10,4]
Expected type: typing.Union[Float[Array, '*batch 3'], Float[ndarray, '*batch 3'], Float[TypedNdArray, '*batch 3']].

The above exception was the direct cause of the following exception:

TypeCheckError                            Traceback (most recent call last)
Cell In[4], line 2
      1 arr = jax.random.normal(key, (2, 10, 4))  # 4D vectors
----> 2 perpendicular_vectors(arr)

    [... skipping hidden 14 frame]

File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:462, in jaxtyped.<locals>.wrapped_fn_impl(args, kwargs, bound, memos)
    460             raise TypeCheckError(msg) from None
    461         else:
--> 462             raise TypeCheckError(msg) from e
    464 # Actually call the function.
    465 out = fn(*args, **kwargs)

TypeCheckError: Type-check error whilst checking the parameters of differt.geometry._utils.perpendicular_vectors.
The problem arose whilst typechecking parameter 'u'.
Actual value: f32[2,10,4]
Expected type: typing.Union[Float[Array, '*batch 3'], Float[ndarray, '*batch 3'], Float[TypedNdArray, '*batch 3']].
----------------------
Called with parameters: {'u': f32[2,10,4]}
Parameter annotations: (u: Union[Float[Array, '*batch 3'], Float[ndarray, '*batch 3'], Float[TypedNdArray, '*batch 3']]) -> Any.

The error message is a bit verbose, but the end of it tells us that we expected Float[Array, '*batch 3'] (or one of the other array types in the union), and that we received f32[2,10,4] (i.e., Float32[Array, "2 10 4"]). The dtype is fine, as float32 is one of the dtypes accepted by Float, but '*batch 3' cannot be matched to '2 10 4', since the last axis has size 4 instead of 3. This shape mismatch is why the error was raised.

Output checking

Return values are also checked by the type checker. If you use one of the functions from our library and provide valid inputs, you are guaranteed to get outputs of the correct type.

In other words, type checking the outputs should never fail. If you encounter a case where your input is valid but the returned output is not, please report it via GitHub issues.

If you define custom functions yourself, it is always good practice to use type annotations and runtime checking:

from beartype import beartype as typechecker
from jaxtyping import Array, Num, jaxtyped


@jaxtyped(typechecker=typechecker)
def my_custom_transpose(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose().transpose()  # Oops, transposed one too many times


x = jnp.arange(70).reshape(10, 7)
x
Array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13],
       [14, 15, 16, 17, 18, 19, 20],
       [21, 22, 23, 24, 25, 26, 27],
       [28, 29, 30, 31, 32, 33, 34],
       [35, 36, 37, 38, 39, 40, 41],
       [42, 43, 44, 45, 46, 47, 48],
       [49, 50, 51, 52, 53, 54, 55],
       [56, 57, 58, 59, 60, 61, 62],
       [63, 64, 65, 66, 67, 68, 69]], dtype=int32)
my_custom_transpose(x)
---------------------------------------------------------------------------
BeartypeCallHintReturnViolation           Traceback (most recent call last)
    [... skipping hidden 1 frame]

File <@beartype(__main__.my_custom_transpose) at 0x70c019a62200>:46, in my_custom_transpose(__beartype_object_105428396956112, __beartype_get_violation, __beartype_conf, __beartype_object_105428394941440, __beartype_check_meta, __beartype_func, *args, **kwargs)

BeartypeCallHintReturnViolation: Function __main__.my_custom_transpose() return "Array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13],
       [14, 15, ...)" violates type hint <class 'jaxtyping.Num[Array, 'n m']'>, as the size of dimension n is 10 which does not equal the existing value of 7.

The above exception was the direct cause of the following exception:

TypeCheckError                            Traceback (most recent call last)
Cell In[6], line 1
----> 1 my_custom_transpose(x)

    [... skipping hidden 1 frame]

File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:515, in jaxtyped.<locals>.wrapped_fn_impl(args, kwargs, bound, memos)
    513             raise TypeCheckError(msg) from None
    514         else:
--> 515             raise TypeCheckError(msg) from e
    517 return out

TypeCheckError: Type-check error whilst checking the return value of __main__.my_custom_transpose.
Actual value: i32[10,7]
Expected type: Num[Array, 'n m'].
----------------------
Called with parameters: {'x': i32[10,7]}
Parameter annotations: (x: Num[Array, 'm n']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
m=10
n=7

Here, the error message tells us that it inferred m=10 and n=7 from the input arguments, but that this does not match the shape of the returned value: the expected output shape is (n, m) = (7, 10), whereas the actual one is (10, 7).

Thanks to the type checker, we quickly caught the error and can now fix the function:

@jaxtyped(typechecker=typechecker)
def my_custom_transpose_fixed(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose()  # Now this is all good


my_custom_transpose_fixed(x)
Array([[ 0,  7, 14, 21, 28, 35, 42, 49, 56, 63],
       [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64],
       [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65],
       [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66],
       [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67],
       [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68],
       [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69]], dtype=int32)