Ray Tracing at City Scale
import jax.numpy as jnp
import numpy as np
from tqdm.notebook import tqdm, trange
from vispy.scene.visuals import Image
from vispy.visuals.transforms import STTransform
import differt.plotting as dplt
from differt.geometry import TriangleMesh
from differt.geometry.triangle_mesh import (
triangles_contain_vertices_assuming_inside_same_plane,
)
from differt.rt.image_method import (
consecutive_vertices_are_on_same_side_of_mirrors,
image_method,
)
from differt.rt.utils import (
generate_all_path_candidates_chunks_iter,
rays_intersect_any_triangle,
rays_intersect_triangles,
)
# --- Scene setup -----------------------------------------------------------
# Load the mesh, plot it, and mark the transmitter (tx) and receiver (rx).
mesh_file = "bruxelles.obj"
mesh = TriangleMesh.load_obj(mesh_file)
canvas = mesh.plot()

tx = jnp.array([-40.0, 75, 30.0])
rx = jnp.array([+20.0, 108.034, 1.50])
dplt.draw_markers(np.array([tx, rx]), ["tx", "rx"], canvas=canvas)

color = ["black", "green", "orange", "yellow"]

# Vertex coordinates of every triangle, gathered from the mesh index array.
all_triangle_vertices = jnp.take(mesh.vertices, mesh.triangles, axis=0)

# Placeholder image: random values stretched over the x/y extent of the
# mesh and placed at z = 1.5.
image_np = np.random.rand(1000, 1000).astype(np.float32)
x = all_triangle_vertices[..., 0]
y = all_triangle_vertices[..., 1]
image = Image(image_np, interpolation="nearest", method="subdivide", cmap="jet")
x_scale = abs(np.max(x) - np.min(x)) / 1000
y_scale = abs(np.max(y) - np.min(y)) / 1000
image.transform = STTransform(
    scale=(x_scale, y_scale),
    translate=(np.min(x), np.min(y), 1.5),
)

# --- Path tracing ----------------------------------------------------------
num_triangles = mesh.triangles.shape[0]

# You probably don't want to try order > 1 (too slow if testing all paths)
for order in trange(0, 2, leave=False):
    candidate_chunks = generate_all_path_candidates_chunks_iter(
        num_triangles, order, chunk_size=2_000_000
    )
    for path_candidates in tqdm(candidate_chunks, leave=False):
        num_path_candidates = path_candidates.shape[0]

        # One copy of tx / rx per path candidate.
        endpoint_shape = (num_path_candidates, tx.shape[-1])
        from_vertices = jnp.broadcast_to(tx, endpoint_shape)
        to_vertices = jnp.broadcast_to(rx, endpoint_shape)

        # Gather the triangles named by each candidate; the first vertex of
        # each triangle is passed to the image method together with the
        # triangle normal.
        triangles = jnp.take(mesh.triangles, path_candidates, axis=0)
        triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)
        mirror_vertices = triangle_vertices[..., 0, :]
        mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

        paths = image_method(
            from_vertices, to_vertices, mirror_vertices, mirror_normals
        )

        # First filter: every path vertex must pass the triangle containment
        # test of its candidate triangle.
        inside = triangles_contain_vertices_assuming_inside_same_plane(
            triangle_vertices,
            paths,
        ).all(axis=-1)

        # Prepend tx and append rx to the surviving paths.
        full_paths = jnp.concatenate(
            (
                from_vertices[inside, ...][..., None, :],
                paths[inside, ...],
                to_vertices[inside, ...][..., None, :],
            ),
            axis=-2,
        )

        # Second filter, evaluated on the already-reduced set of paths.
        same_side = consecutive_vertices_are_on_same_side_of_mirrors(
            full_paths, mirror_vertices[inside, ...], mirror_normals[inside, ...]
        ).all(axis=-1)

        # Third filter: test every path segment against every triangle of
        # the mesh. Segments are repeated along a new axis of length
        # `num_triangles` so that they line up with the triangle array.
        segment_starts = full_paths[..., :-1, :]
        segment_vectors = jnp.diff(full_paths, axis=-2)
        ray_origins = jnp.repeat(
            segment_starts[..., None, :], num_triangles, axis=-2
        )
        ray_directions = jnp.repeat(
            segment_vectors[..., None, :], num_triangles, axis=-2
        )
        t, hit = rays_intersect_triangles(
            ray_origins,
            ray_directions,
            jnp.broadcast_to(all_triangle_vertices, (*ray_origins.shape, 3)),
        )
        # A segment counts as blocked when a hit is reported with t < 0.999.
        blocked = ((t < 0.999) & hit).any(axis=(-1, -2))

        keep = jnp.logical_and(same_side, jnp.logical_not(blocked))
        full_paths = full_paths[keep, ...]

        dplt.draw_paths(full_paths, canvas=canvas)

