# Source code for simvx.core.clustered_lighting

"""Forward+ clustered lighting — CPU-side data model and light-to-cluster assignment.

Divides the camera view frustum into a 3D grid of clusters (tiles_x * tiles_y * depth_slices).
Each cluster stores the indices of lights that overlap it. The graphics backend dispatches
compute shaders using the output buffers; this module provides the data model and assignment logic.

Depth slices use exponential distribution (log-space) matching the non-linear depth buffer,
giving more resolution near the camera where it matters most.
"""


from __future__ import annotations

import logging
import math
from dataclasses import dataclass, field

import numpy as np

log = logging.getLogger(__name__)

__all__ = [
    "ClusterGrid",
    "LightCluster",
    "assign_lights",
    "ClusterConfig",
]


[docs] @dataclass(slots=True) class ClusterConfig: """Configuration for the cluster grid dimensions.""" tiles_x: int = 16 tiles_y: int = 9 depth_slices: int = 24 max_lights_per_cluster: int = 256
[docs] @dataclass(slots=True) class LightCluster: """A single cluster cell in the 3D grid. Holds indices of overlapping lights.""" light_indices: list[int] = field(default_factory=list) @property def light_count(self) -> int: return len(self.light_indices)
[docs] def clear(self) -> None: self.light_indices.clear()
class ClusterGrid:
    """3D grid of clusters dividing the view frustum.

    The grid is addressed as ``(tile_x, tile_y, slice_z)`` where:

    - ``tile_x`` / ``tile_y`` subdivide the screen in pixel-space
    - ``slice_z`` subdivides the depth range in log-space

    Attributes:
        config: Grid dimensions and limits.
        clusters: Flat list of ``LightCluster`` objects (x-major order).
        slice_boundaries: Precomputed view-space Z boundaries per depth slice
            (``depth_slices + 1`` float32 values); empty until :meth:`rebuild` runs.
    """

    __slots__ = ("config", "clusters", "slice_boundaries", "_total")

    def __init__(self, config: ClusterConfig | None = None) -> None:
        """Create an empty grid; a default ``ClusterConfig`` is used when none is given."""
        self.config = ClusterConfig() if config is None else config
        self._total = self.config.tiles_x * self.config.tiles_y * self.config.depth_slices
        self.clusters: list[LightCluster] = [LightCluster() for _ in range(self._total)]
        self.slice_boundaries: np.ndarray = np.empty(0, dtype=np.float32)

    # -- Indexing --

    def index(self, tx: int, ty: int, sz: int) -> int:
        """Flat index from 3D grid coordinates (x-major: x + y * tiles_x + z * tiles_x * tiles_y).

        No bounds checking is performed; callers are expected to pass in-range coordinates.
        """
        cfg = self.config
        return tx + cfg.tiles_x * (ty + cfg.tiles_y * sz)

    def cluster_at(self, tx: int, ty: int, sz: int) -> LightCluster:
        """Return the cluster at grid coordinates."""
        return self.clusters[self.index(tx, ty, sz)]

    # -- Lifecycle --

    def clear(self) -> None:
        """Clear all light assignments from every cluster."""
        for cluster in self.clusters:
            cluster.clear()

    def rebuild(self, near: float, far: float) -> None:
        """Recompute depth slice boundaries for the given clip planes.

        Uses exponential (log-space) distribution::

            depth[i] = near * (far / near) ** (i / num_slices)

        This concentrates slices near the camera, matching perspective depth precision.

        A non-positive ``near`` cannot be distributed in log-space; in that case every
        boundary collapses to ``near``. A grid configured with no depth slices gets an
        empty boundary array, so ``depth_slice`` reports "outside" for every depth.
        """
        n = self.config.depth_slices
        if n < 1:
            # Nothing to slice: avoid the division by zero in ``i / n`` below.
            self.slice_boundaries = np.empty(0, dtype=np.float32)
            return
        ratio = far / near if near > 0 else 1.0
        exponents = np.arange(n + 1, dtype=np.float64) / n
        self.slice_boundaries = (near * np.power(ratio, exponents)).astype(np.float32)

    def depth_slice(self, view_z: float) -> int:
        """Return the depth slice index for a view-space Z value (positive = in front of camera).

        Returns -1 if outside the frustum range (including NaN, or when the boundaries
        have not been built yet). A degenerate depth range (near plane <= 0 or
        far <= near) maps every in-range depth to slice 0.
        """
        if len(self.slice_boundaries) < 2:
            return -1
        near = float(self.slice_boundaries[0])
        far = float(self.slice_boundaries[-1])
        # Chained comparison so that NaN is rejected here instead of reaching int().
        if not (near <= view_z <= far):
            return -1
        # The log-space inverse below needs near > 0 and far/near > 1; otherwise it
        # would divide by zero (near == 0) or by log(1) == 0.
        if near <= 0.0 or far <= near:
            return 0
        # Inverse of: depth[i] = near * (far/near)^(i/n)  =>  i = n * log(z/near) / log(far/near)
        n = self.config.depth_slices
        s = int(n * math.log(view_z / near) / math.log(far / near))
        # view_z == far evaluates to n; fold it into the last slice.
        return min(s, n - 1)

    # -- GPU output buffers --

    def to_light_index_buffer(self) -> np.ndarray:
        """Pack all cluster light indices into a flat uint32 array for GPU upload.

        Returns concatenated light index lists for all clusters, in cluster order.
        The array is empty (shape ``(0,)``) when no cluster holds a light.
        """
        flat = [light for cluster in self.clusters for light in cluster.light_indices]
        return np.array(flat, dtype=np.uint32)

    def to_tile_buffer(self) -> np.ndarray:
        """Pack cluster metadata into an (N, 2) uint32 array: ``(offset, count)`` per cluster.

        ``offset`` is the position of the cluster's first light inside the array
        returned by :meth:`to_light_index_buffer`.
        """
        counts = np.fromiter((len(c.light_indices) for c in self.clusters), dtype=np.uint32)
        buf = np.empty((self._total, 2), dtype=np.uint32)
        # Exclusive prefix sum: each offset is the total count of all earlier clusters.
        buf[:, 0] = np.cumsum(counts, dtype=np.uint32) - counts
        buf[:, 1] = counts
        return buf

    @property
    def total_clusters(self) -> int:
        """Total number of clusters (tiles_x * tiles_y * depth_slices)."""
        return self._total
