Runtime type checking
To avoid common pitfalls with function arguments, such as using the wrong data type or array shape, functions in this library are annotated with type hints and can be type-checked at runtime.
For that, we rely on the jaxtyping and beartype modules.
By default, no type checking is performed, to avoid additional overhead when calling a function.
To enable runtime type checking, you can use jaxtyping.install_import_hook.
Input argument checking
Let’s take the example of the perpendicular_vectors function:
import inspect
import jax
import jax.numpy as jnp
from jaxtyping import install_import_hook
with install_import_hook("differt", "beartype.beartype"):
    from differt.geometry import perpendicular_vectors
inspect.signature(perpendicular_vectors)
<Signature (u: Union[jaxtyping.Float[Array, '*batch 3'], jaxtyping.Float[ndarray, '*batch 3'], jaxtyping.Float[TypedNdArray, '*batch 3']]) -> jaxtyping.Float[Array, '*batch 3']>
As we can see, its signature expects an array of 3D vectors as input and returns an array of 3D vectors as output, with matching batch shapes.
key = jax.random.key(1234)
arr = jax.random.normal(key, (10, 3))
arr
Array([[ 1.1031016 , 0.86306226, -0.33868238],
[ 1.0272458 , -1.0735804 , 0.32937315],
[ 0.75820905, -0.30057552, -0.37079313],
[ 0.5862171 , 0.88390875, -1.2834665 ],
[ 0.2709998 , 1.7730063 , -1.235121 ],
[-1.8687268 , -0.11319233, 0.12312093],
[ 0.14007285, 0.7311839 , 1.9988668 ],
[ 0.03140245, 0.4228713 , 0.4774148 ],
[ 0.04098483, -1.1967784 , 0.7829566 ],
[ 1.3783231 , -0.2089781 , 0.25486845]], dtype=float32)
Hence, if we provide an array of 3D vectors as input, everything works just fine:
perpendicular_vectors(arr)
Array([[ 0.1851116 , 0.14483057, 0.9719865 ],
[ 0.73785526, 0.64527345, -0.19796911],
[ 0.38472942, -0.15251763, 0.9103415 ],
[ 0.9359692 , -0.19969855, 0.2899693 ],
[ 0.99222696, -0.1021079 , 0.07113095],
[ 0.06550259, 0.00396762, 0.9978445 ],
[ 0.9978415 , -0.02255976, -0.06167253],
[ 0.99879 , -0.03260797, -0.03681386],
[ 0.9995896 , 0.02397186, -0.01568287],
[-0.17780961, 0.02695907, 0.98369557]], dtype=float32)
However, if anything other than an array of 3D vectors is provided, an error is raised:
arr = jax.random.normal(key, (2, 10, 4)) # 4D vectors
perpendicular_vectors(arr)
---------------------------------------------------------------------------
BeartypeCallHintParamViolation Traceback (most recent call last)
[... skipping hidden 1 frame]
File <@beartype(differt.geometry._utils.perpendicular_vectors) at 0x70c0245cc2c0>:32, in perpendicular_vectors(__beartype_object_123970544057536, __beartype_get_violation, __beartype_conf, __beartype_check_meta, __beartype_func, *args, **kwargs)
BeartypeCallHintParamViolation: Function differt.geometry._utils.perpendicular_vectors() parameter u="JitTracer<float32[2,10,4]>" violates type hint typing.Union[jaxtyping.Float[Array, '*batch 3'], jaxtyping.Float[ndarray, '*batch 3'], jaxtyping.Float[TypedNdArray, '*batch 3']], as <class "jax._src.interpreters.partial_eval.DynamicJaxprTracer"> "JitTracer<float32[2,10,4]>" not <class "jaxtyping.Float[TypedNdArray, '*batch 3']">, <class "jaxtyping.Float[ndarray, '*batch 3']">, or <class "jaxtyping.Float[Array, '*batch 3']">.
During handling of the above exception, another exception occurred:
BeartypeCallHintParamViolation Traceback (most recent call last)
File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:757, in _get_problem_arg(param_signature, args, kwargs, arguments, module, typechecker)
756 try:
--> 757 fn(*args, **kwargs)
758 except Exception as e:
File <@beartype(differt.geometry._utils.check_single_arg) at 0x70c02423f060>:32, in check_single_arg(__beartype_object_123970544057536, __beartype_get_violation, __beartype_conf, __beartype_check_meta, __beartype_func, *args, **kwargs)
BeartypeCallHintParamViolation: Function differt.geometry._utils.check_single_arg() parameter u="JitTracer<float32[2,10,4]>" violates type hint typing.Union[jaxtyping.Float[Array, '*batch 3'], jaxtyping.Float[ndarray, '*batch 3'], jaxtyping.Float[TypedNdArray, '*batch 3']], as <class "jax._src.interpreters.partial_eval.DynamicJaxprTracer"> "JitTracer<float32[2,10,4]>" not <class "jaxtyping.Float[TypedNdArray, '*batch 3']">, <class "jaxtyping.Float[ndarray, '*batch 3']">, or <class "jaxtyping.Float[Array, '*batch 3']">.
The above exception was the direct cause of the following exception:
TypeCheckError Traceback (most recent call last)
[... skipping hidden 1 frame]
File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:760, in _get_problem_arg(param_signature, args, kwargs, arguments, module, typechecker)
759 keep_value = _pformat(arguments[keep_name], short_self=False)
--> 760 raise TypeCheckError(
761 f"\nThe problem arose whilst typechecking parameter '{keep_name}'.\n"
762 f"Actual value: {keep_value}\n"
763 f"Expected type: {keep_annotation}."
764 ) from e
765 else:
766 # Could not localise the problem to a single argument -- probably due to
767 # e.g. a mismatched typevar, which each individual argument is okay with.
TypeCheckError:
The problem arose whilst typechecking parameter 'u'.
Actual value: f32[2,10,4]
Expected type: typing.Union[Float[Array, '*batch 3'], Float[ndarray, '*batch 3'], Float[TypedNdArray, '*batch 3']].
The above exception was the direct cause of the following exception:
TypeCheckError Traceback (most recent call last)
Cell In[4], line 2
1 arr = jax.random.normal(key, (2, 10, 4)) # 4D vectors
----> 2 perpendicular_vectors(arr)
[... skipping hidden 14 frame]
File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:462, in jaxtyped.<locals>.wrapped_fn_impl(args, kwargs, bound, memos)
460 raise TypeCheckError(msg) from None
461 else:
--> 462 raise TypeCheckError(msg) from e
464 # Actually call the function.
465 out = fn(*args, **kwargs)
TypeCheckError: Type-check error whilst checking the parameters of differt.geometry._utils.perpendicular_vectors.
The problem arose whilst typechecking parameter 'u'.
Actual value: f32[2,10,4]
Expected type: typing.Union[Float[Array, '*batch 3'], Float[ndarray, '*batch 3'], Float[TypedNdArray, '*batch 3']].
----------------------
Called with parameters: {'u': f32[2,10,4]}
Parameter annotations: (u: Union[Float[Array, '*batch 3'], Float[ndarray, '*batch 3'], Float[TypedNdArray, '*batch 3']]) -> Any.
The error message is a bit verbose,
but we can see at the end that the expected type was Float[Array, '*batch 3'] (or its ndarray and TypedNdArray variants),
and that we received f32[2,10,4] (i.e., Float32[Array, "2 10 4"]).
The float32 dtype is compatible with Float,
but *batch 3 cannot be matched to 2 10 4: *batch can absorb the leading dimensions 2 10, yet the last dimension is 4 instead of 3. This mismatch is why the error was raised.
Output checking
Return values are also checked by the type checker. If you call one of the functions from our library with valid inputs, you are guaranteed to get outputs of the correct types.
In other words, type checking the outputs should never fail. If you encounter a case where your input is valid but the returned output is not, please report it via the GitHub issues.
If you define custom functions yourself, it is good practice to use type annotations and runtime checking as well:
from beartype import beartype as typechecker
from jaxtyping import Array, Num, jaxtyped
@jaxtyped(typechecker=typechecker)
def my_custom_transpose(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose().transpose()  # Oops, transposed one too many times
x = jnp.arange(70).reshape(10, 7)
x
Array([[ 0, 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12, 13],
[14, 15, 16, 17, 18, 19, 20],
[21, 22, 23, 24, 25, 26, 27],
[28, 29, 30, 31, 32, 33, 34],
[35, 36, 37, 38, 39, 40, 41],
[42, 43, 44, 45, 46, 47, 48],
[49, 50, 51, 52, 53, 54, 55],
[56, 57, 58, 59, 60, 61, 62],
[63, 64, 65, 66, 67, 68, 69]], dtype=int32)
my_custom_transpose(x)
---------------------------------------------------------------------------
BeartypeCallHintReturnViolation Traceback (most recent call last)
[... skipping hidden 1 frame]
File <@beartype(__main__.my_custom_transpose) at 0x70c019a62200>:46, in my_custom_transpose(__beartype_object_105428396956112, __beartype_get_violation, __beartype_conf, __beartype_object_105428394941440, __beartype_check_meta, __beartype_func, *args, **kwargs)
BeartypeCallHintReturnViolation: Function __main__.my_custom_transpose() return "Array([[ 0, 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12, 13],
[14, 15, ...)" violates type hint <class 'jaxtyping.Num[Array, 'n m']'>, as the size of dimension n is 10 which does not equal the existing value of 7.
The above exception was the direct cause of the following exception:
TypeCheckError Traceback (most recent call last)
Cell In[6], line 1
----> 1 my_custom_transpose(x)
[... skipping hidden 1 frame]
File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:515, in jaxtyped.<locals>.wrapped_fn_impl(args, kwargs, bound, memos)
513 raise TypeCheckError(msg) from None
514 else:
--> 515 raise TypeCheckError(msg) from e
517 return out
TypeCheckError: Type-check error whilst checking the return value of __main__.my_custom_transpose.
Actual value: i32[10,7]
Expected type: Num[Array, 'n m'].
----------------------
Called with parameters: {'x': i32[10,7]}
Parameter annotations: (x: Num[Array, 'm n']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
m=10
n=7
Here, the error message tells us that it inferred m=10 and n=7 from the input argument,
but the returned array has shape (10, 7), which does not match the expected output shape (n, m) = (7, 10).
Thanks to the type checker, we quickly caught the error and can fix the function:
@jaxtyped(typechecker=typechecker)
def my_custom_transpose_fixed(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose()  # Now this is all good
my_custom_transpose_fixed(x)
Array([[ 0, 7, 14, 21, 28, 35, 42, 49, 56, 63],
[ 1, 8, 15, 22, 29, 36, 43, 50, 57, 64],
[ 2, 9, 16, 23, 30, 37, 44, 51, 58, 65],
[ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66],
[ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67],
[ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68],
[ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69]], dtype=int32)