Performance tips#
By design, most functions in this library work on arrays, or even batches of arrays;
see the *batch axes section.
However, the runtime type checking of those functions, coupled with the use of Python logic, introduces some overhead that can degrade performance, especially with nested function calls.
To reduce this overhead, we encourage using JAX’s just-in-time (JIT) compilation. Please read the linked content if you are not familiar with this concept.
Almost all functions we provide are wrapped with jax.jit, in order
to compile them to efficient code. Type checking with jaxtyping is aware of that
and will only check functions at compilation time.
Once compiled, no more type checking will be performed, reducing the overhead to the bare minimum.
JIT-ing an existing function#
Here, we will look at the
rays_intersect_triangles
function and how much it can benefit from JIT compilation.
from inspect import unwrap
import jax
from differt.rt import rays_intersect_triangles
# Because we already applied @jit, we need to remove it first
rays_intersect_triangles = unwrap(rays_intersect_triangles)
key = jax.random.key(1234)
key1, key2, key3 = jax.random.split(key, 3)
batch = (10, 100)
ray_origins = jax.random.uniform(key1, (*batch, 3))
ray_directions = jax.random.uniform(key2, (*batch, 3))
triangle_vertices = jax.random.uniform(key3, (*batch, 3, 3))
Let’s look at the execution time without compilation.
The [0].block_until_ready() is needed for two reasons:
1. the function returns a tuple, so we need to select one of the output arrays (e.g., the first, with [0]);
2. calling .block_until_ready() on that array makes JAX wait until the computation has actually completed.
If the call to .block_until_ready() is omitted, the measured time may only reflect how long it took to dispatch the computation, not how long it took to run it.
%timeit rays_intersect_triangles(ray_origins, ray_directions, triangle_vertices)[0].block_until_ready()
802 μs ± 131 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Then, let’s compare it with its JIT compiled version.
Note that we call the function before timing it, so we do not take the compilation overhead into account.
rays_intersect_triangles_jit = jax.jit(rays_intersect_triangles)
# Warmup to compile code
rays_intersect_triangles_jit(ray_origins, ray_directions, triangle_vertices)[0]
%timeit rays_intersect_triangles_jit(ray_origins, ray_directions, triangle_vertices)[0].block_until_ready()
20.6 μs ± 89.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
See! Here, we reduced the execution time by more than one order of magnitude, which is quite
nice given that we only had to wrap the function with jax.jit, nothing more.
In general, the performance gain will highly depend on the function that is compiled.
We advise first trying without any JIT compilation, then gradually adding @jax.jit
decorators to the functions you feel could benefit from it.
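Outside of a notebook, where %timeit is not available, the same comparison can be made with the standard timeit module. The snippet below is only a minimal sketch: weighted_norm and best_time are hypothetical names introduced for illustration (they are not part of this library), and the measured values will vary from machine to machine.

```python
import timeit

import jax
import jax.numpy as jnp


def weighted_norm(x, w):
    """Hypothetical candidate function, used only for this illustration."""
    return jnp.sqrt((w * x**2).sum(axis=-1))


def best_time(fn, *args, repeat=5, number=100):
    """Return the best average time per call, in seconds (hypothetical helper)."""
    # Warmup call: triggers compilation (if any) so it is not part of the timing.
    jax.block_until_ready(fn(*args))
    # Block on the result so we time the computation, not just its dispatch.
    timer = timeit.Timer(lambda: jax.block_until_ready(fn(*args)))
    return min(timer.repeat(repeat=repeat, number=number)) / number


key_x, key_w = jax.random.split(jax.random.key(0))
x = jax.random.uniform(key_x, (10, 100, 3))
w = jax.random.uniform(key_w, (10, 100, 3))

eager_seconds = best_time(weighted_norm, x, w)
jitted_seconds = best_time(jax.jit(weighted_norm), x, w)

print(f"without JIT: {eager_seconds * 1e6:.1f} us per call")
print(f"with JIT:    {jitted_seconds * 1e6:.1f} us per call")
```

If the JIT-compiled version is not noticeably faster for a given function, leaving that function undecorated is a reasonable choice.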
JIT-ing a new function#
As with existing functions, JIT compilation can be applied to a new function by simply adding
@jax.jit before the function definition, like so:
from jaxtyping import Array, Float, jaxtyped


@jax.jit
def matmul_t_sum(
    x: Float[Array, "m k"], y: Float[Array, "k n"]
) -> Float[Array, " "]:
    return (x @ y).sum()
For advanced usage, see jax.jit’s documentation.
Warning
If you have multiple function decorators, note that the order
in which they are placed matters. Decorators are applied from the bottom up, so you want to place
your @jax.jit decorator at the very top, so that it applies to the whole function.
One common case is combining @jax.jit with @jaxtyped from jaxtyping for
runtime type checking. If @jax.jit is placed below @jaxtyped, it is applied first and the type checker wraps the compiled function,
so you pay the cost of type checking on every call of your function (or, with recent jaxtyping versions, get an error); see the example below.
(Bad) Type checker placed after JIT decorator#
from beartype import beartype as typechecker
key = jax.random.key(1234)
key1, key2, key3 = jax.random.split(key, 3)
batch = (100, 10, 2)
x = jax.random.uniform(key1, batch)
y = jax.random.uniform(key2, batch)
z = jax.random.uniform(key3, batch)
# Don't do this!
# N.B.: as of jaxtyping v0.3.3, an error is now raised when the @jax.jit decorator is
# placed first. Prior to this version, placing the @jax.jit decorator under the
# @jaxtyped would lead to a performance decrease (see the next cell)
@jaxtyped(typechecker=typechecker)
@jax.jit
def jit_then_typecheck(
    a: Float[Array, " *batch"],
    b: Float[Array, " *batch"],
    c: Float[Array, " *batch"],
) -> Float[Array, " "]:
    return (a * b + c).sum()
# Warmup to compile code
jit_then_typecheck(x, y, z)
%timeit jit_then_typecheck(x, y, z).block_until_ready()
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[6], line 17
9 z = jax.random.uniform(key3, batch)
11 # Don't do this!
12 # N.B.: as of jaxtyping v0.3.3, an error is now raised when the @jax.jit decorator is
13 # placed first. Prior to this version, placing the @jax.jit decorator under the
14 # @jaxtyped would lead to a performance decrease (see the next cell)
---> 17 @jaxtyped(typechecker=typechecker)
18 @jax.jit
19 def jit_then_typecheck(
20 a: Float[Array, " *batch"],
21 b: Float[Array, " *batch"],
22 c: Float[Array, " *batch"],
23 ) -> Float[Array, " "]:
24 return (a * b + c).sum()
27 # Warmup to compile code
File ~/checkouts/readthedocs.org/user_builds/differt/envs/latest/lib/python3.12/site-packages/jaxtyping/_decorator.py:399, in jaxtyped(fn, typechecker)
397 new_params = []
398 for p_value in full_signature.parameters.values():
--> 399 p_annotation = _destring_annotation(p_value.annotation, fn.__globals__)
400 p_value = p_value.replace(annotation=p_annotation)
401 new_params.append(p_value)
AttributeError: 'jaxlib._jax.PjitFunction' object has no attribute '__globals__'
(Good) Type checker placed before JIT decorator#
# Do this!
@jax.jit
@jaxtyped(typechecker=typechecker)
def typecheck_then_jit(
    a: Float[Array, " *batch"],
    b: Float[Array, " *batch"],
    c: Float[Array, " *batch"],
) -> Float[Array, " "]:
    return (a * b + c).sum()
# Warmup to compile code
typecheck_then_jit(x, y, z)
%timeit typecheck_then_jit(x, y, z).block_until_ready()
16.9 μs ± 119 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
As you can see, a small permutation of the two decorators changed the performance! Usually, the cost of type checking is fixed and small, but it can add up quite rapidly with many function calls.
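To make the effect of decorator order concrete without depending on a particular jaxtyping version, the sketch below replaces the type checker with a hypothetical counting_check decorator that merely counts how many times its Python-level check runs. It is an illustration under that assumption, not how jaxtyping is implemented.

```python
import jax
import jax.numpy as jnp

check_counts = {"inside_jit": 0, "outside_jit": 0}


def counting_check(name):
    """Hypothetical stand-in for a type checker: counts how often it runs."""

    def decorator(fn):
        def wrapper(*args):
            check_counts[name] += 1  # Python-level work, like a runtime check
            return fn(*args)

        return wrapper

    return decorator


@jax.jit  # applied last: the check is part of what gets traced and compiled
@counting_check("inside_jit")
def check_inside_jit(a):
    return (a * 2).sum()


@counting_check("outside_jit")  # applied last: the check wraps the compiled function
@jax.jit
def check_outside_jit(a):
    return (a * 2).sum()


a = jnp.ones(4)
for _ in range(3):
    check_inside_jit(a)
    check_outside_jit(a)

print(check_counts)  # {'inside_jit': 1, 'outside_jit': 3}
```

With @jax.jit on top, the check runs once, while the function is traced. With @jax.jit at the bottom, the check runs on each of the three calls.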
Why not JIT all functions?#
JIT compilation comes at the cost of compiling the function during its first execution, which can slow things down during debugging stages. Also, if some arguments are static, JAX will need to recompile the function every time the value of a static argument changes.
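The recompilation triggered by static arguments can be observed with a small experiment. In the sketch below, sum_of_powers and trace_log are hypothetical names used only for illustration; the list records each time JAX traces (and therefore recompiles) the function.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []


@partial(jax.jit, static_argnames="power")
def sum_of_powers(x, power):
    # This Python side effect runs only while JAX traces the function,
    # i.e., once per new value of the static argument (for a given input shape).
    trace_log.append(power)
    return (x**power).sum()


x = jnp.arange(4.0)
squares = sum_of_powers(x, power=2)  # first call: traced and compiled
squares_again = sum_of_powers(x, power=2)  # same static value: cache hit
cubes = sum_of_powers(x, power=3)  # new static value: traced and compiled again

print(trace_log)  # [2, 3]
```

A function called with many different static values will therefore spend much of its time compiling, which can outweigh the benefit of JIT compilation.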
Moreover, inside JIT-compiled code, print statements only execute while the function is being traced, impure functions (e.g., relying on globals) are not supported, and the result is not always faster code.
For all those reasons, it is the responsibility of the end-user to determine when to use JIT compilation in their code.
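As an illustration of the purity requirement, the sketch below shows what can happen when a JIT-compiled function reads a global variable. The names scale and scale_by_global are hypothetical and only serve this example.

```python
import jax
import jax.numpy as jnp

scale = 2.0


@jax.jit
def scale_by_global(x):
    # Impure: the value of the global is read once, when the function is traced.
    return x * scale


ones = jnp.ones(3)
first = scale_by_global(ones)  # traced with scale == 2.0

scale = 10.0
second = scale_by_global(ones)  # same input shape: reuses the cached compilation

print(first, second)  # both results were computed with scale == 2.0
```

Passing such values as explicit function arguments avoids this kind of silent mismatch.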
Further reading: profiling your code#
The best way to understand how your code performs is to profile it. JAX provides built-in utilities for this purpose, and the necessary dependencies can be installed with the command uv sync --group=profiling (see install from source) or with Pip. You can then profile any code using the following method:
import jax.profiler

with jax.profiler.trace("/tmp/jax", create_perfetto_link=True):
    f(...).block_until_ready()  # Assuming f returns a JAX array
After a few moments, a link will be printed and the program will wait for you to click on it to load the trace.
As profiling JAX code is a relatively complex subject, we provide links to some important documentation on the topic: