The *batch axes

As you will probably notice, many of the functions we provide accept arrays with a set of leading axes, referred to as *batch.

This notation indicates that you may provide arbitrarily many dimensions (including none) for the batch axes, and the output will preserve those batch dimensions.

For example, the array annotation Float[Array, "*batch n 3"] indicates that the array must have at least two dimensions, with the last one equal to 3. Any additional leading dimensions are treated as batch dimensions.
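To make the shape rule concrete, the following sketch uses a hypothetical helper, split_batch_shape (not part of the library), that splits a shape into its batch part and its trailing core part, given how many core axes the annotation names (two for "*batch n 3": n and 3).

```python
def split_batch_shape(
    shape: tuple[int, ...], num_core_axes: int
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Split ``shape`` into its (batch, core) parts.

    Hypothetical helper for illustration only: the last ``num_core_axes``
    dimensions are the core ones, and any leading ones form ``*batch``.
    """
    if len(shape) < num_core_axes:
        raise ValueError(
            f"expected at least {num_core_axes} dimensions, got {len(shape)}"
        )
    split = len(shape) - num_core_axes
    return shape[:split], shape[split:]


# "*batch n 3" has two core axes: n and 3.
print(split_batch_shape((5, 3), 2))  # ((), (5, 3)): no batch axes
print(split_batch_shape((40, 10, 5, 3), 2))  # ((40, 10), (5, 3))
```

A shape such as (3,) has fewer dimensions than the annotation requires, so the helper raises a ValueError.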

For more details on array annotations, see NumPy vs JAX arrays.

Why we provide batch axes

By design, the *batch axes are optional and functions will work just fine if you do not provide any additional dimensions.

However, in Ray Tracing applications, many functions are called repeatedly on a large number of samples. E.g., image_method may be called on thousands, if not millions, of path candidates, and, for every path candidate, you may also want to repeat the computation for every pair of transmitter and receiver locations.

Thus, allowing for arbitrary batch dimensions will help you write code in a way that is mostly transparent to the number of repetitions.

E.g., the following function computes the dot product between batches of vectors:

>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array, Num
>>>
>>> def dot(
...     x: Num[Array, "*batch n"], y: Num[Array, "*batch n"]
... ) -> Num[Array, " *batch"]:
...     return jnp.sum(x * y, axis=-1)
>>>
>>> *batch, n = 40, 10, 30, 3  # batch = (40, 10, 30), n = 3
>>>
>>> x = jnp.ones((*batch, n)) * 1.0
>>> y = jnp.ones((*batch, n)) * 2.0
>>> z = dot(x, y)
>>>
>>> z.shape
(40, 10, 30)
>>> jnp.allclose(z, 1.0 * 2.0 * n)
Array(True, dtype=bool)
>>> # Of course, you can always use such functions without any *batch axes:
>>> x = jnp.array([1., 2., 3.])
>>> dot(x, x)  # 1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0
Array(14., dtype=float32)

That is, the output has shape *batch, where each entry is the dot product between the corresponding length-n vectors of the x and y input arguments.

When batch axes are not available

If a function does not offer batch axes, there are two possibilities:

  1. you can use vectorization functions, like jax.vmap, to repeat a given function over an additional array axis;

  2. or, if you think the code would really benefit from having batch axes, we recommend opening an issue on GitHub.

For the latter, you can also directly suggest a patch if you know how to implement the batch axes.
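For the first option, here is a minimal sketch. The function norm_unbatched is hypothetical and only accepts a single vector of shape (n,). Nesting jax.vmap once per batch axis handles a fixed number of batch axes, while jax.numpy.vectorize with a gufunc-style signature handles any number of leading batch axes.

```python
import jax
import jax.numpy as jnp


def norm_unbatched(x):
    """Hypothetical function without batch axes: x must have shape (n,)."""
    if x.ndim != 1:
        raise ValueError("expected a single vector of shape (n,)")
    return jnp.sqrt(jnp.sum(x * x))


# Option A: one jax.vmap per batch axis. The outer vmap maps over axis 0,
# then the inner vmap maps over the leading axis of each resulting slice.
norm_two_batch_axes = jax.vmap(jax.vmap(norm_unbatched))

# Option B: jnp.vectorize with a signature; the function only ever sees
# the core dimension (n), whatever the number of leading batch axes.
norm_any_batch = jnp.vectorize(norm_unbatched, signature="(n)->()")

x = jnp.full((40, 10, 4), 2.0)  # batch = (40, 10), n = 4
print(norm_two_batch_axes(x).shape)  # (40, 10)
print(norm_any_batch(x).shape)  # (40, 10)
```

Option A must be rewritten if the number of batch axes changes, whereas option B adapts to it automatically.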

When overly large batches cause out-of-memory errors

The main cost of using many batch dimensions is the memory footprint.

E.g., in Ray Tracing applications, when dealing with large scenes (i.e., a large number of objects) or high-order paths, some dimensions can quickly become so large that the arrays no longer fit in memory. This is also why we propose chunked iterators (e.g., AllPathsFromCompleteGraphChunksIter) as an alternative.

Likewise, when a dimension gets too big, we recommend iterating over batches (chunks) rather than computing everything at once.
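A minimal sketch of chunked iteration, using a hypothetical helper apply_in_chunks (not part of the library): the first batch axis is processed chunk_size entries at a time and the partial results are concatenated, so intermediate arrays are sized by chunk_size rather than by the full axis.

```python
import jax
import jax.numpy as jnp


def apply_in_chunks(fun, x, chunk_size):
    """Apply ``fun`` to ``x`` in chunks along axis 0 (hypothetical helper).

    ``fun`` must preserve the leading axis, as functions with *batch
    axes do, so that the chunked outputs can be concatenated back.
    """
    outputs = [
        fun(x[start : start + chunk_size])
        for start in range(0, x.shape[0], chunk_size)
    ]
    return jnp.concatenate(outputs, axis=0)


@jax.jit
def squared_norm(x):
    # Works on any (*batch, n) input and returns an array of shape (*batch,).
    return jnp.sum(x * x, axis=-1)


x = jnp.linspace(0.0, 1.0, 1000 * 3).reshape(1000, 3)
full = squared_norm(x)
chunked = apply_in_chunks(squared_norm, x, chunk_size=128)
print(chunked.shape)  # (1000,)
```

Since 1000 is not a multiple of 128, the last chunk is smaller and triggers one extra compilation of the jitted function; padding the last chunk to a fixed size is a common way to avoid that.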

However, iterations are slow and can also become a bottleneck in your pipeline. In some cases, if you are only interested in a reduced result (e.g., a sum or a minimum over some axes), JAX may be able to optimize the computation such that some batch dimensions are never allocated. See jax#1929 for reference.
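As an illustration of the reduced-result case, the sketch below (a hypothetical example, not library code) only returns the smallest distance between two point sets, rather than the full array of pairwise distances. Under jax.jit, XLA is free to fuse the broadcasted computation with the final reduction; whether the intermediate arrays are actually materialized is up to the compiler, so this is something to measure on your own problem, not a guarantee.

```python
import jax
import jax.numpy as jnp


@jax.jit
def min_pairwise_distance(a, b):
    """Smallest distance between any point of ``a`` (m, 3) and ``b`` (k, 3).

    Evaluated eagerly, this would build (m, k, 3) and (m, k) arrays; under
    jit, XLA may fuse them with the final reduction.
    """
    diff = a[:, None, :] - b[None, :, :]  # shape (m, k, 3)
    dist = jnp.linalg.norm(diff, axis=-1)  # shape (m, k)
    return dist.min()  # only this scalar is returned to the caller


a = jnp.zeros((500, 3))
b = jnp.ones((2000, 3))
result = min_pairwise_distance(a, b)  # approximately sqrt(3)
```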

In general, finding the best trade-off is a trial-and-error process, and the solution highly depends on your problem parameters.