Performance tips#

By design, most functions in this library work on arrays, or even batches of arrays, see the section about *batch axes.

However, the runtime type checking of those functions, coupled with the use of Python logic, introduces some overhead that can degrade performance, especially with nested function calls.

To reduce this overhead, we encourage using JAX’s just-in-time compilation (JIT). Please read the linked content if you are not familiar with this concept.

Almost all functions we provide are wrapped with jax.jit, in order to compile them to efficient code. The type checkers we use are aware of that and will only check functions at compilation time.

Once compiled, no type checking is performed anymore, reducing the overhead to the bare minimum.
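
To make this concrete, below is a minimal sketch, assuming a hypothetical dot function wrapped in the same way as the functions of this library: the type check runs while the function is traced, i.e., on the first call for a given set of input shapes, and is skipped on later calls.

import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped


@jax.jit
@jaxtyped(typechecker=typechecker)
def dot(x: Float[Array, " n"], y: Float[Array, " n"]) -> Float[Array, " "]:
    return jnp.dot(x, y)


x = jnp.ones(3)
dot(x, x)  # First call: traces, type checks, and compiles
dot(x, x)  # Same shapes: reuses the compiled code, no type check
dot(jnp.ones(4), jnp.ones(4))  # New shapes: re-traces and type checks again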

JIT-ing an existing function#

Here, we will look at the rays_intersect_triangles function and how much it can benefit from JIT compilation.

from inspect import unwrap

import jax

from differt.rt.utils import rays_intersect_triangles

# Because we already applied @jit, we need to remove it first
rays_intersect_triangles = unwrap(rays_intersect_triangles)
key = jax.random.PRNGKey(1234)
key1, key2, key3 = jax.random.split(key, 3)

batch = (10, 100)

ray_origins = jax.random.uniform(key1, (*batch, 3))
ray_directions = jax.random.uniform(key2, (*batch, 3))
triangle_vertices = jax.random.uniform(key3, (*batch, 3, 3))

Let’s look at the execution time without compilation. The [0].block_until_ready() part is needed because:

  1. the function returns a tuple, so we must first select one of the output arrays (e.g., the first with [0]);

  2. calling .block_until_ready() on that array forces JAX, which dispatches computations asynchronously, to actually perform the computation before returning.

If the call to .block_until_ready() is omitted, the measured execution time may not be meaningful.

%timeit rays_intersect_triangles(ray_origins, ray_directions, triangle_vertices)[0].block_until_ready()
1.73 ms ± 83.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Then, let’s compare it with its JIT-compiled version.

Note that we call the function once before timing it, so that we do not take the compilation overhead into account.

rays_intersect_triangles_jit = jax.jit(rays_intersect_triangles)

# Warmup to compile code
rays_intersect_triangles_jit(ray_origins, ray_directions, triangle_vertices)[0]

%timeit rays_intersect_triangles_jit(ray_origins, ray_directions, triangle_vertices)[0].block_until_ready()
40.8 µs ± 136 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

See! Here, we reduced the execution time by more than one order of magnitude, which is quite nice given that we only had to wrap the function with jax.jit, nothing more.

In general, the performance gain will highly depend on the function that is compiled. We advise first trying without any JIT compilation, and gradually adding @jax.jit decorators to the functions you feel could benefit from it, as sketched below.
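
For example, here is a minimal sketch, with hypothetical scaled_norm and normalize functions: applying @jax.jit to the outermost function is usually enough, because inner calls are traced and compiled into a single computation.

import jax
import jax.numpy as jnp


def scaled_norm(x):
    # Plain Python function: easy to test and debug without JIT
    return jnp.sqrt((x * x).sum(axis=-1))


@jax.jit  # Jitting the outermost function also covers inner calls
def normalize(x):
    return x / scaled_norm(x)[..., None]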

JIT-ing a new function#

As with existing functions, JIT compilation can be applied by simply adding @jax.jit before the function definition, like so:

from jaxtyping import Array, Float, jaxtyped


@jax.jit
def matmul_t_sum(x: Float[Array, "m k"], y: Float[Array, "k n"]) -> Float[Array, " "]:
    return (x @ y).sum()

For advanced usage, see jax.jit’s documentation.

Warning

If you have multiple function decorators, pay attention to their order, as it plays an important role: decorators are applied from bottom to top. Logically, you want to place your @jax.jit decorator at the very top, so it applies to the whole decorated function.

One common case is combining @jax.jit with @jaxtyped from jaxtyping for runtime type checking. If the @jax.jit decorator is applied before the type checker (i.e., placed below it), you will pay the cost of type checking on every call of your function, see the example below.

(Bad) Type checker placed after JIT decorator#

from beartype import beartype as typechecker

key = jax.random.PRNGKey(1234)
key1, key2, key3 = jax.random.split(key, 3)

batch = (100, 10, 2)
x = jax.random.uniform(key1, batch)
y = jax.random.uniform(key2, batch)
z = jax.random.uniform(key3, batch)

# Don't do this!


@jaxtyped(typechecker=typechecker)
@jax.jit
def jit_then_typecheck(
    a: Float[Array, " *batch"], b: Float[Array, " *batch"], c: Float[Array, " *batch"]
) -> Float[Array, " "]:
    return (a * b + c).sum()


# Warmup to compile code
jit_then_typecheck(x, y, z)

%timeit jit_then_typecheck(x, y, z).block_until_ready()
75.5 µs ± 344 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

(Good) Type checker placed before JIT decorator#

# Do this!


@jax.jit
@jaxtyped(typechecker=typechecker)
def typecheck_then_jit(
    a: Float[Array, " *batch"], b: Float[Array, " *batch"], c: Float[Array, " *batch"]
) -> Float[Array, " "]:
    return (a * b + c).sum()


# Warmup to compile code
typecheck_then_jit(x, y, z)

%timeit typecheck_then_jit(x, y, z).block_until_ready()
9.76 µs ± 24 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

As you can see, simply swapping the two decorators changed the performance! Usually, the cost of type checking is fixed and small, but it can add up quite rapidly with many function calls.

Why not JIT all functions?#

JIT compilation comes at the cost of compiling the function during its first execution, which can slow down debugging stages. Also, if some arguments are marked as static, JAX will need to re-compile the function every time one of those static arguments changes.
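
For example, here is a minimal sketch, with a hypothetical power function whose first argument is static because it drives Python control flow: every new value of n triggers a new trace and compilation.

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def power(n: int, x):
    # 'n' must be static, as it controls the Python loop below
    y = x
    for _ in range(n - 1):
        y = y * x
    return y


x = jnp.ones(3)
power(2, x)  # Compiles a version specialized for n=2
power(3, x)  # n changed: re-traces and re-compiles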

Moreover, JIT compilation effectively removes Python print statements (they only execute during tracing), requires functions to be pure (e.g., not relying on global state), and might not always produce faster code.
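
For instance, in the minimal sketch below, the Python print only fires while the function is traced, whereas jax.debug.print runs on every call.

import jax
import jax.numpy as jnp


@jax.jit
def add_one(x):
    print("Traced!")  # Only runs during tracing (first call per shape)
    jax.debug.print("x = {x}", x=x)  # Runs on every call
    return x + 1


x = jnp.zeros(3)
add_one(x)  # Prints "Traced!" and "x = ..."
add_one(x)  # Only prints "x = ..."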

For all those reasons, it is the responsibility of the end user to determine when to use JIT compilation in their code.