Sampling Path Candidates with Machine Learning

This notebook aims to be a tutorial for reproducing the results presented in the paper Transform-Invariant Generative Ray Path Sampling for Efficient Radio Propagation Modeling, and assumes you are familiar with its content.

Important

This notebook presents version 2 of our work on sampling path candidates. You can access v1, presented at the 2025 IEEE International Conference on Machine Learning for Communication and Networking (ICMLCN) in our paper Towards Generative Ray Path Sampling for Faster Point-to-Point Ray Tracing [3], by clicking here.

You can run it locally or with Google Colab by clicking on the rocket at the top of this page!

Tip

On Google Colab, make sure to select a GPU or TPU runtime for a faster experience.

If you find this tutorial useful and plan on using this tool for your publications, please cite our work; see Citing.

Warning

Training the models can take quite some time (by default, 500,000 training episodes per path order).

If you want to bypass the training and use a pre-trained model, you can download the weights from the releases page and load them with Model.load_weights.

Summary

In our work, we present a machine learning model that aims to reduce the computational complexity of exhaustive point-to-point ray tracing by learning how to sample path candidates. For further details, please refer to the paper.

Setup

Below are the important steps to properly set up the environment.

Imports

Unlike the previous version of our work, the codebase has been refactored to be more modular and reusable. As a result, most of the code lives in separate files, available on GitHub, and is installed as a Python package (sampling_paths).

This notebook is meant to be a tutorial on how to use the package to train and use the model, and thus does not contain the implementation details of the model itself.

We need to import quite a few Python modules, but all of them should be installed with pip install --group notebooks git+https://github.com/jeertmans/sampling-paths.git (assuming pip >= 25.1). The following cell is hidden by default for readability.


from collections import defaultdict
from functools import partial

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import treescope
from differt.geometry import path_lengths
from differt.plotting import draw_image, reuse, set_defaults
from differt.scene import (
    TriangleScene,
    download_sionna_scenes,
)
from differt.utils import safe_divide
from jaxtyping import Array, Float, Int, PRNGKeyArray
from plotly.subplots import make_subplots
from sampling_paths.agent import Agent
from sampling_paths.model import Model
from sampling_paths.utils import (
    BASE_SCENE,
    random_scene,
    train_dataloader,
    validation_scene_keys,
)
from tqdm.notebook import tqdm, trange

Pretty Printing

Nested structures, like machine learning modules, or more generally PyTrees, do not render very nicely by default. To provide the user with an interactive pretty-printing experience, we use treescope:

treescope.basic_interactive_setup()

JAX Device

While this notebook will run fine on all supported JAX devices (i.e., CPU, GPU, and TPU), using a GPU (or a TPU) will usually decrease the computational time by a significant factor.

To check the currently active JAX device, use jax.devices:

jax.devices()

Generating the training data

For this tutorial, as in the paper, we restrict ourselves to a simple urban scenario, obtained from Sionna [9], from which we will derive random scenes.

download_sionna_scenes()  # Download the Sionna scenes (from the main branch)
# Our scene is simple, and Plotly is the best backend for online interactive plots :-)
set_defaults("plotly")


def make_transparent(fig: go.Figure) -> go.Figure:
    """Hide the x, y, and z axes of every 3D scene in the figure."""
    fig.update_scenes(
        xaxis_visible=False, yaxis_visible=False, zaxis_visible=False
    )
    return fig


fig = BASE_SCENE.plot()
make_transparent(fig)

From the BASE_SCENE object, we can generate random variations. Two main types of variations are considered here:

  1. the number of objects (either triangles or quadrilaterals);

  2. and the TX / RX positions.

Note

Unlike NumPy and other common array libraries, JAX requires an explicit random key (jr.key) whenever you need to generate pseudo-random numbers. While this can lead to more verbose code, it has the major advantage of making random number generation easily reproducible, even across multiple devices. You want the same results? Then just pass the same key!

Below, we demonstrate how to generate (and plot) a random scene.

example_scene = random_scene(sample_in_canyon=True, key=jr.key(1234))
example_fig = example_scene.plot(
    showlegend=False,
    tx_kwargs={"labels": ["TX"]},
    rx_kwargs={"labels": ["RX"]},
    mesh_kwargs={"opacity": 0.5},
)
make_transparent(example_fig)

In the paper, we also study other types of random variations, which can be selected through optional parameters of the random_scene function.

help(random_scene)
Help on _JitWrapper in module sampling_paths.utils:

random_scene(*, min_fill_factor: float = 0.5, max_fill_factor: float = 1.0, tx_z_min: float = 2.0, tx_z_max: float = 50.0, rx_z_min: float = 1.0, rx_z_max: float = 2.0, sample_objects: bool = True, sample_in_canyon: bool = True, include_floor: bool = True, key: Union[jaxtyping.Key[jaxlib._jax.Array, ''], jaxtyping.UInt32[jaxlib._jax.Array, '2']]) -> differt.scene._triangle_scene.TriangleScene
    Return a random scene with one TX and one RX, at random positions, and a random number of objects.
    
    The number of objects is randomly sampled based on a random fill factor.
    
    Args:
        min_fill_factor: The minimum fill factor to be used.
        max_fill_factor: The maximum fill factor to be used.
        tx_z_min: Minimum height of the transmitter.
        tx_z_max: Maximum height of the transmitter.
        rx_z_min: Minimum height of the receiver.
        rx_z_max: Maximum height of the receiver.
        sample_objects: Whether to sample objects in the scene, instead of individual primitives.
        sample_in_canyon: Whether to sample the TX and RX positions within the canyon area.
        include_floor: Whether to always include the floor in the scene.
        key: The random key to be used.
    
    Returns:
        A new scene.

Afterward, we can trace ray paths in the generated scene and plot them.

with reuse(figure=example_fig) as fig:
    num_active_primitives = example_scene.mesh.num_active_primitives

    for order in [0, 1, 2, 3]:
        num_valid_paths = 0
        if order == 0:
            num_path_candidates = 1
        else:
            num_path_candidates = num_active_primitives * (
                num_active_primitives - 1
            ) ** (order - 1)
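        for paths in tqdm(
            example_scene.compute_paths(order=order, chunk_size=1_000_000),
            desc="Processing path candidates",
            leave=False,
        ):
            num_valid_paths += paths.num_valid_paths
            # Plot inside the loop so that the paths of every chunk are drawn,
            # not only those of the last chunk
            paths.plot(showlegend=False)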
        print(
            f"(order = {order}) Found {num_valid_paths:2d} valid paths out of {num_path_candidates:6d} path candidates."
        )

fig
(order = 0) Found  1 valid paths out of      1 path candidates.
(order = 1) Found  2 valid paths out of     62 path candidates.
(order = 2) Found  1 valid paths out of   3782 path candidates.
(order = 3) Found  1 valid paths out of 230702 path candidates.
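The candidate counts printed above follow a simple formula: for N active primitives and an order K ≥ 1, a path candidate is a sequence of K primitives in which the same primitive never appears twice in a row (a ray cannot be specularly reflected twice in a row by the same planar triangle), which gives N(N-1)^(K-1) candidates, while order 0 has a single line-of-sight candidate. The plain-Python sketch below, with hypothetical helper names, checks this closed form against a brute-force enumeration.

```python
from collections.abc import Iterator
from itertools import product


def count_path_candidates(num_primitives: int, order: int) -> int:
    """Closed-form count (hypothetical helper), same formula as in the cell above."""
    if order == 0:
        return 1  # Line-of-sight: one (empty) path candidate
    return num_primitives * (num_primitives - 1) ** (order - 1)


def enumerate_path_candidates(
    num_primitives: int, order: int
) -> Iterator[tuple[int, ...]]:
    """Brute-force enumeration (hypothetical helper) of all path candidates.

    A candidate is a sequence of primitive indices in which no primitive
    appears twice in a row.
    """
    for candidate in product(range(num_primitives), repeat=order):
        if all(a != b for a, b in zip(candidate, candidate[1:])):
            yield candidate


# Check the closed form on a small scene with 5 primitives
for k in [0, 1, 2, 3]:
    assert count_path_candidates(5, k) == sum(
        1 for _ in enumerate_path_candidates(5, k)
    )

# Counts for the example scene above (62 active primitives)
print([count_path_candidates(62, k) for k in [0, 1, 2, 3]])
# → [1, 62, 3782, 230702]
```

The brute-force enumeration is only practical for small N and K: its cost grows exactly like the count itself, which is what makes exhaustive ray tracing expensive at higher orders.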

We then collect some statistics about the scenes we will be training on.

Note

Unlike the example scene, the training set also includes TX and RX positions outside the main street canyon.

You can change this by modifying scene_fn_kwargs below, which is passed both to the train_dataloader function and to the scene_fn function used when instantiating the Agent object.

# Keyword parameters passed to 'train_dataloader' and to the agent's 'scene_fn' function.
scene_fn_kwargs = {"sample_in_canyon": False}

scenes = train_dataloader(key=jr.key(1234), **scene_fn_kwargs)

num_valid_paths = defaultdict(list)  # Number of valid paths
num_total_paths = defaultdict(list)  # Total number of paths (candidates)

colors = {0: "black", 1: "blue", 2: "red", 3: "green"}

for _ in trange(
    10_000,
    desc="Collecting statistics over many realizations",
    leave=False,
):
    scene = next(scenes)
    num_active_primitives = scene.mesh.num_active_primitives
    for order in colors:
        num_valid_paths[order].append(0)
        if order == 0:
            num_path_candidates = 1
        else:
            num_path_candidates = num_active_primitives * (
                num_active_primitives - 1
            ) ** (order - 1)
        num_total_paths[order].append(num_path_candidates)
        for paths in scene.compute_paths(order=order, chunk_size=100_000):
            num_valid_paths[order][-1] += paths.mask.sum()

fig = go.Figure()
for order, color in colors.items():
    print(f"\tStatistics for {order = }:")
    num_valid = jnp.array(num_valid_paths[order])
    num_total = jnp.array(num_total_paths[order])
    where = num_total != 0  # We discard rare cases without any candidate
    avg_total = int(num_total.mean(where=where))
    frac = float((num_valid / num_total).mean(where=where))
    frac_one = float((num_valid > 0).sum(where=where) / where.sum())
    print(
        f"\t- an average of {avg_total} path candidates exist;\n"
        f"\t- out of which {frac:.8%} of the paths are valid\n"
        f"\t- and {frac_one:.8%} of the scenes contained at least one valid path."
    )
    fig.add_histogram(
        x=num_valid,
        histnorm="percent",
        name=f"{order = }",
        marker_color=color,
    )

fig.update_layout(
    title="Distribution of the number of valid ray paths per scene",
    xaxis_title="Number of valid paths",
    yaxis_title="Percentage of scenes (%)",
)
	Statistics for order = 0:
	- an average of 1 path candidates exist;
	- out of which 34.15000141% of the paths are valid
	- and 34.15000141% of the scenes contained at least one valid path.
	Statistics for order = 1:
	- an average of 56 path candidates exist;
	- out of which 1.46392491% of the paths are valid
	- and 37.62000203% of the scenes contained at least one valid path.
	Statistics for order = 2:
	- an average of 3365 path candidates exist;
	- out of which 0.02076994% of the paths are valid
	- and 24.73000139% of the scenes contained at least one valid path.
	Statistics for order = 3:
	- an average of 211723 path candidates exist;
	- out of which 0.00033600% of the paths are valid
	- and 12.03000024% of the scenes contained at least one valid path.
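Note that the fraction of valid paths reported above is the mean of the per-scene ratios (valid paths over path candidates), computed only over scenes with at least one candidate, and not the ratio of the total number of valid paths to the total number of candidates. The plain-Python sketch below, with a hypothetical masked_statistics helper and made-up toy numbers (not data from this notebook), mirrors the two masked averages computed in that cell and compares them with the pooled ratio.

```python
def masked_statistics(
    num_valid: list[int], num_total: list[int]
) -> tuple[float, float]:
    """Hypothetical helper mirroring the masked means used in the cell above.

    Scenes without any path candidate (num_total == 0) are ignored.
    """
    kept = [(v, t) for v, t in zip(num_valid, num_total) if t != 0]
    # Mean of per-scene ratios, like (num_valid / num_total).mean(where=where)
    mean_ratio = sum(v / t for v, t in kept) / len(kept)
    # Fraction of scenes with at least one valid path
    scene_ratio = sum(1 for v, _ in kept if v > 0) / len(kept)
    return mean_ratio, scene_ratio


# Toy data (illustrative only): three scenes, the last one without any candidate
mean_ratio, scene_ratio = masked_statistics([1, 0, 0], [2, 4, 0])
pooled_ratio = (1 + 0) / (2 + 4)  # Pooled ratio, for comparison
print(f"{mean_ratio:.2%} {scene_ratio:.2%} {pooled_ratio:.2%}")
# → 25.00% 50.00% 16.67%
```

The two averages differ because the mean of ratios gives the same weight to every scene, regardless of how many candidates it contains.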

Reward function

As in any reinforcement learning problem, we need to define a reward function that will guide the training of our model.

Here, we use a binary reward function that indicates whether a generated path candidate leads to a valid ray path.

When computing paths with TriangleScene.compute_paths, assuming default parameters, the Paths.mask attribute is a boolean array whose entries indicate whether the corresponding ray paths are valid. Since the reward is evaluated for one path candidate at a time, and we only have one pair of TX and RX, the sum of all entries in this mask, converted to a floating-point value, is either 1 or 0. This is our reward.

The default reward function can be imported with:

from sampling_paths.metrics import reward_fn
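To make the binary reward concrete, the plain-Python sketch below converts a validity mask into a reward. The binary_reward helper is hypothetical and only illustrates the idea; the actual reward_fn takes a path candidate and a scene, and works on JAX arrays.

```python
def binary_reward(mask: list[bool]) -> float:
    """Hypothetical helper: number of valid paths in the mask, as a float.

    For a single path candidate and a single TX/RX pair, the mask has one
    entry, so the reward is either 0.0 (invalid) or 1.0 (valid).
    """
    return float(sum(mask))


print(binary_reward([True]), binary_reward([False]))  # → 1.0 0.0
```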

Differentiable reward

Because differt is a differentiable library, built on JAX, the reward function is differentiable with respect to its arguments.

Simulating non-differentiable reward

In our previous work, we already observed that using a non-differentiable reward leads to worse training performance.

To observe this effect, you can use jax.lax.stop_gradient to prevent gradients from flowing through the reward function. To do so, when creating the model (see below), pass the following reward function:

Model(
    ...,
    reward_fn=lambda path_candidate, scene: jax.lax.stop_gradient(
        reward_fn(path_candidate, scene)
    ),
)

Machine Learning model

Our model is made of four modules: an outer flow model that returns the flows between a parent state and its child states, and three inner models, each encoding different aspects of the scene.

The model’s internals are detailed in the paper and will not be discussed here.
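Without going into those internals, the role of flows between parent and child states can be illustrated generically. In flow-based generative models such as GFlowNets, an object is typically built step by step: from the current (parent) state, the next action is drawn with probability proportional to the flow towards each child state. The plain-Python sketch below applies this idea to path candidates, one primitive per step, with a hypothetical flow_fn standing in for a trained flow model; consecutive repeats are forbidden to match the N(N-1)^(K-1) count used in the statistics cell above. It is an illustration under these assumptions, not the sampling code of the sampling_paths package.

```python
import random
from collections.abc import Callable, Sequence


def sample_path_candidate(
    flow_fn: Callable[[tuple[int, ...]], Sequence[float]],
    num_primitives: int,
    order: int,
    rng: random.Random,
) -> tuple[int, ...]:
    """Sample a path candidate primitive by primitive, proportionally to flows.

    flow_fn is a hypothetical stand-in for a flow model: given a parent state
    (the partial path candidate), it returns one non-negative flow per child,
    i.e., per primitive that could be appended next.
    """
    state: tuple[int, ...] = ()  # Initial state: empty path candidate
    for _ in range(order):
        flows = list(flow_fn(state))
        if state:
            flows[state[-1]] = 0.0  # Never hit the same primitive twice in a row
        total = sum(flows)
        # Forward policy: P(child | parent) = F(parent -> child) / sum of outgoing flows
        probabilities = [flow / total for flow in flows]
        (next_primitive,) = rng.choices(range(num_primitives), weights=probabilities)
        state = (*state, next_primitive)
    return state


def uniform_flows(state: tuple[int, ...]) -> list[float]:
    """Hypothetical flow model assigning the same flow to each of 5 primitives."""
    return [1.0] * 5


rng = random.Random(1234)
print(sample_path_candidate(uniform_flows, num_primitives=5, order=3, rng=rng))
```

With uniform flows, this reduces to sampling uniformly among the N(N-1)^(K-1) candidates; a trained model is instead meant to put more flow on children that are likely to lead to valid ray paths.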

Below, we instantiate the model with default parameters.

order = 2  # Change this to try different path orders
num_embeddings = 128
model = Model(
    order=order,
    num_embeddings=num_embeddings,
    width_size=2 * num_embeddings,
    depth=2,
    key=jr.key(1234),
)
model

Training with an agent

To train the model, we define an agent that will handle the training loop for us.

This is mainly syntactic sugar to avoid writing boilerplate code, as the agent does not add any new functionality to the model itself.

Instead, it provides a convenient interface to train the model over multiple epochs, track metrics, and alternate between the standard training step and the replay buffer training step, which are automatically handled by the agent’s train method.

To optimize the model’s parameters, we rely on the optax library, which provides many common optimizers; here, we use the optax.contrib.muon optimizer.

Every evaluate_every training episodes (1,000 by default), we ask the agent to evaluate the model’s performance on a fixed set of validation scenes, which yields the curves presented in the paper.

class Results(eqx.Module):
    episodes: Int[Array, " n"]
    loss_values: Float[Array, " n"]
    success_rates: Float[Array, " n"]
    hit_rates: Float[Array, " n"]
    fill_rates: Float[Array, " n"] | None


def train(
    model: Model,
    *,
    num_validation_scenes: int = 100,
    num_episodes: int = 500_000,
    evaluate_every: int = 1_000,
    key: PRNGKeyArray,
) -> tuple[Model, Results]:
    key_episodes, key_valid_samples = jr.split(key)
    valid_keys = validation_scene_keys(
        order=model.order,
        num_scenes=num_validation_scenes,
        key=key_valid_samples,
        **scene_fn_kwargs,
    )

    agent = Agent(
        model=model,
        scene_fn=partial(random_scene, **scene_fn_kwargs),
    )

    episodes = []
    loss_values = []
    success_rates = []
    hit_rates = []
    fill_rates = []

    progress_bar = tqdm(
        jr.split(key_episodes, num_episodes), desc="Training model"
    )

    for episode, key_episode in enumerate(progress_bar):
        scene_key, train_key, eval_key = jr.split(key_episode, 3)

        # Train
        agent, loss_value = agent.train(scene_key, key=train_key)

        # Evaluate
        if episode % evaluate_every == 0:
            accuracy, hit_rate = agent.evaluate(valid_keys, key=eval_key)

            progress_bar.set_description(
                "Training model - "
                f"loss: {loss_value:.1e}, "
                f"success rate: {accuracy:.2%}, "
                f"hit rate: {hit_rate:.2%}"
                + (
                    f", buffer filled: {agent.replay_buffer.fill_ratio:.2%}"
                    if agent.replay_buffer is not None
                    else ""
                ),
            )

            episodes.append(episode)
            loss_values.append(loss_value)
            success_rates.append(100 * accuracy)
            hit_rates.append(100 * hit_rate)
            if agent.replay_buffer is not None:
                fill_rates.append(100 * agent.replay_buffer.fill_ratio)

    results = Results(
        episodes=jnp.asarray(episodes),
        loss_values=jnp.asarray(loss_values),
        success_rates=jnp.asarray(success_rates),
        hit_rates=jnp.asarray(hit_rates),
        fill_rates=jnp.asarray(fill_rates) if fill_rates else None,
    )

    return eqx.nn.inference_mode(agent.model), results


trained_model, results = train(model, key=jr.key(1234))
plt.figure()
plt.title(f"Train losses (K = {order})")
plt.semilogy(results.episodes, results.loss_values)
plt.xlabel("Training episodes")
plt.ylabel("Loss")
plt.show()
Figure: Train losses (K = 2).
_, ax1 = plt.subplots()
ax1.set_title(f"Train accuracy and hit rate (K = {order})")
ax1.set_xlabel("Training episodes")
ax1.set_ylabel("Accuracy (%)")
ax1.plot(results.episodes, results.success_rates, label="Accuracy")
ax2 = ax1.twinx()
ax2.set_ylabel("Hit rate (%)")
ax2.plot(results.episodes, results.hit_rates, "k--", label="Hit Rate")
plt.show()
Figure: Train accuracy and hit rate (K = 2).
if results.fill_rates is not None:
    plt.figure()
    plt.title(f"Replay buffer fill ratio (K = {order})")
    plt.semilogx(results.episodes, results.fill_rates, label="Fill Ratio")
    plt.xlabel("Training episodes")
    plt.ylabel("Fill Ratio (%)")
    plt.show()
Figure: Replay buffer fill ratio (K = 2).
trained_models = {trained_model.order: trained_model}
save_weights = True

for order in tqdm([0, 1, 2, 3], desc="Orders"):
    if order in trained_models:
        continue

    model = Model(
        order=order,
        num_embeddings=num_embeddings,
        width_size=2 * num_embeddings,
        depth=2,
        key=jr.key(1234),
    )
    if order > 0:
        trained_model, _ = train(model, key=jr.key(1234))
    else:
        # No need to train LOS model
        trained_model = eqx.nn.inference_mode(model)
    trained_models[order] = trained_model

if save_weights:
    for order, trained_model in trained_models.items():
        if order == 0:
            continue  # Don't save line-of-sight model
        with open(f"model_{order}.eqx", "wb") as f:
            eqx.tree_serialise_leaves(f, trained_model)

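Using the model for coverage map estimation

We now compare two coverage maps of the base scene, with a single TX placed at (0, 0, 32) and a grid of receivers. The left map relies on the trained models to sample M path candidates per receiver for orders 1 to 3 (the line-of-sight contribution is always computed exactly), while the right map relies on exhaustive ray tracing. Both maps share the same color scale.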

def compute_cm(
    receivers: Float[Array, "dim_x dim_y 3"],
    base_scene: TriangleScene,
    model: Model,
    batch_size: int,
    key: PRNGKeyArray | None,
) -> Float[Array, "dim_x dim_y"]:
    # N.B.: we use sequential_vmap to avoid out-of-memory errors in the exhaustive case
    @jax.custom_batching.sequential_vmap
    def compute_cm_for_one_receiver(
        receiver: Float[Array, "3"],
    ) -> Float[Array, ""]:
        scene = eqx.tree_at(lambda s: s.receivers, base_scene, receiver)
        if key is not None and model.order > 0:
            path_candidates = jax.vmap(lambda key: model(scene, key=key))(
                jr.split(key, batch_size)
            )
            paths = scene.compute_paths(path_candidates=path_candidates)
            paths = (
                paths.mask_duplicate_objects()
            )  # Important to remove duplicate path candidates
        else:
            # Exhaustive method does not contain duplicate path candidates
            paths = scene.compute_paths(order=model.order)
        r = path_lengths(paths.vertices)
        e = safe_divide(1.0, r)
        return (e * e).sum(where=paths.mask)

    return jax.vmap(jax.vmap(compute_cm_for_one_receiver))(receivers)


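M = 10  # Number of path candidates sampled by each model, per receiver
batch = (300, 300)  # Shape of the receiver grid
z0 = 1.5  # Height of the receivers
# Place a single TX at (0, 0, 32)
base_scene = eqx.tree_at(
    lambda s: s.transmitters, BASE_SCENE, jnp.array([[0.0, 0.0, 32.0]])
)
receivers = base_scene.with_receivers_grid(*batch, height=z0).receivers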

# Only needed for plotting purposes
x, y, _ = jnp.unstack(receivers, axis=-1)

fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scene"}, {"type": "scene"}]],
)

with reuse(figure=fig) as fig:
    base_scene.plot(
        tx_kwargs={"marker_color": "#636EFA"}, showlegend=False, row=1, col=1
    )
    base_scene.plot(
        tx_kwargs={"marker_color": "#636EFA"}, showlegend=False, row=1, col=2
    )

    G_model = G_exhau = jnp.zeros(batch)

    for order in [0, 1, 2, 3]:
        G_model += compute_cm(
            receivers, base_scene, trained_models[order], M, jr.key(0)
        )
        G_exhau += compute_cm(
            receivers, base_scene, trained_models[order], 0, None
        )

    G_min = jnp.minimum(G_model.min(), G_exhau.min())
    G_max = jnp.maximum(G_model.max(), G_exhau.max())

    draw_image(
        10 * jnp.log10(G_model),
        x=x[0, :],
        y=y[:, 0],
        z0=z0,
        colorbar={"title": "Gain (dB)"},
        colorscale="viridis",
        cmin=float(10 * jnp.log10(G_min)),
        cmax=float(10 * jnp.log10(G_max)),
        row=1,
        col=1,
        showscale=False,
    )
    draw_image(
        10 * jnp.log10(G_exhau),
        x=x[0, :],
        y=y[:, 0],
        z0=z0,
        colorbar={"title": "Gain (dB)"},
        colorscale="viridis",
        cmin=float(10 * jnp.log10(G_min)),
        cmax=float(10 * jnp.log10(G_max)),
        row=1,
        col=2,
    )
fig
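For each receiver, compute_cm accumulates a simple gain proxy, the sum of 1/r² over all valid paths, where r is the path length returned by path_lengths; the maps then display 10·log₁₀ of that sum. When the model samples M path candidates, the same candidate may be drawn more than once, hence the call to mask_duplicate_objects: without it, a valid path would contribute several times. The plain-Python sketch below mimics this computation for a single receiver. The helper names, and representing each path as a list of 3D vertices with a validity flag, are assumptions for illustration only.

```python
import math

Vertex = tuple[float, float, float]


def polyline_length(vertices: list[Vertex]) -> float:
    """Total length of a path TX -> interaction points -> RX."""
    return sum(math.dist(a, b) for a, b in zip(vertices, vertices[1:]))


def gain_proxy_db(
    candidates: list[tuple[int, ...]],
    paths: list[list[Vertex]],
    valid: list[bool],
) -> float:
    """Sum 1/r^2 over valid, unique path candidates, then convert to dB.

    Hypothetical helper: duplicates are detected on the candidate tuples
    themselves, which may differ from DiffeRT's mask_duplicate_objects.
    Assumes every valid path has a strictly positive length.
    """
    seen: set[tuple[int, ...]] = set()
    gain = 0.0
    for candidate, vertices, is_valid in zip(candidates, paths, valid):
        if candidate in seen:
            continue  # Duplicate sample: its path was already counted once
        seen.add(candidate)
        if is_valid:
            r = polyline_length(vertices)
            gain += 1.0 / (r * r)
    # Same dB conversion as in the plotting code; no valid path gives -inf
    return 10 * math.log10(gain) if gain > 0 else -math.inf


tx, rx = (0.0, 0.0, 0.0), (3.0, 4.0, 0.0)
los = [tx, rx]  # Line-of-sight path of length 5
# The same candidate sampled twice is counted once: 10 * log10(1 / 25) ≈ -13.98 dB
print(gain_proxy_db([(), ()], [los, los], [True, True]))
```

In this sketch, each unique valid path contributes exactly once, so sampling more candidates can only add missing contributions: the model-based estimate approaches the exhaustive one from below, without ever exceeding it.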