"""GPU render pass for transform gizmo overlays (translate, rotate, scale).
Builds per-frame vertex arrays for gizmo visuals and renders them with
depth testing DISABLED so they always appear on top of the scene. Uses
the same vertex format as debug lines: position (vec3) + colour (vec4).
Visual feedback:
- Default axis colours: X=red, Y=green, Z=blue
- Hover: the hovered handle is brightened and drawn fully opaque
- Active drag: the active handle turns bright white; all other handles dim to 30% of their normal opacity
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import vulkan as vk
from ..gpu.memory import create_buffer, upload_numpy
from ..gpu.pipeline import create_gizmo_pipeline, create_shader_module
from ..materials.shader_compiler import compile_shader
__all__ = ["GizmoPass", "GizmoRenderData"]
log = logging.getLogger(__name__)
# Interleaved vertex layout shared with the debug-line pass:
# position (3 x float32) followed by colour (4 x float32) = 28 bytes per vertex.
_VERTEX_DTYPE = np.dtype([("position", np.float32, (3,)), ("colour", np.float32, (4,))])

# Handle identifiers; these integers must stay in sync with the GizmoAxis enum.
(
    _AXIS_X,
    _AXIS_Y,
    _AXIS_Z,
    _AXIS_XY,
    _AXIS_XZ,
    _AXIS_YZ,
    _AXIS_ALL,
) = range(7)

# Resting RGBA colour per handle (Godot-style palette). Single axes are opaque;
# plane handles and the "all axes" handle are translucent.
_COLORS = {
    handle: np.array(rgba, dtype=np.float32)
    for handle, rgba in (
        (_AXIS_X, (0.96, 0.26, 0.28, 1.0)),  # warm red
        (_AXIS_Y, (0.40, 0.84, 0.36, 1.0)),  # vivid green
        (_AXIS_Z, (0.26, 0.52, 0.96, 1.0)),  # clear blue
        (_AXIS_XY, (0.96, 0.86, 0.26, 0.45)),  # translucent yellow
        (_AXIS_XZ, (0.86, 0.45, 0.96, 0.45)),  # translucent magenta
        (_AXIS_YZ, (0.26, 0.96, 0.86, 0.45)),  # translucent cyan
        (_AXIS_ALL, (0.90, 0.90, 0.90, 0.35)),  # translucent white
    )
}
# Colour of the handle currently being dragged.
_ACTIVE_COLOR = np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32)
# Multiplier applied to the alpha of every non-active handle during a drag.
_DIM_ALPHA = 0.3
# Colour of the small cube marking the gizmo origin.
_CENTER_COLOR = np.array([0.92, 0.92, 0.92, 0.85], dtype=np.float32)

# Geometry tuning (world units unless stated otherwise).
_CONE_SEGMENTS = 16  # rim segments per arrowhead cone
_RING_SEGMENTS = 48  # line segments per rotation ring
_CONE_RADIUS = 0.08  # arrowhead cone base radius
_CONE_HEIGHT = 0.20  # arrowhead cone height beyond the shaft end
_CUBE_SIZE = 0.07  # half extent of the scale-handle end cubes
_PLANE_HANDLE_FRAC = 0.25  # plane-handle side length as a fraction of axis_length
_CENTER_CUBE_SIZE = 0.05  # half extent of the origin cube
@dataclass
class GizmoRenderData:
    """Per-frame description of the gizmo that ``GizmoPass.render`` should draw."""

    # Gizmo origin; the renderer converts it to float32 before building geometry.
    position: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=np.float32))
    # 0 = translate, 1 = rotate; any other value is rendered as the scale gizmo.
    mode: int = 0
    # Hovered handle id (0-6, same encoding as the _AXIS_* constants) or -1 for none.
    hover_axis: int = -1
    # Handle being dragged (0-6) or -1 for none; a drag takes precedence over hover.
    active_axis: int = -1
    # 4x4 view and projection matrices in numpy (row-major) layout; they are
    # transposed when packed into push constants for GLSL.
    view_matrix: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
    proj_matrix: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
    # Shaft length for translate/scale and ring radius for rotate.
    axis_length: float = 1.0
def _axis_colour(axis_idx: int, hover: int, active: int) -> np.ndarray:
    """Return a fresh RGBA colour for handle *axis_idx* given the interaction state.

    Precedence: while a drag is active (``active >= 0``) the dragged handle is
    bright white and every other handle keeps its hue with alpha scaled by
    ``_DIM_ALPHA``; hover is ignored. With no drag, the hovered handle is
    brightened and made opaque. Otherwise the resting palette colour is used.
    The returned array is always a copy, so callers may mutate it freely.
    """
    dragging = active >= 0
    if dragging and axis_idx == active:
        return _ACTIVE_COLOR.copy()

    colour = _COLORS[axis_idx].copy()
    if dragging:
        colour[3] *= _DIM_ALPHA
    elif axis_idx == hover:
        # Lift the handle's own hue towards white instead of using a flat tint.
        colour[:3] = np.minimum(colour[:3] * 1.3 + 0.25, 1.0)
        colour[3] = 1.0
    return colour
# ---------------------------------------------------------------------------
# Vertex builders
# ---------------------------------------------------------------------------
def _build_translate_vertices(
    pos: np.ndarray,
    length: float,
    hover: int,
    active: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build vertices for the translate gizmo: arrows, plane handles and origin cube.

    Args:
        pos: Gizmo origin (3 components).
        length: Length of each axis shaft; the arrowhead cone sits just past it.
        hover: Hovered handle id, or -1 for none.
        active: Dragged handle id, or -1 for none.

    Returns:
        ``(line_verts, tri_verts)`` -- two structured arrays (built by
        ``_pack_verts``) to be drawn as a line list and a triangle list.
    """
    lines: list[tuple] = []
    tris: list[tuple] = []
    axes = (
        np.array([1, 0, 0], dtype=np.float32),
        np.array([0, 1, 0], dtype=np.float32),
        np.array([0, 0, 1], dtype=np.float32),
    )
    # Rim angles shared by all cones; the final entry (2*pi) closes the loop.
    angles = [2.0 * math.pi * s / _CONE_SEGMENTS for s in range(_CONE_SEGMENTS + 1)]

    for i, d in enumerate(axes):
        c = _axis_colour(i, hover, active)
        base_center = pos + d * length

        # Axis shaft from the origin to the base of the arrowhead.
        lines.append((pos.copy(), c))
        lines.append((base_center, c))

        # Arrowhead cone: side triangles fanning to the tip plus a base cap.
        tip = pos + d * (length + _CONE_HEIGHT)
        perp1, perp2 = _perpendiculars(d)
        rim = [
            base_center + (perp1 * math.cos(a) + perp2 * math.sin(a)) * _CONE_RADIUS
            for a in angles
        ]
        for p0, p1 in zip(rim, rim[1:]):
            tris.extend([(tip, c), (p0, c), (p1, c)])
            # Base cap uses the rim points in reverse order relative to the sides.
            tris.extend([(base_center, c), (p1, c), (p0, c)])

    _add_plane_handles(pos, length, hover, active, tris)
    _add_cube(pos, _CENTER_CUBE_SIZE, _CENTER_COLOR, tris)
    return _pack_verts(lines), _pack_verts(tris)
