Performance tips#
By design, most functions in this library work on arrays, or even batches of arrays,
see the *batch
axes section.
However, the runtime type checking of those functions, coupled with the use of Python logic, introduces some overhead that can degrade performance, especially with nested function calls.
To this end, we encourage using JAX’s just-in-time compilation (JIT). Please read the linked content if you are not familiar with this concept.
Almost all functions we provide are wrapped with jax.jit
, in order
to compile them to efficient code. The type checkers we use are aware of that
and will only check functions at compilation time.
Once compiled, no more type checking is performed, reducing the overhead to the bare minimum.
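To see why the checks run only at compilation time, recall that jax.jit traces the Python body of a function once per input shape and dtype, then reuses the compiled code. The minimal sketch below (not part of this library) uses a plain Python counter to show that Python-level logic, which is where runtime type checking happens, executes only during that single trace:

```python
import jax
import jax.numpy as jnp

calls = 0


@jax.jit
def f(x):
    global calls
    calls += 1  # Python-level side effect: runs only while tracing
    return jnp.sum(x)


x = jnp.ones(3)
f(x)  # first call traces and compiles the function
f(x)  # second call reuses the compiled code, no Python executed
assert calls == 1
```

Any type checking performed by the Python body is paid once, at trace time, which is why wrapping functions with jax.jit amortizes its cost.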
JIT-ing an existing function#
Here, we will look at the
rays_intersect_triangles
function and how much it can benefit from JIT compilation.
from inspect import unwrap
import jax
from differt.rt.utils import rays_intersect_triangles
# Because we already applied @jit, we need to remove it first
rays_intersect_triangles = unwrap(rays_intersect_triangles)
key = jax.random.PRNGKey(1234)
key1, key2, key3 = jax.random.split(key, 3)
batch = (10, 100)
ray_origins = jax.random.uniform(key1, (*batch, 3))
ray_directions = jax.random.uniform(key2, (*batch, 3))
triangle_vertices = jax.random.uniform(key3, (*batch, 3, 3))
Let’s look at the execution time without compilation.
The [0].block_until_ready()
is needed because the function returns a tuple, so we must select one of the output arrays (e.g., the first, with
[0]
) and call
.block_until_ready()
on it, so that JAX knows it must actually perform the computation.
If the call to .block_until_ready()
is omitted, the measured execution time may not be relevant, because JAX dispatches computations asynchronously.
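The asynchronous dispatch mentioned above can be illustrated with a small standalone sketch (unrelated to this library's functions): the call returns immediately with a "future" array, and only block_until_ready() guarantees the computation has finished.

```python
import jax.numpy as jnp

x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)      # returns immediately: dispatch is asynchronous
y.block_until_ready()  # waits until the result is actually computed
```

Without the final call, a timing loop could measure only the (cheap) dispatch, not the computation itself.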
%timeit rays_intersect_triangles(ray_origins, ray_directions, triangle_vertices)[0].block_until_ready()
1.73 ms ± 83.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Then, let’s compare it with its JIT compiled version.
Note that we call the function before timing it, so we do not take the compilation overhead into account.
rays_intersect_triangles_jit = jax.jit(rays_intersect_triangles)
# Warmup to compile code
rays_intersect_triangles_jit(ray_origins, ray_directions, triangle_vertices)[0]
%timeit rays_intersect_triangles_jit(ray_origins, ray_directions, triangle_vertices)[0].block_until_ready()
40.8 µs ± 136 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
See! Here, we reduced the execution time by more than one order of magnitude, which is quite
nice given that we only had to wrap the function with jax.jit
, nothing more.
In general, the amount of performance gained will highly depend on the function that is compiled.
We advise first trying without any JIT compilation, then gradually adding @jax.jit
decorators to the functions you feel could benefit from it.
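When experimenting this way, it can be convenient to toggle compilation off without editing decorators. A possible approach, sketched below, is JAX's jax.disable_jit context manager, which runs jitted functions eagerly (useful for debugging with prints or breakpoints):

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.sin(x) ** 2


# Inside this context, f runs eagerly, as if @jax.jit were absent
with jax.disable_jit():
    y = f(jnp.array(1.0))
```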
JIT-ing a new function#
Like for already existing functions, JIT compilation can be applied with a simple addition
of @jax.jit
before the function definition, like so:
from jaxtyping import Array, Float, jaxtyped
@jax.jit
def matmul_t_sum(x: Float[Array, "m k"], y: Float[Array, "k n"]) -> Float[Array, " "]:
return (x @ y).sum()
For advanced usage, see jax.jit
’s documentation.
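One advanced feature worth mentioning is static arguments. The hypothetical power function below (not from this library) marks its first argument as static with static_argnums, so plain Python control flow on it is allowed; note that JAX recompiles for each new static value:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def power(n, x):
    # n is static: the Python loop is unrolled at trace time
    y = x
    for _ in range(n - 1):
        y = y * x
    return y


power(3, jnp.array(2.0))  # computes x ** 3
```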
Warning
If you have multiple function decorators, pay attention to the order
in which they are placed, as it plays an important role. Logically, you want to place
your @jax.jit
decorator at the very top, so it applies to the whole function.
One common case is combining @jax.jit
with @jaxtyped
from jaxtyping
for
runtime type checking. If you apply the @jax.jit
decorator before the type checker,
you will pay the cost of type checking on every call of your function; see the example below.
(Bad) Type checker placed after JIT decorator#
from beartype import beartype as typechecker
key = jax.random.PRNGKey(1234)
key1, key2, key3 = jax.random.split(key, 3)
batch = (100, 10, 2)
x = jax.random.uniform(key1, batch)
y = jax.random.uniform(key2, batch)
z = jax.random.uniform(key3, batch)
# Don't do this!
@jaxtyped(typechecker=typechecker)
@jax.jit
def jit_then_typecheck(
a: Float[Array, " *batch"], b: Float[Array, " *batch"], c: Float[Array, " *batch"]
) -> Float[Array, " "]:
return (a * b + c).sum()
# Warmup to compile code
jit_then_typecheck(x, y, z)
%timeit jit_then_typecheck(x, y, z).block_until_ready()
75.5 µs ± 344 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
(Good) Type checker placed before JIT decorator#
# Do this!
@jax.jit
@jaxtyped(typechecker=typechecker)
def typecheck_then_jit(
a: Float[Array, " *batch"], b: Float[Array, " *batch"], c: Float[Array, " *batch"]
) -> Float[Array, " "]:
return (a * b + c).sum()
# Warmup to compile code
typecheck_then_jit(x, y, z)
%timeit typecheck_then_jit(x, y, z).block_until_ready()
9.76 µs ± 24 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
As you can see, a small permutation of the two decorators changed the performance! Usually, the cost of type checking is fixed and small, but it can add up quite rapidly with many function calls.
Why not JIT all functions?#
JIT compilation comes at the cost of compiling the function during its first execution, which can become slow during debugging stages. Also, if some arguments are static, JAX will need to recompile the function every time the static arguments change.
Moreover, JIT compilation removes print statements, does not allow impure functions (e.g., those using globals), and might not always produce faster code.
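Regarding print statements: JAX provides jax.debug.print as a way to print runtime values from inside compiled code. A small sketch, using a toy function, contrasts it with a plain Python print:

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    print("traced")               # Python print: runs only once, during tracing
    jax.debug.print("x = {}", x)  # runs on every call, even under JIT
    return x + 1


y = f(jnp.array(1.0))
```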
For all those reasons, it is the responsibility of the end user to determine when to use JIT compilation in their code.