# --- Display ---------------------------------------------------------------
view = dplt.view_from_canvas(canvas)
view.add(image)
camera_state = {
    "scale_factor": 138.81554751457762,
    "center": (20.0, 108.034, 46.0),
    "fov": 45.0,
    "elevation": 13.0,
    "azimuth": 39.0,
    "roll": 0.0,
}
view.camera.set_state(camera_state)
canvas
WARNING: QOpenGLWidget is not supported on this platform.
snapshot
# --- Scene setup -----------------------------------------------------------
# Load the mesh, plot it, and mark the transmitter (tx) and a reference
# receiver (rx). Only tx is used in the computation below; the receivers are
# the grid points RX.
mesh_file = "manhattan.obj"
mesh = TriangleMesh.load_obj(mesh_file)
canvas = mesh.plot()
tx = jnp.array([-40.0, 75, 30.0])
rx = jnp.array([+20.0, 108.034, 1.50])
dplt.draw_markers(np.array([tx, rx]), ["tx", "rx"], canvas=canvas)

# Vertex coordinates of every triangle, gathered from the mesh index array.
all_triangle_vertices = jnp.take(mesh.vertices, mesh.triangles, axis=0)
x_tri = all_triangle_vertices[..., 0]
y_tri = all_triangle_vertices[..., 1]

# N x N grid of receiver positions spanning the x/y extent of the mesh,
# all at z = 1.5.
N = 125
x = jnp.linspace(x_tri.min(), x_tri.max(), N)
y = jnp.linspace(y_tri.min(), y_tri.max(), N)
X, Y = jnp.meshgrid(x, y)
Z = 1.5 * jnp.ones_like(X)
RX = jnp.stack((X, Y, Z), axis=-1)

# Accumulated power per grid point, summed over all valid paths.
power = jnp.zeros_like(X)

num_triangles = mesh.triangles.shape[0]

# NOTE: `range(0, 0)` is empty, so the loop body below never runs and `power`
# stays all zeros. Widen the range to actually compute the map.
# NOTE(review): presumably disabled because this mesh is too expensive to
# process exhaustively -- confirm.
# You probably don't want to try order > 1 (too slow if testing all paths)
for order in range(0, 0):
    for path_candidates in generate_all_path_candidates_chunks_iter(
        num_triangles, order, chunk_size=1000
    ):
        num_path_candidates = path_candidates.shape[0]

        # Replicate tx, the receiver grid and the candidates so that all of
        # them share the leading (N, N, num_path_candidates) axes.
        from_vertices = jnp.tile(tx, (N, N, num_path_candidates, 1))
        to_vertices = jnp.tile(
            jnp.expand_dims(RX, axis=-2), (1, 1, num_path_candidates, 1)
        )
        path_candidates = jnp.tile(path_candidates, (N, N, 1, 1))

        # Gather the triangles named by each candidate; the first vertex of
        # each triangle is passed to the image method together with the
        # triangle normal.
        triangles = jnp.take(mesh.triangles, path_candidates, axis=0)
        triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)
        mirror_vertices = triangle_vertices[..., 0, :]
        mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

        paths = image_method(
            from_vertices, to_vertices, mirror_vertices, mirror_normals
        )

        # Filter 1: every path vertex must pass the triangle containment test.
        mask = triangles_contain_vertices_assuming_inside_same_plane(
            triangle_vertices,
            paths,
        )
        mask_1 = jnp.all(mask, axis=-1)

        # Prepend tx and append the receiver position to every path.
        full_paths = jnp.concatenate(
            (
                jnp.expand_dims(from_vertices, axis=-2),
                paths,
                jnp.expand_dims(to_vertices, axis=-2),
            ),
            axis=-2,
        )

        # Filter 2: side-of-mirror check on consecutive vertices.
        mask_2 = consecutive_vertices_are_on_same_side_of_mirrors(
            full_paths, mirror_vertices, mirror_normals
        )
        mask_2 = jnp.all(mask_2, axis=-1)

        # Filter 3: discard paths with any segment that intersects the mesh.
        ray_origins = full_paths[..., :-1, :]
        ray_directions = jnp.diff(full_paths, axis=-2)
        intersect_any = rays_intersect_any_triangle(
            ray_origins,
            ray_directions,
            all_triangle_vertices,
        )
        mask_2 = mask_2 & ~jnp.any(intersect_any, axis=-1)
        mask = mask_1 & mask_2

        # Total path length = sum of the segment lengths. The segment vectors
        # are `ray_directions` (differences of consecutive vertices); taking
        # the norm of the vertices themselves would instead measure their
        # distance to the coordinate origin.
        lengths = jnp.linalg.norm(ray_directions, axis=-1).sum(axis=-1)

        # Inverse-square of the path length, zeroed for rejected paths.
        power_per_path = 1.0 / (lengths * lengths)
        power_per_path *= mask.astype(power_per_path.dtype)

        # Sum over the path-candidate axis and accumulate per grid point.
        power += power_per_path.sum(axis=-1)

