Multipath Lifetime Map#

This notebook is a tutorial for reproducing the results presented in the paper Comparing Differentiable and Dynamic Ray Tracing: Introducing the Multipath Lifetime Map [6]. It assumes you are familiar with the paper's content.

You can run it locally, or on Google Colab by clicking on the rocket icon at the top of this page!

Tip

On Google Colab, make sure to select a GPU or TPU runtime for a faster experience.

If you find this tutorial useful and plan to use this tool in your publications, please cite our work; see Citing.

Summary#

In our work, we present the Multipath Lifetime Map (MLM), a visual tool, along with two metrics, to help determine the scope of application of the Dynamic Ray Tracing (RT) method. For further details, please refer to the paper.

Important

The implementation below is far from the most efficient: its primary goal is to produce a nice visual output, and it is tailored to users moving on a 2D grid.

While our method extends to any number and kind of dynamic objects, it may become extremely hard to provide a visual representation of the MLM, especially for higher dimensions.

Imports#

As our code includes some non-trivial plots, we need to import quite a few Python modules. All of them should be installed with differt[all].

import hashlib
import random

import equinox as eqx
import jax.numpy as jnp
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from differt.em import (
    Dipole,
    materials,
    poynting_vector,
    reflection_coefficients,
    sp_directions,
)
from differt.geometry import (
    TriangleMesh,
    merge_cell_ids,
    min_distance_between_cells,
    normalize,
)
from differt.plotting import draw_image, draw_markers, reuse, set_backend
from differt.scene import (
    TriangleScene,
    download_sionna_scenes,
    get_sionna_scene,
)
from differt.utils import safe_divide
from jaxtyping import Array, Bool, Int
from plotly.colors import convert_to_RGB_255
from plotly.subplots import make_subplots

Simple Urban Street Canyon Scenario#

Street canyons are probably among the most common scenarios studied in RT, both because they are simple to model and because they are widespread in big cities.

As we provide a compatibility layer with Sionna scenes [9], we will simply load the 'simple_street_canyon' scene.

You can download all the scenes from the Sionna repository with download_sionna_scenes. By default, they will be placed in a subfolder of the differt Python module.

download_sionna_scenes()  # Let's download Sionna scenes (from the main branch)
set_backend(
    "plotly"
)  # Our scene is simple, and Plotly is the best backend for online interactive plots :-)

file = get_sionna_scene("simple_street_canyon")
scene = TriangleScene.load_xml(file)
scene.plot()

In the cell below (hidden by default), we define quite a few utility functions to get nice plots, but this code is not needed if you are only interested in the metrics, and not the visual output.

def hashfun(*objects: bytes) -> bytes:
    m = hashlib.sha256()

    for obj in objects:
        m.update(obj)

    return m.digest()


def get_cell_hashes(
    cell_ids: Int[Array, " *batch"],
    mask: Bool[Array, "*batch num_path_candidates"],
) -> dict[int, bytes]:
    mask = mask.reshape(-1, mask.shape[-1])

    return {
        int(i): hashfun(mask[i, :].tobytes())
        for i in jnp.unique(cell_ids, return_index=True)[1]
    }


def merge_cell_ids_and_hashes(
    cell_ids: Int[Array, " *batch"],
    new_cell_ids: Int[Array, " *batch"],
    cell_hashes: dict[int, bytes],
    new_cell_hashes: dict[int, bytes],
) -> tuple[Int[Array, " *batch"], dict[int, bytes]]:
    ret_cell_ids = merge_cell_ids(cell_ids, new_cell_ids)

    ret_cell_hashes = {}

    for index in jnp.unique(ret_cell_ids, return_index=True)[1]:
        i = cell_ids.ravel()[index]
        j = new_cell_ids.ravel()[index]

        ret_cell_hashes[int(index)] = hashfun(
            cell_hashes[int(i)],
            new_cell_hashes[int(j)],
        )

    return ret_cell_ids, ret_cell_hashes


def draw_mesh_2d(mesh: TriangleMesh, figure: go.Figure) -> None:
    assert mesh.object_bounds is not None

    for i, j in mesh.object_bounds:
        sub_mesh = mesh[i:j]

        (xs, ys, (_, z_max)) = sub_mesh.bounding_box.T

        layer = "below" if z_max < 1e-6 else None

        assert sub_mesh.face_colors is not None
        color = convert_to_RGB_255(sub_mesh.face_colors[0, :])

        figure.add_shape(
            type="rect",
            x0=xs[0],
            y0=ys[0],
            x1=xs[1],
            y1=ys[1],
            fillcolor=f"rgb{color!s}",
            layer=layer,
        )


def random_rgb(cell_hash: bytes) -> str:
    rng = random.Random(cell_hash)  # noqa: S311
    r = rng.randint(0, 255)
    g = rng.randint(0, 255)
    b = rng.randint(0, 255)
    return f"rgb({r},{g},{b})"


def create_discrete_colorscale(
    cell_ids: Int[Array, " *batch"],
    cell_hashes: dict[int, bytes],
    first_is_multipath_cell: bool,
) -> list[list[float | str]]:
    unique_ids = jnp.unique(cell_ids).tolist()
    min_id = min(unique_ids)
    max_id = max(unique_ids)
    scale_factor = 1 + max_id - min_id

    def scale(id_: int) -> float:
        return (id_ - min_id) / scale_factor

    colorscale = [
        [scale(id_ + offset), random_rgb(cell_hashes[id_])]
        for id_ in unique_ids
        for offset in (0, 1)
    ]

    if first_is_multipath_cell:  # Let's hide the cell with no multipath
        colorscale[0][1] = colorscale[1][1] = "rgba(0,0,0,0)"

    return colorscale
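
The coloring idea behind the helpers above can be sketched on its own: hash a path-validity mask, then seed a random generator with the digest so that the same mask always receives the same color. This is a minimal standalone sketch using only the standard library; `mask_to_rgb` and the example masks are hypothetical names and values, not part of the library.

```python
import hashlib
import random


def mask_to_rgb(mask: list[bool]) -> str:
    """Map a path-validity mask to a reproducible RGB string."""
    # bytes() turns each boolean into a 0x00 or 0x01 byte
    digest = hashlib.sha256(bytes(mask)).digest()
    # Seeding with the digest makes the color depend only on the mask
    rng = random.Random(digest)
    r, g, b = (rng.randint(0, 255) for _ in range(3))
    return f"rgb({r},{g},{b})"


same_a = mask_to_rgb([True, False, True])
same_b = mask_to_rgb([True, False, True])
other = mask_to_rgb([False, False, True])
print(same_a, same_b, other)
```

Because the color depends only on the mask content, a cell keeps its color from one plot to the next, even if its numeric identifier changes.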

Ray Tracing with a Grid of Receivers#

Because we want to detect multipath cells for a moving receiver, we need to perform RT simulations for many receiving antenna locations. The easiest way to do so is to use the TriangleScene.with_receivers_grid method and to rely on the library's support for batched dimensions, which lets us perform many RT simulations at the same time.
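
To make the notion of a multipath cell concrete, here is a toy sketch in plain Python. It assumes that a cell is simply the set of receivers sharing the same path-validity mask, and that cells found at different reflection orders are merged by requiring agreement at every order. The helper `group_ids` and all values are hypothetical; this is an illustration, not the library's implementation.

```python
def group_ids(rows):
    """Give each distinct row an integer id, in order of first appearance."""
    seen = {}
    return [seen.setdefault(tuple(row), len(seen)) for row in rows]


# Hypothetical validity masks: 6 receivers, 3 path candidates each.
mask = [
    [True, False, False],
    [True, False, False],
    [True, True, False],
    [False, False, False],
    [True, True, False],
    [False, False, False],
]
new_cell_ids = group_ids(mask)
print(new_cell_ids)  # [0, 0, 1, 2, 1, 2]

# Cells found at a previous reflection order (also hypothetical).
old_cell_ids = [0, 0, 0, 1, 1, 1]

# Two receivers stay in the same merged cell only if they agree on both ids.
merged_cell_ids = group_ids(zip(old_cell_ids, new_cell_ids))
print(merged_cell_ids)  # [0, 0, 1, 2, 3, 2]
```

Note how the third and fifth receivers share the same mask at the new order but end up in different merged cells, because they already differed at the previous order.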

Warning

Beware that large batches or higher order reflection paths could rapidly cause out-of-memory issues!

If you have multiple GPUs or TPUs at your disposal, you can also use the parallel feature of TriangleScene.compute_paths.

In the cell below, we also estimate the coverage map of the received power.

# Let's put one transmitter and many receivers in our scene
scene = eqx.tree_at(
    lambda s: s.transmitters, scene, jnp.array([-33.0, 0.0, 32.0])
)
# Our scene can be simplified to quadrilaterals,
# so informing the code of that matter will make it run faster
scene = scene.set_assume_quads()
batch = (
    100,
    100,
)  # Warning: too large a batch could easily cause OOM issues;
#    in that case, reduce the batch or the 'chunk_size' value below.
z0 = 1.5  # The z coordinate of the receivers
scene_grid = scene.with_receivers_grid(*batch, height=z0)

# And also keep track of multipath cells
cell_ids = jnp.zeros(batch, dtype=jnp.int32)  # Multipath cell indices
unique_cell_ids = jnp.empty(
    (), dtype=jnp.int32
)  # Contains unique values of 'cell_ids'
cell_hashes = {
    0: b""
}  # This is only used to generate constant random color per cell
has_multipath = jnp.zeros(
    batch, dtype=bool
)  # Will be true if a receiver has at least one valid ray path

# Only needed for plotting purposes
x, y, _ = jnp.unstack(scene_grid.receivers, axis=-1)

with reuse() as fig:
    scene.plot()

    ant = Dipole(
        2.4e9, center=scene.transmitters, look_at=[0.0, 0.0, z0]
    )  # 2.4 GHz
    ant.plot_radiation_pattern(distance=0.25, opacity=0.5, showscale=False)
    A_e = ant.aperture
    E = jnp.zeros((*batch, 3), dtype=jnp.complex64)
    B = jnp.zeros_like(E)

    eta_r = jnp.array([
        materials[mat_name].relative_permittivity(ant.frequency)
        for mat_name in scene.mesh.material_names
    ])
    n_r = jnp.sqrt(eta_r)

    for order in range(2):
        for paths in scene_grid.compute_paths(order=order, chunk_size=1_000):
            # 1 - Identify multipath cells
            # [*batch]
            new_cell_ids = paths.multipath_cells()
            # [num_unique_new_cell_ids]
            new_unique_cell_ids = jnp.unique(new_cell_ids)
            new_cell_hashes = get_cell_hashes(new_cell_ids, paths.mask)
            # [*batch]
            has_multipath |= paths.mask.any(axis=-1)
            cell_ids, cell_hashes = merge_cell_ids_and_hashes(
                cell_ids, new_cell_ids, cell_hashes, new_cell_hashes
            )
            # 2 - Compute EM fields (optional)

            E_i, B_i = ant.fields(paths.vertices[..., 1, :])

            if order > 0:
                # [*batch num_path_candidates order]
                obj_indices = paths.objects[..., 1:-1]
                # [*batch num_path_candidates order]
                mat_indices = jnp.take(
                    scene.mesh.face_materials, obj_indices, axis=0
                )
                # [*batch num_path_candidates order 3]
                obj_normals = jnp.take(scene.mesh.normals, obj_indices, axis=0)
                # [*batch num_path_candidates order]
                obj_n_r = jnp.take(n_r, mat_indices, axis=0)
                # [*batch num_path_candidates order+1 3]
                path_segments = jnp.diff(paths.vertices, axis=-2)
                # [*batch num_path_candidates order+1 3],
                # [*batch num_path_candidates order+1 1]
                k, s = normalize(path_segments, keepdims=True)
                # [*batch num_path_candidates order 3]
                k_i = k[..., :-1, :]
                k_r = k[..., +1:, :]
                # [*batch num_path_candidates order 3]
                (e_i_s, e_i_p), (e_r_s, e_r_p) = sp_directions(
                    k_i, k_r, obj_normals
                )
                # [*batch num_path_candidates order 1]
                cos_theta = jnp.sum(obj_normals * -k_i, axis=-1, keepdims=True)
                # [*batch num_path_candidates order 1]
                r_s, r_p = reflection_coefficients(
                    obj_n_r[..., None], cos_theta
                )
                # [*batch num_path_candidates 1]
                r_s = jnp.prod(r_s, axis=-2)
                r_p = jnp.prod(r_p, axis=-2)
                # [*batch num_path_candidates 1]
                E_i_s = jnp.sum(E_i * e_i_s[..., 0, :], axis=-1, keepdims=True)
                E_i_p = jnp.sum(E_i * e_i_p[..., 0, :], axis=-1, keepdims=True)
                B_i_s = jnp.sum(B_i * e_i_s[..., 0, :], axis=-1, keepdims=True)
                B_i_p = jnp.sum(B_i * e_i_p[..., 0, :], axis=-1, keepdims=True)
                # [*batch num_path_candidates 1]
                E_r_s = r_s * E_i_s
                E_r_p = r_p * E_i_p
                B_r_s = r_s * B_i_s
                B_r_p = r_p * B_i_p
                # [*batch num_path_candidates 3]
                E_r = E_r_s * e_r_s[..., -1, :] + E_r_p * e_r_p[..., -1, :]
                B_r = B_r_s * e_r_s[..., -1, :] + B_r_p * e_r_p[..., -1, :]
                # [*batch num_path_candidates 1]
                s_tot = s.sum(axis=-2)
                spreading_factor = safe_divide(
                    s[..., 0, :], s_tot
                )  # Far-field approximation
                phase_shift = jnp.exp(1j * s_tot * ant.wavenumber)
                # [*batch num_path_candidates 3]
                E_r *= spreading_factor * phase_shift
                B_r *= spreading_factor * phase_shift
            else:
                # [*batch num_path_candidates 3]
                E_r = E_i
                B_r = B_i

            # [*batch 3]
            E += jnp.sum(E_r, axis=-2, where=paths.mask[..., None])
            B += jnp.sum(B_r, axis=-2, where=paths.mask[..., None])

    S = poynting_vector(E, B)
    P = A_e * jnp.linalg.norm(S, axis=-1)
    G_dB = 10 * jnp.log10(P / ant.reference_power)

    draw_image(
        G_dB,
        x=x[0, :],
        y=y[:, 0],
        z0=z0,
        colorbar={"title": "Gain (dB)"},
    )

# We set cell ids with no multipath to -1 for easier identification

if not has_multipath.all():
    # Simple way to retrieve the cell index that has no multipath
    cell_id = jnp.max(cell_ids, initial=0, where=~has_multipath)
    cell_hashes[-1] = cell_hashes.pop(int(cell_id))

cell_ids = jnp.where(has_multipath, cell_ids, -1)

fig
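
The last step of the cell above, turning fields into a gain in dB, can be checked by hand in a simplified setting. The sketch below assumes a single plane wave in free space with peak electric field amplitude `e_peak`, for which the time-averaged Poynting vector magnitude is |E|² / (2 Z₀). The notebook instead computes the Poynting vector from the full E and B fields, so this is only a sanity check under that assumption; `received_power` and `to_db` are hypothetical helper names.

```python
import math

MU_0 = 4e-7 * math.pi  # vacuum permeability (H/m), classical value
C = 299_792_458.0  # speed of light in vacuum (m/s)
Z_0 = MU_0 * C  # free-space impedance, about 376.73 ohms


def received_power(e_peak: float, aperture: float) -> float:
    """Power (W) collected from a plane wave of peak amplitude e_peak (V/m)."""
    power_density = e_peak**2 / (2 * Z_0)  # time-averaged |S| in W/m^2
    return aperture * power_density


def to_db(power: float, reference_power: float) -> float:
    """Ratio of two powers expressed in decibels."""
    return 10 * math.log10(power / reference_power)


# Doubling the field amplitude quadruples the power, i.e. about +6.02 dB.
gain_db = to_db(received_power(2.0, 1.0), received_power(1.0, 1.0))
print(f"{gain_db:.2f} dB")
```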

Plotting a Multipath Lifetime Map#

From the previous simulation, and thanks to our utility functions, plotting an MLM takes only a few lines.

# We renumber unique indices to be between 0 and num_unique_cell_ids (excluded)
# Because `jax.numpy.unique` sorts entries, the first cell id will always refer to
# the 'no multipath' cell, if it exists.
unique_ids, renumbered_cell_ids = jnp.unique(cell_ids, return_inverse=True)
renumbered_cell_ids = renumbered_cell_ids.reshape(cell_ids.shape)
renumbered_cell_hashes = {
    i: cell_hashes[int(id_)] for i, id_ in enumerate(unique_ids)
}
# We create a discrete colorscale
colorscale = create_discrete_colorscale(
    renumbered_cell_ids,
    renumbered_cell_hashes,
    first_is_multipath_cell=bool(~has_multipath.all()),
)

with reuse() as fig:
    tx_x, tx_y, _ = scene_grid.transmitters.reshape(3, 1)
    draw_mesh_2d(scene.mesh, fig)
    fig.add_scatter(
        x=tx_x,
        y=tx_y,
        mode="markers+text",
        text=["tx"],
        marker={"color": "#EF553B", "size": 15},
        showlegend=False,
    )
    fig.add_heatmap(
        z=np.asarray(renumbered_cell_ids),
        x=np.asarray(x[0, :]),
        y=np.asarray(y[:, 0]),
        colorscale=colorscale,
        showscale=False,
    )

    fig.update_layout(
        height=600,
        xaxis={"range": [x.min(), x.max()]},
        yaxis={
            "range": [y.min(), y.max()],
            "scaleanchor": "x",
            "scaleratio": 1,
        },
    )

fig
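
The discrete colorscale used above relies on a general Plotly convention: a colorscale is a list of `[position, color]` stops with positions in [0, 1], and repeating a color at both ends of an interval yields a flat band. The sketch below builds such a scale for cell ids renumbered from 0 to n - 1; `discrete_colorscale` and the color names are illustrative assumptions, and at least two colors are assumed.

```python
def discrete_colorscale(colors):
    """Build a Plotly-style colorscale with one flat band per color."""
    n = len(colors)
    stops = []
    for i, color in enumerate(colors):
        stops.append([i / n, color])  # band start
        stops.append([(i + 1) / n, color])  # band end, same color
    return stops


colors = ["red", "green", "blue"]
colorscale = discrete_colorscale(colors)

# A heatmap maps integer ids 0..n-1 linearly onto [0, 1], so id i lands
# at i / (n - 1), which must fall inside band i of the colorscale.
n = len(colors)
inside_band = [i / n <= i / (n - 1) <= (i + 1) / n for i in range(n)]
print(colorscale[0], colorscale[-1], inside_band)
```

Each id therefore picks exactly one color, with no interpolation between neighboring cells.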

Metrics#

In the next two sections, we show how to compute the two metrics presented in the paper.

Area per cell#

To estimate the area per cell, we simply count the number of receivers (RXs) in each cell, and multiply it by the area per point (i.e., per RX) to obtain the area covered by each cell.
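
As a small worked example of this counting argument, consider a hypothetical 3 × 4 grid of receivers covering an assumed 8 m × 6 m area, where -1 marks receivers with no multipath. All values below are made up for illustration.

```python
from collections import Counter

# Hypothetical cell ids on a 3 x 4 grid of receivers (-1: no multipath)
cell_ids = [
    [-1, 0, 0, 1],
    [-1, 0, 1, 1],
    [2, 2, 1, 1],
]
length_x, length_y = 8.0, 6.0  # assumed extent of the grid, in meters

num_points = sum(len(row) for row in cell_ids)
area_per_point = length_x * length_y / num_points  # 48 / 12 = 4 m^2

# Count receivers per cell, ignoring the 'no multipath' cell
counts = Counter(i for row in cell_ids for i in row if i != -1)
area_per_cell = {i: n * area_per_point for i, n in sorted(counts.items())}
print(area_per_cell)  # {0: 12.0, 1: 20.0, 2: 8.0}
```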

length_x = x.max() - x.min()
length_y = y.max() - y.min()
surface = length_x * length_y
num_points = cell_ids.size
surface_per_point = (
    surface / num_points
)  # ~ Roughly, because RXs are not placed at centers of tiles

unique_ids, points_per_cell = jnp.unique(cell_ids, return_counts=True)
points_per_cell = points_per_cell[
    unique_ids != -1
]  # We remove cell with no multipath
points_per_cell
Array([163, 189, 569,  67,  29, 122, 156,  46,  10,   6,   4, 192, 354,
       197, 228,  71,  17,  72,  12,  92, 170,   4,  36,  15,   1, 205],      dtype=int32)
surface_per_cell = points_per_cell * surface_per_point

labels = {
    "x": "Surface",
    "y": "Normalized cells occupying a given surface",
}
counts, bins = np.histogram(surface_per_cell, bins=30)
bins = 0.5 * (bins[:-1] + bins[1:])

px.bar(
    x=bins,
    y=counts,
    labels=labels,
)

Average Minimal Inter-cell Distance#

For each RX in each cell, we compute the minimal distance it must travel to reach a different cell. We then compute the mean and the standard deviation of these distances, per cell.
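
The metric can be illustrated with a brute-force sketch on a handful of points. It assumes the distance is Euclidean and measured to the nearest receiver belonging to any other cell; the function name and the data are hypothetical, and the quadratic loop is only suitable for tiny examples.

```python
import math
from statistics import mean

# Hypothetical receivers on a line, with their cell ids
points = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (6.0, 0.0)]
cell_ids = [0, 0, 0, 1, 1]


def min_distance_to_other_cell(points, cell_ids):
    """For each point, distance to the closest point of a different cell."""
    return [
        min(
            (math.dist(p, q) for q, d in zip(points, cell_ids) if d != c),
            default=math.inf,  # only one cell: no other cell to reach
        )
        for p, c in zip(points, cell_ids)
    ]


min_dist = min_distance_to_other_cell(points, cell_ids)
mean_per_cell = {
    c: mean(d for d, e in zip(min_dist, cell_ids) if e == c)
    for c in sorted(set(cell_ids))
}
print(min_dist)  # [3.0, 2.0, 1.0, 1.0, 4.0]
print(mean_per_cell)  # {0: 2.0, 1: 2.5}
```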

min_dist = min_distance_between_cells(scene_grid.receivers, cell_ids)

for cell_id in jnp.unique(cell_ids):
    same_cell = cell_ids == cell_id
    mean_min_dist = jnp.mean(min_dist, where=same_cell)
    std_min_dist = jnp.std(min_dist, where=same_cell)

    print(
        f"cell id = {int(cell_id):5d} has an average minimal distance "
        f"to next cell of {float(mean_min_dist):5.2f} "
        f"(std: {float(std_min_dist):4.2f})"
    )

cell id =    -1 has an average minimal distance to next cell of 17.38 (std: 12.12)
cell id =    22 has an average minimal distance to next cell of  4.29 (std: 2.60)
cell id =    36 has an average minimal distance to next cell of  3.05 (std: 1.94)
cell id =    40 has an average minimal distance to next cell of  3.37 (std: 2.00)
cell id =   236 has an average minimal distance to next cell of  1.99 (std: 0.59)
cell id =  1035 has an average minimal distance to next cell of  1.79 (std: 0.23)
cell id =  1139 has an average minimal distance to next cell of  2.81 (std: 1.23)
cell id =  1800 has an average minimal distance to next cell of  3.49 (std: 2.01)
cell id =  3698 has an average minimal distance to next cell of  1.67 (std: 0.64)
cell id =  3859 has an average minimal distance to next cell of  1.39 (std: 0.35)
cell id =  4063 has an average minimal distance to next cell of  1.22 (std: 0.00)
cell id =  4164 has an average minimal distance to next cell of  1.22 (std: 0.00)
cell id =  4193 has an average minimal distance to next cell of  1.83 (std: 0.86)
cell id =  4316 has an average minimal distance to next cell of  4.18 (std: 2.32)
cell id =  4343 has an average minimal distance to next cell of  1.78 (std: 0.69)
cell id =  4360 has an average minimal distance to next cell of  3.22 (std: 1.78)
cell id =  4371 has an average minimal distance to next cell of  1.60 (std: 0.55)
cell id =  4599 has an average minimal distance to next cell of  1.56 (std: 0.52)
cell id =  5000 has an average minimal distance to next cell of  2.12 (std: 1.00)
cell id =  5099 has an average minimal distance to next cell of  1.39 (std: 0.38)
cell id =  5291 has an average minimal distance to next cell of  1.77 (std: 0.68)
cell id =  5841 has an average minimal distance to next cell of  3.28 (std: 1.57)
cell id =  5867 has an average minimal distance to next cell of  1.22 (std: 0.00)
cell id =  5885 has an average minimal distance to next cell of  1.54 (std: 0.51)
cell id =  5962 has an average minimal distance to next cell of  1.34 (std: 0.29)
cell id =  6066 has an average minimal distance to next cell of  1.22 (std: 0.00)
cell id =  6434 has an average minimal distance to next cell of  4.32 (std: 2.59)

Animating Over Multiple Transmitter Positions#

Of course, one may be interested in studying how the MLM evolves with the transmitter (TX) position.

The code below shows how to produce a nice interactive plot handling multiple TX positions, and you may recognize the three positions we used in our paper[1].

fig = make_subplots(
    rows=2,
    cols=1,
    specs=[[{"type": "scene"}], [{"type": "heatmap"}]],
)

with reuse(figure=fig) as fig:
    scene.plot(tx_kwargs={"visible": False}, row=1, col=1)
    draw_mesh_2d(scene.mesh, fig)

    offset = len(fig.data)

    x_positions = jnp.linspace(x.min(), x.max())

    for x_pos in x_positions:
        scene_grid = eqx.tree_at(
            lambda s: s.transmitters,
            scene_grid,
            scene_grid.transmitters.at[0].set(x_pos),
        )
        cell_ids = jnp.zeros(batch, dtype=jnp.int32)
        cell_hashes = {0: b""}
        has_multipath = jnp.zeros(batch, dtype=bool)

        for order in range(2):
            for paths in scene_grid.compute_paths(
                order=order, chunk_size=1_000
            ):
                new_cell_ids = paths.multipath_cells()
                new_cell_hashes = get_cell_hashes(new_cell_ids, paths.mask)
                has_multipath |= paths.mask.any(axis=-1)
                cell_ids, cell_hashes = merge_cell_ids_and_hashes(
                    cell_ids,
                    new_cell_ids,
                    cell_hashes,
                    new_cell_hashes,
                )

        if not has_multipath.all():
            cell_id = jnp.max(cell_ids, initial=0, where=~has_multipath)
            cell_hashes[-1] = cell_hashes.pop(int(cell_id))

        cell_ids = jnp.where(has_multipath, cell_ids, -1)
        unique_ids, renumbered_cell_ids = jnp.unique(
            cell_ids, return_inverse=True
        )
        renumbered_cell_ids = renumbered_cell_ids.reshape(cell_ids.shape)
        renumbered_cell_hashes = {
            i: cell_hashes[int(id_)] for i, id_ in enumerate(unique_ids)
        }
        colorscale = create_discrete_colorscale(
            renumbered_cell_ids,
            renumbered_cell_hashes,
            first_is_multipath_cell=bool(~has_multipath.all()),
        )

        draw_markers(
            np.asarray(scene_grid.transmitters.reshape(-1, 3)),
            labels=["tx"],
            showlegend=False,
            visible=False,
            row=1,
            col=1,
        )

        tx_x, tx_y, _ = scene_grid.transmitters.reshape(3, 1)

        fig.add_scatter(
            x=tx_x,
            y=tx_y,
            mode="markers+text",
            text=["tx"],
            marker={"color": "#EF553B", "size": 15},
            showlegend=False,
            visible=False,
            row=2,
            col=1,
        )

        fig.add_heatmap(
            x=np.asarray(x[0, :]),
            y=np.asarray(y[:, 0]),
            z=np.asarray(renumbered_cell_ids),
            colorscale=colorscale,
            hovertemplate="cell id: %{z}",
            showscale=False,
            visible=False,
            row=2,
            col=1,
        )

    steps = []

    for i, _ in enumerate(x_positions):
        step = {
            "method": "update",
            "args": [
                {"visible": [False, True] + [False] * len(x_positions) * 3},
            ],
        }
        step["args"][0]["visible"][offset + 3 * i + 0] = (
            True  # Show TX position on scene
        )
        step["args"][0]["visible"][offset + 3 * i + 1] = (
            True  # Show TX position on MLM
        )
        step["args"][0]["visible"][offset + 3 * i + 2] = True  # Show MLM
        steps.append(step)

    sliders = [
        {
            "active": 0,
            "currentvalue": {"prefix": "TX index: "},
            "pad": {"t": 50},
            "steps": steps,
        }
    ]

    fig.data[offset + 0].visible = True
    fig.data[offset + 1].visible = True
    fig.data[offset + 2].visible = True

    fig.update_layout(
        height=800,
        sliders=sliders,
        xaxis={"range": [x.min(), x.max()]},
        yaxis={
            "range": [y.min(), y.max()],
            "scaleanchor": "x",
            "scaleratio": 1,
        },
    )

fig
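
The slider logic above boils down to building one visibility list per step: static traces come first, followed by a fixed number of traces per TX position, and each step shows only the traces of its own position. The sketch below isolates that bookkeeping. `visibility_steps` is a hypothetical helper, and it assumes every static trace stays visible, unlike the cell above, which keeps the original TX marker hidden.

```python
def visibility_steps(num_static, num_frames, traces_per_frame):
    """Build one Plotly 'update' step per frame, toggling trace visibility."""
    steps = []
    for i in range(num_frames):
        # Static traces stay visible; all per-frame traces start hidden
        visible = [True] * num_static + [False] * (num_frames * traces_per_frame)
        start = num_static + i * traces_per_frame
        # Reveal only the traces that belong to frame i
        visible[start : start + traces_per_frame] = [True] * traces_per_frame
        steps.append({"method": "update", "args": [{"visible": visible}]})
    return steps


steps = visibility_steps(num_static=2, num_frames=3, traces_per_frame=3)
print(steps[1]["args"][0]["visible"])
```

Each resulting dictionary can be used as an entry of the `steps` list of a Plotly slider, as done in the cell above.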

Conclusion#

Voilà! That’s it!

We hope this short tutorial helped you understand the procedure required to reproduce the results presented in our paper.

Should you have any questions, recommendations, or remarks, do not hesitate to reach out to us via e-mail or via GitHub issues and discussions!