# ============================================================================
# Light-to-cluster assignment (CPU-side)
# ============================================================================

# Class name -> light kind. Matching is done by name against the light's MRO,
# so this module does not have to import the light classes themselves.
_LIGHT_KIND_BY_CLASS_NAME = {
    "DirectionalLight3D": "directional",
    "SpotLight3D": "spot",
    "PointLight3D": "point",
}


def _light_kind(light: object) -> str:
    """Classify a light as "directional", "spot" or "point".

    The light's own class is checked first and then its base classes in MRO
    order, so a subclass of a known light class is classified like that class.
    Anything unrecognised is treated as a point light.
    """
    for klass in type(light).__mro__:
        kind = _LIGHT_KIND_BY_CLASS_NAME.get(klass.__name__)
        if kind is not None:
            return kind
    return "point"


def _extract_light_data(lights: list) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[str]]:
    """Extract position, range, direction, and type from Light3D nodes.

    Each light must expose ``world_position`` and ``forward`` (indexable with
    0..2); non-directional lights must also expose ``range``.

    Returns:
        positions: (N, 3) float32 — taken from ``light.world_position``
        ranges: (N,) float32 — light range (inf for directional)
        directions: (N, 3) float32 — taken from ``light.forward``
            (meaningful for spot/directional)
        types: list of "directional", "point", or "spot"
    """
    n = len(lights)
    positions = np.zeros((n, 3), dtype=np.float32)
    ranges = np.full(n, np.inf, dtype=np.float32)
    directions = np.zeros((n, 3), dtype=np.float32)
    types: list[str] = []

    for i, light in enumerate(lights):
        pos = light.world_position
        positions[i] = (float(pos[0]), float(pos[1]), float(pos[2]))
        fwd = light.forward
        directions[i] = (float(fwd[0]), float(fwd[1]), float(fwd[2]))

        kind = _light_kind(light)
        types.append(kind)
        # Directional lights have no finite range; their slot keeps the inf default.
        if kind != "directional":
            ranges[i] = float(light.range)

    return positions, ranges, directions, types