# --- Display ---------------------------------------------------------------
# Stretch the N x N power map over the x/y extent of the grid, at z = 1.5.
image = Image(power, interpolation="nearest", method="subdivide", cmap="jet")
image.transform = STTransform(
    scale=(abs(np.max(x) - np.min(x)) / N, abs(np.max(y) - np.min(y)) / N),
    translate=(np.min(x), np.min(y), 1.5),
)
view = dplt.view_from_canvas(canvas)
view.add(image)
canvas
WARNING: QOpenGLWidget is not supported on this platform.
snapshot
# --- Scene setup -----------------------------------------------------------
# Load the mesh, plot it, and mark the transmitter (tx) and a reference
# receiver (rx). Only tx is used in the computation below; the receivers are
# the grid points RX.
mesh_file = "manhattan_small.obj"
mesh = TriangleMesh.load_obj(mesh_file)
canvas = mesh.plot()
tx = jnp.array([-40.0, 75, 30.0])
rx = jnp.array([+20.0, 108.034, 1.50])
dplt.draw_markers(np.array([tx, rx]), ["tx", "rx"], canvas=canvas)

# Vertex coordinates of every triangle, gathered from the mesh index array.
all_triangle_vertices = jnp.take(mesh.vertices, mesh.triangles, axis=0)
x_tri = all_triangle_vertices[..., 0]
y_tri = all_triangle_vertices[..., 1]

# N x N grid of receiver positions spanning the x/y extent of the mesh,
# all at z = 1.5.
N = 125
x = jnp.linspace(x_tri.min(), x_tri.max(), N)
y = jnp.linspace(y_tri.min(), y_tri.max(), N)
X, Y = jnp.meshgrid(x, y)
Z = 1.5 * jnp.ones_like(X)
RX = jnp.stack((X, Y, Z), axis=-1)

# Accumulated power per grid point, summed over all valid paths.
power = jnp.zeros_like(X)

num_triangles = mesh.triangles.shape[0]

# `trange(0, 1)` only evaluates order 0.
# You probably don't want to try order > 1 (too slow if testing all paths)
for order in trange(0, 1, leave=False):
    for path_candidates in tqdm(
        generate_all_path_candidates_chunks_iter(
            num_triangles, order, chunk_size=1000
        ),
        leave=False,
    ):
        num_path_candidates = path_candidates.shape[0]

        # Replicate tx, the receiver grid and the candidates so that all of
        # them share the leading (N, N, num_path_candidates) axes.
        from_vertices = jnp.tile(tx, (N, N, num_path_candidates, 1))
        to_vertices = jnp.tile(
            jnp.expand_dims(RX, axis=-2), (1, 1, num_path_candidates, 1)
        )
        path_candidates = jnp.tile(path_candidates, (N, N, 1, 1))

        # Gather the triangles named by each candidate; the first vertex of
        # each triangle is passed to the image method together with the
        # triangle normal.
        triangles = jnp.take(mesh.triangles, path_candidates, axis=0)
        triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)
        mirror_vertices = triangle_vertices[..., 0, :]
        mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

        paths = image_method(
            from_vertices, to_vertices, mirror_vertices, mirror_normals
        )

        # Filter 1: every path vertex must pass the triangle containment test.
        mask = triangles_contain_vertices_assuming_inside_same_plane(
            triangle_vertices,
            paths,
        )
        mask_1 = jnp.all(mask, axis=-1)

        # Prepend tx and append the receiver position to every path.
        full_paths = jnp.concatenate(
            (
                jnp.expand_dims(from_vertices, axis=-2),
                paths,
                jnp.expand_dims(to_vertices, axis=-2),
            ),
            axis=-2,
        )

        # Filter 2: side-of-mirror check on consecutive vertices.
        mask_2 = consecutive_vertices_are_on_same_side_of_mirrors(
            full_paths, mirror_vertices, mirror_normals
        )
        mask_2 = jnp.all(mask_2, axis=-1)

        # Filter 3: discard paths with any segment that intersects the mesh.
        ray_origins = full_paths[..., :-1, :]
        ray_directions = jnp.diff(full_paths, axis=-2)
        intersect_any = rays_intersect_any_triangle(
            ray_origins,
            ray_directions,
            all_triangle_vertices,
        )
        mask_2 = mask_2 & ~jnp.any(intersect_any, axis=-1)
        mask = mask_1 & mask_2

        # Total path length = sum of the segment lengths. The segment vectors
        # are `ray_directions` (differences of consecutive vertices); taking
        # the norm of the vertices themselves would instead measure their
        # distance to the coordinate origin.
        lengths = jnp.linalg.norm(ray_directions, axis=-1).sum(axis=-1)

        # Inverse-square of the path length, zeroed for rejected paths.
        power_per_path = 1.0 / (lengths * lengths)
        power_per_path *= mask.astype(power_per_path.dtype)

        # Sum over the path-candidate axis and accumulate per grid point.
        power += power_per_path.sum(axis=-1)

# --- Display ---------------------------------------------------------------
# Stretch the N x N power map over the x/y extent of the grid, at z = 1.5.
image = Image(power, interpolation="nearest", method="subdivide", cmap="jet")
image.transform = STTransform(
    scale=(abs(np.max(x) - np.min(x)) / N, abs(np.max(y) - np.min(y)) / N),
    translate=(np.min(x), np.min(y), 1.5),
)
view = dplt.view_from_canvas(canvas)
view.add(image)
canvas
WARNING: QOpenGLWidget is not supported on this platform.
snapshot