"""Mesh registry — manages uploaded GPU mesh buffers and returns handles."""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import vulkan as vk
from .._types import MeshHandle
from ..gpu.memory import create_buffer, upload_numpy
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Names exported by a star-import of this module.
__all__ = ["MeshRegistry"]
# (Sphinx "[docs]" source-link marker left over from HTML extraction; not part of the module.)
class MeshRegistry:
    """Upload meshes to GPU, return handles for efficient referencing.

    Every registered mesh owns one vertex buffer and one index buffer, each
    with its own device-memory allocation. Both are created with
    host-visible, host-coherent memory and filled via ``upload_numpy``.
    Call :meth:`destroy` to release everything the registry holds.
    """

    def __init__(self, device: Any, physical_device: Any):
        """Create an empty registry bound to the given Vulkan device handles.

        Args:
            device: Logical device passed to every buffer/memory call.
            physical_device: Physical device passed to ``create_buffer``.
        """
        self.device = device
        self.physical_device = physical_device
        # id -> (vb, vb_mem, ib, ib_mem, vertex_count, index_count)
        self._meshes: dict[int, tuple[Any, Any, Any, Any, int, int]] = {}
        self._next_id = 0

    def register(self, vertices: np.ndarray, indices: np.ndarray) -> MeshHandle:
        """Upload mesh to GPU, return handle.

        Args:
            vertices: Vertex data. Either a structured array with a
                ``"position"`` field, or a plain array whose first three
                columns are treated as the position.
            indices: Index data, uploaded as-is.
                NOTE(review): the dtype is not validated here; presumably the
                caller supplies the index type the draw path expects — confirm.

        Returns:
            A ``MeshHandle`` carrying the new id, the vertex/index counts and
            the bounding radius measured from the local origin.

        Raises:
            Exception: Anything raised by ``create_buffer`` / ``upload_numpy``
                propagates unchanged, after the GPU objects this call had
                already obtained have been released.
        """
        # Do all CPU-only work first so that a failure here cannot strand
        # GPU allocations.
        radius = self._bounding_radius(vertices)
        vertex_count = len(vertices)
        index_count = len(indices)

        vb, vb_mem = self._create_and_upload(vertices, vk.VK_BUFFER_USAGE_VERTEX_BUFFER_BIT)
        try:
            ib, ib_mem = self._create_and_upload(indices, vk.VK_BUFFER_USAGE_INDEX_BUFFER_BIT)
        except Exception:
            # The vertex buffer is not tracked in _meshes yet, so destroy()
            # would never see it; release it here before propagating.
            self._release(vb, vb_mem)
            raise

        mesh_id = self._next_id
        self._next_id += 1
        self._meshes[mesh_id] = (vb, vb_mem, ib, ib_mem, vertex_count, index_count)
        return MeshHandle(
            id=mesh_id,
            vertex_count=vertex_count,
            index_count=index_count,
            bounding_radius=radius,
        )

    def get_buffers(self, handle: MeshHandle) -> tuple[Any, Any]:
        """Get (vertex_buffer, index_buffer) for a mesh handle.

        Raises:
            KeyError: If ``handle.id`` is not a currently registered mesh.
        """
        vb, _, ib, _, _, _ = self._meshes[handle.id]
        return vb, ib

    def destroy(self) -> None:
        """Free all mesh buffers and forget every registered mesh."""
        # Remove each entry before releasing it: if a Vulkan call raises
        # part-way through, handles that were already freed are no longer in
        # the registry, so a second destroy() cannot double-free them.
        while self._meshes:
            _, (vb, vb_mem, ib, ib_mem, _, _) = self._meshes.popitem()
            self._release(vb, vb_mem)
            self._release(ib, ib_mem)

    @staticmethod
    def _bounding_radius(vertices: np.ndarray) -> float:
        """Return the bounding-sphere radius measured from the local origin.

        Falls back to ``1.0`` when no (N, >=3) position array can be found,
        and returns ``0.0`` for a position array with zero rows.
        """
        # Compute bounding sphere from local origin.
        # The frustum culler places the sphere at model_matrix[:3, 3] (the node's
        # world position), which corresponds to local origin (0,0,0). The radius
        # must therefore be the max distance from origin to any vertex — NOT from
        # the mean vertex position, which would create a mismatched sphere center.
        field_names = vertices.dtype.names
        if field_names and "position" in field_names:
            positions = vertices["position"]
        else:
            positions = vertices
        if positions.ndim != 2 or positions.shape[1] < 3:
            return 1.0
        if positions.shape[0] == 0:
            # max() over an empty array raises ValueError; an empty point set
            # has a zero-radius bound.
            return 0.0
        return float(np.linalg.norm(positions[:, :3], axis=1).max())

    def _create_and_upload(self, data: np.ndarray, usage: int) -> tuple[Any, Any]:
        """Create a host-visible buffer sized for ``data`` and copy ``data`` in.

        Returns ``(buffer, memory)``. If the upload raises, the freshly
        created buffer and memory are released before the error propagates.
        """
        buffer, memory = create_buffer(
            self.device,
            self.physical_device,
            data.nbytes,
            usage,
            vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        )
        try:
            upload_numpy(self.device, memory, data)
        except Exception:
            self._release(buffer, memory)
            raise
        return buffer, memory

    def _release(self, buffer: Any, memory: Any) -> None:
        """Destroy one buffer and free its backing memory."""
        vk.vkDestroyBuffer(self.device, buffer, None)
        vk.vkFreeMemory(self.device, memory, None)