differt.geometry.Paths

class Paths(vertices, objects, mask=None, interaction_types=None, confidence_threshold=0.5)

Bases: Module, Generic[_M]

A convenient wrapper class around path vertices and object indices.

This class can hold arbitrarily many paths, but they must share the same length, i.e., the same number of vertices per path.

The generic type parameter _M is either None, indicating that all paths are deemed valid, or a JAX array filled with boolean or floating-point values; see mask for further details.

Attributes

confidence_threshold

A threshold used to decide, e.g., when plotting, whether a given path is valid or not.

interaction_types

An optional array to indicate the type of each interaction.

mask

An optional mask to indicate which paths are valid and should be used.

masked_objects

The array of masked objects, with batched dimensions flattened into one.

masked_vertices

The array of masked vertices, with batched dimensions flattened into one.

num_valid_paths

The number of paths kept by mask.

order

The length (i.e., number of vertices) of each individual path, excluding start and end vertices.

path_length

The length (i.e., number of vertices) of each individual path.

shape

Return the batch shape of the paths.

vertices

The array of path vertices.

objects

The array of object indices.

Methods

group_by_objects()

Return an array of group indices, where paths with identical object sequences share the same index.

mask_duplicate_objects([axis])

Return a new paths instance with duplicate objects masked along a given axis.

masked()

Return a flattened version of this object that only keeps valid paths.

multipath_cells([axis])

Return an array of indices identifying identical multipath cells.

plot(**kwargs)

Plot the (masked) paths on a 3D scene.

reduce(fun[, axis])

Apply a function to all path vertices and accumulate the results into a scalar value (or an array if axis is provided).

reshape(*batch)

Return a new paths instance with its batch dimensions reshaped to match a given shape.

squeeze([axis])

Return a new paths instance with one or more of its batch axes squeezed.

Detailed documentation

confidence_threshold: Float[ArrayLike, ''] = 0.5

A threshold used to decide, e.g., when plotting, whether a given path is valid or not.

A path is considered valid if its confidence is greater than or equal to this threshold. The threshold is ignored if mask is None or contains boolean values.
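The validity rule can be illustrated with a minimal sketch, written with NumPy for portability (the same logic can be written with jax.numpy). The helper is_valid is hypothetical and not part of DiffeRT; it assumes that a boolean mask is used as-is and that a floating-point mask is compared against the threshold with >=.

```python
import numpy as np


def is_valid(mask, confidence_threshold=0.5):
    """Hypothetical helper: turn a mask into a boolean validity array."""
    mask = np.asarray(mask)
    if mask.dtype == np.bool_:
        # Boolean masks are used as-is; the threshold is ignored.
        return mask
    # Floating-point masks hold confidences: a path is valid when its
    # confidence is greater than or equal to the threshold.
    return mask >= confidence_threshold


print(is_valid(np.array([0.2, 0.5, 0.9])))  # [False  True  True]
print(is_valid(np.array([True, False, True]), 0.99))  # [ True False  True]
```

Note that a confidence exactly equal to the threshold (0.5 above) counts as valid.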

group_by_objects()

Return an array of group indices, where paths with identical object sequences share the same index.

This method is useful for grouping paths that interact with the same sequence of objects.

Internally, it uses the same logic as multipath_cells, but applied to object indices rather than to mask.

Return type:

Int[Array, '*batch']

Returns:

An array of group indices.

Examples

The following example shows how to group paths by objects. There are two different objects, with indices 0 and 1, and each path is made of three vertices.

>>> import jax
>>> import jax.numpy as jnp
>>> from differt.geometry import Paths
>>>
>>> objects = jnp.array([
...     [[1, 1, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0], [1, 1, 1], [1, 1, 1]],
...     [[1, 0, 0], [1, 1, 1], [0, 0, 1], [1, 1, 0], [0, 0, 1], [1, 0, 0]],
... ])
>>> key = jax.random.key(1234)
>>> vertices = jax.random.uniform(key, (*objects.shape, 3))
>>> paths = Paths(vertices, objects)
>>> groups = paths.group_by_objects()
>>> groups
Array([[0, 1, 2, 3, 4, 4],
       [3, 4, 1, 0, 1, 3]], dtype=int32)
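A NumPy sketch of one grouping scheme that reproduces the output above: after flattening the batch dimensions, every path receives the flat index of the first path with the same object sequence. This scheme is an assumption inferred from the example, not necessarily how DiffeRT computes group indices internally, and group_objects is a hypothetical name.

```python
import numpy as np


def group_objects(objects):
    """Hypothetical sketch: label each path by its first identical path."""
    objects = np.asarray(objects)
    batch_shape = objects.shape[:-1]
    # Flatten all batch dimensions: (num_paths, path_length).
    flat = objects.reshape(-1, objects.shape[-1])
    # same[i, j] is True when paths i and j visit the same objects.
    same = (flat[:, None, :] == flat[None, :, :]).all(axis=-1)
    # argmax returns the first True column, i.e., the first occurrence.
    first = same.argmax(axis=-1)
    return first.reshape(batch_shape)


objects = np.array([
    [[1, 1, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0], [1, 1, 1], [1, 1, 1]],
    [[1, 0, 0], [1, 1, 1], [0, 0, 1], [1, 1, 0], [0, 0, 1], [1, 0, 0]],
])
print(group_objects(objects))
# [[0 1 2 3 4 4]
#  [3 4 1 0 1 3]]
```

Under this scheme, indices are not necessarily consecutive (they happen to be in this example), and the pairwise comparison uses memory quadratic in the number of paths.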
interaction_types: Int[Array, '*batch path_length-2'] | None = None

An optional array to indicate the type of each interaction.

If not specified, InteractionType.REFLECTION is assumed.

mask: TypeVar(_M, bound=Bool[Array, '*batch'] | Float[Array, '*batch'] | None) = None

An optional mask to indicate which paths are valid and should be used.

The mask is stored separately from vertices so that the batch dimensions (*batch) are preserved, which would not be possible if only the valid paths were stored directly.

If mask contains floating-point values, then they are interpreted as confidence values between 0 and 1, where values greater than or equal to confidence_threshold are considered valid.

mask_duplicate_objects(axis=-1)

Return a new paths instance with duplicate objects masked along a given axis.

For example, when path candidates are generated by a machine learning model (see Sampling Path Candidates with Machine Learning), the model may generate the same path candidate multiple times. This method masks these duplicates while preserving the batch dimensions and remaining compatible with jax.jit.

Parameters:

axis (int) –

The batch axis along which the unique values are computed.

It defaults to the last axis, which is the axis where different path candidates are stored when generating paths with TriangleScene.compute_paths.

Return type:

Self

Returns:

A new paths instance with masked duplicate objects.

Raises:

ValueError – If the provided axis is out-of-bounds.
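The idea can be sketched in NumPy with fixed-size operations only (pairwise comparisons instead of a data-dependent unique), which is what allows the batch shape to stay unchanged. The function duplicate_free_mask is hypothetical; it assumes that the first occurrence of each object sequence is kept and it ignores any existing mask, neither of which is confirmed by this documentation.

```python
import numpy as np


def duplicate_free_mask(objects, axis=-1):
    """Hypothetical sketch: True for the first occurrence of each object sequence."""
    objects = np.asarray(objects)
    batch_ndim = objects.ndim - 1
    if not -batch_ndim <= axis < batch_ndim:
        raise ValueError(f"axis {axis} is out-of-bounds for {batch_ndim} batch dimensions")
    axis = axis % batch_ndim
    # Move the compared batch axis just before the path_length axis.
    moved = np.moveaxis(objects, axis, -2)
    n, path_length = moved.shape[-2:]
    flat = moved.reshape(-1, n, path_length)
    # same[m, i, j] is True when candidates i and j of slice m are identical.
    same = (flat[:, :, None, :] == flat[:, None, :, :]).all(axis=-1)
    # Only compare with earlier candidates (j < i), so the first copy is kept.
    earlier = np.tril(np.ones((n, n), dtype=bool), k=-1)
    is_duplicate = (same & earlier).any(axis=-1)
    keep = ~is_duplicate.reshape(moved.shape[:-1])
    # Restore the original order of the batch axes.
    return np.moveaxis(keep, -1, axis)


candidates = np.array([[0, 1, 2], [0, 1, 2], [3, 4, 5], [0, 1, 2]])
print(duplicate_free_mask(candidates))  # [ True False  True False]
```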

masked()

Return a flattened version of this object that only keeps valid paths.

The returned object has all batch dimensions flattened into one, keeping only the paths where mask is True (or where mask is greater than or equal to confidence_threshold), and the mask attribute is then set to None.

Return type:

Paths[None]

Returns:

A new paths instance with flattened batch dimensions and only valid paths.
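A minimal NumPy sketch of this flattening, with a hypothetical function name. It relies on boolean indexing: indexing an array of shape (*batch, path_length, 3) with a boolean array of shape (*batch) directly yields shape (num_valid_paths, path_length, 3).

```python
import numpy as np


def masked_arrays(vertices, objects, mask=None, confidence_threshold=0.5):
    """Hypothetical sketch of masked(): flatten batch dims, keep valid paths."""
    path_length = objects.shape[-1]
    if mask is None:
        # All paths are valid: only flatten the batch dimensions.
        return vertices.reshape(-1, path_length, 3), objects.reshape(-1, path_length)
    valid = mask if mask.dtype == np.bool_ else mask >= confidence_threshold
    # Indexing with a (*batch) boolean array collapses all batch dims into one.
    return vertices[valid], objects[valid]


rng = np.random.default_rng(0)
vertices = rng.uniform(size=(2, 3, 4, 3))  # batch shape (2, 3), path_length 4
objects = rng.integers(0, 10, size=(2, 3, 4))
mask = np.array([[True, False, True], [False, False, True]])
v, o = masked_arrays(vertices, objects, mask)
print(v.shape, o.shape)  # (3, 4, 3) (3, 4)
```

Because the output shape depends on the mask values, this kind of boolean indexing cannot, in general, be traced with jax.jit.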

property masked_objects: Int[Array, 'num_valid_paths path_length']

The array of masked objects, with batched dimensions flattened into one.

Similar to masked_vertices, but for objects.

property masked_vertices: Float[Array, 'num_valid_paths path_length 3']

The array of masked vertices, with batched dimensions flattened into one.

If mask is None, then the returned array is simply vertices with the batch dimensions flattened.

multipath_cells(axis=-1)

Return an array of indices identifying identical multipath cells.

Let cell_ids be the returned array. Then cell_ids[i] == cell_ids[j] for any indices i and j such that self.mask[i, :] and self.mask[j, :] are equal elementwise, assuming that each array has been reshaped to a two-dimensional array with axis as the last dimension. In practice, this method handles any number of dimensions and reshapes the output array to match the initial shape, except for dimension axis, which is removed.

The purpose of this method is to easily identify similar multipath structures, i.e., batch entries whose valid path candidates are exactly the same.

If the path candidates are not all aligned along the same axis, e.g., as a result of masking invalid paths, you can still use group_by_objects to identify similar paths. Note that group_by_objects may return different indices for different transmitter/receiver pairs, because their object indices differ. To avoid this, you should probably slice the objects array to exclude the first and last objects, i.e., use self.objects[..., 1:-1].

Parameters:

axis (int) –

The axis along which paths are compared.

By default, the last axis is used to match the num_path_candidates axis as returned by TriangleScene.compute_paths.

Return type:

Int[Array, '*partial_batch']

Returns:

The array of group indices.

Raises:

ValueError – If mask is None.
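A NumPy sketch of the definition above, restricted to boolean masks. The function cell_ids is hypothetical, and labelling each row with the index of its first identical row is an assumption: any labelling that gives equal rows equal indices and different rows different indices satisfies the definition.

```python
import numpy as np


def cell_ids(mask, axis=-1):
    """Hypothetical sketch: equal ids for batch entries with equal masks along axis."""
    if mask is None:
        raise ValueError("mask must not be None")
    mask = np.moveaxis(np.asarray(mask), axis, -1)
    partial_batch = mask.shape[:-1]
    # One row per remaining batch index, one column per path candidate.
    rows = mask.reshape(-1, mask.shape[-1])
    same = (rows[:, None, :] == rows[None, :, :]).all(axis=-1)
    # Label every row with the index of its first identical row.
    return same.argmax(axis=-1).reshape(partial_batch)


# Three batch entries, four path candidates each (last axis).
mask = np.array([
    [True, False, True, False],
    [False, False, True, True],
    [True, False, True, False],
])
print(cell_ids(mask))  # [0 1 0]
```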

property num_valid_paths: int | Int[Array, '']

The number of paths kept by mask.

If mask is not None, then the output value can be traced by JAX.
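The two cases can be sketched as follows (hypothetical helper, NumPy instead of JAX): without a mask, the count depends only on the batch shape and is a plain Python int; otherwise, it is the sum of the validity array, which in JAX would be a traceable 0-dimensional array.

```python
import math

import numpy as np


def count_valid_paths(batch_shape, mask=None, confidence_threshold=0.5):
    """Hypothetical sketch of num_valid_paths."""
    if mask is None:
        # Every path is valid: a static Python int, known without tracing.
        return math.prod(batch_shape)
    valid = mask if mask.dtype == np.bool_ else mask >= confidence_threshold
    # Data-dependent count; in JAX, this would be a traced scalar array.
    return valid.sum()


print(count_valid_paths((2, 3)))  # 6
print(count_valid_paths((3,), np.array([0.1, 0.6, 0.5])))  # 2
```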

objects: Int[Array, '*batch path_length']

The array of object indices.

Each path vertex corresponds to one object (e.g., a triangle). A placeholder value of -1 can be used in specific cases, like for transmitter and receiver positions.

property order: int

The length (i.e., number of vertices) of each individual path, excluding start and end vertices.

property path_length: int

The length (i.e., number of vertices) of each individual path.
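The shape-related properties follow directly from the array shapes. A NumPy sketch, assuming vertices has shape (*batch, path_length, 3) as in its type annotation; shape_info is a hypothetical helper returning the batch shape (see shape), path_length, and order (path_length minus the start and end vertices).

```python
import numpy as np


def shape_info(vertices):
    """Hypothetical helper: (batch shape, path_length, order) of a vertices array."""
    *batch, path_length, _ = vertices.shape
    # order excludes the start (e.g., transmitter) and end (e.g., receiver) vertices.
    return tuple(batch), path_length, path_length - 2


vertices = np.zeros((5, 10, 4, 3))  # 5 x 10 paths with 4 vertices each
print(shape_info(vertices))  # ((5, 10), 4, 2)
```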

plot(**kwargs)

Plot the (masked) paths on a 3D scene.

Parameters:

kwargs (Any) – Keyword arguments passed to draw_paths.

Return type:

Any

Returns:

The resulting plot output.

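reduce(fun, axis=None)

Apply a function to all path vertices and accumulate the results into a scalar value (or an array if axis is provided).

Parameters:

fun – The function applied to the path vertices.

axis – The batch axis along which the results are accumulated. If None (the default), all results are accumulated into a scalar value.

Return type: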

Num[Array, ''] | Num[Array, '*reduced_batch']

Returns:

The sum of the results, where contributions from invalid paths are set to zero.
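A NumPy sketch of this accumulation, assuming that fun maps the vertices array of shape (*batch, path_length, 3) to one value per path, of shape (*batch), and that mask is boolean. Both path_lengths (total Euclidean length of each path) and reduce_paths (a stand-in for the method) are hypothetical.

```python
import numpy as np


def path_lengths(vertices):
    """Euclidean length of each path: the sum of its segment lengths."""
    segments = np.diff(vertices, axis=-2)
    return np.linalg.norm(segments, axis=-1).sum(axis=-1)


def reduce_paths(fun, vertices, mask=None, axis=None):
    """Hypothetical sketch of reduce(): sum fun(vertices) over valid paths."""
    values = fun(vertices)
    if mask is not None:
        # Invalid paths contribute zero to the sum.
        values = np.where(mask, values, 0)
    return values.sum(axis=axis)


vertices = np.array([
    [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [3.0, 4.0, 1.0]],  # length 5 + 1 = 6
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]],  # length 1 + 1 = 2
])
print(reduce_paths(path_lengths, vertices))  # 8.0
print(reduce_paths(path_lengths, vertices, np.array([True, False])))  # 6.0
```

Selecting with np.where, rather than multiplying by the mask, keeps a NaN produced by an invalid path from leaking into the sum.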

reshape(*batch)

Return a new paths instance with its batch dimensions reshaped to match a given shape.

Parameters:

batch (int) – The new batch shape.

Return type:

Self

Returns:

A new paths instance with specified batch dimensions.
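A NumPy sketch of reshaping the batch dimensions of every per-path array consistently, so that vertices, objects, and mask keep matching batch shapes. The function reshape_batch is hypothetical, interaction_types is omitted for brevity, and -1 is assumed to be accepted as in numpy.reshape.

```python
import numpy as np


def reshape_batch(vertices, objects, mask, *batch):
    """Hypothetical sketch: reshape the batch dims of all per-path arrays."""
    path_length = objects.shape[-1]
    vertices = vertices.reshape(*batch, path_length, 3)
    objects = objects.reshape(*batch, path_length)
    # The mask has no trailing per-path axes, so it takes the batch shape as-is.
    if mask is not None:
        mask = mask.reshape(batch)
    return vertices, objects, mask


vertices = np.zeros((6, 4, 3))
objects = np.zeros((6, 4), dtype=int)
mask = np.ones(6, dtype=bool)
v, o, m = reshape_batch(vertices, objects, mask, 2, 3)
print(v.shape, o.shape, m.shape)  # (2, 3, 4, 3) (2, 3, 4) (2, 3)
```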

property shape: tuple[int, ...]

Return the batch shape of the paths.

Returns:

The shape of the paths' batch dimensions.

squeeze(axis=None)

Return a new paths instance with one or more of its batch axes squeezed.

Parameters:

axis (int | Sequence[int] | None) – See jax.numpy.squeeze for allowed values.

Return type:

Self

Returns:

A new paths instance with squeezed batch dimensions.

Raises:

ValueError – If one of the provided axes is out-of-bounds, or if trying to squeeze a 0-dimensional batch.
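A NumPy sketch of squeezing batch axes only: axes are validated against the number of batch dimensions, so negative values refer to batch axes and the trailing path_length and coordinate axes are never squeezed. The function squeeze_batch is hypothetical; its error conditions mirror the ones listed above.

```python
import numpy as np


def squeeze_batch(vertices, axis=None):
    """Hypothetical sketch: squeeze batch axes of a (*batch, path_length, 3) array."""
    batch_shape = vertices.shape[:-2]
    ndim = len(batch_shape)
    if ndim == 0:
        raise ValueError("cannot squeeze a 0-dimensional batch")
    if axis is None:
        # Squeeze every batch axis of size 1, but never path_length or xyz.
        axes = tuple(i for i, size in enumerate(batch_shape) if size == 1)
    else:
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        if any(not -ndim <= a < ndim for a in axes):
            raise ValueError(f"axis {axis} is out-of-bounds for batch shape {batch_shape}")
        # Map negative axes onto batch axes, not onto the trailing axes.
        axes = tuple(a % ndim for a in axes)
    if not axes:
        return vertices
    # np.squeeze raises ValueError if a selected axis does not have size 1.
    return np.squeeze(vertices, axis=axes)


vertices = np.zeros((1, 5, 1, 4, 3))
print(squeeze_batch(vertices).shape)  # (5, 4, 3)
print(squeeze_batch(vertices, axis=-1).shape)  # (1, 5, 4, 3)
```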

vertices: Float[Array, '*batch path_length 3']

The array of path vertices.