# Source code for simvx.core.collision

"""Collision system — shapes, broadphase AABB grid, narrowphase GJK/SAT, raycasting."""


from __future__ import annotations

import logging
import math
from dataclasses import dataclass, field
from typing import Any

import numpy as np

# Module-level logger; it appears to be unused by the code in this module.
log = logging.getLogger(__name__)

# Public names exported by ``from simvx.core.collision import *``.
# Underscore-prefixed helpers (``_gjk_intersect``, ``_ray_aabb``, ``_Body``, ...)
# are not exported.
__all__ = [
    "CollisionShape",
    "SphereShape",
    "BoxShape",
    "ConvexShape",
    "CapsuleShape",
    "SpatialHashGrid",
    "RayHit",
    "ShapeCastResult",
    "CollisionWorld",
]


# ============================================================================
# Collision Shapes
# ============================================================================


class CollisionShape:
    """Abstract base for collision shapes.

    Subclasses must implement :meth:`get_aabb` (used by the broadphase) and
    :meth:`support` (used by the GJK narrowphase). The interface takes only a
    world-space position; shapes have no orientation parameter.
    """

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(min_corner, max_corner)`` in world space.

        Args:
            position: World-space position of the shape's origin.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(f"{type(self).__name__} must implement get_aabb()")

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """GJK support function: furthest point of the shape in *direction*.

        Args:
            position: World-space position of the shape's origin.
            direction: Query direction (not necessarily normalised).

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement support()")
class SphereShape(CollisionShape):
    """Sphere centred on the body position, described only by its radius."""

    __slots__ = ("radius",)

    def __init__(self, radius: float = 1.0):
        self.radius = radius

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the cube of half-size ``radius`` around *position* that bounds the sphere."""
        half_size = np.full(3, self.radius, dtype=np.float32)
        lower = position - half_size
        upper = position + half_size
        return lower, upper

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Return the surface point lying furthest along *direction*.

        *direction* need not be normalised. For a (near-)zero direction, a copy
        of the centre is returned instead.
        """
        length = np.linalg.norm(direction)
        return (
            position.copy()
            if length < 1e-10
            else position + direction * (self.radius / length)
        )
class BoxShape(CollisionShape):
    """Axis-aligned box centred on the body position, sized by its half-extents."""

    __slots__ = ("half_extents",)

    def __init__(self, half_extents: tuple[float, float, float] = (0.5, 0.5, 0.5)):
        # Always copy into a fresh float32 array so later edits to the caller's
        # data do not affect the shape.
        self.half_extents = np.array(half_extents, dtype=np.float32)

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """The box is already axis-aligned, so its AABB is the box itself."""
        half = self.half_extents
        return position - half, position + half

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Return the box corner furthest along *direction*.

        Components of *direction* that are exactly zero pick the positive side
        on that axis.
        """
        corner_signs = np.sign(direction)
        corner_signs[corner_signs == 0] = 1.0
        offset = corner_signs * self.half_extents
        return position + offset
class ConvexShape(CollisionShape):
    """Convex hull collision shape given by a point set.

    The vertices are kept relative to the body position (local space) and are
    offset by the position on every AABB or support query. They are expected
    to form an ``(N, 3)`` array; this is not validated here.
    """

    __slots__ = ("vertices",)

    def __init__(self, vertices: np.ndarray):
        self.vertices = np.asarray(vertices, dtype=np.float32)

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the per-axis min/max of the translated vertices."""
        translated = self.vertices + position
        lower = translated.min(axis=0)
        upper = translated.max(axis=0)
        return lower, upper

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Return (a copy of) the translated vertex with the largest projection on *direction*."""
        translated = self.vertices + position
        best = np.argmax(translated @ direction)
        return translated[best].copy()