def _sphere_aabb_in_cluster(
    cx_min: float,
    cx_max: float,
    cy_min: float,
    cy_max: float,
    cz_min: float,
    cz_max: float,
    sx: float,
    sy: float,
    sz: float,
    sr: float,
) -> bool:
    """Test if a sphere (centre sx,sy,sz radius sr) overlaps an AABB.

    The box is given by its min/max extents on each axis. Touching counts as
    overlapping.
    """
    # Per-axis distance from the sphere centre to the box (0 when the centre lies
    # within the box's extent on that axis).
    dx = max(cx_min - sx, 0.0, sx - cx_max)
    dy = max(cy_min - sy, 0.0, sy - cy_max)
    dz = max(cz_min - sz, 0.0, sz - cz_max)
    return (dx * dx + dy * dy + dz * dz) <= sr * sr
# Smallest depth used as a projection divisor; guards against a near plane at ~0.
_MIN_PROJ_Z = 1e-6


def _project_interval(lo: float, hi: float, z_front: float, z_back: float, scale: float) -> tuple[float, float]:
    """Conservative NDC bounds of the view-space interval ``[lo, hi]`` over depths ``[z_front, z_back]``.

    ``coord / z`` is largest in magnitude at the front plane, so an edge on the
    far side of the view axis is extreme at ``z_front`` while an edge that has not
    crossed the axis is extreme at ``z_back``.
    """
    ndc_lo = lo / z_front if lo < 0.0 else lo / z_back
    ndc_hi = hi / z_front if hi > 0.0 else hi / z_back
    return ndc_lo * scale, ndc_hi * scale


def _tile_span(lo_px: float, hi_px: float, tile_size: float, tile_count: int) -> range:
    """Tile indices overlapped by the pixel interval ``[lo_px, hi_px]`` (empty when off-screen)."""
    first = lo_px / tile_size
    last = hi_px / tile_size
    # Written positively so NaN bounds yield an empty span.
    if not (last >= 0.0 and first < tile_count):
        return range(0)
    # Clamp while still in float: int() truncates toward zero (which would pull a
    # slightly negative bound onto tile 0) and raises on +/-inf.
    return range(int(max(first, 0.0)), int(min(last, tile_count - 1)) + 1)


