"""Transform gizmos for 3D scene editing (translate, rotate, scale).
Handles the interaction math for picking and dragging gizmo handles.
Actual rendering of gizmo visuals is the graphics backend's responsibility.
"""
from __future__ import annotations
import logging
import math
from enum import Enum, auto
from .math.types import Vec3
log = logging.getLogger(__name__)
__all__ = ["GizmoMode", "GizmoAxis", "Gizmo"]
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class GizmoMode(Enum):
    """Manipulation mode the gizmo is currently operating in.

    The mode decides which handles are pickable and how a drag delta is
    interpreted (translation offset, rotation angle or scale offset).
    """

    TRANSLATE = auto()  # move along an axis / within a plane
    ROTATE = auto()     # spin around a principal axis
    SCALE = auto()      # grow or shrink along the selected axes
class GizmoAxis(Enum):
    """Handle (constraint) the user is interacting with.

    Single letters are the principal axes, two-letter members are the plane
    spanned by those two axes, and ``ALL`` means no axis constraint.
    """

    # Single-axis constraints
    X = auto()
    Y = auto()
    Z = auto()

    # Two-axis (plane) constraints
    XY = auto()
    XZ = auto()
    YZ = auto()

    # Unconstrained: all three axes at once
    ALL = auto()
# ---------------------------------------------------------------------------
# Unit direction vectors (module constants)
# ---------------------------------------------------------------------------
# Direction of each single-axis handle.  Iteration order (X, Y, Z) is relied
# on wherever ties between axes are resolved "first one wins".
_AXIS_DIRS: dict[GizmoAxis, Vec3] = {
    GizmoAxis.X: Vec3(1, 0, 0),  # +X
    GizmoAxis.Y: Vec3(0, 1, 0),  # +Y
    GizmoAxis.Z: Vec3(0, 0, 1),  # +Z
}
# Normal of each two-axis handle: the principal axis that is *not* part of
# the plane's name.
_PLANE_NORMALS: dict[GizmoAxis, Vec3] = {
    GizmoAxis.XY: Vec3(0, 0, 1),  # XY plane -> normal +Z
    GizmoAxis.XZ: Vec3(0, 1, 0),  # XZ plane -> normal +Y
    GizmoAxis.YZ: Vec3(1, 0, 0),  # YZ plane -> normal +X
}
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
def _ray_axis_closest_point(
ray_origin: Vec3,
ray_dir: Vec3,
axis_origin: Vec3,
axis_dir: Vec3,
) -> float:
"""Return parameter *t* along *axis_dir* of the closest point to the ray.
Uses the standard closest-point-between-two-lines formula. Returns 0.0
when the lines are nearly parallel.
"""
w = ray_origin - axis_origin
a = ray_dir.dot(ray_dir) # always > 0
b = ray_dir.dot(axis_dir)
c = axis_dir.dot(axis_dir) # always > 0
d = ray_dir.dot(w)
e = axis_dir.dot(w)
denom = a * c - b * b
if abs(denom) < 1e-10:
return 0.0
t_axis = (a * e - b * d) / denom
return t_axis
def _ray_plane_intersect(
ray_origin: Vec3,
ray_dir: Vec3,
plane_point: Vec3,
plane_normal: Vec3,
) -> Vec3 | None:
"""Intersect a ray with a plane. Returns the hit point, or *None*."""
denom = plane_normal.dot(ray_dir)
if abs(denom) < 1e-10:
return None
t = plane_normal.dot(plane_point - ray_origin) / denom
if t < 0:
return None
return ray_origin + ray_dir * t
# ---------------------------------------------------------------------------
# Gizmo
# ---------------------------------------------------------------------------
class Gizmo:
    """Interactive transform gizmo operating in world space.

    Supports three manipulation modes (translate, rotate, scale) and provides
    ray-based picking / dragging logic.  The graphics backend is responsible
    for drawing the visual handles; this class only does the math.

    The drag API is incremental: :meth:`begin_drag` records the starting
    state, every :meth:`update_drag` call returns the change since the
    previous call, and :meth:`end_drag` clears the drag state.
    """

    # -- picking tolerances (compared against world-space distances) --------
    _SHAFT_THRESHOLD = 0.1     # max ray-to-shaft distance that counts as a hit
    _PLANE_HANDLE_SIZE = 0.25  # plane-handle edge length as a fraction of axis_length
    _RING_THRESHOLD = 0.12     # max radial distance from a rotate ring that counts as a hit

    def __init__(self) -> None:
        self._mode: GizmoMode = GizmoMode.TRANSLATE
        # ``active`` and ``hover_axis`` are not read by this class; they are
        # plain state slots for the caller (presumably UI / renderer).
        self.active: bool = False
        self.position: Vec3 = Vec3(0, 0, 0)   # gizmo centre, world space
        self.axis_length: float = 1.0         # shaft length and ring radius
        self.hover_axis: GizmoAxis | None = None
        self.dragging: bool = False
        # Drag bookkeeping; reset by end_drag().
        self._drag_start: Vec3 | None = None
        self._drag_axis: GizmoAxis | None = None
        self._last_point: Vec3 | None = None  # constraint point at the previous update
        self._last_angle: float = 0.0         # ring angle (radians) at the previous update

    # -- mode property ------------------------------------------------------

    @property
    def mode(self) -> GizmoMode:
        """Current manipulation mode."""
        return self._mode

    @mode.setter
    def mode(self, value: GizmoMode) -> None:
        # A drag started in one mode is meaningless in another, so switching
        # modes ends any drag in progress first.
        if self.dragging:
            self.end_drag()
        self._mode = value

    # -- picking ------------------------------------------------------------

    def pick_axis(self, ray_origin: Vec3, ray_dir: Vec3) -> GizmoAxis | None:
        """Ray-test the gizmo handles and return the closest hit axis.

        For translate/scale modes the handles are axis shafts (thin cylinders)
        and small plane-pair quads.  For rotate mode the handles are three
        circles (one per principal axis).

        Args:
            ray_origin: Start of the pick ray in world space.
            ray_dir: Direction of the pick ray; it is normalized here.

        Returns:
            The handle that was hit, or ``None`` when nothing is hit.
        """
        ray_dir = ray_dir.normalized()
        if self._mode in (GizmoMode.TRANSLATE, GizmoMode.SCALE):
            return self._pick_linear(ray_origin, ray_dir)
        return self._pick_rotate(ray_origin, ray_dir)

    def _pick_linear(self, ray_origin: Vec3, ray_dir: Vec3) -> GizmoAxis | None:
        """Pick among the axis shafts and plane quads (translate / scale)."""
        best_axis: GizmoAxis | None = None
        best_dist = float("inf")

        # Axis shafts (X, Y, Z): hit when the ray passes within
        # _SHAFT_THRESHOLD of the segment [0, axis_length] along the axis.
        for axis_enum, axis_dir in _AXIS_DIRS.items():
            t = _ray_axis_closest_point(
                ray_origin,
                ray_dir,
                self.position,
                axis_dir,
            )
            if t < 0 or t > self.axis_length:
                continue
            p_axis = self.position + axis_dir * t
            # Project that axis point back onto the ray; a negative parameter
            # means the shaft is behind the ray origin.
            to_axis = p_axis - ray_origin
            t_ray = to_axis.dot(ray_dir)
            if t_ray < 0:
                continue
            p_ray = ray_origin + ray_dir * t_ray
            dist = (p_ray - p_axis).length()
            if dist < self._SHAFT_THRESHOLD and dist < best_dist:
                best_dist = dist
                best_axis = axis_enum

        # Plane handles (XY, XZ, YZ): small quads in the positive quadrant
        # next to the gizmo centre.
        handle_size = self.axis_length * self._PLANE_HANDLE_SIZE
        for plane_axis, normal in _PLANE_NORMALS.items():
            hit = _ray_plane_intersect(
                ray_origin,
                ray_dir,
                self.position,
                normal,
            )
            if hit is None:
                continue
            local = hit - self.position
            # In-plane coordinates of the hit along the plane's two axes.
            if plane_axis is GizmoAxis.XY:
                u, v = local.x, local.y
            elif plane_axis is GizmoAxis.XZ:
                u, v = local.x, local.z
            else:  # YZ
                u, v = local.y, local.z
            if 0 < u < handle_size and 0 < v < handle_size:
                # NOTE(review): the hit lies on the plane, so this distance is
                # ~0 and a plane-handle hit effectively always outranks a
                # shaft hit.  Confirm that this priority is intended.
                dist = abs(normal.dot(local))
                if dist < best_dist:
                    best_dist = dist
                    best_axis = plane_axis
        return best_axis

    def _pick_rotate(self, ray_origin: Vec3, ray_dir: Vec3) -> GizmoAxis | None:
        """Pick among the three rotation rings (rotate mode)."""
        best_axis: GizmoAxis | None = None
        best_dist = float("inf")
        radius = self.axis_length
        for axis_enum, normal in _AXIS_DIRS.items():
            # Each ring lies in the plane through the centre whose normal is
            # the rotation axis.
            hit = _ray_plane_intersect(
                ray_origin,
                ray_dir,
                self.position,
                normal,
            )
            if hit is None:
                continue
            # How far the in-plane hit is from the circle of radius `radius`.
            dist_to_center = (hit - self.position).length()
            ring_error = abs(dist_to_center - radius)
            if ring_error < self._RING_THRESHOLD and ring_error < best_dist:
                best_dist = ring_error
                best_axis = axis_enum
        return best_axis

    # -- dragging -----------------------------------------------------------

    def begin_drag(
        self,
        axis: GizmoAxis,
        ray_origin: Vec3,
        ray_dir: Vec3,
    ) -> None:
        """Start a drag interaction on *axis*.

        Records the constraint point and the ring angle under the ray so the
        first :meth:`update_drag` call has a reference to diff against.

        Args:
            axis: Handle being dragged (usually the result of :meth:`pick_axis`).
            ray_origin: Start of the cursor ray in world space.
            ray_dir: Direction of the cursor ray; it is normalized here.
        """
        ray_dir = ray_dir.normalized()
        self.dragging = True
        self._drag_axis = axis
        self._drag_start = self._project_ray(axis, ray_origin, ray_dir)
        self._last_point = self._drag_start
        self._last_angle = self._compute_angle(axis, ray_origin, ray_dir)

    def update_drag(self, ray_origin: Vec3, ray_dir: Vec3) -> Vec3:
        """Compute the incremental delta since the last update.

        Returns a ``Vec3`` whose meaning depends on the current mode:

        * **TRANSLATE** -- world-space translation delta.
        * **ROTATE** -- (angle_x, angle_y, angle_z) in *radians*.
        * **SCALE** -- per-axis scale factor delta (centred on 1.0 = no change,
          returned as additive offset so caller does ``current_scale += delta``).

        A zero vector is returned when no drag is in progress.
        """
        if not self.dragging or self._drag_axis is None:
            return Vec3(0, 0, 0)
        ray_dir = ray_dir.normalized()
        axis = self._drag_axis
        if self._mode is GizmoMode.TRANSLATE:
            return self._update_translate(axis, ray_origin, ray_dir)
        if self._mode is GizmoMode.ROTATE:
            return self._update_rotate(axis, ray_origin, ray_dir)
        return self._update_scale(axis, ray_origin, ray_dir)

    def end_drag(self) -> None:
        """Finish dragging and reset internal drag state."""
        self.dragging = False
        self._drag_axis = None
        self._drag_start = None
        self._last_point = None
        self._last_angle = 0.0

    # -- mode cycling -------------------------------------------------------

    def cycle_mode(self) -> None:
        """Cycle through TRANSLATE -> ROTATE -> SCALE -> TRANSLATE.

        Goes through the ``mode`` setter, so an active drag is ended.
        """
        order = [GizmoMode.TRANSLATE, GizmoMode.ROTATE, GizmoMode.SCALE]
        idx = order.index(self._mode)
        self.mode = order[(idx + 1) % len(order)]

    # -- private helpers ----------------------------------------------------

    def _project_ray(
        self,
        axis: GizmoAxis,
        ray_origin: Vec3,
        ray_dir: Vec3,
    ) -> Vec3:
        """Project the ray onto the constraint defined by *axis*.

        For single-axis constraints (X/Y/Z) the closest point on that axis
        line is returned.  For plane constraints (XY/XZ/YZ) the ray-plane
        intersection is returned.  For ALL, the ray-plane intersection with
        a camera-facing plane through ``self.position`` is used (approximated
        by the principal plane most perpendicular to the ray).

        Whenever the ray misses the chosen plane (parallel to it, or the
        plane is behind the ray origin) ``self.position`` is returned.
        """
        if axis in _AXIS_DIRS:
            axis_dir = _AXIS_DIRS[axis]
            t = _ray_axis_closest_point(
                ray_origin,
                ray_dir,
                self.position,
                axis_dir,
            )
            return self.position + axis_dir * t

        if axis in _PLANE_NORMALS:
            normal = _PLANE_NORMALS[axis]
            hit = _ray_plane_intersect(
                ray_origin,
                ray_dir,
                self.position,
                normal,
            )
            return hit if hit is not None else self.position

        # GizmoAxis.ALL -- use the principal plane the ray crosses most
        # squarely, i.e. the normal with the largest |dot| against the ray.
        # max() keeps the first of equal candidates (X, then Y, then Z).
        # A plain max avoids the earlier sentinel clash: "best_dot == 1.0"
        # meant "unset", so an exactly axis-aligned ray (|dot| == 1.0) had
        # its best normal overwritten by one parallel to the plane.
        best_normal = max(
            _AXIS_DIRS.values(),
            key=lambda n: abs(ray_dir.dot(n)),
        )
        hit = _ray_plane_intersect(
            ray_origin,
            ray_dir,
            self.position,
            best_normal,
        )
        return hit if hit is not None else self.position

    def _update_translate(self, axis: GizmoAxis, ray_origin: Vec3, ray_dir: Vec3) -> Vec3:
        """Return the movement of the constraint point since the last update."""
        current = self._project_ray(axis, ray_origin, ray_dir)
        if self._last_point is None:
            # No reference yet: adopt the current point and report no motion.
            self._last_point = current
            return Vec3(0, 0, 0)
        delta = current - self._last_point
        self._last_point = current
        return delta

    def _compute_angle(self, axis: GizmoAxis, ray_origin: Vec3, ray_dir: Vec3) -> float:
        """Compute the angle (radians) of the ray hit around the given axis.

        The ray is intersected with the plane through ``self.position`` that
        is perpendicular to *axis*; the angle of the hit is then measured in
        that plane (X: from +Y towards +Z, Y: from +Z towards +X, Z: from +X
        towards +Y).

        Returns ``0.0`` for non-principal axes, and the previously stored
        angle when the ray misses the plane (so the next delta is zero).
        """
        if axis not in _AXIS_DIRS:
            return 0.0
        normal = _AXIS_DIRS[axis]
        hit = _ray_plane_intersect(
            ray_origin,
            ray_dir,
            self.position,
            normal,
        )
        if hit is None:
            return self._last_angle
        local = hit - self.position
        if axis is GizmoAxis.X:
            return math.atan2(local.z, local.y)
        if axis is GizmoAxis.Y:
            return math.atan2(local.x, local.z)
        # Z
        return math.atan2(local.y, local.x)

    def _update_rotate(self, axis: GizmoAxis, ray_origin: Vec3, ray_dir: Vec3) -> Vec3:
        """Return the rotation since the last update as per-axis radians."""
        angle = self._compute_angle(axis, ray_origin, ray_dir)
        delta_rad = angle - self._last_angle
        # atan2 yields values in (-pi, pi]; take the short way round when the
        # cursor crosses that seam.
        if delta_rad > math.pi:
            delta_rad -= 2 * math.pi
        elif delta_rad < -math.pi:
            delta_rad += 2 * math.pi
        self._last_angle = angle
        if axis is GizmoAxis.X:
            return Vec3(delta_rad, 0, 0)
        if axis is GizmoAxis.Y:
            return Vec3(0, delta_rad, 0)
        if axis is GizmoAxis.Z:
            return Vec3(0, 0, delta_rad)
        # Rotation is only defined for the three principal axes.
        return Vec3(0, 0, 0)

    def _update_scale(self, axis: GizmoAxis, ray_origin: Vec3, ray_dir: Vec3) -> Vec3:
        """Return the additive scale offset since the last update.

        The offset is the change in distance of the constraint point from the
        gizmo centre, divided by ``axis_length``; it is applied to every axis
        that *axis* involves.
        """
        current = self._project_ray(axis, ray_origin, ray_dir)
        if self._last_point is None or self._drag_start is None:
            # No reference yet: adopt the current point and report no change.
            self._last_point = current
            return Vec3(0, 0, 0)

        def _signed_dist(p: Vec3) -> float:
            # Signed along the axis for X/Y/Z; plain (unsigned) distance from
            # the centre for plane and ALL constraints.
            d = p - self.position
            if axis in _AXIS_DIRS:
                return d.dot(_AXIS_DIRS[axis])
            return d.length()

        prev_d = _signed_dist(self._last_point)
        curr_d = _signed_dist(current)
        # Normalize by axis_length; fall back to 1.0 to avoid dividing by a
        # (near-)zero length.
        ref = self.axis_length if self.axis_length > 1e-6 else 1.0
        factor = (curr_d - prev_d) / ref
        self._last_point = current
        if axis is GizmoAxis.X:
            return Vec3(factor, 0, 0)
        if axis is GizmoAxis.Y:
            return Vec3(0, factor, 0)
        if axis is GizmoAxis.Z:
            return Vec3(0, 0, factor)
        # Plane or ALL -- uniform along involved axes
        if axis is GizmoAxis.XY:
            return Vec3(factor, factor, 0)
        if axis is GizmoAxis.XZ:
            return Vec3(factor, 0, factor)
        if axis is GizmoAxis.YZ:
            return Vec3(0, factor, factor)
        # ALL
        return Vec3(factor, factor, factor)