Runtime type checking

To avoid common pitfalls with function arguments, such as using the wrong data type or array shape, most functions in this library are wrapped with a runtime type checker that uses their type annotations to determine which inputs and outputs to expect.

For that, we rely on the jaxtyping and beartype modules.
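To give an intuition for what "wrapped with a runtime type checker" means, here is a minimal, standard-library-only sketch of an annotation-driven checker. It is not how jaxtyping or beartype are implemented. The names `check_annotations` and `scale` are hypothetical and used only for illustration, and the sketch only handles plain class annotations.

```python
import functools
import inspect


def check_annotations(fn):
    """Hypothetical decorator: check arguments and return value against annotations."""
    sig = inspect.signature(fn)

    def _check(label, value, expected):
        # Skip missing annotations and anything that is not a plain class.
        if expected is inspect.Signature.empty or not isinstance(expected, type):
            return
        if not isinstance(value, expected):
            raise TypeError(
                f"{label} expected {expected.__name__}, got {type(value).__name__}"
            )

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Check every argument before calling the function.
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            _check(f"parameter {name!r}", value, sig.parameters[name].annotation)
        # Call the function, then check what it returned.
        out = fn(*args, **kwargs)
        _check("return value", out, sig.return_annotation)
        return out

    return wrapper


@check_annotations
def scale(values: list, factor: float) -> list:
    return [factor * v for v in values]


result = scale([1, 2], 2.0)  # passes both the input and the output checks
```

Calling `scale("ab", 2.0)` would raise a `TypeError` before the function body runs, because `"ab"` is not a `list`. The real libraries go much further, for example by checking array data types and shapes.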

Input argument checking

Let’s take the example of the sorted_array2 function:

import inspect

import jax
import jax.numpy as jnp

from differt.utils import sorted_array2

inspect.signature(sorted_array2)
<Signature (array: jaxtyping.Shaped[Array, 'm n']) -> jaxtyping.Shaped[Array, 'm n']>

As we can see, its signature expects a 2D array as input and returns a 2D array as output, with matching shapes.

key = jax.random.PRNGKey(1234)

arr = jax.random.randint(key, (10, 4), 0, 2)
arr
Array([[1, 1, 1, 0],
       [1, 0, 1, 1],
       [1, 0, 0, 0],
       [1, 0, 0, 1],
       [0, 1, 0, 1],
       [0, 0, 0, 0],
       [1, 1, 0, 0],
       [0, 0, 1, 0],
       [1, 1, 1, 1],
       [1, 0, 1, 0]], dtype=int32)

Hence, if we provide a 2D array as input, everything works just fine:

sorted_array2(arr)
Array([[0, 0, 0, 0],
       [0, 0, 1, 0],
       [0, 1, 0, 1],
       [1, 0, 0, 0],
       [1, 0, 0, 1],
       [1, 0, 1, 0],
       [1, 0, 1, 1],
       [1, 1, 0, 0],
       [1, 1, 1, 0],
       [1, 1, 1, 1]], dtype=int32)

However, if anything other than a 2D array is provided, an error is raised:

arr = jax.random.randint(key, (2, 10, 4), 0, 2)  # 3D array
sorted_array2(arr)
---------------------------------------------------------------------------
BeartypeCallHintParamViolation            Traceback (most recent call last)
    [... skipping hidden 1 frame]

File <@beartype(differt.utils.check_params) at 0x7f6a99c5cc20>:29, in check_params(__beartype_object_94412212175472, __beartype_get_violation, __beartype_conf, __beartype_func, *args, **kwargs)

BeartypeCallHintParamViolation: Function differt.utils.check_params() parameter array="Traced<ShapedArray(int32[2,10,4])>with<DynamicJaxprTrace(level=1/0)>" violates type hint <class 'jaxtyping.Shaped[Array, 'm n']'>, as this array has 3 dimensions, not the 2 expected by the type hint.

The above exception was the direct cause of the following exception:

TypeCheckError                            Traceback (most recent call last)
Cell In[4], line 2
      1 arr = jax.random.randint(key, (2, 10, 4), 0, 2)  # 3D array
----> 2 sorted_array2(arr)

    [... skipping hidden 11 frame]

File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.11/site-packages/jaxtyping/_decorator.py:447, in jaxtyped.<locals>.wrapped_fn(*args, **kwargs)
    445         raise TypeCheckError(msg) from None
    446     else:
--> 447         raise TypeCheckError(msg) from e
    449 # Actually call the function.
    450 out = fn(*args, **kwargs)

TypeCheckError: Type-check error whilst checking the parameters of sorted_array2.
The problem arose whilst typechecking parameter 'array'.
Actual value: i32[2,10,4]
Expected type: <class 'Shaped[Array, 'm n']'>.
----------------------
Called with parameters: {'array': i32[2,10,4]}
Parameter annotations: (array: Shaped[Array, 'm n']).

The error message is a bit verbose, but we can see at the end that we expected Shaped[Array, 'm n'] and we received i32[2,10,4] (i.e., Int32[Array, "2 10 4"]). Int32 is a subclass of Shaped, but m n cannot be matched to 2 10 4, as there is one extra dimension. This mismatch is why the error was raised.

Output checking

The output values are also checked by the type checker. If you use one of the functions from our library, you are guaranteed to have correct output types if you provided valid inputs.

In other words, type checking the outputs should never fail. If you encounter a case where your input is valid, but the returned output is not, please report it via the GitHub issues.

If you define custom functions yourself, it is always good practice to use type annotations and runtime checking:

from beartype import beartype as typechecker
from jaxtyping import Array, Num, jaxtyped


@jaxtyped(typechecker=typechecker)
def my_custom_transpose(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose().transpose()  # Oops, transposed one too many times


x = jnp.arange(70).reshape(10, 7)
x
Array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13],
       [14, 15, 16, 17, 18, 19, 20],
       [21, 22, 23, 24, 25, 26, 27],
       [28, 29, 30, 31, 32, 33, 34],
       [35, 36, 37, 38, 39, 40, 41],
       [42, 43, 44, 45, 46, 47, 48],
       [49, 50, 51, 52, 53, 54, 55],
       [56, 57, 58, 59, 60, 61, 62],
       [63, 64, 65, 66, 67, 68, 69]], dtype=int32)
my_custom_transpose(x)
---------------------------------------------------------------------------
BeartypeCallHintReturnViolation           Traceback (most recent call last)
File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.11/site-packages/jaxtyping/_decorator.py:469, in jaxtyped.<locals>.wrapped_fn(*args, **kwargs)
    468 try:
--> 469     full_fn(*args, **kwargs)
    470 except AnnotationError:

File <@beartype(__main__.check_return) at 0x7f6a81bafba0>:47, in check_return(__beartype_object_94412225234032, __beartype_get_violation, __beartype_conf, __beartype_object_94412225211104, __beartype_func, *args, **kwargs)

BeartypeCallHintReturnViolation: Function __main__.check_return() return "Array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13],
       [14, 15, ...)" violates type hint <class 'jaxtyping.Num[Array, 'n m']'>, as the size of dimension n is 10 which does not equal the existing value of 7.

The above exception was the direct cause of the following exception:

TypeCheckError                            Traceback (most recent call last)
Cell In[6], line 1
----> 1 my_custom_transpose(x)

File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.11/site-packages/jaxtyping/_decorator.py:500, in jaxtyped.<locals>.wrapped_fn(*args, **kwargs)
    498                 raise TypeCheckError(msg) from None
    499             else:
--> 500                 raise TypeCheckError(msg) from e
    502     return out
    503 finally:

TypeCheckError: Type-check error whilst checking the return value of my_custom_transpose.
Actual value: i32[10,7]
Expected type: Num[Array, 'n m'].
----------------------
Called with parameters: {'x': i32[10,7]}
Parameter annotations: (x: Num[Array, 'm n']).
The current values for each jaxtyping axis annotation are as follows.
m=10
n=7

Here, the error message tells us that it inferred m=10 and n=7 from the input arguments, but that this does not match the expected output shape, i.e., (n, m) = (7, 10) != (10, 7).
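The way axis names are bound on the inputs and then reused on the output can be modeled in a few lines of plain Python. This is only a simplified sketch of the idea, not jaxtyping's actual implementation, and `bind_axes` is a hypothetical helper introduced for illustration.

```python
def bind_axes(spec, shape, bindings):
    """Match `shape` against space-separated axis names, recording new sizes.

    Returns True if the shape is consistent with the sizes already bound.
    """
    names = spec.split()
    if len(names) != len(shape):
        return False  # wrong number of dimensions
    for name, size in zip(names, shape):
        # Bind the name on first use, otherwise compare with the known size.
        if bindings.setdefault(name, size) != size:
            return False
    return True


bindings = {}
input_ok = bind_axes("m n", (10, 7), bindings)        # binds m=10 and n=7
bad_output_ok = bind_axes("n m", (10, 7), bindings)   # n would be 10, but n=7
good_output_ok = bind_axes("n m", (7, 10), bindings)  # consistent with m and n
```

With this model, the doubly transposed output of shape (10, 7) is rejected because `n` was already bound to 7, while the correctly transposed shape (7, 10) is accepted.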

Thanks to the type checker, we rapidly caught the error, and we can fix the function:

@jaxtyped(typechecker=typechecker)
def my_custom_transpose_fixed(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose()  # Now this is all good


my_custom_transpose_fixed(x)
Array([[ 0,  7, 14, 21, 28, 35, 42, 49, 56, 63],
       [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64],
       [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65],
       [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66],
       [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67],
       [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68],
       [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69]], dtype=int32)