def _build_rotate_vertices(
    pos: np.ndarray,
    length: float,
    hover: int,
    active: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build vertices for the rotate gizmo: one ring per axis plus the origin cube.

    Args:
        pos: Gizmo origin (3 components).
        length: Ring radius.
        hover: Hovered handle id, or -1 for none.
        active: Dragged handle id, or -1 for none.

    Returns:
        ``(line_verts, tri_verts)`` -- ring segments as a line list and the
        origin cube as a triangle list, both built by ``_pack_verts``.
    """
    lines: list[tuple] = []
    tris: list[tuple] = []
    normals = (
        np.array([1, 0, 0], dtype=np.float32),
        np.array([0, 1, 0], dtype=np.float32),
        np.array([0, 0, 1], dtype=np.float32),
    )
    # Ring angles shared by all three rings; the final entry (2*pi) closes the loop.
    angles = [2.0 * math.pi * s / _RING_SEGMENTS for s in range(_RING_SEGMENTS + 1)]

    for i, normal in enumerate(normals):
        c = _axis_colour(i, hover, active)
        # The ring for axis i lies in the plane perpendicular to that axis.
        perp1, perp2 = _perpendiculars(normal)
        ring = [pos + (perp1 * math.cos(a) + perp2 * math.sin(a)) * length for a in angles]
        for p0, p1 in zip(ring, ring[1:]):
            lines.append((p0, c))
            lines.append((p1, c))

    _add_cube(pos, _CENTER_CUBE_SIZE, _CENTER_COLOR, tris)
    return _pack_verts(lines), _pack_verts(tris)
def _build_scale_vertices(
    pos: np.ndarray,
    length: float,
    hover: int,
    active: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build vertices for the scale gizmo: cube-tipped axis lines, plane handles, origin cube.

    Args:
        pos: Gizmo origin (3 components).
        length: Length of each axis line; a cube is centred on each line end.
        hover: Hovered handle id, or -1 for none.
        active: Dragged handle id, or -1 for none.

    Returns:
        ``(line_verts, tri_verts)`` -- two structured arrays (built by
        ``_pack_verts``) to be drawn as a line list and a triangle list.
    """
    lines: list[tuple] = []
    tris: list[tuple] = []
    axes = (
        np.array([1, 0, 0], dtype=np.float32),
        np.array([0, 1, 0], dtype=np.float32),
        np.array([0, 0, 1], dtype=np.float32),
    )
    for i, d in enumerate(axes):
        c = _axis_colour(i, hover, active)
        end = pos + d * length
        lines.append((pos.copy(), c))
        lines.append((end, c))
        # End cube shares the axis colour so hover/drag feedback covers it too.
        _add_cube(end, _CUBE_SIZE, c, tris)

    _add_plane_handles(pos, length, hover, active, tris)
    _add_cube(pos, _CENTER_CUBE_SIZE, _CENTER_COLOR, tris)
    return _pack_verts(lines), _pack_verts(tris)
def _add_plane_handles(
    pos: np.ndarray,
    length: float,
    hover: int,
    active: int,
    tris: list[tuple],
) -> None:
    """Append one two-triangle quad per plane handle (XY, XZ, YZ) to *tris*.

    Each quad is a square of side ``length * _PLANE_HANDLE_FRAC`` with one
    corner at *pos*. It lies in the plane spanned by its two axes and is
    coloured according to that handle's hover/active state.
    """
    side = length * _PLANE_HANDLE_FRAC
    # (handle id, corner along first axis, far corner, corner along second axis)
    layout = (
        (_AXIS_XY, [side, 0, 0], [side, side, 0], [0, side, 0]),
        (_AXIS_XZ, [side, 0, 0], [side, 0, side], [0, 0, side]),
        (_AXIS_YZ, [0, side, 0], [0, side, side], [0, 0, side]),
    )
    for handle, first, far, second in layout:
        colour = _axis_colour(handle, hover, active)
        origin = pos.copy()
        q1 = pos + np.array(first, dtype=np.float32)
        q2 = pos + np.array(far, dtype=np.float32)
        q3 = pos + np.array(second, dtype=np.float32)
        # Split the quad along the origin->far diagonal.
        tris.extend([(origin, colour), (q1, colour), (q2, colour)])
        tris.extend([(origin, colour), (q2, colour), (q3, colour)])
def _add_cube(
    center: np.ndarray,
    half: float,
    colour: np.ndarray,
    tris: list[tuple],
) -> None:
    """Append 12 triangles (36 vertices) of a solid axis-aligned cube to *tris*.

    The cube is centred on *center*, extends *half* along each axis, and every
    vertex uses *colour*.
    """
    # Corner k: bits 2/1/0 of k select +half (set) or -half (clear) on x/y/z.
    signs = (-1, 1)
    corners = [
        center + np.array([sx * half, sy * half, sz * half], dtype=np.float32)
        for sx in signs
        for sy in signs
        for sz in signs
    ]
    # Each face is a quad of corner indices, split into (a, b, c) and (a, c, d).
    quads = (
        (0, 1, 3, 2),  # -X
        (4, 6, 7, 5),  # +X
        (0, 4, 5, 1),  # -Y
        (2, 3, 7, 6),  # +Y
        (0, 2, 6, 4),  # -Z
        (1, 5, 7, 3),  # +Z
    )
    for quad in quads:
        a, b, c, d = (corners[k] for k in quad)
        for tri in ((a, b, c), (a, c, d)):
            tris.extend((vertex, colour) for vertex in tri)
def _perpendiculars(d: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return two unit vectors perpendicular to *d* and to each other.

    *d* need not be unit length. Integer input is promoted to float64; float
    input keeps its dtype. A zero vector yields two zero vectors.

    The first vector is ``d x helper`` (normalised), where helper is +X unless
    *d* is nearly parallel to X, in which case +Y is used. The second is
    ``d x first`` (normalised).
    """
    d = np.asarray(d)
    if not np.issubdtype(d.dtype, np.floating):
        # In-place true division below would fail on an integer array.
        d = d.astype(np.float64)
    # Compare the x component relative to |d| so the helper-axis choice does
    # not depend on d's length (a short vector along X previously picked X
    # itself and produced a zero cross product).
    length = float(np.linalg.norm(d))
    if abs(float(d[0])) < 0.9 * length:
        up = np.array([1, 0, 0], dtype=np.float32)
    else:
        up = np.array([0, 1, 0], dtype=np.float32)
    perp1 = np.cross(d, up)
    perp1 /= np.linalg.norm(perp1) + 1e-12
    perp2 = np.cross(d, perp1)
    perp2 /= np.linalg.norm(perp2) + 1e-12
    return perp1, perp2
def _pack_verts(raw: list[tuple]) -> np.ndarray:
    """Pack ``(position, colour)`` pairs into a structured ``_VERTEX_DTYPE`` array.

    Each field is filled with one vectorised assignment instead of writing a
    structured scalar per vertex. This runs on every frame for several hundred
    vertices, so the per-element Python loop was needless overhead.

    Args:
        raw: Sequence of pairs holding a 3-component position and a
            4-component RGBA colour; values are cast to float32.

    Returns:
        Array of shape ``(len(raw),)``; an empty input yields an empty array.
    """
    arr = np.zeros(len(raw), dtype=_VERTEX_DTYPE)
    if len(raw) == 0:
        return arr
    arr["position"] = [p for p, _ in raw]
    arr["colour"] = [c for _, c in raw]
    return arr
# ---------------------------------------------------------------------------
# GizmoPass
# ---------------------------------------------------------------------------
class GizmoPass:
    """GPU render pass that draws gizmo overlays on top of the 3D scene.

    Two pipelines are built from the same ``gizmo.vert``/``gizmo.frag``
    shaders via ``create_gizmo_pipeline``:

    * a LINE_LIST pipeline for line segments (axis shafts, rotation rings);
    * a TRIANGLE_LIST pipeline for filled shapes (arrowheads, cubes, plane
      handles).

    Per the module docstring, both render with depth testing disabled so
    gizmos stay on top. Geometry is rebuilt on the CPU each frame and written
    into host-visible vertex buffers that only ever grow.

    Lifecycle: ``setup()`` once, ``render()`` per frame, ``cleanup()`` at
    teardown. ``cleanup()`` may be followed by another ``setup()``.
    """

    def __init__(self, engine: Any) -> None:
        # Engine object; this class reads ctx.device, ctx.physical_device,
        # shader_dir, render_pass and extent, and calls push_constants().
        self._engine = engine
        # GPU resources
        self._line_pipeline: Any = None
        self._line_pipeline_layout: Any = None
        self._tri_pipeline: Any = None
        self._tri_pipeline_layout: Any = None
        self._vert_module: Any = None
        self._frag_module: Any = None
        # Vertex buffers; *_cap is the allocated size in bytes (0 = none).
        self._line_vb: Any = None
        self._line_vb_mem: Any = None
        self._line_vb_cap: int = 0
        self._tri_vb: Any = None
        self._tri_vb_mem: Any = None
        self._tri_vb_cap: int = 0
        self._ready = False

    def setup(self) -> None:
        """Compile the gizmo shaders and create the line and triangle pipelines.

        ``_ready`` is set only after everything was created. If creation fails
        part-way, the objects created so far remain referenced on the instance
        and are released by ``cleanup()``.
        """
        e = self._engine
        device = e.ctx.device
        shader_dir = e.shader_dir
        vert_spv = compile_shader(shader_dir / "gizmo.vert")
        frag_spv = compile_shader(shader_dir / "gizmo.frag")
        self._vert_module = create_shader_module(device, vert_spv)
        self._frag_module = create_shader_module(device, frag_spv)
        self._line_pipeline, self._line_pipeline_layout = create_gizmo_pipeline(
            device,
            self._vert_module,
            self._frag_module,
            e.render_pass,
            e.extent,
            topology=vk.VK_PRIMITIVE_TOPOLOGY_LINE_LIST,
        )
        self._tri_pipeline, self._tri_pipeline_layout = create_gizmo_pipeline(
            device,
            self._vert_module,
            self._frag_module,
            e.render_pass,
            e.extent,
            topology=vk.VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST,
        )
        self._ready = True

    def render(self, cmd: Any, data: GizmoRenderData, extent: tuple[int, int]) -> None:
        """Build this frame's gizmo geometry and record its draw commands into *cmd*.

        Does nothing until ``setup()`` has completed. Lines are drawn before
        triangles; each batch is uploaded into its own vertex buffer.

        Args:
            cmd: Command buffer currently being recorded.
            data: Gizmo state for this frame.
            extent: ``(width, height)`` in pixels for the dynamic viewport and scissor.
        """
        if not self._ready:
            return
        # NOTE(review): both vertex buffers are overwritten every frame. If the
        # engine keeps several frames in flight, a frame still executing on the
        # GPU could read data written for a later frame -- confirm frame pacing.
        pos = np.asarray(data.position, dtype=np.float32)
        h = data.hover_axis
        a = data.active_axis
        if data.mode == 0:
            line_verts, tri_verts = _build_translate_vertices(pos, data.axis_length, h, a)
        elif data.mode == 1:
            line_verts, tri_verts = _build_rotate_vertices(pos, data.axis_length, h, a)
        else:
            line_verts, tri_verts = _build_scale_vertices(pos, data.axis_length, h, a)

        device = self._engine.ctx.device
        pc_data = self._pack_push_constants(data.view_matrix, data.proj_matrix)
        viewport = vk.VkViewport(
            x=0.0,
            y=0.0,
            width=float(extent[0]),
            height=float(extent[1]),
            minDepth=0.0,
            maxDepth=1.0,
        )
        scissor = vk.VkRect2D(
            offset=vk.VkOffset2D(x=0, y=0),
            extent=vk.VkExtent2D(width=extent[0], height=extent[1]),
        )

        if len(line_verts) > 0:
            self._ensure_vb(device, line_verts.nbytes, "line")
            upload_numpy(device, self._line_vb_mem, line_verts)
            self._record_draw(
                cmd,
                self._line_pipeline,
                self._line_pipeline_layout,
                self._line_vb,
                len(line_verts),
                viewport,
                scissor,
                pc_data,
            )
        if len(tri_verts) > 0:
            self._ensure_vb(device, tri_verts.nbytes, "tri")
            upload_numpy(device, self._tri_vb_mem, tri_verts)
            self._record_draw(
                cmd,
                self._tri_pipeline,
                self._tri_pipeline_layout,
                self._tri_vb,
                len(tri_verts),
                viewport,
                scissor,
                pc_data,
            )

    @staticmethod
    def _pack_push_constants(view: np.ndarray, proj: np.ndarray) -> bytes:
        """Return view then projection as 2 x 64 bytes of column-major float32.

        Both matrices are cast to float32 first. A float64 matrix would
        otherwise be emitted as 128 bytes of doubles. That breaks the layout
        produced by the float32 defaults and overflows the 128 bytes of
        push-constant space Vulkan guarantees.
        """
        mats = (np.ascontiguousarray(np.asarray(m, dtype=np.float32).T) for m in (view, proj))
        return b"".join(mat.tobytes() for mat in mats)

    def _record_draw(
        self,
        cmd: Any,
        pipeline: Any,
        layout: Any,
        vertex_buffer: Any,
        vertex_count: int,
        viewport: Any,
        scissor: Any,
        pc_data: bytes,
    ) -> None:
        """Bind *pipeline* and state, then record one non-indexed draw of *vertex_count* vertices."""
        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline)
        vk.vkCmdSetViewport(cmd, 0, 1, [viewport])
        vk.vkCmdSetScissor(cmd, 0, 1, [scissor])
        self._engine.push_constants(cmd, layout, pc_data)
        vk.vkCmdBindVertexBuffers(cmd, 0, 1, [vertex_buffer], [0])
        vk.vkCmdDraw(cmd, vertex_count, 1, 0, 0)

    def _ensure_vb(self, device: Any, needed: int, kind: str) -> None:
        """Grow the ``"line"`` vertex buffer (any other *kind*: triangle buffer) to *needed* bytes.

        Buffers never shrink. New capacity is at least 8 KiB for lines and
        32 KiB for triangles.
        """
        is_line = kind == "line"
        if is_line:
            old_vb, old_mem, cap = self._line_vb, self._line_vb_mem, self._line_vb_cap
        else:
            old_vb, old_mem, cap = self._tri_vb, self._tri_vb_mem, self._tri_vb_cap
        if needed <= cap:
            return
        # Forget the old handles before freeing them, so a failed allocation
        # below cannot leave destroyed handles for cleanup() to free again.
        self._store_vb(is_line, None, None, 0)
        if old_vb:
            # Previously submitted frames may still reference the old buffer,
            # and Vulkan forbids destroying it until they have completed.
            vk.vkDeviceWaitIdle(device)
        self._free_vb(device, old_vb, old_mem)
        new_cap = max(needed, 8192 if is_line else 32768)
        new_vb, new_mem = create_buffer(
            device,
            self._engine.ctx.physical_device,
            new_cap,
            vk.VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        )
        self._store_vb(is_line, new_vb, new_mem, new_cap)

    def _store_vb(self, is_line: bool, buffer: Any, memory: Any, cap: int) -> None:
        """Record *buffer*/*memory*/*cap* as the line (or triangle) vertex buffer state."""
        if is_line:
            self._line_vb, self._line_vb_mem, self._line_vb_cap = buffer, memory, cap
        else:
            self._tri_vb, self._tri_vb_mem, self._tri_vb_cap = buffer, memory, cap

    @staticmethod
    def _free_vb(device: Any, buffer: Any, memory: Any) -> None:
        """Destroy *buffer* and free *memory*, skipping whichever is unset."""
        if buffer:
            vk.vkDestroyBuffer(device, buffer, None)
        if memory:
            vk.vkFreeMemory(device, memory, None)

    def cleanup(self) -> None:
        """Release all GPU resources owned by this pass.

        Every handle is checked on its own, so resources left by a partially
        failed ``setup()`` are released too. All handles and buffer capacities
        are reset afterwards: a second call is a no-op, and a later
        ``setup()``/``render()`` cycle allocates fresh vertex buffers instead
        of drawing from destroyed ones.
        """
        pipelines = (self._line_pipeline, self._tri_pipeline)
        layouts = (self._line_pipeline_layout, self._tri_pipeline_layout)
        modules = (self._vert_module, self._frag_module)
        buffers = ((self._line_vb, self._line_vb_mem), (self._tri_vb, self._tri_vb_mem))
        owned = (*pipelines, *layouts, *modules, *(h for pair in buffers for h in pair))
        if any(owned):
            device = self._engine.ctx.device
            for pipeline in pipelines:
                if pipeline:
                    vk.vkDestroyPipeline(device, pipeline, None)
            for layout in layouts:
                if layout:
                    vk.vkDestroyPipelineLayout(device, layout, None)
            for module in modules:
                if module:
                    vk.vkDestroyShaderModule(device, module, None)
            for vb, mem in buffers:
                self._free_vb(device, vb, mem)
        self._line_pipeline = self._tri_pipeline = None
        self._line_pipeline_layout = self._tri_pipeline_layout = None
        self._vert_module = self._frag_module = None
        self._store_vb(True, None, None, 0)
        self._store_vb(False, None, None, 0)
        self._ready = False