Multipath Lifetime Map#

This notebook is a tutorial for reproducing the results presented in the paper Comparing Differentiable and Dynamic Ray Tracing: Introducing the Multipath Lifetime Map [6], and it assumes you are familiar with that paper.

You can run it locally or with Google Colab by clicking on the rocket at the top of this page!

Tip

On Google Colab, make sure to select a GPU or TPU runtime for a faster experience.

If you find this tutorial useful and plan to use this tool in your publications, please cite our work (see Citing).

Summary#

In our work, we present the Multipath Lifetime Map (MLM), a visual tool, along with two metrics, to help determine the scope of application of the Dynamic Ray Tracing (RT) method. For further details, please refer to the paper.

Important

The implementation below is far from the most efficient: its primary goal is to produce a nice visual output, and it is tailored to users moving on a 2D grid.

While our method extends to any number and kind of dynamic objects, the MLM can become extremely hard to visualize, especially in higher dimensions.

Imports#

As our code includes some non-trivial plots, we need to import quite a few Python modules. All of them should be installed with differt[all].

import hashlib
import random

import equinox as eqx
import jax.numpy as jnp
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from differt.em import (
    Dipole,
    materials,
    poynting_vector,
    reflection_coefficients,
    sp_directions,
)
from differt.geometry import (
    TriangleMesh,
    merge_cell_ids,
    min_distance_between_cells,
    normalize,
)
from differt.plotting import draw_image, draw_markers, reuse, set_backend
from differt.scene import (
    TriangleScene,
    download_sionna_scenes,
    get_sionna_scene,
)
from differt.utils import safe_divide
from jaxtyping import Array, Bool, Int
from plotly.colors import convert_to_RGB_255
from plotly.subplots import make_subplots

Simple Urban Street Canyon Scenario#

Street canyons are probably one of the most common types of scenarios studied in RT, because they are simple to model and very common in big cities.

As we provide a compatibility layer with Sionna scenes [9], we will simply load the 'simple_street_canyon' scene.

You can download all the scenes from the Sionna repository with download_sionna_scenes. By default, they will be placed in a subfolder of the differt Python module.

download_sionna_scenes()  # Let's download Sionna scenes (from the main branch)
set_backend(
    "plotly"
)  # Our scene is simple, and Plotly is the best backend for online interactive plots :-)

file = get_sionna_scene("simple_street_canyon")
scene = TriangleScene.load_xml(file)
scene.plot()

In the cell below (hidden by default), we define quite a few utility functions to get nice plots, but this code is not needed if you are only interested in the metrics, and not the visual output.


def hashfun(*objects: bytes) -> bytes:
    m = hashlib.sha256()

    for obj in objects:
        m.update(obj)

    return m.digest()


def get_cell_hashes(
    cell_ids: Int[Array, " *batch"],
    mask: Bool[Array, "*batch num_path_candidates"],
) -> dict[int, bytes]:
    mask = mask.reshape(-1, mask.shape[-1])

    return {
        int(i): hashfun(mask[i, :].tobytes())
        for i in jnp.unique(cell_ids, return_index=True)[1]
    }


def merge_cell_ids_and_hashes(
    cell_ids: Int[Array, " *batch"],
    new_cell_ids: Int[Array, " *batch"],
    cell_hashes: dict[int, bytes],
    new_cell_hashes: dict[int, bytes],
) -> tuple[Int[Array, " *batch"], dict[int, bytes]]:
    ret_cell_ids = merge_cell_ids(cell_ids, new_cell_ids)

    ret_cell_hashes = {}

    for index in jnp.unique(ret_cell_ids, return_index=True)[1]:
        i = cell_ids.ravel()[index]
        j = new_cell_ids.ravel()[index]

        ret_cell_hashes[int(index)] = hashfun(
            cell_hashes[int(i)],
            new_cell_hashes[int(j)],
        )

    return ret_cell_ids, ret_cell_hashes


def draw_mesh_2d(mesh: TriangleMesh, figure: go.Figure) -> None:
    assert mesh.object_bounds is not None

    for i, j in mesh.object_bounds:
        sub_mesh = mesh[i:j]

        (xs, ys, (_, z_max)) = sub_mesh.bounding_box.T

        layer = "below" if z_max < 1e-6 else None

        assert sub_mesh.face_colors is not None
        color = convert_to_RGB_255(sub_mesh.face_colors[0, :])

        figure.add_shape(
            type="rect",
            x0=xs[0],
            y0=ys[0],
            x1=xs[1],
            y1=ys[1],
            fillcolor=f"rgb{color!s}",
            layer=layer,
        )


def random_rgb(cell_hash: bytes) -> str:
    rng = random.Random(cell_hash)  # noqa: S311
    r = rng.randint(0, 255)
    g = rng.randint(0, 255)
    b = rng.randint(0, 255)
    return f"rgb({r},{g},{b})"


def create_discrete_colorscale(
    cell_ids: Int[Array, " *batch"],
    cell_hashes: dict[int, bytes],
    first_is_multipath_cell: bool,
) -> list[list[float | str]]:
    unique_ids = jnp.unique(cell_ids).tolist()
    min_id = min(unique_ids)
    max_id = max(unique_ids)
    scale_factor = 1 + max_id - min_id

    def scale(id_: int) -> float:
        return (id_ - min_id) / scale_factor

    colorscale = [
        [scale(id_ + offset), random_rgb(cell_hashes[id_])]
        for id_ in unique_ids
        for offset in (0, 1)
    ]

    if first_is_multipath_cell:  # Let's hide the cell with no multipath
        colorscale[0][1] = colorscale[1][1] = "rgba(0,0,0,0)"

    return colorscale
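
The helpers above derive each cell's color from a hash of its path mask, so that identical sets of valid paths always get the same color. A minimal, standalone sketch of that idea follows. It only uses the standard library, `color_from_bytes` is a hypothetical name, and the byte strings are made-up stand-ins for `mask[i, :].tobytes()`.

```python
import hashlib
import random


def color_from_bytes(data: bytes) -> str:
    """Map arbitrary bytes to a reproducible 'rgb(r,g,b)' string."""
    digest = hashlib.sha256(data).digest()
    rng = random.Random(digest)  # Seeding with bytes is deterministic
    r, g, b = (rng.randint(0, 255) for _ in range(3))
    return f"rgb({r},{g},{b})"


# Hypothetical masks: one byte per path candidate (1 = valid path)
mask_a = bytes([1, 0, 1, 0])
mask_b = bytes([1, 0, 1, 0])  # Same valid paths as 'mask_a'
mask_c = bytes([0, 1, 1, 0])  # Different set of valid paths

color_a = color_from_bytes(mask_a)
color_b = color_from_bytes(mask_b)
color_c = color_from_bytes(mask_c)
```

Here, `color_a` and `color_b` are equal because their masks are identical, while `color_c` will almost certainly differ. The colors therefore stay stable across runs, which makes plots easier to compare.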

Ray Tracing with a Grid of Receivers#

Because we want to detect multipath cells for a moving receiver, we need to run RT simulations for many receiver locations. The easiest way to do so is to use the TriangleScene.with_receivers_grid method and to rely on the library-wide support for batch dimensions, which lets us run many RT simulations at once.
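
Conceptually, receivers belong to the same multipath cell when they share the same set of valid paths, that is, when their rows of the path mask are identical. The sketch below illustrates this grouping in plain Python on a made-up mask. The `group_by_mask` helper and its choice of ids are assumptions made for this example, not the library's implementation of `multipath_cells`.

```python
def group_by_mask(mask_rows: list[tuple[bool, ...]]) -> list[int]:
    """Give the same id to receivers whose valid-path masks are identical.

    The id of a group is the index of its first receiver (an arbitrary
    but convenient choice made for this sketch).
    """
    first_index: dict[tuple[bool, ...], int] = {}
    return [first_index.setdefault(row, i) for i, row in enumerate(mask_rows)]


# Hypothetical mask: 5 receivers (rows) x 3 path candidates (columns)
mask = [
    (True, False, False),
    (True, False, False),
    (True, True, False),
    (False, False, False),  # No valid path at all
    (True, True, False),
]

cell_ids = group_by_mask(mask)
print(cell_ids)  # [0, 0, 2, 3, 2]
```

Receivers 0 and 1 share a cell, as do receivers 2 and 4, while receiver 3 sits alone in a cell with no multipath.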

Warning

Beware that large batches or higher-order reflection paths can rapidly cause out-of-memory issues!

If you have multiple GPUs or TPUs at your disposal, you can also use the parallel feature of TriangleScene.compute_paths.
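
To get a feeling for how quickly memory grows, here is a back-of-the-envelope sketch. It assumes that every sequence of objects without immediate repetition is a path candidate, and that path vertices are stored as 32-bit floats in an array of shape `[*batch, num_path_candidates, order + 2, 3]`. Both helper functions and the number of objects are hypothetical, chosen only for illustration.

```python
import math


def num_path_candidates(num_objects: int, order: int) -> int:
    """Count object sequences of length 'order' with no immediate repetition."""
    if order == 0:
        return 1  # Line-of-sight only
    return num_objects * (num_objects - 1) ** (order - 1)


def vertices_nbytes(
    batch: tuple[int, ...], num_objects: int, order: int, itemsize: int = 4
) -> int:
    """Size in bytes of an array of shape [*batch, num_candidates, order + 2, 3]."""
    return (
        math.prod(batch)
        * num_path_candidates(num_objects, order)
        * (order + 2)
        * 3
        * itemsize
    )


# Hypothetical scene with 50 objects and a 100 x 100 grid of receivers
for order in range(4):
    size = vertices_nbytes((100, 100), num_objects=50, order=order)
    print(f"order={order}: {size / 1e9:.3f} GB")
```

Under these assumptions, each additional reflection order multiplies the memory footprint by roughly the number of objects, which is why processing path candidates in chunks (see the `chunk_size` argument used below) helps keep memory usage bounded.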

In the cell below, we also estimate the coverage map of the received power.

# Let's put one transmitter and many receivers in our scene
scene = eqx.tree_at(
    lambda s: s.transmitters, scene, jnp.array([-33.0, 0.0, 32.0])
)
# Our scene can be simplified to quadrilaterals,
# so telling the library about it makes the computation faster
scene = scene.set_assume_quads()
batch = (
    100,
    100,
)  # Warning: too large a batch can easily cause OOM issues;
#    if that happens, reduce it or lower the 'chunk_size' value below.
z0 = 1.5  # The z coordinate of the receivers
scene_grid = scene.with_receivers_grid(*batch, height=z0)

# And also keep track of multipath cells
cell_ids = jnp.zeros(batch, dtype=jnp.int32)  # Multipath cell indices
unique_cell_ids = jnp.empty(
    (), dtype=jnp.int32
)  # Contains unique values of 'cell_ids'
cell_hashes = {
    0: b""
}  # This is only used to generate constant random color per cell
has_multipath = jnp.zeros(
    batch, dtype=bool
)  # Will be true if a receiver has at least one valid ray path

# Only needed for plotting purposes
x, y, _ = jnp.unstack(scene_grid.receivers, axis=-1)

with reuse() as fig:
    scene.plot()

    ant = Dipole(
        2.4e9, center=scene.transmitters, look_at=[0.0, 0.0, z0]
    )  # 2.4 GHz
    ant.plot_radiation_pattern(distance=0.25, opacity=0.5, showscale=False)
    A_e = ant.aperture
    E = jnp.zeros((*batch, 3), dtype=jnp.complex64)
    B = jnp.zeros_like(E)

    eta_r = jnp.array([
        materials[mat_name].relative_permittivity(ant.frequency)
        for mat_name in scene.mesh.material_names
    ])
    n_r = jnp.sqrt(eta_r)

    for order in range(2):
        for paths in scene_grid.compute_paths(order=order, chunk_size=1_000):
            # 1 - Identify multipath cells
            # [*batch]
            new_cell_ids = paths.multipath_cells()
            # [num_unique_new_cell_ids]
            new_unique_cell_ids = jnp.unique(new_cell_ids)
            new_cell_hashes = get_cell_hashes(new_cell_ids, paths.mask)
            # [*batch]
            has_multipath |= paths.mask.any(axis=-1)
            cell_ids, cell_hashes = merge_cell_ids_and_hashes(
                cell_ids, new_cell_ids, cell_hashes, new_cell_hashes
            )
            # 2 - Compute EM fields (optional)

            E_i, B_i = ant.fields(paths.vertices[..., 1, :])

            if order > 0:
                # [*batch num_path_candidates order]
                obj_indices = paths.objects[..., 1:-1]
                # [*batch num_path_candidates order]
                mat_indices = jnp.take(
                    scene.mesh.face_materials, obj_indices, axis=0
                )
                # [*batch num_path_candidates order 3]
                obj_normals = jnp.take(scene.mesh.normals, obj_indices, axis=0)
                # [*batch num_path_candidates order]
                obj_n_r = jnp.take(n_r, mat_indices, axis=0)
                # [*batch num_path_candidates order+1 3]
                path_segments = jnp.diff(paths.vertices, axis=-2)
                # [*batch num_path_candidates order+1 3],
                # [*batch num_path_candidates order+1 1]
                k, s = normalize(path_segments, keepdims=True)
                # [*batch num_path_candidates order 3]
                k_i = k[..., :-1, :]
                k_r = k[..., +1:, :]
                # [*batch num_path_candidates order 3]
                (e_i_s, e_i_p), (e_r_s, e_r_p) = sp_directions(
                    k_i, k_r, obj_normals
                )
                # [*batch num_path_candidates order 1]
                cos_theta = jnp.sum(obj_normals * -k_i, axis=-1, keepdims=True)
                # [*batch num_path_candidates order 1]
                r_s, r_p = reflection_coefficients(
                    obj_n_r[..., None], cos_theta
                )
                # [*batch num_path_candidates 1]
                r_s = jnp.prod(r_s, axis=-2)
                r_p = jnp.prod(r_p, axis=-2)
                # [*batch num_path_candidates 1]
                E_i_s = jnp.sum(E_i * e_i_s[..., 0, :], axis=-1, keepdims=True)
                E_i_p = jnp.sum(E_i * e_i_p[..., 0, :], axis=-1, keepdims=True)
                B_i_s = jnp.sum(B_i * e_i_s[..., 0, :], axis=-1, keepdims=True)
                B_i_p = jnp.sum(B_i * e_i_p[..., 0, :], axis=-1, keepdims=True)
                # [*batch num_path_candidates 1]
                E_r_s = r_s * E_i_s
                E_r_p = r_p * E_i_p
                B_r_s = r_s * B_i_s
                B_r_p = r_p * B_i_p
                # [*batch num_path_candidates 3]
                E_r = E_r_s * e_r_s[..., -1, :] + E_r_p * e_r_p[..., -1, :]
                B_r = B_r_s * e_r_s[..., -1, :] + B_r_p * e_r_p[..., -1, :]
                # [*batch num_path_candidates 1]
                s_tot = s.sum(axis=-2)
                spreading_factor = safe_divide(
                    s[..., 0, :], s_tot
                )  # Far-field approximation
                phase_shift = jnp.exp(1j * s_tot * ant.wavenumber)
                # [*batch num_path_candidates 3]
                E_r *= spreading_factor * phase_shift
                B_r *= spreading_factor * phase_shift
            else:
                # [*batch num_path_candidates 3]
                E_r = E_i
                B_r = B_i

            # [*batch 3]
            E += jnp.sum(E_r, axis=-2, where=paths.mask[..., None])
            B += jnp.sum(B_r, axis=-2, where=paths.mask[..., None])

    S = poynting_vector(E, B)
    P = A_e * jnp.linalg.norm(S, axis=-1)
    G_dB = 10 * jnp.log10(P / ant.reference_power)

    draw_image(
        G_dB,
        x=x[0, :],
        y=y[:, 0],
        z0=z0,
        colorbar={"title": "Gain (dB)"},
    )

# We set cell ids with no multipath to -1 for easier identification

if not has_multipath.all():
    # Simple way to retrieve the cell index that has no multipath
    cell_id = jnp.max(cell_ids, initial=0, where=~has_multipath)
    cell_hashes[-1] = cell_hashes.pop(int(cell_id))

cell_ids = jnp.where(has_multipath, cell_ids, -1)

fig
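
For first-order paths, the cell above weights the incident field components by Fresnel reflection coefficients. As a rough, standalone illustration, the sketch below evaluates the textbook Fresnel coefficients for a wave arriving from air onto a non-magnetic medium. The function name, its signature, and the sign convention are assumptions made for this example and may differ from those of `reflection_coefficients`.

```python
import cmath
import math


def fresnel_coefficients(
    n_r: complex, cos_theta: float
) -> tuple[complex, complex]:
    """Return (r_s, r_p) for incidence from air onto a medium of index n_r."""
    sin_theta_sq = 1.0 - cos_theta**2
    # n_r * cos(theta_t), obtained from Snell's law
    root = cmath.sqrt(n_r**2 - sin_theta_sq)
    r_s = (cos_theta - root) / (cos_theta + root)
    r_p = (n_r**2 * cos_theta - root) / (n_r**2 * cos_theta + root)
    return r_s, r_p


n_r = math.sqrt(5.0)  # Hypothetical lossless material, relative permittivity 5

# Normal incidence: both polarizations reflect the same amount of power
r_s, r_p = fresnel_coefficients(n_r, cos_theta=1.0)
print(abs(r_s) ** 2, abs(r_p) ** 2)

# Brewster angle (tan(theta) = n_r): the parallel component vanishes
cos_brewster = 1.0 / math.sqrt(1.0 + n_r**2)
_, r_p_brewster = fresnel_coefficients(n_r, cos_brewster)
print(abs(r_p_brewster))
```

The two checks are independent of the sign convention: at normal incidence the two polarizations are indistinguishable in magnitude, and at the Brewster angle a lossless material reflects no parallel-polarized field.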