Advanced Path Tracing#

DiffeRT provides both high-level and low-level interfaces to Path Tracing.

This tutorial provides a quick tour of what you can do with the lower-level API, and the logic used to perform Ray Tracing (RT).

Example on a simple scene#

Before diving into a complex scene, it is worth starting with a very simple one.

Note

All the logic presented in this section is contained in the TriangleScene.compute_paths method.

That method also contains additional post-processing steps to avoid degenerate solutions, as well as optimized routines, but we omit them here.

Necessary imports#

Because we are taking the lower-level route, we will need quite a few imports.

import differt.plotting as dplt
import jax.numpy as jnp
import numpy as np
from differt.geometry import (
    TriangleMesh,
    assemble_paths,
    triangles_contain_vertices_assuming_inside_same_plane,
)
from differt.rt import (
    consecutive_vertices_are_on_same_side_of_mirrors,
    generate_all_path_candidates,
    image_method,
    rays_intersect_triangles,
)
from jaxtyping import Array, Bool, Float

Loading a mesh#

For each supported mesh type, we provide utilities to load a mesh from a file.

mesh_file = "two_buildings.obj"  # Very simple scene with two buildings
mesh = TriangleMesh.load_obj(mesh_file)

Plotting your setup#

Here, we will use Plotly as the plotting backend, because it renders very nicely, especially on the web. On larger scenes, you will likely need something more performant, like Vispy; see Choosing your plotting backend.

dplt.set_backend("plotly")  # Let's use the Plotly backend

fig = mesh.plot(opacity=0.5)
fig

Ray Tracing without start and end points is not very interesting. Let’s add one transmitter (TX) and one receiver (RX) in the scene, represented by their 3D coordinates.

tx = jnp.array([0.0, 4.9352, 22.0])
rx = jnp.array([0.0, 10.034, 1.50])

dplt.draw_markers(
    np.array([tx, rx]), labels=["tx", "rx"], figure=fig, name="nodes"
)

How we trace rays#

Ray Tracing can be implemented in many ways, depending on the desired performance, the level of accuracy needed, or the representation of the geometry.

Here, we will implement exhaustive (also referred to as deterministic or exact) RT. That is, we want to generate all possible paths from TX to RX that undergo up to a maximum number of interactions with the environment. Interactions can be reflections, diffractions, etc.

One way to generate all possible paths is to represent the problem as a graph. Then, the goal is to find all the paths from the node corresponding to TX, to the node corresponding to RX, while possibly visiting intermediate nodes in the graph, where each corresponds to a specific primitive—or object—in the scene (here, a triangle).

A graph algorithm will therefore generate a list of path candidates. We use the word candidate to emphasize that this is not a real path (i.e., not 3D coordinates), but only an ordered list of nodes to visit, for a given path.
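To make the notion of path candidate concrete, here is a minimal pure-Python sketch that enumerates candidates for a tiny scene. The helper `path_candidates` is hypothetical (it is not part of DiffeRT), and it assumes that a path never interacts twice in a row with the same primitive.

```python
from itertools import product


def path_candidates(num_primitives: int, order: int) -> list[tuple[int, ...]]:
    """Enumerate ordered lists of primitive indices to visit between TX and RX."""
    return [
        candidate
        for candidate in product(range(num_primitives), repeat=order)
        # Assumption: a path never hits the same primitive twice in a row
        if all(a != b for a, b in zip(candidate, candidate[1:]))
    ]


# Order 0 is the line-of-sight candidate: no primitive is visited
print(path_candidates(num_primitives=3, order=0))  # [()]
# Order 2 on three primitives
print(path_candidates(num_primitives=3, order=2))
# [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
```

Each tuple only names the primitives to visit, in order. No 3D coordinate has been computed at this stage.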

It is then the role of the path tracing method (e.g., image_method or fermat_path_on_planar_mirrors) to determine the exact coordinates of that path.

Let’s select a subset of our primitives to understand what we have just talked about.

select = [
    8,  # Red
    9,  # Red
    22,  # Green
    23,  # Green
]  # In practice, you will never hard-code the primitive indices yourself

vertices = mesh.vertices
triangles = mesh.triangles[select, :]

dplt.draw_mesh(vertices, triangles[:2, :], figure=fig, color="red")
dplt.draw_mesh(vertices, triangles[2:, :], figure=fig, color="green")

Looking at the above, we can clearly see that a line-of-sight (LOS) path between TX and RX exists.

With a bit of thinking, we could also imagine that a path with one or more reflections might join TX and RX.

For example, TX -> Red surface -> RX would probably produce a valid path. The same logic can be applied to TX -> Red surface -> Green surface -> RX.

# A list of colors to easily differentiate paths
color = ["black", "green", "orange", "yellow", "blue"]

select = jnp.array(
    select[::2],
    dtype=int,
)  # We actually only need one triangle per plane, so [8, 22]

# Iterate through path candidates
#
#                         ┌> order 0
#                         |           ┌> order 1
#                         |           |           ┌> order 2
for path_candidate in [select[:0], select[:1], select[:2]]:
    # 1 - Prepare input arrays
    mirror_vertices = mesh.vertices[mesh.triangles[path_candidate, 0], :]
    mirror_normals = mesh.normals[path_candidate, :]

    # 2 - Trace paths

    path = image_method(tx, rx, mirror_vertices, mirror_normals)

    # 3 - ??

    # 4 - Obtain final valid paths and plot

    # The full path is [tx, paths, rx]
    full_path = jnp.concatenate(
        (
            tx[None, :],
            path,
            rx[None, :],
        ),
    )

    # Then we plot it
    dplt.draw_paths(
        full_path,
        figure=fig,
        line={"color": color[len(path_candidate)], "width": 3},
        name=f"Order {len(path_candidate)}",
    )

fig

Nice! Thanks to the image_method, we successfully generated the paths we just mentioned.
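For intuition, the idea behind the image method can be sketched for a single reflection in pure Python. This is only an illustration under stated assumptions (one infinite planar mirror, unit-length normal, hypothetical helper names); it is not the implementation used by image_method.

```python
Vec3 = tuple[float, float, float]


def dot(a: Vec3, b: Vec3) -> float:
    return sum(x * y for x, y in zip(a, b))


def image_method_one_mirror(
    tx: Vec3, rx: Vec3, mirror_vertex: Vec3, mirror_normal: Vec3
) -> Vec3:
    """Reflection point on an infinite plane (unit normal assumed)."""
    # 1 - Mirror TX across the plane to obtain its image
    tx_offset = tuple(t - v for t, v in zip(tx, mirror_vertex))
    signed_dist = dot(tx_offset, mirror_normal)
    image = tuple(t - 2.0 * signed_dist * n for t, n in zip(tx, mirror_normal))
    # 2 - Intersect the segment image -> RX with the plane
    direction = tuple(r - i for r, i in zip(rx, image))
    to_plane = tuple(v - i for v, i in zip(mirror_vertex, image))
    # A zero denominator means that the segment is parallel to the mirror
    t = dot(to_plane, mirror_normal) / dot(direction, mirror_normal)
    return tuple(i + t * d for i, d in zip(image, direction))


# Ground plane z = 0, TX and RX both one meter above it
point = image_method_one_mirror(
    tx=(0.0, 0.0, 1.0),
    rx=(2.0, 0.0, 1.0),
    mirror_vertex=(0.0, 0.0, 0.0),
    mirror_normal=(0.0, 0.0, 1.0),
)
print(point)  # (1.0, 0.0, 0.0)
```

For higher orders, the same two steps are chained: images are computed forward through every mirror, then intersections are computed backward, starting from RX.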

Scaling to more paths and more surfaces#

Manually identifying the surfaces of interest and generating all possible path candidates can rapidly become tedious as the number of surfaces or the path order increases.

For this purpose, we created the generate_all_path_candidates function. Written in Rust for performance purposes, this function can generate millions of path candidates per second!

This is all nice, but it has one important side effect: if you generate all possible path candidates, how do you remove invalid paths that may, e.g., cross a building?

This is where our third step comes into play: we need to validate each path against a series of checks. We can usually identify three types of checks:

  1. Are path coordinates within the boundary of their respective objects? Many times, the objects are assumed to be infinitely long. Then, a check is performed to verify if the solution was found within the object’s boundaries;

  2. Are all interactions valid? E.g., do all reflections occur with an angle of reflection equal to the angle of incidence? Most path tracing methods have some failure cases where they can return degenerate solutions;

  3. Does any object in the scene obstruct the path? Usually, the path is first computed without taking the surrounding objects into account, which can produce paths that cross buildings.
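As an illustration of the first check, the sketch below tests whether a point lies inside a triangle using barycentric coordinates. It is a hypothetical pure-Python helper, not the library routine, and it assumes that the point already lies in the plane of the triangle.

```python
Vec3 = tuple[float, float, float]


def sub(a: Vec3, b: Vec3) -> Vec3:
    return tuple(x - y for x, y in zip(a, b))


def dot(a: Vec3, b: Vec3) -> float:
    return sum(x * y for x, y in zip(a, b))


def point_in_triangle(p: Vec3, a: Vec3, b: Vec3, c: Vec3, eps: float = 1e-9) -> bool:
    """Check that p (assumed coplanar with a, b, c) lies inside the triangle."""
    v0, v1, v2 = sub(b, a), sub(c, a), sub(p, a)
    d00, d01, d11 = dot(v0, v0), dot(v0, v1), dot(v1, v1)
    d20, d21 = dot(v2, v0), dot(v2, v1)
    denom = d00 * d11 - d01 * d01  # Zero only for degenerate triangles
    v = (d11 * d20 - d01 * d21) / denom
    w = (d00 * d21 - d01 * d20) / denom
    u = 1.0 - v - w
    # Inside (or on the boundary) when all barycentric coordinates are >= 0
    return min(u, v, w) >= -eps


a, b, c = (0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)
print(point_in_triangle((0.25, 0.25, 0.0), a, b, c))  # True
print(point_in_triangle((1.0, 1.0, 0.0), a, b, c))  # False
```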

A possible implementation of the above rules, applied to the image_method, is provided below. A lot of the code is just broadcasting arrays into the right shapes, to benefit from vectorized computations on arrays instead of using slow Python for-loops.

fig.data = fig.data[:2]  # Keep only first 2 traces: geometry and TX/RX

# [num_triangles 3 3]
all_triangle_vertices = mesh.triangle_vertices

num_triangles = mesh.num_triangles

for order in range(5):
    # 1 - Prepare input arrays

    # [num_path_candidates order]
    path_candidates = generate_all_path_candidates(num_triangles, order)
    num_path_candidates = path_candidates.shape[0]

    # [num_path_candidates order 3]
    triangles = jnp.take(mesh.triangles, path_candidates, axis=0)

    # [num_path_candidates order 3 3]
    triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)

    # [num_path_candidates order 3]
    mirror_vertices = triangle_vertices[
        ...,
        0,
        :,
    ]  # Only one vertex per triangle is needed
    # [num_path_candidates order 3]
    mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

    # 2 - Trace paths

    # [num_path_candidates order 3]
    paths = image_method(tx, rx, mirror_vertices, mirror_normals)

    # 3 - Remove invalid paths

    # 3.1 - Remove paths with vertices outside triangles
    # [num_path_candidates order]
    mask = triangles_contain_vertices_assuming_inside_same_plane(
        triangle_vertices,
        paths,
    )
    # [num_path_candidates]
    mask = jnp.all(mask, axis=-1)

    # [num_paths_inter order+2 3]
    full_paths = assemble_paths(
        tx,
        paths[mask, ...],
        rx,
    )
    # 3.2 - Remove paths with vertices not on the same side of mirrors
    # [num_paths_inter order]
    mask = consecutive_vertices_are_on_same_side_of_mirrors(
        full_paths,
        mirror_vertices[mask, ...],
        mirror_normals[mask, ...],
    )

    # [num_paths_inter]
    mask = jnp.all(mask, axis=-1)  # We will actually remove them later

    # 3.3 - Remove paths that are obstructed by other objects
    # [num_paths_inter order+1 3]
    ray_origins = full_paths[..., :-1, :]
    # [num_paths_inter order+1 3]
    ray_directions = jnp.diff(full_paths, axis=-2)

    # [num_paths_inter order+1 num_triangles], [num_paths_inter order+1 num_triangles]
    t, hit = rays_intersect_triangles(
        ray_origins[..., None, :],
        ray_directions[..., None, :],
        all_triangle_vertices[None, None, ...],
    )
    # In theory, we could do t < 1.0 (because t == 1.0 means we are perfectly on a surface,
    # which is probably desirable, e.g., from a reflection) but in practice numerical
    # errors accumulate and will make this check impossible.
    # [num_paths_inter order+1 num_triangles]
    intersect = (t < 0.999) & hit
    #  [num_paths_inter]
    intersect = jnp.any(intersect, axis=(-1, -2))
    #  [num_paths_inter]
    mask = mask & ~intersect

    # 4 - Obtain final valid paths and plot

    #  [num_paths_final]
    full_paths = full_paths[mask, ...]

    dplt.draw_paths(
        full_paths,
        figure=fig,
        line={"color": color[order], "width": 3},
        name=f"Order {order}",
    )

fig

Another path tracing method that is fully compatible with the above cell is fermat_path_on_planar_mirrors. You can safely use it in place of image_method, and it should produce the same result. Note that the Fermat path tracing method is much slower than the image method, but it can be applied to other types of interactions than pure specular reflection. Swapping the two is left as an exercise to the reader.
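To make the obstruction check of step 3.3 more concrete, here is a minimal pure-Python sketch based on the Möller-Trumbore ray-triangle intersection algorithm. Whether rays_intersect_triangles uses this exact algorithm is an assumption, and the helper names are hypothetical. The ray direction is the full segment vector, so that t = 1 corresponds to the end of the segment, as in the cell above.

```python
Vec3 = tuple[float, float, float]


def sub(a: Vec3, b: Vec3) -> Vec3:
    return tuple(x - y for x, y in zip(a, b))


def dot(a: Vec3, b: Vec3) -> float:
    return sum(x * y for x, y in zip(a, b))


def cross(a: Vec3, b: Vec3) -> Vec3:
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def ray_triangle_t(origin, direction, v0, v1, v2, eps=1e-9):
    """Return t such that origin + t * direction hits the triangle, else None."""
    e1, e2 = sub(v1, v0), sub(v2, v0)
    h = cross(direction, e2)
    det = dot(e1, h)
    if abs(det) < eps:  # Ray is parallel to the triangle's plane
        return None
    s = sub(origin, v0)
    u = dot(s, h) / det
    q = cross(s, e1)
    v = dot(direction, q) / det
    if u < 0.0 or v < 0.0 or u + v > 1.0:  # Hit point is outside the triangle
        return None
    t = dot(e2, q) / det
    return t if t > eps else None


def segment_is_obstructed(start, end, triangle, threshold=0.999):
    """A segment is obstructed if it hits the triangle before reaching its end."""
    t = ray_triangle_t(start, sub(end, start), *triangle)
    return t is not None and t < threshold


triangle = ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0))
# The segment crosses the triangle halfway: obstructed
print(segment_is_obstructed((0.25, 0.25, 1.0), (0.25, 0.25, -1.0), triangle))  # True
# The segment ends on the triangle (e.g., a reflection point): not obstructed
print(segment_is_obstructed((0.25, 0.25, 1.0), (0.25, 0.25, 0.0), triangle))  # False
```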

Example on more complex scenes#

Most of the code we presented so far scales pretty well on larger scenes. However, there is one notable exception: generate_all_path_candidates.

With a bit of math[1], we can determine that a call to generate_all_path_candidates(num_triangles, order) generates an array of size \(\texttt{num_triangles}(\texttt{num_triangles}-1)^{\texttt{order}-1} \times \texttt{order}\).
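The growth predicted by this formula can be explored with a short sketch. The scene size below is hypothetical, and the 8 bytes per index used for the memory estimate is an assumption, not a documented property of the returned array.

```python
def num_path_candidates(num_triangles: int, order: int) -> int:
    """Number of candidates: n * (n - 1) ** (order - 1), and 1 for order 0."""
    if order == 0:
        return 1  # Only the line-of-sight candidate
    return num_triangles * (num_triangles - 1) ** (order - 1)


def candidates_array_size(num_triangles: int, order: int) -> int:
    """Number of elements in an array of shape [num_path_candidates order]."""
    return num_path_candidates(num_triangles, order) * order


num_triangles = 10_000  # Hypothetical scene size
bytes_per_index = 8  # Assumption: 64-bit indices

for order in range(1, 5):
    size = candidates_array_size(num_triangles, order)
    gigabytes = size * bytes_per_index / 1e9
    print(f"order {order}: {size:_} elements, about {gigabytes:_.1f} GB")
```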

On scenes with many triangles, this rapidly becomes too big to fit in any computer memory. To circumvent this issue, we also provide an iterator variant, generate_all_path_candidates_chunks_iter, that produces arrays of a smaller size, defined by the chunk_size argument.
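The chunking idea can be illustrated with plain Python generators. The two functions below are hypothetical stand-ins written for this example; they do not reproduce the signature or the implementation of generate_all_path_candidates_chunks_iter.

```python
from collections.abc import Iterator
from itertools import islice, product


def path_candidates_iter(num_primitives: int, order: int) -> Iterator[tuple[int, ...]]:
    """Lazily yield path candidates, one at a time."""
    for candidate in product(range(num_primitives), repeat=order):
        # Assumption: no two consecutive interactions on the same primitive
        if all(a != b for a, b in zip(candidate, candidate[1:])):
            yield candidate


def path_candidates_chunks(
    num_primitives: int, order: int, chunk_size: int
) -> Iterator[list[tuple[int, ...]]]:
    """Yield lists of at most chunk_size candidates."""
    candidates = path_candidates_iter(num_primitives, order)
    while chunk := list(islice(candidates, chunk_size)):
        yield chunk


# 4 * 3 = 12 candidates, split into chunks of at most 5
chunk_lengths = [len(chunk) for chunk in path_candidates_chunks(4, 2, chunk_size=5)]
print(chunk_lengths)  # [5, 5, 2]
```

Only one chunk is held in memory at a time, so the peak memory usage is bounded by chunk_size rather than by the total number of candidates.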

While this offers a solution to the memory allocation issue, this does not reduce the number of path candidates. To reduce this number, you somehow need to prune a subset of the path candidates before you actually generate them.

Recalling the graph analogy we mentioned above, we can implement this behavior by disconnecting some primitives (i.e., triangles) in the graph. There is no unique solution to this challenge, but we provide a small utility to estimate the visibility matrix between objects in a given scene: triangles_visible_from_vertices.

Then, from this visibility matrix, which is actually just an adjacency matrix of the nodes in the graph, we can instantiate a DiGraph from the differt_core.rt module.
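The effect of an adjacency matrix on the set of path candidates can be sketched with a small depth-first search. This is an illustrative pure-Python version with a hypothetical all_paths function; it is not the differt_core implementation, and it makes no claim about the exact output format of DiGraph.all_paths.

```python
from collections.abc import Iterator


def all_paths(
    adjacency: list[list[bool]], from_: int, to: int, depth: int
) -> Iterator[list[int]]:
    """Yield every path of exactly `depth` nodes going from `from_` to `to`."""

    def visit(path: list[int]) -> Iterator[list[int]]:
        if len(path) == depth - 1:
            # Last hop: only accept it if the end node is reachable
            if adjacency[path[-1]][to]:
                yield [*path, to]
            return
        for node, connected in enumerate(adjacency[path[-1]]):
            if connected and node != to:
                yield from visit([*path, node])

    yield from visit([from_])


T, F = True, False
# Nodes 0, 1, 2 are primitives; node 3 is TX and node 4 is RX
adjacency = [
    [F, T, F, F, T],  # Primitive 0 sees primitive 1 and RX
    [T, F, T, F, T],  # Primitive 1 sees primitives 0 and 2, and RX
    [F, T, F, F, T],  # Primitive 2 sees primitive 1 and RX
    [T, T, F, F, T],  # TX sees primitives 0 and 1, and RX
    [F, F, F, F, F],  # RX is a sink
]

paths = list(all_paths(adjacency, from_=3, to=4, depth=4))
print(paths)  # [[3, 0, 1, 4], [3, 1, 0, 4], [3, 1, 2, 4]]
```

With a complete graph, the same query would return 3 * 2 = 6 second-order candidates; removing edges halves that number here.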

Numbers getting crazy#

To illustrate what we said above, we will load a much larger scene that contains quite a few objects, i.e., triangles.

A transmitter and a receiver are placed in the scene as example positions.

from differt.scene import TriangleScene

mesh_file = "bruxelles.obj"
mesh = TriangleMesh.load_obj(mesh_file)

tx = jnp.array([-40.0, 75, 30.0])
rx = jnp.array([+20.0, 108.034, 1.50])

scene = TriangleScene(transmitters=tx, receivers=rx, mesh=mesh)
scene.plot()
# This is the number of triangles
mesh.num_primitives
14206

This number isn’t actually that big: the triangle count can easily exceed a million for large cities. However, it is large enough to present serious challenges when it comes to performing exhaustive RT.

Using the core library, we can compute the exact number of path candidates one would have to try for a given number of interactions.

from differt_core.rt import CompleteGraph

graph = CompleteGraph(mesh.num_primitives)

from_ = graph.num_nodes  # Index of TX in the graph
to = from_ + 1  # Index of RX in the graph
order = 2  # Number of interactions
depth = order + 2  # + 2 because we add TX and RX nodes

num_path_candidates = len(graph.all_paths(from_, to, depth))
print(f"Number of path candidates: {num_path_candidates:_}")
Number of path candidates: 201_796_230

That means that there are over 200 million second order reflection paths to test… We need to reduce that number!

Assuming quadrilaterals#

In many cases, a scene is simply a collection of quadrilaterals, that are each split into two triangles. This is not always true, and probably not the case for our scene, but we will assume it is.

Using set_assume_quads, the mesh will now tell all other functions that they should use, when available, optimized routines for quadrilateral facets.

mesh = mesh.set_assume_quads(True)
# This is now the number of quadrilaterals, exactly half the number of triangles
mesh.num_primitives
7103

Again, we can compute the number of path candidates, and see that it is reduced by a factor of almost 4.

In general, the reduction factor is nearly \(2^\texttt{order}\).

graph = CompleteGraph(mesh.num_primitives)

from_ = graph.num_nodes
to = from_ + 1
order = 2
depth = order + 2  # TX and RX nodes are part of the path

num_path_candidates = len(graph.all_paths(from_, to, depth))
# Roughly a quarter of the previous number
print(f"Number of path candidates: {num_path_candidates:_}")
Number of path candidates: 50_445_506
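The reduction factor can be checked against the closed-form count for a complete graph. The sketch below assumes that this count is n * (n - 1) ** (order - 1), which reproduces the two numbers printed above for order 2.

```python
def num_path_candidates(num_primitives: int, order: int) -> int:
    """Closed-form count for a complete graph without self-loops (order >= 1)."""
    return num_primitives * (num_primitives - 1) ** (order - 1)


num_triangles = 14_206
num_quads = num_triangles // 2  # 7_103

for order in range(1, 5):
    with_triangles = num_path_candidates(num_triangles, order)
    with_quads = num_path_candidates(num_quads, order)
    ratio = with_triangles / with_quads
    print(f"order {order}: reduction factor {ratio:.4f} (2**order = {2**order})")
```

The factor is slightly larger than 2 ** order because (n - 1) / (n / 2 - 1) is slightly larger than 2.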

Determining TX’s visibility#

Another way to reduce the number of path candidates is to indicate to the graph that TX cannot reach all objects in the scene, but only a subset of the objects.

Such information can be obtained by estimating the visibility vector of some TX, and using it when creating the path candidates iterator.

If one knows the location of the receiving antenna, a similar logic can be used to compute the to_adjacency vector, which is also a visibility vector, but from RX.

On the other hand, if the mesh is fixed but the TX / RX are not, it is also possible to compute the visibility vector of each triangle in the scene, thereby constructing the visibility matrix of the scene, and use it to construct the graph with DiGraph.from_adjacency_matrix. As computing such a matrix can be extremely expensive, it is recommended to perform that as a pre-processing step and save the resulting matrix in a file.

The code below shows how to estimate[2] the objects (i.e., triangles) seen by TX. For this example, visible triangles are colored in red, and hidden ones in black.

%%time

from differt.rt import triangles_visible_from_vertices

tx = jnp.array([-40.0, 75, 30.0])

default_color = jnp.array([[0.2, 0.2, 0.2]])  # Hidden, black
visible_color = jnp.array([[1.0, 0.2, 0.2]])  # Visible, red
visible_triangles = triangles_visible_from_vertices(
    tx,
    mesh.triangle_vertices,
)

mesh = mesh.set_face_colors(default_color)
mesh = mesh.set_face_colors(
    mesh.face_colors.at[visible_triangles].set(visible_color)
)

with dplt.reuse() as fig:
    dplt.draw_markers(np.array([tx]), ["tx"])
    mesh.plot()

fig
CPU times: user 2min 49s, sys: 232 ms, total: 2min 49s
Wall time: 1min 25s

A visibility vector is simply an array of booleans, each entry indicating whether the corresponding object (here, a triangle) can be seen from TX.

The number of visible triangles is then the sum of all true entries in the array.

ratio = visible_triangles.sum() / mesh.num_triangles
print(f"Percentage of visible triangles: {100 * ratio:.2f}%")
Percentage of visible triangles: 17.72%

It is also possible to get the number of visible quadrilaterals by counting visible triangles in pairs. If either of the two triangles forming a quadrilateral is visible, then this quadrilateral is considered visible.

visible_quads = visible_triangles.reshape(mesh.num_quads, 2).any(axis=-1)
ratio = visible_quads.sum() / mesh.num_quads
print(f"Percentage of visible quadrilaterals: {100 * ratio:.2f}%")
Percentage of visible quadrilaterals: 26.79%

We can then use this result to inform the graph about the limited number of faces visible from TX.

As expected, the number of path candidates is reduced to about 27% of the previous value, which matches the fraction of visible quadrilaterals.

However, 27% visibility is probably too high to justify switching from a CompleteGraph to a DiGraph, as iterating through the latter is considerably slower (the former is optimized for complete graphs).

%%time

from differt_core.rt import DiGraph

graph = DiGraph.from_complete_graph(CompleteGraph(mesh.num_quads))
from_, to = graph.insert_from_and_to_nodes(
    from_adjacency=np.asarray(visible_quads)
)

# DiGraph iterators are not sized, so we consume them to determine their size
num_path_candidates = graph.all_paths(from_, to, depth).count()
# Roughly 27% of the previous number
print(f"Number of path candidates: {num_path_candidates:_}")
Number of path candidates: 13_515_106
CPU times: user 41.8 s, sys: 495 ms, total: 42.3 s
Wall time: 42 s
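The printed count can be cross-checked with a little arithmetic. For order 2, and assuming that only the first hop (from TX) is restricted while all other connections remain complete, the number of candidates is the number of visible quadrilaterals times the number of other quadrilaterals.

```python
num_quads = 7_103
num_candidates_complete = 50_445_506  # CompleteGraph, order 2
num_candidates_visible = 13_515_106  # DiGraph with from_adjacency, order 2

# TX -> (visible quad) -> (any other quad) -> RX
num_visible_quads, remainder = divmod(num_candidates_visible, num_quads - 1)
assert remainder == 0  # The count factors exactly, as expected

kept = num_candidates_visible / num_candidates_complete
visible = num_visible_quads / num_quads
print(f"{num_visible_quads} visible quads ({100 * visible:.2f}%)")
print(f"{100 * kept:.2f}% of the path candidates are kept")
```

Both percentages come out at 26.79%, the fraction of visible quadrilaterals reported earlier.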

What about Ray Launching#

Eventually, all the above solutions hit a ceiling at one point or another, where the number of path candidates outgrows any possible optimization.

In those cases, Ray Launching (RL) can be used as an alternative to exhaustive RT, as the number of path candidates is usually fixed, a bit like when estimating the visibility from TX. This is in fact what tools like Sionna use for coverage maps.

Currently, DiffeRT provides a basic shooting and bouncing rays (SBR) method if you specify method="sbr" when calling TriangleScene.compute_paths. Below, we provide a simplified implementation of this method to find reflection paths between TX and RX in (part of) the city of Bruxelles.

Contrary to the exhaustive path tracing method, SBR launches a fixed number of rays from TX, and allows each of them a number of reflections with the environment. Before each reflection, we check whether the ray passes in the vicinity of RX. If it does, the path is considered valid and will later be corrected[3] to include RX.

%%time

import jax
from differt.geometry import fibonacci_lattice, viewing_frustum
from differt.rt import first_triangles_hit_by_rays

mesh = TriangleMesh.load_obj(mesh_file)  # Reload mesh to reset colors

# [num_triangles 3 3]
triangle_vertices = mesh.triangle_vertices

num_triangles = mesh.num_triangles

with dplt.reuse() as fig:
    dplt.draw_markers(
        np.array([tx, rx]),
        ["tx", "rx"],
        showlegend=False,
    )
    mesh.plot()

    num_rays = int(1e5)
    max_dist = 1.0**2  # Squared distance (to avoid sqrt)
    max_order = 2

    # [2 3]
    frustum = viewing_frustum(
        tx, triangle_vertices.reshape(-1, 3)
    )  # This avoids launching rays where there are no objects

    # [num_rays 3]
    ray_origins = jnp.broadcast_to(tx, (num_rays, 3))
    ray_directions = fibonacci_lattice(num_rays, frustum=frustum)

    ScanC = tuple[
        Float[Array, f"{num_rays} 3"],  # Ray origins
        Float[Array, f"{num_rays} 3"],  # Ray directions (unit length)
        Bool[Array, f" {num_rays} "],  # Whether ray is still valid
    ]
    ScanR = tuple[
        Float[Array, f"{num_rays} 3"],  # Path vertices
        Bool[
            Array, f"{num_rays}"
        ],  # Whether paths pass close (i.e., < max_dist) to RX
    ]

    def scan_fun(
        ray_origins_directions_and_valids: ScanC, _: None
    ) -> tuple[ScanC, ScanR]:
        ray_origins, ray_directions, valid_rays = (
            ray_origins_directions_and_valids
        )

        # 1 - Compute next intersection with triangles

        # [num_rays]
        triangles, t_hit = first_triangles_hit_by_rays(
            ray_origins,
            ray_directions,
            triangle_vertices,
        )  # This may generate jnp.inf values, so we will need to be careful with those

        # 2 - Check if the rays pass near RX

        # [num_rays 3]
        ray_origins_to_rx = rx - ray_origins

        # [num_rays]
        # note: the fact that ray directions have unit length allows for
        #       some simplifications.
        ray_distances_to_rx = jnp.square(
            jnp.cross(ray_directions, ray_origins_to_rx)
        ).sum(axis=-1)  # Squared distance from rays to RXs
        t_rx = jnp.sum(
            ray_directions * ray_origins_to_rx, axis=-1
        )  # Distance along each ray from its origin to the projection of RX onto the ray
        masks = jnp.where(
            (t_rx < t_hit)
            & (t_rx > 0)
            & valid_rays,  # Check if RX is between origin and first triangle hit
            ray_distances_to_rx < max_dist,  # Check if RX is close enough
            False,
        )  # Whether rays pass near RX

        # 3 - Update rays

        # [num_rays 3]
        mirror_normals = jnp.take(mesh.normals, triangles, axis=0)

        ray_origins += t_hit[..., None] * ray_directions
        ray_directions = (
            ray_directions
            - 2.0
            * jnp.sum(ray_directions * mirror_normals, axis=-1, keepdims=True)
            * mirror_normals
        )
        # We mark rays that left the scene
        # i.e., when they no longer hit any object (t_hit is +inf.)
        valid_rays = valid_rays & jnp.isfinite(t_hit)

        return (ray_origins, ray_directions, valid_rays), (
            ray_origins,
            masks,
        )

    # Initially, all rays are valid; rays that leave the scene are marked invalid during the scan
    valid_rays = jnp.ones(num_rays, dtype=bool)

    # [max_order+1 num_rays 3], [max_order+1 num_rays]
    _, (paths, masks) = jax.lax.scan(
        scan_fun,
        (ray_origins, ray_directions, valid_rays),
        length=max_order + 1,
    )

    # We swap 'max_order' and 'num_rays' axes
    # [num_rays max_order+1 3], [num_rays max_order+1]
    paths = jnp.moveaxis(paths, 0, 1)
    masks = jnp.moveaxis(masks, 0, 1)

    for order in range(max_order + 1):
        full_paths = assemble_paths(
            tx,
            # [num_valid_rays order 3]
            paths[masks[..., order], :order, :],
            rx,
        )

        dplt.draw_paths(
            full_paths,
            showlegend=False,
        )

fig
CPU times: user 2min, sys: 1.46 s, total: 2min 2s
Wall time: 1min 2s
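The two geometric ingredients of the scan body above, the specular reflection of a direction and the squared distance from a ray to RX, can be checked in isolation with a few lines of pure Python. The helper names are hypothetical, and both helpers assume unit-length ray directions and normals, as in the cell above.

```python
Vec3 = tuple[float, float, float]


def dot(a: Vec3, b: Vec3) -> float:
    return sum(x * y for x, y in zip(a, b))


def cross(a: Vec3, b: Vec3) -> Vec3:
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def reflect(direction: Vec3, normal: Vec3) -> Vec3:
    """Specular reflection: d - 2 (d . n) n, with n of unit length."""
    scale = 2.0 * dot(direction, normal)
    return tuple(d - scale * n for d, n in zip(direction, normal))


def ray_to_point(origin: Vec3, direction: Vec3, point: Vec3) -> tuple[float, float]:
    """Return (squared distance from the ray's line to point, position t along the ray)."""
    to_point = tuple(p - o for p, o in zip(point, origin))
    c = cross(direction, to_point)
    return dot(c, c), dot(direction, to_point)


# A ray going down at an angle bounces off the ground (normal +z)
print(reflect((0.6, 0.0, -0.8), (0.0, 0.0, 1.0)))

# A ray along +x passes 0.5 m away from RX, 5 m after its origin
squared_dist, t_rx = ray_to_point((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.5, 0.0))
print(squared_dist, t_rx)  # 0.25 5.0
```

Comparing squared_dist with max_dist, and t_rx with the distance to the first triangle hit, gives the same test as the masks computed in the scan.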

We hope this tutorial gave you a good overview of how paths can be traced using this library. Of course, some types of paths, like diffraction paths, are not yet documented.

Moreover, the above methods may not be optimized yet, and you can expect changes in future releases.

If you want to contribute to extending or improving DiffeRT, please feel free to reach out on GitHub!