"""Collision system — shapes, broadphase AABB grid, narrowphase GJK/SAT, raycasting."""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass, field
from typing import Any
import numpy as np
log = logging.getLogger(__name__)

# Public API exported by ``from <this module> import *``.
__all__ = [
    # Shapes
    "CollisionShape",
    "SphereShape",
    "BoxShape",
    "ConvexShape",
    "CapsuleShape",
    # Broadphase
    "SpatialHashGrid",
    # Query results
    "RayHit",
    "ShapeCastResult",
    # World
    "CollisionWorld",
]
# ============================================================================
# Collision Shapes
# ============================================================================
[docs]
class CollisionShape:
    """Abstract base for collision shapes.

    Concrete shapes answer two queries, both evaluated with the shape
    translated to *position*:

    * :meth:`get_aabb` -- world-space bounding box used by the broadphase.
    * :meth:`support` -- GJK support mapping used by the narrowphase.
    """

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(min_corner, max_corner)`` in world space.

        Subclasses must override this; the base version always raises
        :class:`NotImplementedError`.
        """
        raise NotImplementedError()

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Return the point of the shape furthest along *direction* (GJK support).

        Subclasses must override this; the base version always raises
        :class:`NotImplementedError`.
        """
        raise NotImplementedError()
[docs]
class SphereShape(CollisionShape):
    """Sphere collision shape defined by a radius."""

    __slots__ = ("radius",)

    def __init__(self, radius: float = 1.0):
        self.radius = radius

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the cube ``position +/- radius`` that encloses the sphere."""
        half = np.array([self.radius] * 3, dtype=np.float32)
        lower = position - half
        upper = position + half
        return lower, upper

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Return the surface point furthest along *direction*.

        A (near-)zero direction has no meaningful furthest point, so a copy of
        the centre is returned instead.
        """
        length = np.linalg.norm(direction)
        if length < 1e-10:
            return position.copy()
        scale = self.radius / length
        return position + direction * scale
