The *batch axes

As you will probably notice, many of the functions we provide accept arrays with a set of leading axes, referred to as *batch axes.
This notation indicates that you may provide arbitrarily many batch dimensions, and the output will preserve those batch dimensions.
For example, the array annotation Float[Array, "*batch n 3"]
indicates that the array must have at least two dimensions,
with the last one equal to 3. Any additional leading dimensions
are treated as batch dimensions.
For more details on array annotations, see NumPy vs JAX arrays.
Why we provide batch axes
By design, the *batch axes are optional, and functions will work just fine if you do not provide any additional dimensions.
However, in Ray Tracing applications, many functions are called
repeatedly on a number of samples, e.g., the image_method will be
called on thousands, if not millions, of path candidates. For
every path candidate, you may also want to repeat the computation for every pair of
transmitter and receiver locations.
Thus, allowing arbitrary batch dimensions helps you write code in a way that is mostly transparent to the number of repetitions.
E.g., the following function computes the dot product between batches of arrays:
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array, Num
>>>
>>> def dot(
... x: Num[Array, "*batch n"], y: Num[Array, "*batch n"]
... ) -> Num[Array, " *batch"]:
... return jnp.sum(x * y, axis=-1)
>>>
>>> *batch, n = 40, 10, 30, 3  # batch = [40, 10, 30], n = 3
>>>
>>> x = jnp.ones((*batch, n)) * 1.0
>>> y = jnp.ones((*batch, n)) * 2.0
>>> z = dot(x, y)
>>>
>>> z.shape
(40, 10, 30)
>>> jnp.allclose(z, 1.0 * 2.0 * n)
Array(True, dtype=bool)
>>> # Of course, you can always use such functions without any *batch axes:
>>> x = jnp.array([1., 2., 3.])
>>> dot(x, x) # 1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0
Array(14., dtype=float32)
That is, the resulting output will have a shape of *batch,
where each entry is the result of the dot product between the n
pairs of values from the corresponding entries in the x and y
input arguments.
When batch axes are not available
If a function does not offer batch axes, there are two possibilities:
either you can use vectorization functions, like jax.vmap,
to repeat a given function over another array axis;
or you think the code would really benefit from having batch axes, in which case we recommend opening an issue on GitHub.
For the latter, you can also directly suggest a patch if you know how to implement the batch axes.
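As a sketch of the first option, jax.vmap can wrap a function that was written without batch axes. The function below, angle_between, is a hypothetical example, not part of this library:

```python
import jax
import jax.numpy as jnp

# A hypothetical function written for a single pair of 3-D vectors,
# with no support for batch axes.
def angle_between(u, v):
    cos = jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))
    return jnp.arccos(jnp.clip(cos, -1.0, 1.0))

# jax.vmap maps the function over a leading axis of each argument;
# nesting jax.vmap calls adds further batch axes.
batched = jax.vmap(angle_between)

us = jnp.ones((100, 3))
vs = jnp.ones((100, 3))
angles = batched(us, vs)
print(angles.shape)  # (100,)
```

Note that jax.vmap maps over one axis per call, so you would nest it (or reshape your inputs) to emulate several batch axes.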
When too large batches cause out-of-memory errors
The biggest cost of using many batch dimensions is the memory footprint.
E.g., in Ray Tracing applications, when dealing with large scenes
(i.e., a large number of objects) or high-order paths,
the size of some dimensions can rapidly become so large that they
cannot fit inside your memory.
This is also why we propose chunked iterators
(e.g., AllPathsFromCompleteGraphChunksIter)
as an alternative.
Likewise, when a dimension is getting too big, it is recommended to iterate over batches, rather than computing everything all at once.
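As a minimal sketch of manual chunking (with hypothetical shapes, not using this library's chunked iterators), you could slice the leading batch axis and process one chunk at a time:

```python
import jax.numpy as jnp

def dot(x, y):
    # Same batched dot product as in the example above.
    return jnp.sum(x * y, axis=-1)

# Hypothetical large batch: processing all rows at once may exhaust
# memory, so we split the leading axis into chunks and process them
# sequentially, concatenating the partial results at the end.
x = jnp.ones((1_000_000, 3))
y = jnp.ones((1_000_000, 3)) * 2.0

chunk_size = 100_000
results = []
for start in range(0, x.shape[0], chunk_size):
    results.append(dot(x[start : start + chunk_size], y[start : start + chunk_size]))
z = jnp.concatenate(results)
print(z.shape)  # (1000000,)
```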
However, iterations are slow, and this can also become a bottleneck in your pipeline. In some cases, if you are only interested in a reduced result, JAX may be able to optimize the computation such that some batch dimensions are never allocated. See jax#1929 for reference.
In general, finding the optimum is a trial-and-error process, where the solution will highly depend on your problem parameters.