class CapsuleShape(CollisionShape):
    """Capsule collision shape — cylinder with hemispherical caps along the Y axis.

    Geometrically it is a vertical segment of half-length
    ``max(0, height / 2 - radius)`` swept by a sphere of ``radius``. When
    ``height < 2 * radius`` the segment collapses to a point and the capsule
    behaves like a sphere of ``radius``.
    """

    __slots__ = ("radius", "height")

    def __init__(self, radius: float = 0.5, height: float = 2.0):
        self.radius = radius
        self.height = height  # total height including caps

    def get_aabb(self, position: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Returns AABB for a capsule at the given position.

        The Y half-extent is ``half_segment + radius``, which equals
        ``max(height / 2, radius)``. The segment is clamped to zero length
        exactly as in :meth:`support`, so a short capsule still reaches
        ``radius`` above and below its centre. The X and Z half-extents are
        ``radius``.
        """
        # Using height / 2 alone under-sizes the box when height < 2 * radius.
        half_y = max(self.height / 2.0, self.radius)
        r = self.radius
        ext = np.array([r, half_y, r], dtype=np.float32)
        return position - ext, position + ext

    def support(self, position: np.ndarray, direction: np.ndarray) -> np.ndarray:
        """Returns the furthest point on the capsule in the given direction.

        The capsule is the Minkowski sum of a line segment (along Y) and a
        sphere of the given radius. The support point is the centre of the
        segment end more aligned with *direction* (the top end when
        ``direction[1] >= 0``), offset by ``radius`` along the normalised
        direction. For a (near-)zero direction, that end point is returned.
        """
        half_seg = max(0.0, self.height / 2.0 - self.radius)
        # Pick the hemisphere center that projects furthest along direction
        if direction[1] >= 0:
            center = position + np.array([0.0, half_seg, 0.0], dtype=np.float32)
        else:
            center = position - np.array([0.0, half_seg, 0.0], dtype=np.float32)
        d = np.linalg.norm(direction)
        if d < 1e-10:
            return center.copy()
        return center + direction * (self.radius / d)
# ============================================================================
# Spatial Hash Grid (Broadphase)
# ============================================================================
class SpatialHashGrid:
    """Uniform-grid spatial hash for broadphase collision culling.

    Maps axis-aligned bounding boxes to integer grid cells. An overlap query
    then only needs to look at bodies registered in the cells the query box
    covers, instead of every body in the world.

    Args:
        cell_size: Edge length of a cubic grid cell. Must be positive and finite.

    Raises:
        ValueError: If *cell_size* is not a positive finite number.
    """

    __slots__ = ("cell_size", "_inv_cell", "_grid")

    def __init__(self, cell_size: float = 2.0):
        # A negative size makes every cell range empty, so nothing would ever be
        # inserted or found. Zero would fail with an opaque ZeroDivisionError.
        if not (cell_size > 0 and math.isfinite(cell_size)):
            raise ValueError(f"cell_size must be a positive finite number, got {cell_size!r}")
        self.cell_size = cell_size
        self._inv_cell = 1.0 / cell_size
        self._grid: dict[tuple[int, int, int], list[int]] = {}

    def _cell_key(self, x: float, y: float, z: float) -> tuple[int, int, int]:
        """Return the integer cell coordinates containing point ``(x, y, z)``."""
        return (
            int(math.floor(x * self._inv_cell)),
            int(math.floor(y * self._inv_cell)),
            int(math.floor(z * self._inv_cell)),
        )

    def _cells(self, aabb_min: np.ndarray, aabb_max: np.ndarray):
        """Yield every cell key covered by the box ``[aabb_min, aabb_max]`` (inclusive)."""
        lo = self._cell_key(float(aabb_min[0]), float(aabb_min[1]), float(aabb_min[2]))
        hi = self._cell_key(float(aabb_max[0]), float(aabb_max[1]), float(aabb_max[2]))
        for ix in range(lo[0], hi[0] + 1):
            for iy in range(lo[1], hi[1] + 1):
                for iz in range(lo[2], hi[2] + 1):
                    yield (ix, iy, iz)

    def insert(self, body_id: int, aabb_min: np.ndarray, aabb_max: np.ndarray) -> None:
        """Insert a body into all grid cells its AABB overlaps."""
        for key in self._cells(aabb_min, aabb_max):
            self._grid.setdefault(key, []).append(body_id)

    def query(self, aabb_min: np.ndarray, aabb_max: np.ndarray) -> set[int]:
        """Return body IDs in cells that overlap the given AABB.

        This is a broadphase superset: callers must still test actual overlap.
        """
        result: set[int] = set()
        for key in self._cells(aabb_min, aabb_max):
            bucket = self._grid.get(key)
            if bucket:
                result.update(bucket)
        return result

    def clear(self) -> None:
        """Reset the grid for the next frame."""
        self._grid.clear()
# ============================================================================
# Query Results (Ray Hit, Shape Cast)
# ============================================================================
@dataclass
class RayHit:
    """Result of a raycast query.

    Instances are produced by ``CollisionWorld.raycast``, which normalises the
    ray direction before testing. ``distance`` is therefore measured in world
    units along the ray.

    Attributes:
        body: The collider reference that was hit (the object passed as
            ``body`` to ``CollisionWorld.add_body``).
        point: World-space hit position (``origin + direction * distance``).
        distance: Distance from ray origin to the hit point.
    """

    body: Any
    point: np.ndarray
    distance: float
@dataclass
class ShapeCastResult:
    """Result of a shape cast (swept shape) query.

    Instances are produced by ``CollisionWorld.shape_cast``.

    Attributes:
        body: The collider reference (the object passed as ``body`` to
            ``CollisionWorld.add_body``) that the swept shape touched.
        point: Approximate contact point. It is the hit body's support point
            in the direction of the cast shape's contact position, or the hit
            body's position when the two positions coincide.
        normal: Unit vector from the hit body's position towards the cast
            shape's contact position (``[0, 1, 0]`` when they coincide). This
            is a centre-to-centre approximation, not a true surface normal.
        distance: Distance travelled along the sweep before contact
            (``fraction * sweep length``).
        fraction: Normalised contact time in ``[0, 1]``; ``0`` means the
            shape already overlapped the body at the start position.
    """

    body: Any
    point: np.ndarray
    normal: np.ndarray
    distance: float
    fraction: float
# ============================================================================
# GJK Intersection Test
# ============================================================================


def _gjk_intersect(
    shape_a: CollisionShape,
    pos_a: np.ndarray,
    shape_b: CollisionShape,
    pos_b: np.ndarray,
) -> bool:
    """GJK algorithm for convex shape intersection.

    Searches the Minkowski difference ``A - B`` (built from the two shapes'
    support functions) for the origin. Returns True when the shapes intersect.
    Touching counts as intersecting, because the early-out uses a strict
    ``< 0``, and a (near-)zero search direction is also reported as contact.
    Returns False if a separating direction is found, or if no decision is
    reached within 32 iterations.

    The simplex list is ordered oldest to newest; the newest support point is
    always last.
    """
    direction = pos_b - pos_a
    # Coincident centres give no useful initial direction; pick +X.
    if np.dot(direction, direction) < 1e-10:
        direction = np.array([1.0, 0.0, 0.0], dtype=np.float32)

    def support(d: np.ndarray) -> np.ndarray:
        # Support point of the Minkowski difference A - B in direction d.
        return shape_a.support(pos_a, d) - shape_b.support(pos_b, -d)

    simplex = [support(direction)]
    direction = -simplex[0]

    for _ in range(32):
        a = support(direction)
        # The new point did not get past the origin, so the origin lies
        # outside A - B and the shapes are separated.
        if np.dot(a, direction) < 0:
            return False
        simplex.append(a)

        if len(simplex) == 2:
            # Line case
            ab = simplex[0] - simplex[1]
            ao = -simplex[1]
            if np.dot(ab, ao) > 0:
                # Search perpendicular to the segment, towards the origin.
                direction = np.cross(np.cross(ab, ao), ab)
                if np.dot(direction, direction) < 1e-10:
                    # Origin lies (numerically) on the segment.
                    return True
            else:
                simplex = [simplex[1]]
                direction = ao

        elif len(simplex) == 3:
            # Triangle case
            ab = simplex[1] - simplex[2]
            ac = simplex[0] - simplex[2]
            ao = -simplex[2]
            abc = np.cross(ab, ac)

            if np.dot(np.cross(abc, ac), ao) > 0:
                # Origin is outside edge AC.
                if np.dot(ac, ao) > 0:
                    simplex = [simplex[0], simplex[2]]
                    direction = np.cross(np.cross(ac, ao), ac)
                else:
                    simplex = [simplex[2]]
                    direction = ao
            elif np.dot(np.cross(ab, abc), ao) > 0:
                # Origin is outside edge AB.
                if np.dot(ab, ao) > 0:
                    simplex = [simplex[1], simplex[2]]
                    direction = np.cross(np.cross(ab, ao), ab)
                else:
                    simplex = [simplex[2]]
                    direction = ao
            else:
                # Origin is above or below the triangle. Reorder so that the
                # triangle normal cross(B - A, C - A) faces the origin; the
                # tetrahedron case relies on this winding.
                if np.dot(abc, ao) > 0:
                    direction = abc
                else:
                    simplex = [simplex[1], simplex[0], simplex[2]]
                    direction = -abc

        elif len(simplex) == 4:
            # Tetrahedron — check if origin is inside
            a, b, c, d = simplex[3], simplex[2], simplex[1], simplex[0]
            ao = -a
            ab, ac, ad = b - a, c - a, d - a
            # Normals of the three faces that contain the newest point. Given
            # the triangle winding above, they point away from the opposite
            # vertex.
            abc = np.cross(ab, ac)
            acd = np.cross(ac, ad)
            adb = np.cross(ad, ab)

            # Drop the vertex opposite the first face that sees the origin, and
            # keep that face's winding for the next triangle step.
            if np.dot(abc, ao) > 0:
                simplex = [simplex[1], simplex[2], simplex[3]]
                direction = abc
            elif np.dot(acd, ao) > 0:
                simplex = [simplex[0], simplex[1], simplex[3]]
                direction = acd
            elif np.dot(adb, ao) > 0:
                simplex = [simplex[2], simplex[0], simplex[3]]
                direction = adb
            else:
                return True  # Origin inside tetrahedron

        # A degenerate (near-zero) search direction means the origin lies on
        # the current simplex feature, within tolerance; report contact.
        if np.dot(direction, direction) < 1e-10:
            return True

    return False


# ============================================================================
# AABB Overlap (Broadphase)
# ============================================================================


def _aabb_overlap(
    min_a: np.ndarray,
    max_a: np.ndarray,
    min_b: np.ndarray,
    max_b: np.ndarray,
) -> bool:
    """Fast AABB overlap check.

    Boxes that only touch (shared face, edge or corner) count as overlapping,
    because the comparisons are inclusive.
    """
    return bool(np.all(max_a >= min_b) and np.all(max_b >= min_a))


# ============================================================================
# Ray-Shape Intersection
# ============================================================================


def _ray_sphere(
    origin: np.ndarray,
    direction: np.ndarray,
    center: np.ndarray,
    radius: float,
    max_dist: float,
) -> float | None:
    """Ray-sphere intersection. Returns distance or None.

    Solves ``|origin + t * direction - center|^2 = radius^2`` for ``t``. The
    result is in multiples of ``|direction|``, so it is a world distance only
    when *direction* is normalised. If the origin is inside the sphere, the
    near root is negative and the exit distance (far root) is returned.
    Returns None if there is no root in ``[0, max_dist]``.
    """
    oc = origin - center
    a = np.dot(direction, direction)
    b = 2.0 * np.dot(oc, direction)
    c = np.dot(oc, oc) - radius * radius
    disc = b * b - 4 * a * c
    if disc < 0:
        return None
    t = (-b - math.sqrt(disc)) / (2 * a)
    if t < 0:
        t = (-b + math.sqrt(disc)) / (2 * a)
    return t if 0 <= t <= max_dist else None


def _ray_aabb(
    origin: np.ndarray,
    direction: np.ndarray,
    box_min: np.ndarray,
    box_max: np.ndarray,
    max_dist: float,
) -> float | None:
    """Ray-AABB intersection (slab method). Returns distance or None.

    If the origin is inside the box (entry distance negative), the exit
    distance is returned. Returns None when the ray misses, when the box is
    entirely behind the origin, or when the hit lies beyond *max_dist*.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        # Near-zero direction components are replaced by a huge signed inverse
        # so the slab arithmetic stays finite instead of producing inf/NaN.
        inv_dir = np.where(np.abs(direction) > 1e-10, 1.0 / direction, np.copysign(1e30, direction))
    t1 = (box_min - origin) * inv_dir
    t2 = (box_max - origin) * inv_dir
    # Latest entry and earliest exit across the three slabs.
    tmin = float(np.max(np.minimum(t1, t2)))
    tmax = float(np.min(np.maximum(t1, t2)))
    if tmax < 0 or tmin > tmax or tmin > max_dist:
        return None
    t = tmin if tmin >= 0 else tmax
    return t if t <= max_dist else None


# ============================================================================
# Collision World
# ============================================================================


@dataclass
class _Body:
    """Internal body wrapper."""

    # The user object passed to CollisionWorld.add_body; returned by queries.
    ref: Any
    shape: CollisionShape
    # World-space position.
    position: np.ndarray
    # Bitmask of layers this body is on.
    layer: int = 1
    # Bitmask of layers this body tests against.
    mask: int = 1
    # Cached world-space AABB; CollisionWorld recomputes it whenever the
    # position is set.
    aabb_min: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=np.float32))
    aabb_max: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=np.float32))