[docs]
class BoxShape(CollisionShape):
    """Axis-aligned box collision shape defined by half-extents."""

    __slots__ = ("half_extents",)

    def __init__(self, half_extents: tuple[float, float, float] = (0.5, 0.5, 0.5)):
        self.half_extents = np.array(half_extents, dtype=np.float32)

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the box's own corners -- an axis-aligned box is its own AABB."""
        he = self.half_extents
        lower = position - he
        upper = position + he
        return lower, upper

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Return the box vertex furthest along *direction*."""
        octant = np.sign(direction)
        # np.sign yields 0 for zero components; choose the positive face there
        # so the result is always an actual vertex of the box.
        octant[octant == 0] = 1.0
        offset = self.half_extents * octant
        return position + offset
[docs]
class ConvexShape(CollisionShape):
    """Convex hull collision shape built from a set of vertices.

    Vertices are stored in local space as an ``(N, 3)`` float32 array with
    ``N >= 1`` and are translated by *position* during AABB and support
    queries. Only translation is applied; there is no rotation.

    Raises:
        ValueError: If *vertices* is not a non-empty ``(N, 3)`` array.
    """

    __slots__ = ("vertices",)

    def __init__(self, vertices: np.ndarray):
        verts = np.asarray(vertices, dtype=np.float32)
        # Fail fast: an empty or mis-shaped vertex set would otherwise only
        # surface later as an opaque min/argmax error inside a query.
        if verts.ndim != 2 or verts.shape[0] == 0 or verts.shape[1] != 3:
            raise ValueError(
                f"ConvexShape needs a non-empty (N, 3) vertex array, got shape {verts.shape}"
            )
        self.vertices = verts

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the AABB bounding the hull at the given position.

        Translation preserves per-axis ordering, so the local bounds are
        computed first and only the two corners are translated.
        """
        return self.vertices.min(axis=0) + position, self.vertices.max(axis=0) + position

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Return the hull vertex with maximum projection onto *direction*.

        ``(v + p) . d == v . d + p . d`` and ``p . d`` is the same for every
        vertex, so the best vertex is selected in local space and only that
        one is translated (returned as a new array).
        """
        best = int(np.argmax(self.vertices @ direction))
        return self.vertices[best] + position
[docs]
class CapsuleShape(CollisionShape):
    """Capsule collision shape -- cylinder with hemispherical caps along the Y axis.

    Args:
        radius: Radius of the cylinder and of both hemispherical caps.
        height: Total height including caps. When ``height < 2 * radius`` the
            inner segment collapses to zero length and the capsule behaves
            like a sphere of *radius*.
    """

    __slots__ = ("radius", "height")

    def __init__(self, radius: float = 0.5, height: float = 2.0):
        self.radius = radius
        self.height = height  # total height including caps

    def _half_segment(self) -> float:
        """Half-length of the inner Y segment, clamped at 0."""
        return max(0.0, self.height / 2.0 - self.radius)

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the AABB for a capsule at the given position.

        The capsule is oriented along the Y axis. Its Y half-extent is the
        half-segment ``max(0, height / 2 - radius)`` plus the cap radius,
        i.e. ``max(height / 2, radius)``. The ``max`` keeps the box consistent
        with :meth:`support` when ``height < 2 * radius``. Using plain
        ``height / 2`` would then under-cover the shape and make the
        broadphase miss real overlaps.
        """
        half_y = max(self.height / 2.0, self.radius)
        r = self.radius
        ext = np.array([r, half_y, r], dtype=np.float32)
        return position - ext, position + ext

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Return the furthest point on the capsule in the given direction.

        The capsule is the Minkowski sum of a Y-aligned segment and a sphere
        of ``radius``. The support point is the end of the segment that is
        more aligned with *direction*, offset by ``radius`` along the
        normalised direction. A (near-)zero direction returns that segment end.
        """
        offset = np.array([0.0, self._half_segment(), 0.0], dtype=np.float32)
        # Pick the hemisphere centre that projects furthest along direction.
        if direction[1] >= 0:
            center = position + offset
        else:
            center = position - offset
        d = np.linalg.norm(direction)
        if d < 1e-10:
            return center.copy()
        return center + direction * (self.radius / d)
# ============================================================================
# Spatial Hash Grid (Broadphase)
# ============================================================================
[docs]
class SpatialHashGrid:
    """Uniform-grid spatial hash for broadphase collision culling.

    Maps axis-aligned bounding boxes to grid cells so that overlap queries
    only need to check bodies in neighbouring cells rather than every body
    in the world. Each body id is stored in every cell its AABB touches.
    A query therefore returns a superset of truly overlapping ids, and
    callers still need an exact AABB test.

    Args:
        cell_size: Edge length of a cubic grid cell; must be positive.

    Raises:
        ValueError: If *cell_size* is not a positive number.
    """

    __slots__ = ("cell_size", "_inv_cell", "_grid")

    def __init__(self, cell_size: float = 2.0):
        # A zero size would divide by zero. A negative one inverts every cell
        # range (lo > hi), which would silently make all queries empty.
        # ``not > 0`` also rejects NaN.
        if not cell_size > 0:
            raise ValueError(f"cell_size must be positive, got {cell_size!r}")
        self.cell_size = cell_size
        self._inv_cell = 1.0 / cell_size
        self._grid: dict[tuple[int, int, int], list[int]] = {}

    def _cell_key(self, x: float, y: float, z: float) -> tuple[int, int, int]:
        """Return the integer cell coordinates containing point ``(x, y, z)``."""
        return (
            int(math.floor(x * self._inv_cell)),
            int(math.floor(y * self._inv_cell)),
            int(math.floor(z * self._inv_cell)),
        )

    def _cells(self, aabb_min: np.ndarray, aabb_max: np.ndarray):
        """Yield every cell key covered by the box ``[aabb_min, aabb_max]``."""
        lo = self._cell_key(float(aabb_min[0]), float(aabb_min[1]), float(aabb_min[2]))
        hi = self._cell_key(float(aabb_max[0]), float(aabb_max[1]), float(aabb_max[2]))
        for ix in range(lo[0], hi[0] + 1):
            for iy in range(lo[1], hi[1] + 1):
                for iz in range(lo[2], hi[2] + 1):
                    yield (ix, iy, iz)

    def insert(self, body_id: int, aabb_min: np.ndarray, aabb_max: np.ndarray) -> None:
        """Insert a body into all grid cells its AABB overlaps."""
        for key in self._cells(aabb_min, aabb_max):
            self._grid.setdefault(key, []).append(body_id)

    def query(self, aabb_min: np.ndarray, aabb_max: np.ndarray) -> set[int]:
        """Return body IDs in cells that overlap the given AABB."""
        result: set[int] = set()
        for key in self._cells(aabb_min, aabb_max):
            bucket = self._grid.get(key)
            if bucket:
                result.update(bucket)
        return result

    def clear(self) -> None:
        """Reset the grid for the next frame."""
        self._grid.clear()
# ============================================================================
# Ray Hit Result
# ============================================================================
[docs]
@dataclass
class RayHit:
    """A single hit reported by a raycast query.

    Attributes:
        body: The collider reference that was hit.
        point: World-space hit position.
        distance: Distance from the ray origin to the hit point.
    """

    body: Any  # object originally registered with the collision world
    point: np.ndarray  # world-space hit location
    distance: float  # distance along the ray from its origin


@dataclass
class ShapeCastResult:
    """A single contact reported by a shape cast (swept shape) query.

    Attributes:
        body: The collider reference that the swept shape touched.
        point: Approximate world-space contact point on *body*.
        normal: Approximate contact normal.
        distance: Distance travelled along the sweep before contact.
        fraction: The same contact expressed as a fraction of the full sweep.
    """

    body: Any  # object originally registered with the collision world
    point: np.ndarray  # approximate contact point (world space)
    normal: np.ndarray  # approximate contact normal
    distance: float  # sweep distance to first contact
    fraction: float  # sweep fraction to first contact
# ============================================================================
# GJK Intersection Test
# ============================================================================
def _gjk_intersect(
    shape_a: CollisionShape,
    pos_a: np.ndarray,
    shape_b: CollisionShape,
    pos_b: np.ndarray,
) -> bool:
    """GJK algorithm for convex shape intersection.

    Builds a simplex (point, segment, triangle, tetrahedron) from support
    points of the Minkowski difference ``A - B`` and tries to enclose the
    origin with it; the shapes overlap iff the origin lies in ``A - B``.

    Args:
        shape_a: First shape; its ``support(pos_a, d)`` is queried.
        pos_a: World position passed to ``shape_a.support``.
        shape_b: Second shape; its ``support(pos_b, -d)`` is queried.
        pos_b: World position passed to ``shape_b.support``.

    Returns:
        False as soon as a support point fails to reach the origin (a
        separating direction was found) or after 32 iterations without a
        verdict. True when the origin is found inside a tetrahedron or the
        search direction degenerates to (near) zero.
    """
    direction = pos_b - pos_a
    if np.dot(direction, direction) < 1e-10:
        # Coincident positions give no useful hint; start along +X.
        direction = np.array([1.0, 0.0, 0.0], dtype=np.float32)

    def support(d: np.ndarray) -> np.ndarray:
        # Support point of the Minkowski difference A - B along d.
        return shape_a.support(pos_a, d) - shape_b.support(pos_b, -d)

    # The simplex is kept oldest -> newest; simplex[-1] is always the newest
    # point (called "A" below).
    simplex = [support(direction)]
    direction = -simplex[0]
    for _ in range(32):
        a = support(direction)
        if np.dot(a, direction) < 0:
            # The furthest point of A - B along `direction` stays short of
            # the origin, so `direction` separates the shapes. Exactly zero
            # (touching) is not rejected here.
            return False
        simplex.append(a)
        if len(simplex) == 2:
            # Line case
            ab = simplex[0] - simplex[1]
            ao = -simplex[1]
            if np.dot(ab, ao) > 0:
                # Origin is beside segment AB: search perpendicular to it,
                # towards the origin.
                direction = np.cross(np.cross(ab, ao), ab)
                if np.dot(direction, direction) < 1e-10:
                    # Zero perpendicular -> origin lies on the segment's line.
                    return True
            else:
                simplex = [simplex[1]]
                direction = ao
        elif len(simplex) == 3:
            # Triangle case: A = simplex[2] (newest), B = simplex[1], C = simplex[0].
            ab = simplex[1] - simplex[2]
            ac = simplex[0] - simplex[2]
            ao = -simplex[2]
            abc = np.cross(ab, ac)
            if np.dot(np.cross(abc, ac), ao) > 0:
                # Origin is outside edge AC.
                # NOTE(review): when AC.AO <= 0 this falls straight back to
                # vertex A without testing edge AB, a simplification of the
                # textbook case analysis. It appears to rely on later
                # iterations to recover; confirm.
                if np.dot(ac, ao) > 0:
                    simplex = [simplex[0], simplex[2]]
                    direction = np.cross(np.cross(ac, ao), ac)
                else:
                    simplex = [simplex[2]]
                    direction = ao
            elif np.dot(np.cross(ab, abc), ao) > 0:
                # Origin is outside edge AB.
                if np.dot(ab, ao) > 0:
                    simplex = [simplex[1], simplex[2]]
                    direction = np.cross(np.cross(ab, ao), ab)
                else:
                    simplex = [simplex[2]]
                    direction = ao
            else:
                # Origin projects inside the triangle: search above or below
                # it. When below, swap B and C so the winding keeps the face
                # normal pointing towards the origin.
                if np.dot(abc, ao) > 0:
                    direction = abc
                else:
                    simplex = [simplex[1], simplex[0], simplex[2]]
                    direction = -abc
        elif len(simplex) == 4:
            # Tetrahedron — check if origin is inside
            a, b, c, d = simplex[3], simplex[2], simplex[1], simplex[0]
            ao = -a
            ab, ac, ad = b - a, c - a, d - a
            # Normals of the three faces that contain the newest point a.
            # The winding inherited from the triangle case makes them face
            # away from the opposite vertex.
            abc = np.cross(ab, ac)
            acd = np.cross(ac, ad)
            adb = np.cross(ad, ab)
            # If the origin is in front of a face, drop the vertex opposite
            # it and continue from that face (ordered so the triangle case
            # sees its normal as (B - A) x (C - A)).
            if np.dot(abc, ao) > 0:
                simplex = [simplex[1], simplex[2], simplex[3]]
                direction = abc
            elif np.dot(acd, ao) > 0:
                simplex = [simplex[0], simplex[1], simplex[3]]
                direction = acd
            elif np.dot(adb, ao) > 0:
                simplex = [simplex[2], simplex[0], simplex[3]]
                direction = adb
            else:
                return True  # Origin inside tetrahedron
        if np.dot(direction, direction) < 1e-10:
            # A (near-)zero search direction means the origin lies on the
            # current simplex feature; treat as overlap.
            return True
    # Iteration cap reached without a verdict: reported as no overlap.
    return False
# ============================================================================
# AABB Overlap (Broadphase)
# ============================================================================
def _aabb_overlap(
min_a: np.ndarray,
max_a: np.ndarray,
min_b: np.ndarray,
max_b: np.ndarray,
) -> bool:
"""Fast AABB overlap check."""
return bool(np.all(max_a >= min_b) and np.all(max_b >= min_a))
# ============================================================================
# Ray-Shape Intersection
# ============================================================================
def _ray_sphere(
origin: np.ndarray,
direction: np.ndarray,
center: np.ndarray,
radius: float,
max_dist: float,
) -> float | None:
"""Ray-sphere intersection. Returns distance or None."""
oc = origin - center
a = np.dot(direction, direction)
b = 2.0 * np.dot(oc, direction)
c = np.dot(oc, oc) - radius * radius
disc = b * b - 4 * a * c
if disc < 0:
return None
t = (-b - math.sqrt(disc)) / (2 * a)
if t < 0:
t = (-b + math.sqrt(disc)) / (2 * a)
return t if 0 <= t <= max_dist else None
def _ray_aabb(
origin: np.ndarray,
direction: np.ndarray,
box_min: np.ndarray,
box_max: np.ndarray,
max_dist: float,
) -> float | None:
"""Ray-AABB intersection (slab method). Returns distance or None."""
with np.errstate(divide="ignore", invalid="ignore"):
inv_dir = np.where(np.abs(direction) > 1e-10, 1.0 / direction, np.copysign(1e30, direction))
t1 = (box_min - origin) * inv_dir
t2 = (box_max - origin) * inv_dir
tmin = float(np.max(np.minimum(t1, t2)))
tmax = float(np.min(np.maximum(t1, t2)))
if tmax < 0 or tmin > tmax or tmin > max_dist:
return None
t = tmin if tmin >= 0 else tmax
return t if t <= max_dist else None
# ============================================================================
# Collision World
# ============================================================================
@dataclass
class _Body:
"""Internal body wrapper."""
ref: Any
shape: CollisionShape
position: np.ndarray
layer: int = 1
mask: int = 1
aabb_min: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=np.float32))
aabb_max: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=np.float32))
[docs]
class CollisionWorld:
    """Broadphase (AABB) + narrowphase (GJK) collision world.

    Bodies are registered with a shape, position, and collision layer/mask.
    Supports overlap queries between bodies, raycasting, and swept shape
    casts against all bodies.

    Bodies are keyed by ``id(body)``. The world holds a strong reference to
    every registered body, so an id cannot be recycled while its body is
    registered.

    Args:
        use_spatial_hash: When True, uses a :class:`SpatialHashGrid` for
            broadphase culling in :meth:`query_overlaps` instead of
            brute-force AABB checks. Raycasts and shape casts always scan
            every body.
        cell_size: Grid cell size for the spatial hash (ignored when
            *use_spatial_hash* is False).
    """

    def __init__(self, use_spatial_hash: bool = False, cell_size: float = 2.0):
        self._bodies: dict[int, _Body] = {}
        self._use_spatial_hash = use_spatial_hash
        self._spatial_hash: SpatialHashGrid | None = SpatialHashGrid(cell_size) if use_spatial_hash else None
        # True when grid contents may be stale (a body moved, was removed or
        # was replaced); query_overlaps() rebuilds the grid lazily.
        self._spatial_hash_dirty = False

    def add_body(
        self,
        body: Any,
        shape: CollisionShape,
        position: np.ndarray | None = None,
        layer: int = 1,
        mask: int = 1,
    ) -> None:
        """Register a body with a collision shape.

        Args:
            body: Any object. It is keyed by ``id(body)`` and returned as-is
                from queries. Re-adding a registered body replaces its shape,
                position and layer/mask.
            shape: Collision shape used for AABB and GJK tests.
            position: World position (defaults to the origin). The value is
                copied, so later in-place edits to the caller's array cannot
                desynchronise the body from its cached AABB.
            layer: Layer bit(s) of this body.
            mask: Layer bit(s) this body is tested against.
        """
        source = position if position is not None else [0, 0, 0]
        # np.array (not np.asarray) so a caller's float32 array is not aliased.
        pos = np.array(source, dtype=np.float32)
        b = _Body(ref=body, shape=shape, position=pos, layer=layer, mask=mask)
        b.aabb_min, b.aabb_max = shape.get_aabb(pos)
        bid = id(body)
        replaced = bid in self._bodies
        self._bodies[bid] = b
        if self._spatial_hash is not None:
            if replaced:
                # The grid still lists this id under the old AABB's cells; a
                # lazy rebuild drops those and inserts the new AABB.
                self._spatial_hash_dirty = True
            else:
                self._spatial_hash.insert(bid, b.aabb_min, b.aabb_max)

    def remove_body(self, body: Any) -> None:
        """Unregister a body (no-op if it is not registered)."""
        if self._bodies.pop(id(body), None) is not None and self._spatial_hash is not None:
            self._spatial_hash_dirty = True

    def update_position(self, body: Any, position: np.ndarray) -> None:
        """Update a body's position and recompute its AABB.

        Unknown bodies are ignored. The position is copied (see
        :meth:`add_body`). When the spatial hash is enabled, the grid is
        rebuilt lazily before the next :meth:`query_overlaps` call via
        :meth:`rebuild_spatial_hash`.
        """
        b = self._bodies.get(id(body))
        if b is not None:
            b.position = np.array(position, dtype=np.float32)
            b.aabb_min, b.aabb_max = b.shape.get_aabb(b.position)
            if self._spatial_hash is not None:
                self._spatial_hash_dirty = True

    def rebuild_spatial_hash(self) -> None:
        """Rebuild the spatial hash grid from all current bodies."""
        if self._spatial_hash is None:
            return
        self._spatial_hash.clear()
        for bid, b in self._bodies.items():
            self._spatial_hash.insert(bid, b.aabb_min, b.aabb_max)
        self._spatial_hash_dirty = False

    def query_overlaps(self, body: Any) -> list[Any]:
        """Find all bodies overlapping with the given body.

        A pair is tested when either body's mask includes the other's layer.
        When the spatial hash is enabled, only candidate bodies from
        neighbouring grid cells are checked instead of the entire world.
        Returns an empty list for an unregistered body.
        """
        bid = id(body)
        b = self._bodies.get(bid)
        if b is None:
            return []
        # Determine candidate set
        if self._spatial_hash is not None:
            if self._spatial_hash_dirty:
                self.rebuild_spatial_hash()
            candidate_ids = self._spatial_hash.query(b.aabb_min, b.aabb_max)
            candidate_ids.discard(bid)
            candidates = (self._bodies[cid] for cid in candidate_ids if cid in self._bodies)
        else:
            candidates = (other for other in self._bodies.values() if other is not b)
        results = []
        for other in candidates:
            if not (b.mask & other.layer) and not (other.mask & b.layer):
                continue
            if not _aabb_overlap(b.aabb_min, b.aabb_max, other.aabb_min, other.aabb_max):
                continue
            if _gjk_intersect(b.shape, b.position, other.shape, other.position):
                results.append(other.ref)
        return results

    def test_overlap(self, body_a: Any, body_b: Any) -> bool:
        """Test if two specific bodies overlap.

        Layer/mask filtering is not applied. Returns False if either body is
        not registered.
        """
        a = self._bodies.get(id(body_a))
        b = self._bodies.get(id(body_b))
        if a is None or b is None:
            return False
        if not _aabb_overlap(a.aabb_min, a.aabb_max, b.aabb_min, b.aabb_max):
            return False
        return _gjk_intersect(a.shape, a.position, b.shape, b.position)

    def raycast(
        self,
        origin: np.ndarray,
        direction: np.ndarray,
        max_dist: float = 1000.0,
        layer_mask: int = 0xFFFFFFFF,
    ) -> list[RayHit]:
        """Cast a ray and return all hits sorted by ascending distance.

        *direction* is normalised first; a (near-)zero direction yields no
        hits. Each body's AABB is tested first. Spheres and boxes then get an
        exact shape test. Every other shape (``ConvexShape``,
        ``CapsuleShape``, custom shapes) uses the AABB hit distance as an
        approximation, so such shapes may report hits the exact shape would
        not produce.
        """
        origin = np.asarray(origin, dtype=np.float32)
        direction = np.asarray(direction, dtype=np.float32)
        d_len = np.linalg.norm(direction)
        if d_len < 1e-10:
            return []
        direction = direction / d_len
        hits: list[RayHit] = []
        for b in self._bodies.values():
            if not (layer_mask & b.layer):
                continue
            # First test AABB
            t_aabb = _ray_aabb(origin, direction, b.aabb_min, b.aabb_max, max_dist)
            if t_aabb is None:
                continue
            # Narrowphase: shape-specific ray test
            if isinstance(b.shape, SphereShape):
                t = _ray_sphere(origin, direction, b.position, b.shape.radius, max_dist)
            elif isinstance(b.shape, BoxShape):
                bmin = b.position - b.shape.half_extents
                bmax = b.position + b.shape.half_extents
                t = _ray_aabb(origin, direction, bmin, bmax, max_dist)
            else:
                # No exact ray test for this shape: use the AABB hit.
                t = t_aabb
            if t is not None:
                hits.append(RayHit(body=b.ref, point=origin + direction * t, distance=t))
        hits.sort(key=lambda h: h.distance)
        return hits

    @staticmethod
    def _first_contact_fraction(
        shape: CollisionShape,
        from_pos: np.ndarray,
        to_pos: np.ndarray,
        sweep: np.ndarray,
        other: _Body,
    ) -> float | None:
        """Return the first sweep fraction in [0, 1] at which *shape* touches *other*, or None."""
        if _gjk_intersect(shape, from_pos, other.shape, other.position):
            return 0.0
        lo, hi = 0.0, 1.0  # lo: known non-overlapping, hi: known overlapping
        if not _gjk_intersect(shape, to_pos, other.shape, other.position):
            # Neither endpoint overlaps: probe a few interior samples. A body
            # thin enough to fit between samples can still be missed.
            for f in (0.25, 0.5, 0.75):
                if _gjk_intersect(shape, from_pos + sweep * f, other.shape, other.position):
                    hi = f
                    break
                lo = f  # this sample missed, so contact starts later
            else:
                return None
        # Bisect between the last miss (lo) and the first overlap (hi).
        for _ in range(16):
            mid = (lo + hi) * 0.5
            if _gjk_intersect(shape, from_pos + sweep * mid, other.shape, other.position):
                hi = mid
            else:
                lo = mid
        return hi

    def shape_cast(
        self, shape: CollisionShape, from_pos: np.ndarray, to_pos: np.ndarray,
        layer_mask: int = 0xFFFFFFFF, exclude: set | None = None, max_results: int = 32,
    ) -> list[ShapeCastResult]:
        """Sweep a shape from from_pos to to_pos and return the intersected bodies.

        Candidates are culled with the swept AABB. For each candidate, the
        first contact fraction is found by coarse sampling plus binary
        search. Contact point and normal are approximations: the normal
        points from the body's position to the shape's position at contact.

        Returns:
            The nearest ``max_results`` hits sorted by ascending distance
            (an empty list when ``max_results <= 0``).
        """
        from_pos = np.asarray(from_pos, dtype=np.float32)
        to_pos = np.asarray(to_pos, dtype=np.float32)
        sweep = to_pos - from_pos
        sweep_len = float(np.linalg.norm(sweep))
        exclude_ids = {id(e) for e in exclude} if exclude else set()
        # Swept AABB = union of the shape's AABB at start and end (exact for
        # a pure translation).
        aabb_start_min, aabb_start_max = shape.get_aabb(from_pos)
        aabb_end_min, aabb_end_max = shape.get_aabb(to_pos)
        swept_min = np.minimum(aabb_start_min, aabb_end_min)
        swept_max = np.maximum(aabb_start_max, aabb_end_max)
        results: list[ShapeCastResult] = []
        for b in self._bodies.values():
            if not (layer_mask & b.layer):
                continue
            if id(b.ref) in exclude_ids:
                continue
            if not _aabb_overlap(swept_min, swept_max, b.aabb_min, b.aabb_max):
                continue
            fraction = self._first_contact_fraction(shape, from_pos, to_pos, sweep, b)
            if fraction is None:
                continue
            contact_pos = from_pos + sweep * fraction
            diff = contact_pos - b.position
            diff_len = float(np.linalg.norm(diff))
            if diff_len > 1e-10:
                normal = diff / diff_len
                contact_point = b.shape.support(b.position, diff)
            else:
                normal = np.array([0, 1, 0], dtype=np.float32)
                contact_point = b.position.copy()
            results.append(ShapeCastResult(
                body=b.ref, point=contact_point, normal=normal,
                distance=sweep_len * fraction, fraction=fraction,
            ))
        # Sort BEFORE truncating so the cap keeps the nearest hits rather
        # than whichever bodies happened to be iterated first.
        results.sort(key=lambda r: r.distance)
        return results[:max(0, max_results)]

    @property
    def body_count(self) -> int:
        """Number of registered bodies."""
        return len(self._bodies)