def assign_lights(
    lights: list,
    view_matrix: np.ndarray,
    projection_matrix: np.ndarray,
    viewport_size: tuple[int, int],
    grid: ClusterGrid,
) -> ClusterGrid:
    """Assign lights to clusters based on camera view and projection.

    This is the CPU-side culling pass. For each light, it determines which clusters the
    light's bounding volume overlaps and records the light index in those clusters.
    Directional lights are recorded in every cluster. Point and spot lights are bounded
    by a view-space box around their range sphere; for each depth slice the box is
    projected at both slice planes so off-axis lights are not under-covered. Clusters
    already holding ``max_lights_per_cluster`` indices ignore further lights.

    Args:
        lights: List of ``Light3D`` nodes (PointLight3D, SpotLight3D, DirectionalLight3D).
        view_matrix: 4x4 camera view matrix (row-major numpy array).
        projection_matrix: 4x4 camera projection matrix (row-major numpy array).
            Only the magnitudes of elements ``[0, 0]`` and ``[1, 1]`` are used.
        viewport_size: ``(width, height)`` in pixels.
        grid: ``ClusterGrid`` to populate (will be cleared first). Its slice
            boundaries must already be built; otherwise nothing is assigned.

    Returns:
        The same ``grid`` instance, populated with light assignments.
    """
    grid.clear()
    if not lights:
        return grid

    cfg = grid.config
    vw, vh = viewport_size
    if len(grid.slice_boundaries) < 2 or cfg.tiles_x < 1 or cfg.tiles_y < 1:
        return grid
    # A zero-sized viewport has no tiles to project onto; only directional lights
    # (which need no projection) are still recorded.
    has_screen_area = vw > 0 and vh > 0

    tile_w = vw / cfg.tiles_x
    tile_h = vh / cfg.tiles_y
    half_w = 0.5 * vw
    half_h = 0.5 * vh
    near = float(grid.slice_boundaries[0])
    far = float(grid.slice_boundaries[-1])
    max_lights = cfg.max_lights_per_cluster
    last_slice = cfg.depth_slices - 1

    # proj[0,0] = f/aspect, proj[1,1] = f (or -f for Vulkan); abs() ignores the flip.
    px = abs(float(projection_matrix[0, 0]))
    py = abs(float(projection_matrix[1, 1]))

    positions, ranges, _directions, types = _extract_light_data(lights)

    # Transform light positions to view space: view_pos = (view_matrix @ [x, y, z, 1])
    n = len(lights)
    world_pos_h = np.hstack([positions, np.ones((n, 1), dtype=np.float32)])  # (N, 4)
    view_pos = (view_matrix @ world_pos_h.T).T[:, :3]  # (N, 3)
    # In our convention, view space has -Z into the screen. Convert to positive depth.
    view_z = -view_pos[:, 2]

    for li in range(n):
        if types[li] == "directional":
            # Directional lights affect all clusters
            for cluster in grid.clusters:
                if len(cluster.light_indices) < max_lights:
                    cluster.light_indices.append(li)
            continue
        if not has_screen_area:
            continue

        light_range = float(ranges[li])
        vx = float(view_pos[li, 0])
        vy = float(view_pos[li, 1])
        vz = float(view_z[li])

        # Depth extent of the bounding sphere; skip lights entirely outside the
        # frustum depth (the positive form also skips NaN extents).
        z_min = vz - light_range
        z_max = vz + light_range
        if not (z_max >= near and z_min <= far):
            continue

        slice_min = max(0, grid.depth_slice(max(z_min, near)))
        slice_max = min(last_slice, grid.depth_slice(min(z_max, far)))
        if slice_max < 0:
            continue

        x_lo, x_hi = vx - light_range, vx + light_range
        y_lo, y_hi = vy - light_range, vy + light_range

        for sz in range(slice_min, slice_max + 1):
            z_front = max(float(grid.slice_boundaries[sz]), near, _MIN_PROJ_Z)
            z_back = max(float(grid.slice_boundaries[sz + 1]), z_front)

            ndc_x_lo, ndc_x_hi = _project_interval(x_lo, x_hi, z_front, z_back, px)
            ndc_y_lo, ndc_y_hi = _project_interval(y_lo, y_hi, z_front, z_back, py)

            # NDC [-1, 1] -> pixels [0, size], then pixels -> tile indices.
            tx_span = _tile_span((ndc_x_lo + 1.0) * half_w, (ndc_x_hi + 1.0) * half_w, tile_w, cfg.tiles_x)
            ty_span = _tile_span((ndc_y_lo + 1.0) * half_h, (ndc_y_hi + 1.0) * half_h, tile_h, cfg.tiles_y)

            for ty in ty_span:
                for tx in tx_span:
                    cluster = grid.cluster_at(tx, ty, sz)
                    # Each (tx, ty, sz) is visited once per light after the initial
                    # clear, so no duplicate check is needed.
                    if len(cluster.light_indices) < max_lights:
                        cluster.light_indices.append(li)

    return grid