Multipath Lifetime Map
This notebook is a tutorial for reproducing the results presented in the paper Comparing Differentiable and Dynamic Ray Tracing: Introducing the Multipath Lifetime Map [6], and it assumes you are familiar with its content.
You can run it locally or with Google Colab by clicking on the rocket at the top of this page!
Tip
On Google Colab, make sure to select a GPU or TPU runtime for a faster experience.
If you find this tutorial useful and plan to use this tool in your publications, please cite our work (see Citing).
Summary
In our work, we present the Multipath Lifetime Map (MLM), a visual tool, along with two metrics, to help determine the scope of application of the Dynamic Ray Tracing (RT) method. For further details, please refer to the paper.
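To make the notion of a multipath cell concrete, here is a minimal NumPy sketch under a simplifying assumption: each receiver position is described by a boolean mask over a fixed list of path candidates, and two positions belong to the same cell if and only if their masks are identical (i.e., the same set of paths is valid). The function name `multipath_cell_ids` is hypothetical; in this notebook, this step is performed by DiffeRT's `paths.multipath_cells()`, which we assume follows the same idea.

```python
import numpy as np


def multipath_cell_ids(mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: assign a cell index to each receiver.

    mask: boolean array of shape [*batch, num_path_candidates], where
    mask[..., p] is True if path candidate p is valid at that receiver.
    Returns an integer array of shape [*batch]; receivers share an index
    if and only if they have exactly the same set of valid paths.
    """
    batch_shape = mask.shape[:-1]
    flat = mask.reshape(-1, mask.shape[-1])
    # Each unique row (i.e., each distinct set of valid paths) is one cell
    _, inverse = np.unique(flat, axis=0, return_inverse=True)
    return inverse.reshape(batch_shape)


# Toy example: 2x2 grid of receivers, 3 path candidates (LOS + 2 reflections)
mask = np.array([
    [[True, True, False], [True, True, False]],
    [[True, False, False], [False, False, False]],
])
cells = multipath_cell_ids(mask)
print(cells)
```

The two top receivers share a cell because they see the same paths, while the bottom ones each form their own cell; the actual index values depend on the sorting done by `np.unique`.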
Important
The implementation below is far from the most efficient: it primarily aims to provide a nice visual output, and it is tailored to users moving on a 2D grid.
While our method extends to any number and kind of dynamic objects, providing a visual representation of the MLM may become extremely hard, especially in higher dimensions.
Imports
As our code includes some non-trivial plots, we need to import quite a few Python modules.
All of them should be installed with differt[all].
import hashlib
import random
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from differt.em import (
    Dipole,
    materials,
    poynting_vector,
    reflection_coefficients,
    sp_directions,
)
from differt.geometry import (
    TriangleMesh,
    merge_cell_ids,
    min_distance_between_cells,
    normalize,
)
from differt.plotting import draw_image, draw_markers, reuse, set_backend
from differt.scene import (
    TriangleScene,
    download_sionna_scenes,
    get_sionna_scene,
)
from differt.utils import safe_divide
from jaxtyping import Array, Bool, Int
from plotly.colors import convert_to_RGB_255
from plotly.subplots import make_subplots
Simple Urban Street Canyon Scenario
Street canyons are probably among the most commonly studied scenarios in RT, both because they are simple to model and because they are ubiquitous in big cities.
Since we provide a compatibility layer with Sionna scenes [9], we can simply load the 'simple_street_canyon' scene.
You can download all the scenes from the Sionna repository with download_sionna_scenes; by default, they are placed in a subfolder of the differt Python module.
download_sionna_scenes()  # Let's download Sionna scenes (from the main branch)
# Our scene is simple, and Plotly is the best backend for online interactive plots :-)
set_backend("plotly")
file = get_sionna_scene("simple_street_canyon")
scene = TriangleScene.load_xml(file)
scene.plot()
In the cell below (hidden by default), we define quite a few utility functions for producing nice plots; this code is not needed if you are only interested in the metrics and not in the visual output.
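Two of the helpers used later, `get_cell_hashes` and `merge_cell_ids_and_hashes`, come from that hidden cell. The sketch below is a hypothetical, simplified NumPy stand-in for the ideas behind them, not the notebook's actual code. Merging two cell-index arrays yields a new index such that two receivers share a merged cell only if they share a cell in both inputs, which is how cells obtained from different path orders or chunks can be combined. A color is then derived deterministically from a cell's hash, so each cell keeps the same color across runs. Both function names are illustrative.

```python
import hashlib
import random

import numpy as np


def merge_cell_index_arrays(ids_a: np.ndarray, ids_b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: merge two cell-index arrays of the same shape.

    Two receivers get the same merged index if and only if they share
    the same index in ids_a AND the same index in ids_b.
    """
    pairs = np.stack([ids_a.ravel(), ids_b.ravel()], axis=-1)
    _, inverse = np.unique(pairs, axis=0, return_inverse=True)
    return inverse.reshape(ids_a.shape)


def cell_color(cell_hash: bytes) -> tuple[int, int, int]:
    """Hypothetical helper: map a cell hash to a constant RGB color."""
    # Seeding a local RNG from a stable digest gives the same color every run,
    # unlike Python's built-in hash(), which is salted per process for bytes.
    seed = int.from_bytes(hashlib.sha256(cell_hash).digest()[:8], "big")
    rng = random.Random(seed)
    return (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))


# Cells from line-of-sight paths only, then from single-reflection paths only
ids_los = np.array([[0, 0, 1], [0, 0, 1]])
ids_ref = np.array([[0, 1, 1], [0, 1, 1]])
merged = merge_cell_index_arrays(ids_los, ids_ref)
print(merged)
print(cell_color(b"LOS+ground"))
```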
Ray Tracing with a Grid of Receivers
Because we want to detect multipath cells for a moving receiver, we need to perform RT simulations for many receiving antenna locations.
The easiest way to do so is to use the TriangleScene.with_receivers_grid method and rely on this library's global support for batched dimensions to perform many RT simulations at the same time.
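The idea of a grid of receivers can be illustrated without DiffeRT. The sketch below is a hypothetical NumPy equivalent, assuming the grid uniformly spans given x and y bounds at a fixed height; we do not claim this is how TriangleScene.with_receivers_grid chooses its bounds or orders its axes. With `indexing="xy"`, x varies along the last grid axis and y along the first, which matches how the plotting code in the next cell reads axis coordinates with `x[0, :]` and `y[:, 0]`.

```python
import numpy as np


def receivers_grid(
    nx: int,
    ny: int,
    x_bounds: tuple[float, float],
    y_bounds: tuple[float, float],
    height: float,
) -> np.ndarray:
    """Hypothetical helper: a regular grid of receivers at a fixed height.

    Returns an array of shape [ny, nx, 3] holding (x, y, z) coordinates,
    so each grid point is one entry of a [ny, nx] batch.
    """
    xs = np.linspace(*x_bounds, nx)
    ys = np.linspace(*y_bounds, ny)
    # indexing="xy": x varies along axis 1, y along axis 0
    x, y = np.meshgrid(xs, ys, indexing="xy")
    z = np.full_like(x, height)
    return np.stack([x, y, z], axis=-1)


rx = receivers_grid(4, 3, (-10.0, 10.0), (-5.0, 5.0), height=1.5)
x, y, _ = np.moveaxis(rx, -1, 0)  # Same unpacking idea as jnp.unstack(..., axis=-1)
print(rx.shape)  # (3, 4, 3)
print(x[0, :])  # x coordinates along the grid's columns
print(y[:, 0])  # y coordinates along the grid's rows
```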
Warning
Beware that large batches or higher-order reflection paths can rapidly cause out-of-memory issues!
If you have multiple GPUs or TPUs at your disposal, you can also use the parallel feature of TriangleScene.compute_paths.
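Chunking trades speed for memory. The generic sketch below assumes, as the loop over compute_paths(..., chunk_size=1_000) in the next cell suggests, that each chunk is a subset of path candidates evaluated for all receivers at once, so per-receiver quantities (fields, multipath cells) must be accumulated across chunks. The helper `iter_chunks` and the random per-candidate contributions are hypothetical and only illustrate the pattern.

```python
from collections.abc import Iterator

import numpy as np


def iter_chunks(num_candidates: int, chunk_size: int) -> Iterator[slice]:
    """Yield slices covering range(num_candidates) in blocks of chunk_size."""
    for start in range(0, num_candidates, chunk_size):
        yield slice(start, min(start + chunk_size, num_candidates))


rng = np.random.default_rng(0)
num_receivers, num_candidates = 5, 10
# Toy per-(receiver, candidate) contributions and validity mask
contrib = rng.normal(size=(num_receivers, num_candidates))
mask = rng.random((num_receivers, num_candidates)) > 0.5

total = np.zeros(num_receivers)
for sl in iter_chunks(num_candidates, chunk_size=3):
    # In a real simulation, only this chunk's paths would be computed and
    # stored; here we just sum the valid contributions of the chunk
    total += np.sum(contrib[:, sl], axis=-1, where=mask[:, sl])

# Chunked accumulation matches processing all candidates at once
assert np.allclose(total, np.sum(contrib, axis=-1, where=mask))
```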
In the cell below, we also estimate the coverage map of the received power.
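For reference, the field and power computations performed in the cell below can be summarized as follows, with symbols named after the code variables. For a path with N reflections, the incident field E_i at the first interaction point is projected onto the s and p directions of the first interaction, scaled by the product of the Fresnel coefficients along the path, re-expanded on the s and p directions of the last interaction (for N = 1, both are the same interaction), and multiplied by a far-field spreading factor and a phase term. Here s_0 is the length of the first path segment, s_tot the total path length, k the wavenumber, A_e the antenna aperture, and P_ref the antenna's reference power; the magnetic field B is processed identically, and for line-of-sight paths (N = 0) the contribution is simply E_i.

```latex
\begin{aligned}
\mathbf{E}_r &= \left( r_s \, (\mathbf{E}_i \cdot \hat{\mathbf{e}}^{\,i}_{s,1}) \, \hat{\mathbf{e}}^{\,r}_{s,N}
  + r_p \, (\mathbf{E}_i \cdot \hat{\mathbf{e}}^{\,i}_{p,1}) \, \hat{\mathbf{e}}^{\,r}_{p,N} \right)
  \frac{s_0}{s_\text{tot}} \, e^{j k s_\text{tot}},
  \qquad r_{s,p} = \prod_{n=1}^{N} r_{s,p}^{(n)}, \\
\mathbf{E} &= \sum_{\text{valid paths}} \mathbf{E}_r, \qquad
P = A_e \, \lVert \mathbf{S}(\mathbf{E}, \mathbf{B}) \rVert, \qquad
G_\text{dB} = 10 \log_{10} \frac{P}{P_\text{ref}},
\end{aligned}
```

where S is the Poynting vector returned by poynting_vector.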
# Let's put one transmitter and many receivers in our scene
scene = eqx.tree_at(lambda s: s.transmitters, scene, jnp.array([-33.0, 0.0, 32.0]))

# Our scene can be simplified to quadrilaterals,
# so telling the code about it makes it run faster
scene = scene.set_assume_quads()

# Warning: a too-large batch could easily cause OOM issues;
# if so, reduce it or reduce the 'chunk_size' value below.
batch = (100, 100)
z0 = 1.5  # The z coordinate of the receivers
scene_grid = scene.with_receivers_grid(*batch, height=z0)

# And also keep track of multipath cells
cell_ids = jnp.zeros(batch, dtype=jnp.int32)  # Multipath cell indices
# Contains unique values of 'cell_ids'
unique_cell_ids = jnp.empty((), dtype=jnp.int32)
# This is only used to generate a constant random color per cell
cell_hashes = {0: b""}
# Will be true if a receiver has at least one valid ray path
has_multipath = jnp.zeros(batch, dtype=bool)

# Only needed for plotting purposes
x, y, _ = jnp.unstack(scene_grid.receivers, axis=-1)
with reuse() as fig:
    scene.plot()

    ant = Dipole(2.4e9, center=scene.transmitters, look_at=[0.0, 0.0, z0])  # 2.4 GHz
    ant.plot_radiation_pattern(distance=0.25, opacity=0.5, showscale=False)
    A_e = ant.aperture

    # Total fields at each receiver, accumulated over all valid paths
    E = jnp.zeros((*batch, 3), dtype=jnp.complex64)
    B = jnp.zeros_like(E)

    # Relative permittivity of each material at the antenna frequency
    eta_r = jnp.array([
        materials[mat_name].relative_permittivity(ant.frequency)
        for mat_name in scene.mesh.material_names
    ])
    n_r = jnp.sqrt(eta_r)  # Relative refractive index

    # Order 0 is line of sight, order 1 is single reflection
    for order in range(2):
        for paths in scene_grid.compute_paths(order=order, chunk_size=1_000):
            # 1 - Identify multipath cells
            # [*batch]
            new_cell_ids = paths.multipath_cells()
            # [num_unique_new_cell_ids]
            new_unique_cell_ids = jnp.unique(new_cell_ids)
            new_cell_hashes = get_cell_hashes(new_cell_ids, paths.mask)
            # [*batch]
            has_multipath |= paths.mask.any(axis=-1)

            cell_ids, cell_hashes = merge_cell_ids_and_hashes(
                cell_ids, new_cell_ids, cell_hashes, new_cell_hashes
            )

            # 2 - Compute EM fields (optional)
            # Incident fields at the first vertex after the transmitter
            E_i, B_i = ant.fields(paths.vertices[..., 1, :])

            if order > 0:
                # [*batch num_path_candidates order]
                obj_indices = paths.objects[..., 1:-1]
                # [*batch num_path_candidates order]
                mat_indices = jnp.take(scene.mesh.face_materials, obj_indices, axis=0)
                # [*batch num_path_candidates order 3]
                obj_normals = jnp.take(scene.mesh.normals, obj_indices, axis=0)
                # [*batch num_path_candidates order]
                obj_n_r = jnp.take(n_r, mat_indices, axis=0)
                # [*batch num_path_candidates order+1 3]
                path_segments = jnp.diff(paths.vertices, axis=-2)
                # [*batch num_path_candidates order+1 3],
                # [*batch num_path_candidates order+1 1]
                k, s = normalize(path_segments, keepdims=True)
                # [*batch num_path_candidates order 3]
                k_i = k[..., :-1, :]
                k_r = k[..., +1:, :]
                # [*batch num_path_candidates order 3]
                (e_i_s, e_i_p), (e_r_s, e_r_p) = sp_directions(k_i, k_r, obj_normals)
                # [*batch num_path_candidates order 1]
                cos_theta = jnp.sum(obj_normals * -k_i, axis=-1, keepdims=True)
                # [*batch num_path_candidates order 1]
                r_s, r_p = reflection_coefficients(obj_n_r[..., None], cos_theta)
                # [*batch num_path_candidates 1]
                r_s = jnp.prod(r_s, axis=-2)
                r_p = jnp.prod(r_p, axis=-2)
                # [*batch num_path_candidates 1]
                E_i_s = jnp.sum(E_i * e_i_s[..., 0, :], axis=-1, keepdims=True)
                E_i_p = jnp.sum(E_i * e_i_p[..., 0, :], axis=-1, keepdims=True)
                B_i_s = jnp.sum(B_i * e_i_s[..., 0, :], axis=-1, keepdims=True)
                B_i_p = jnp.sum(B_i * e_i_p[..., 0, :], axis=-1, keepdims=True)
                # [*batch num_path_candidates 1]
                E_r_s = r_s * E_i_s
                E_r_p = r_p * E_i_p
                B_r_s = r_s * B_i_s
                B_r_p = r_p * B_i_p
                # [*batch num_path_candidates 3]
                E_r = E_r_s * e_r_s[..., -1, :] + E_r_p * e_r_p[..., -1, :]
                B_r = B_r_s * e_r_s[..., -1, :] + B_r_p * e_r_p[..., -1, :]
                # [*batch num_path_candidates 1]
                s_tot = s.sum(axis=-2)
                spreading_factor = safe_divide(s[..., 0, :], s_tot)  # Far-field approximation
                phase_shift = jnp.exp(1j * s_tot * ant.wavenumber)
                # [*batch num_path_candidates 3]
                E_r *= spreading_factor * phase_shift
                B_r *= spreading_factor * phase_shift
            else:
                # [*batch num_path_candidates 3]
                E_r = E_i
                B_r = B_i

            # [*batch 3]
            E += jnp.sum(E_r, axis=-2, where=paths.mask[..., None])
            B += jnp.sum(B_r, axis=-2, where=paths.mask[..., None])

    S = poynting_vector(E, B)
    P = A_e * jnp.linalg.norm(S, axis=-1)
    G_dB = 10 * jnp.log10(P / ant.reference_power)

    draw_image(
        G_dB,
        x=x[0, :],
        y=y[:, 0],
        z0=z0,
        colorbar={"title": "Gain (dB)"},
    )

# We set cell ids with no multipath to -1 for easier identification
if not has_multipath.all():
    # Simple way to retrieve the cell index that has no multipath
    cell_id = jnp.max(cell_ids, initial=0, where=~has_multipath)
    cell_hashes[-1] = cell_hashes.pop(int(cell_id))
    cell_ids = jnp.where(has_multipath, cell_ids, -1)
fig
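The reflection step above relies on DiffeRT's reflection_coefficients(n_r, cos_theta). As a cross-check of the underlying physics, here is a standalone sketch of the textbook Fresnel reflection coefficients for a smooth interface between air and a non-magnetic medium with relative refractive index `n_r`. This is an independent implementation under that assumption, not the library's code, and its sign convention for the p polarization (one common choice among several) may differ from DiffeRT's.

```python
import numpy as np


def fresnel_coefficients(n_r, cos_theta):
    """Fresnel reflection coefficients (r_s, r_p) for a non-magnetic medium.

    n_r: relative (possibly complex) refractive index of the medium.
    cos_theta: cosine of the angle of incidence, measured from the normal.
    Uses the convention where r_p = -r_s at normal incidence.
    """
    n_r = np.asarray(n_r, dtype=complex)
    cos_theta = np.asarray(cos_theta, dtype=float)
    sin2_theta = 1.0 - cos_theta**2
    # sqrt(n^2 - sin^2(theta)) equals n * cos(theta_t) on the transmission side
    root = np.sqrt(n_r**2 - sin2_theta)
    r_s = (cos_theta - root) / (cos_theta + root)
    r_p = (n_r**2 * cos_theta - root) / (n_r**2 * cos_theta + root)
    return r_s, r_p


n = 2.0  # e.g., a lossless medium with relative permittivity 4
r_s, r_p = fresnel_coefficients(n, 1.0)  # Normal incidence
print(r_s, r_p)

theta_b = np.arctan(n)  # Brewster angle: p-polarized reflection vanishes
_, r_p_b = fresnel_coefficients(n, np.cos(theta_b))
print(abs(r_p_b))
```

For a multi-bounce path, the per-interaction coefficients are multiplied together, as done with jnp.prod in the cell above.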