"""Forward+ clustered lighting — CPU-side data model and light-to-cluster assignment.
Divides the camera view frustum into a 3D grid of clusters (tiles_x * tiles_y * depth_slices).
Each cluster stores the indices of lights that overlap it. The graphics backend dispatches
compute shaders using the output buffers; this module provides the data model and assignment logic.
Depth slices use exponential distribution (log-space) matching the non-linear depth buffer,
giving more resolution near the camera where it matters most.
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass, field
import numpy as np
log = logging.getLogger(__name__)

# Names exported by ``from <module> import *``; this is the module's public API.
__all__ = ["ClusterGrid", "LightCluster", "assign_lights", "ClusterConfig"]
[docs]
@dataclass(slots=True)
class ClusterConfig:
"""Configuration for the cluster grid dimensions."""
tiles_x: int = 16
tiles_y: int = 9
depth_slices: int = 24
max_lights_per_cluster: int = 256
[docs]
@dataclass(slots=True)
class LightCluster:
"""A single cluster cell in the 3D grid. Holds indices of overlapping lights."""
light_indices: list[int] = field(default_factory=list)
@property
def light_count(self) -> int:
return len(self.light_indices)
[docs]
def clear(self) -> None:
self.light_indices.clear()
class ClusterGrid:
    """3D grid of clusters dividing the view frustum.

    The grid is addressed as ``(tile_x, tile_y, slice_z)`` where:

    - ``tile_x`` / ``tile_y`` subdivide the screen in pixel-space
    - ``slice_z`` subdivides the depth range in log-space

    Attributes:
        config: Grid dimensions and limits.
        clusters: Flat list of ``LightCluster`` objects; ``tile_x`` varies fastest,
            then ``tile_y``, then ``slice_z`` (see :meth:`index`).
        slice_boundaries: View-space depth boundaries as float32, ``depth_slices + 1``
            values after :meth:`rebuild` has run (empty before that).
    """

    __slots__ = ("config", "clusters", "slice_boundaries", "_total")

    def __init__(self, config: ClusterConfig | None = None) -> None:
        self.config = config if config is not None else ClusterConfig()
        self._total = self.config.tiles_x * self.config.tiles_y * self.config.depth_slices
        self.clusters: list[LightCluster] = [LightCluster() for _ in range(self._total)]
        self.slice_boundaries: np.ndarray = np.empty(0, dtype=np.float32)

    # -- Indexing --

    def index(self, tx: int, ty: int, sz: int) -> int:
        """Flat index from 3D grid coordinates: ``x + y * tiles_x + z * tiles_x * tiles_y``.

        No bounds checking is performed.
        """
        return tx + ty * self.config.tiles_x + sz * self.config.tiles_x * self.config.tiles_y

    def cluster_at(self, tx: int, ty: int, sz: int) -> LightCluster:
        """Return the cluster at grid coordinates."""
        return self.clusters[self.index(tx, ty, sz)]

    # -- Lifecycle --

    def clear(self) -> None:
        """Clear all light assignments from every cluster."""
        for c in self.clusters:
            c.clear()

    def rebuild(self, near: float, far: float) -> None:
        """Recompute depth slice boundaries for the given clip planes.

        Uses exponential (log-space) distribution::

            depth[i] = near * (far / near) ** (i / num_slices)

        This concentrates slices near the camera, matching perspective depth precision.
        If ``near`` or ``far`` is not positive, the ratio is undefined for a geometric
        series, so every boundary is set to ``near`` (a degenerate, zero-thickness grid).
        """
        n = self.config.depth_slices
        # A negative ratio raised to a fractional power yields a complex number,
        # which cannot be stored in a float32 array, so it is treated as degenerate too.
        ratio = far / near if near > 0 and far > 0 else 1.0
        self.slice_boundaries = np.array(
            [near * (ratio ** (i / n)) for i in range(n + 1)], dtype=np.float32
        )

    def depth_slice(self, view_z: float) -> int:
        """Return the depth slice index for a view-space Z value (positive = in front of camera).

        Returns -1 if the boundaries have not been built or ``view_z`` lies outside
        ``[near, far]``. A degenerate range (``near <= 0`` or ``far <= near``) maps every
        in-range value to slice 0.
        """
        bounds = self.slice_boundaries
        if len(bounds) < 2:
            return -1
        near = float(bounds[0])
        far = float(bounds[-1])
        if view_z < near or view_z > far:
            return -1
        # No usable log scale here; checking near <= 0 also avoids dividing by near == 0.
        if near <= 0.0 or far <= near:
            return 0
        # Inverse of: depth[i] = near * (far/near)^(i/n) => i = n * log(z/near) / log(far/near)
        s = int(self.config.depth_slices * math.log(view_z / near) / math.log(far / near))
        return min(s, self.config.depth_slices - 1)

    # -- GPU output buffers --

    def to_light_index_buffer(self) -> np.ndarray:
        """Pack all cluster light indices into a flat uint32 array for GPU upload.

        Returns the concatenated light index lists of all clusters, in cluster order
        (an empty array when no light is assigned).
        """
        flat = [idx for c in self.clusters for idx in c.light_indices]
        return np.array(flat, dtype=np.uint32)

    def to_tile_buffer(self) -> np.ndarray:
        """Pack cluster metadata into an (N, 2) uint32 array: ``(offset, count)`` per cluster.

        ``offset`` indexes into the array returned by :meth:`to_light_index_buffer`,
        which concatenates the clusters in the same order.
        """
        buf = np.empty((self._total, 2), dtype=np.uint32)
        offset = 0
        for i, c in enumerate(self.clusters):
            count = len(c.light_indices)
            buf[i, 0] = offset
            buf[i, 1] = count
            offset += count
        return buf

    @property
    def total_clusters(self) -> int:
        """Total number of clusters (``tiles_x * tiles_y * depth_slices``)."""
        return self._total
# ============================================================================
# Light-to-cluster assignment (CPU-side)
# ============================================================================
def _extract_light_data(lights: list) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[str]]:
"""Extract position, range, direction, and type from Light3D nodes.
Returns:
positions: (N, 3) float32 — world-space positions
ranges: (N,) float32 — light range (inf for directional)
directions: (N, 3) float32 — forward vector (meaningful for spot/directional)
types: list of "directional", "point", or "spot"
"""
n = len(lights)
positions = np.zeros((n, 3), dtype=np.float32)
ranges = np.full(n, np.inf, dtype=np.float32)
directions = np.zeros((n, 3), dtype=np.float32)
types: list[str] = []
for i, light in enumerate(lights):
pos = light.world_position
positions[i] = [float(pos[0]), float(pos[1]), float(pos[2])]
fwd = light.forward
directions[i] = [float(fwd[0]), float(fwd[1]), float(fwd[2])]
cls_name = type(light).__name__
if cls_name == "DirectionalLight3D":
types.append("directional")
elif cls_name == "SpotLight3D":
types.append("spot")
ranges[i] = float(light.range)
else:
types.append("point")
ranges[i] = float(light.range)
return positions, ranges, directions, types
def _sphere_aabb_in_cluster(
cx_min: float, cx_max: float,
cy_min: float, cy_max: float,
cz_min: float, cz_max: float,
sx: float, sy: float, sz: float, sr: float,
) -> bool:
"""Test if a sphere (centre sx,sy,sz radius sr) overlaps an AABB."""
# Closest point on AABB to sphere centre
dx = max(cx_min - sx, 0.0, sx - cx_max)
dy = max(cy_min - sy, 0.0, sy - cy_max)
dz = max(cz_min - sz, 0.0, sz - cz_max)
return (dx * dx + dy * dy + dz * dz) <= sr * sr
# Smallest view-space depth used when projecting; keeps the perspective divide finite.
_MIN_PROJ_DEPTH = 1e-6


def _ndc_tile_span(
    lo: float, hi: float, scale: float, z_a: float, z_b: float, tiles: int
) -> tuple[int, int] | None:
    """Return the inclusive tile range covered by view-space extent ``[lo, hi]`` at depths ``[z_a, z_b]``.

    ``scale`` is the projection focal term (``|P[0, 0]|`` or ``|P[1, 1]|``), so a view-space
    coordinate ``c`` at depth ``z > 0`` lands at NDC ``c * scale / z``. Since ``c / z`` is
    monotonic in both ``c`` and ``z``, the extremes over the box are reached at the depth
    endpoints, which makes the returned range conservative. Returns ``None`` when the
    extent lies entirely off-screen on this axis.
    """
    ndc_lo = min(lo * scale / z_a, lo * scale / z_b)
    ndc_hi = max(hi * scale / z_a, hi * scale / z_b)
    if ndc_hi < -1.0 or ndc_lo > 1.0:
        return None
    # Tiles split the viewport evenly, i.e. NDC [-1, 1] evenly, so map NDC straight to a
    # tile. Clamping first also keeps infinite extents from overflowing int().
    first = int((max(ndc_lo, -1.0) + 1.0) * 0.5 * tiles)
    last = int((min(ndc_hi, 1.0) + 1.0) * 0.5 * tiles)
    return min(first, tiles - 1), min(last, tiles - 1)


def assign_lights(
    lights: list,
    view_matrix: np.ndarray,
    projection_matrix: np.ndarray,
    viewport_size: tuple[int, int],
    grid: ClusterGrid,
) -> ClusterGrid:
    """Assign lights to clusters based on camera view and projection.

    This is the CPU-side culling pass.

    - Directional lights are added to every cluster.
    - Point and spot lights are bounded by a view-space sphere of radius ``light.range``;
      spot cones are not tightened.
    - For every depth slice the sphere overlaps, the sphere's view-space box is projected
      conservatively across that slice's depth interval, and the light index is recorded
      in each covered cluster.
    - A cluster never holds more than ``config.max_lights_per_cluster`` indices; further
      lights are dropped.

    Args:
        lights: List of ``Light3D`` nodes (PointLight3D, SpotLight3D, DirectionalLight3D).
        view_matrix: 4x4 camera view matrix applied to column vectors (row-major numpy array).
        projection_matrix: 4x4 camera projection matrix (row-major numpy array). Only
            ``[0, 0]`` and ``[1, 1]`` are read, by absolute value.
        viewport_size: ``(width, height)`` in pixels. Tiles divide the viewport evenly,
            so tile membership depends only on normalized device coordinates. The size
            is accepted for API compatibility and does not change the result.
        grid: ``ClusterGrid`` to populate (cleared first). Its depth slices must already
            be built with ``rebuild``; otherwise it is returned empty.

    Returns:
        The same ``grid`` instance, populated with light assignments.
    """
    grid.clear()
    if not lights:
        return grid
    bounds = grid.slice_boundaries
    if len(bounds) < 2:
        return grid

    cfg = grid.config
    cap = cfg.max_lights_per_cluster
    last_slice = cfg.depth_slices - 1
    # Plain floats avoid repeated float32 -> float conversions in the inner loops.
    slice_z = [float(b) for b in bounds]
    near = slice_z[0]
    far = slice_z[-1]
    # Loop-invariant focal terms; abs() ignores a flipped Y axis (e.g. a negative [1, 1]).
    px = abs(float(projection_matrix[0, 0]))
    py = abs(float(projection_matrix[1, 1]))

    positions, ranges, _directions, types = _extract_light_data(lights)
    n = len(lights)
    world_pos_h = np.hstack([positions, np.ones((n, 1), dtype=np.float32)])  # (N, 4)
    view_pos = (view_matrix @ world_pos_h.T).T[:, :3]  # (N, 3)
    # View space looks down -Z; negate so depth is positive in front of the camera.
    view_z = -view_pos[:, 2]

    for li in range(n):
        if types[li] == "directional":
            # Directional lights affect all clusters.
            for cluster in grid.clusters:
                if len(cluster.light_indices) < cap:
                    cluster.light_indices.append(li)
            continue

        radius = float(ranges[li])
        vx = float(view_pos[li, 0])
        vy = float(view_pos[li, 1])
        vz = float(view_z[li])
        z_min = vz - radius
        z_max = vz + radius
        if z_max < near or z_min > far:
            continue  # entirely outside the frustum depth range

        slice_lo = max(0, grid.depth_slice(max(z_min, near)))
        slice_hi = min(last_slice, grid.depth_slice(min(z_max, far)))
        if slice_hi < 0:
            continue

        for sz in range(slice_lo, slice_hi + 1):
            z_top = slice_z[sz + 1] if sz + 1 < len(slice_z) else far
            # Depth interval of the sphere inside this slice. Projecting at both ends (not
            # just the slice's near boundary) keeps off-axis spheres from missing tiles.
            z_a = max(slice_z[sz], z_min, _MIN_PROJ_DEPTH)
            z_b = max(min(z_top, z_max), z_a)
            x_span = _ndc_tile_span(vx - radius, vx + radius, px, z_a, z_b, cfg.tiles_x)
            y_span = _ndc_tile_span(vy - radius, vy + radius, py, z_a, z_b, cfg.tiles_y)
            if x_span is None or y_span is None:
                continue
            for ty in range(y_span[0], y_span[1] + 1):
                for tx in range(x_span[0], x_span[1] + 1):
                    cluster = grid.cluster_at(tx, ty, sz)
                    # Each (tx, ty, sz) is visited at most once per light, so no
                    # duplicate scan of light_indices is needed.
                    if len(cluster.light_indices) < cap:
                        cluster.light_indices.append(li)
    return grid