class CollisionWorld:
    """Broadphase (AABB) + narrowphase (GJK) collision world.

    Bodies are registered with a shape, position, and collision layer/mask.
    The world supports overlap queries between bodies, raycasting, and shape
    casting against all registered bodies.

    Bodies are keyed internally by ``id(body)``. The world holds a reference to
    every registered body, so an id cannot be recycled while the body is still
    registered.

    Args:
        use_spatial_hash: When True, uses a :class:`SpatialHashGrid` for
            broadphase culling in :meth:`query_overlaps` instead of
            brute-force AABB checks. :meth:`raycast` and :meth:`shape_cast`
            always scan every body.
        cell_size: Grid cell size for the spatial hash (ignored when
            *use_spatial_hash* is False).
    """

    def __init__(self, use_spatial_hash: bool = False, cell_size: float = 2.0):
        self._bodies: dict[int, _Body] = {}
        self._use_spatial_hash = use_spatial_hash
        self._spatial_hash: SpatialHashGrid | None = SpatialHashGrid(cell_size) if use_spatial_hash else None
        # True when the grid no longer matches the bodies' AABBs (after a move,
        # removal or re-registration). query_overlaps() rebuilds it lazily.
        self._spatial_hash_dirty = False

    def add_body(
        self,
        body: Any,
        shape: CollisionShape,
        position: np.ndarray | None = None,
        layer: int = 1,
        mask: int = 1,
    ) -> None:
        """Register a body with a collision shape.

        Registering a body that is already in the world replaces its shape,
        position and layer/mask.

        Args:
            body: Any object identifying the body; it is what queries return.
            shape: Collision shape of the body.
            position: World-space position (defaults to the origin).
            layer: Bitmask of the layers this body is on.
            mask: Bitmask of the layers this body tests against.
        """
        pos = np.asarray(position if position is not None else [0, 0, 0], dtype=np.float32)
        b = _Body(ref=body, shape=shape, position=pos, layer=layer, mask=mask)
        b.aabb_min, b.aabb_max = shape.get_aabb(pos)
        bid = id(body)
        replacing = bid in self._bodies
        self._bodies[bid] = b
        if self._spatial_hash is not None:
            if replacing:
                # The grid still lists this id under its previous AABB's cells.
                # Rebuild lazily rather than leave stale entries behind.
                self._spatial_hash_dirty = True
            else:
                self._spatial_hash.insert(bid, b.aabb_min, b.aabb_max)

    def remove_body(self, body: Any) -> None:
        """Unregister a body. Removing an unknown body is a no-op."""
        if self._bodies.pop(id(body), None) is not None and self._spatial_hash is not None:
            self._spatial_hash_dirty = True

    def update_position(self, body: Any, position: np.ndarray) -> None:
        """Update a body's position and recompute its AABB.

        Unknown bodies are ignored. When the spatial hash is enabled, the grid
        is rebuilt lazily before the next :meth:`query_overlaps` call via
        :meth:`rebuild_spatial_hash`.
        """
        b = self._bodies.get(id(body))
        if b is not None:
            b.position = np.asarray(position, dtype=np.float32)
            b.aabb_min, b.aabb_max = b.shape.get_aabb(b.position)
            if self._spatial_hash is not None:
                self._spatial_hash_dirty = True

    def rebuild_spatial_hash(self) -> None:
        """Rebuild the spatial hash grid from all current bodies (no-op if disabled)."""
        if self._spatial_hash is None:
            return
        self._spatial_hash.clear()
        for bid, b in self._bodies.items():
            self._spatial_hash.insert(bid, b.aabb_min, b.aabb_max)
        self._spatial_hash_dirty = False

    def query_overlaps(self, body: Any) -> list[Any]:
        """Find all bodies overlapping with the given body.

        A pair is tested when either body's mask includes the other's layer.
        When the spatial hash is enabled, only candidate bodies from the grid
        cells covered by the body's AABB are checked instead of the entire
        world. Returns an empty list for an unregistered body.
        """
        bid = id(body)
        b = self._bodies.get(bid)
        if not b:
            return []

        # Determine candidate set
        if self._spatial_hash is not None:
            if self._spatial_hash_dirty:
                self.rebuild_spatial_hash()
            candidate_ids = self._spatial_hash.query(b.aabb_min, b.aabb_max)
            candidate_ids.discard(bid)
            candidates = (self._bodies[cid] for cid in candidate_ids if cid in self._bodies)
        else:
            candidates = (other for other in self._bodies.values() if other is not b)

        results = []
        for other in candidates:
            if not (b.mask & other.layer) and not (other.mask & b.layer):
                continue
            if not _aabb_overlap(b.aabb_min, b.aabb_max, other.aabb_min, other.aabb_max):
                continue
            if _gjk_intersect(b.shape, b.position, other.shape, other.position):
                results.append(other.ref)
        return results

    def test_overlap(self, body_a: Any, body_b: Any) -> bool:
        """Test if two specific bodies overlap.

        Layers and masks are not consulted. Returns False if either body is
        not registered.
        """
        a = self._bodies.get(id(body_a))
        b = self._bodies.get(id(body_b))
        if not a or not b:
            return False
        if not _aabb_overlap(a.aabb_min, a.aabb_max, b.aabb_min, b.aabb_max):
            return False
        return _gjk_intersect(a.shape, a.position, b.shape, b.position)

    def raycast(
        self,
        origin: np.ndarray,
        direction: np.ndarray,
        max_dist: float = 1000.0,
        layer_mask: int = 0xFFFFFFFF,
    ) -> list[RayHit]:
        """Cast a ray and return all hits sorted by ascending distance.

        *direction* is normalised first; a (near-)zero direction returns no
        hits. Every body whose layer intersects *layer_mask* is culled by its
        AABB, then tested with a shape-specific narrowphase:
        exact for ``SphereShape`` and ``BoxShape``. For every other shape
        (``ConvexShape``, ``CapsuleShape``, ...) the AABB hit distance is used
        as an approximation. If the origin is inside a sphere or box, the exit
        distance is reported.
        """
        origin = np.asarray(origin, dtype=np.float32)
        direction = np.asarray(direction, dtype=np.float32)
        d_len = np.linalg.norm(direction)
        if d_len < 1e-10:
            return []
        direction = direction / d_len

        hits: list[RayHit] = []
        for b in self._bodies.values():
            if not (layer_mask & b.layer):
                continue

            # First test AABB
            t_aabb = _ray_aabb(origin, direction, b.aabb_min, b.aabb_max, max_dist)
            if t_aabb is None:
                continue

            # Narrowphase: shape-specific ray test
            if isinstance(b.shape, SphereShape):
                t = _ray_sphere(origin, direction, b.position, b.shape.radius, max_dist)
            elif isinstance(b.shape, BoxShape):
                bmin = b.position - b.shape.half_extents
                bmax = b.position + b.shape.half_extents
                t = _ray_aabb(origin, direction, bmin, bmax, max_dist)
            else:
                # No exact ray test for this shape; use the AABB hit as an approximation.
                t = t_aabb

            if t is not None:
                hits.append(RayHit(body=b.ref, point=origin + direction * t, distance=t))

        hits.sort(key=lambda h: h.distance)
        return hits

    @staticmethod
    def _first_contact_fraction(
        shape: CollisionShape,
        from_pos: np.ndarray,
        sweep: np.ndarray,
        other: _Body,
        samples: tuple[float, ...] = (0.25, 0.5, 0.75, 1.0),
        iterations: int = 16,
    ) -> float | None:
        """Return the earliest sweep fraction at which *shape* overlaps *other*, or None.

        In exact arithmetic, the fractions at which two convex shapes overlap
        along a straight sweep form a single interval. So once a coarse sample
        hits, bisecting between the last missed sample and that hit converges
        on the entry fraction from above. The returned fraction always tested
        as overlapping. Bodies thinner than one sample step can be missed.
        """
        if _gjk_intersect(shape, from_pos, other.shape, other.position):
            return 0.0

        lo = 0.0
        hi = None
        for f in samples:
            if _gjk_intersect(shape, from_pos + sweep * f, other.shape, other.position):
                hi = f
                break
            lo = f  # last known-free fraction tightens the bisection bracket
        if hi is None:
            return None

        for _ in range(iterations):
            mid = (lo + hi) * 0.5
            if _gjk_intersect(shape, from_pos + sweep * mid, other.shape, other.position):
                hi = mid
            else:
                lo = mid
        return hi

    def shape_cast(
        self,
        shape: CollisionShape,
        from_pos: np.ndarray,
        to_pos: np.ndarray,
        layer_mask: int = 0xFFFFFFFF,
        exclude: set | None = None,
        max_results: int = 32,
    ) -> list[ShapeCastResult]:
        """Sweep a shape from from_pos to to_pos and return the nearest intersecting bodies.

        Bodies are culled by layer, by *exclude* and by the swept AABB. For each
        remaining body, coarse sampling plus bisection along the sweep finds
        the first contact fraction. All hits are sorted by distance, and the
        nearest *max_results* are returned (none when *max_results* <= 0).
        Contact points and normals are approximations; see ShapeCastResult.
        """
        from_pos = np.asarray(from_pos, dtype=np.float32)
        to_pos = np.asarray(to_pos, dtype=np.float32)
        sweep = to_pos - from_pos
        sweep_len = float(np.linalg.norm(sweep))
        exclude_ids = {id(e) for e in exclude} if exclude else set()

        # Compute swept AABB (union of shape AABB at start and end)
        aabb_start_min, aabb_start_max = shape.get_aabb(from_pos)
        aabb_end_min, aabb_end_max = shape.get_aabb(to_pos)
        swept_min = np.minimum(aabb_start_min, aabb_end_min)
        swept_max = np.maximum(aabb_start_max, aabb_end_max)

        results: list[ShapeCastResult] = []
        for b in self._bodies.values():
            if not (layer_mask & b.layer):
                continue
            if id(b.ref) in exclude_ids:
                continue
            if not _aabb_overlap(swept_min, swept_max, b.aabb_min, b.aabb_max):
                continue

            fraction = self._first_contact_fraction(shape, from_pos, sweep, b)
            if fraction is None:
                continue

            contact_pos = from_pos + sweep * fraction
            diff = contact_pos - b.position
            diff_len = float(np.linalg.norm(diff))
            if diff_len > 1e-10:
                normal = diff / diff_len
                contact_point = b.shape.support(b.position, diff)
            else:
                normal = np.array([0, 1, 0], dtype=np.float32)
                contact_point = b.position.copy()
            results.append(ShapeCastResult(
                body=b.ref,
                point=contact_point,
                normal=normal,
                distance=sweep_len * fraction,
                fraction=fraction,
            ))

        # Sort before truncating so the nearest hits are kept, not the first
        # ones found in dict iteration order.
        results.sort(key=lambda r: r.distance)
        return results[:max(max_results, 0)]

    @property
    def body_count(self) -> int:
        """Number of registered bodies."""
        return len(self._bodies)