# Source code for simvx.graphics.renderer.gpu_batch

"""GPU-driven batch rendering with multi-draw indirect (MDI) and fallback."""


from __future__ import annotations

import logging
from typing import Any

import numpy as np
import vulkan as vk

from simvx.graphics._types import INDIRECT_DRAW_DTYPE
from simvx.graphics.gpu.memory import create_indirect_buffer

log = logging.getLogger(__name__)

__all__ = ["GPUBatch"]


[docs] class GPUBatch: """Manages batched draw commands. When ``use_mdi=True`` (default), uses ``vkCmdDrawIndexedIndirect`` to issue all draw commands in a single GPU call. When ``use_mdi=False``, falls back to a loop of ``vkCmdDrawIndexed`` calls — functionally identical but slower on GPUs that support MDI. Usage:: batch = GPUBatch(device, physical_device, max_draws=100) batch.add_draw(index_count=36, first_instance=0) batch.upload() batch.draw(cmd) """ def __init__(self, device: Any, physical_device: Any, max_draws: int = 1000, *, use_mdi: bool = True): self.device = device self.max_draws = max_draws self.draw_count = 0 self._use_mdi = use_mdi self.indirect_buffer, self.indirect_memory = create_indirect_buffer(device, physical_device, max_draws) self._commands = np.zeros(max_draws, dtype=INDIRECT_DRAW_DTYPE)
[docs] def add_draw( self, index_count: int, instance_count: int = 1, first_index: int = 0, vertex_offset: int = 0, first_instance: int = 0, ) -> int: """Add a draw command. Returns the draw index.""" if self.draw_count >= self.max_draws: raise RuntimeError(f"Batch full (max {self.max_draws} draws)") idx = self.draw_count cmd = self._commands[idx] cmd["index_count"] = index_count cmd["instance_count"] = instance_count cmd["first_index"] = first_index cmd["vertex_offset"] = vertex_offset cmd["first_instance"] = first_instance self.draw_count += 1 return idx
[docs] def add_draws( self, index_count: int, first_instances: np.ndarray | list[int], ) -> int: """Bulk-add draw commands sharing the same mesh — avoids per-instance Python loop. Args: index_count: Index count for the mesh (same for all draws). first_instances: (N,) array of SSBO instance indices. Returns: Batch offset of the first added draw command. """ arr = np.asarray(first_instances, dtype=np.uint32) n = len(arr) if self.draw_count + n > self.max_draws: raise RuntimeError(f"Batch full (max {self.max_draws} draws, need {self.draw_count + n})") start = self.draw_count sl = self._commands[start : start + n] sl["index_count"] = index_count sl["instance_count"] = 1 sl["first_index"] = 0 sl["vertex_offset"] = 0 sl["first_instance"] = arr self.draw_count += n return start
[docs] def upload(self) -> None: """Upload draw commands to GPU indirect buffer.""" if self.draw_count == 0: return data = self._commands[: self.draw_count] ffi = vk.ffi size = data.nbytes dst = vk.vkMapMemory(self.device, self.indirect_memory, 0, size, 0) ffi.memmove(dst, ffi.cast("void*", data.ctypes.data), size) vk.vkUnmapMemory(self.device, self.indirect_memory)
[docs] def draw(self, cmd: Any) -> None: """Record draw commands for the entire batch.""" if self.draw_count == 0: return if self._use_mdi: vk.vkCmdDrawIndexedIndirect(cmd, self.indirect_buffer, 0, self.draw_count, INDIRECT_DRAW_DTYPE.itemsize) else: self._draw_individual(cmd, 0, self.draw_count)
[docs] def draw_range(self, cmd: Any, offset: int, count: int) -> None: """Draw a sub-range of commands. Args: cmd: Vulkan command buffer offset: First draw command index (not byte offset) count: Number of draw commands to execute """ if count == 0: return if self._use_mdi: byte_offset = offset * INDIRECT_DRAW_DTYPE.itemsize vk.vkCmdDrawIndexedIndirect(cmd, self.indirect_buffer, byte_offset, count, INDIRECT_DRAW_DTYPE.itemsize) else: self._draw_individual(cmd, offset, count)
def _draw_individual(self, cmd: Any, offset: int, count: int) -> None: """Fallback: issue individual vkCmdDrawIndexed calls from the CPU-side command array.""" cmds = self._commands[offset : offset + count] for c in cmds: vk.vkCmdDrawIndexed( cmd, int(c["index_count"]), int(c["instance_count"]), int(c["first_index"]), int(c["vertex_offset"]), int(c["first_instance"]), )
[docs] def reset(self) -> None: """Clear batch for next frame.""" self.draw_count = 0
[docs] def destroy(self) -> None: """Free GPU resources.""" vk.vkDestroyBuffer(self.device, self.indirect_buffer, None) vk.vkFreeMemory(self.device, self.indirect_memory, None)