"""GPU-driven batch rendering with multi-draw indirect (MDI) and fallback."""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import vulkan as vk
from simvx.graphics._types import INDIRECT_DRAW_DTYPE
from simvx.graphics.gpu.memory import create_indirect_buffer
log = logging.getLogger(__name__)
__all__ = ["GPUBatch"]
class GPUBatch:
    """Manages batched draw commands.

    When ``use_mdi=True`` (default), uses ``vkCmdDrawIndexedIndirect`` to issue
    all draw commands in a single GPU call. When ``use_mdi=False``, falls back
    to a loop of ``vkCmdDrawIndexed`` calls — functionally identical but slower
    on GPUs that support MDI.

    Commands are staged in a CPU-side structured array and are only copied into
    the indirect buffer by :meth:`upload`, so in MDI mode ``upload()`` must be
    called after the last ``add_draw``/``add_draws`` and before drawing.

    Usage::

        batch = GPUBatch(device, physical_device, max_draws=100)
        batch.add_draw(index_count=36, first_instance=0)
        batch.upload()
        batch.draw(cmd)
    """

    def __init__(self, device: Any, physical_device: Any, max_draws: int = 1000, *, use_mdi: bool = True):
        """Create the indirect buffer and the CPU-side command staging array.

        Args:
            device: Vulkan logical device handle.
            physical_device: Vulkan physical device handle, forwarded to ``create_indirect_buffer``.
            max_draws: Capacity of the batch, in draw commands.
            use_mdi: Record a single ``vkCmdDrawIndexedIndirect`` call when true,
                otherwise loop over ``vkCmdDrawIndexed``.
        """
        self.device = device
        self.max_draws = max_draws
        self.draw_count = 0
        self._use_mdi = use_mdi
        self.indirect_buffer, self.indirect_memory = create_indirect_buffer(device, physical_device, max_draws)
        # CPU mirror of the indirect buffer; also the data source for the non-MDI fallback.
        self._commands = np.zeros(max_draws, dtype=INDIRECT_DRAW_DTYPE)

    def add_draw(
        self,
        index_count: int,
        instance_count: int = 1,
        first_index: int = 0,
        vertex_offset: int = 0,
        first_instance: int = 0,
    ) -> int:
        """Add a draw command. Returns the draw index.

        Raises:
            RuntimeError: If the batch already holds ``max_draws`` commands.
        """
        if self.draw_count >= self.max_draws:
            raise RuntimeError(f"Batch full (max {self.max_draws} draws)")
        idx = self.draw_count
        # Indexing a structured array yields a record that writes through to the array.
        cmd = self._commands[idx]
        cmd["index_count"] = index_count
        cmd["instance_count"] = instance_count
        cmd["first_index"] = first_index
        cmd["vertex_offset"] = vertex_offset
        cmd["first_instance"] = first_instance
        self.draw_count += 1
        return idx

    def add_draws(
        self,
        index_count: int,
        first_instances: np.ndarray | list[int],
        *,
        instance_count: int = 1,
        first_index: int = 0,
        vertex_offset: int = 0,
    ) -> int:
        """Bulk-add draw commands sharing the same mesh — avoids per-instance Python loop.

        Args:
            index_count: Index count for the mesh (same for all draws).
            first_instances: (N,) array of SSBO instance indices. Scalars and
                multi-dimensional inputs are flattened.
            instance_count: Instance count written to every added command.
            first_index: First index written to every added command.
            vertex_offset: Vertex offset written to every added command.

        Returns:
            Batch offset of the first added draw command.

        Raises:
            RuntimeError: If the batch cannot hold all N additional commands.
        """
        arr = np.asarray(first_instances, dtype=np.uint32).reshape(-1)
        n = len(arr)
        if self.draw_count + n > self.max_draws:
            raise RuntimeError(f"Batch full (max {self.max_draws} draws, need {self.draw_count + n})")
        start = self.draw_count
        # Slicing gives a view, so the field assignments below fill the staging array in place.
        sl = self._commands[start : start + n]
        sl["index_count"] = index_count
        sl["instance_count"] = instance_count
        sl["first_index"] = first_index
        sl["vertex_offset"] = vertex_offset
        sl["first_instance"] = arr
        self.draw_count += n
        return start

    def upload(self) -> None:
        """Upload draw commands to GPU indirect buffer.

        Copies the first ``draw_count`` staged commands into the indirect
        buffer's memory. Does nothing when the batch is empty.
        """
        if self.draw_count == 0:
            return
        data = self._commands[: self.draw_count]
        ffi = vk.ffi
        size = data.nbytes
        dst = vk.vkMapMemory(self.device, self.indirect_memory, 0, size, 0)
        try:
            ffi.memmove(dst, ffi.cast("void*", data.ctypes.data), size)
        finally:
            # Always unmap, even if the copy fails, so the memory is not left mapped.
            vk.vkUnmapMemory(self.device, self.indirect_memory)

    def draw(self, cmd: Any) -> None:
        """Record draw commands for the entire batch.

        Args:
            cmd: Vulkan command buffer to record into.
        """
        if self.draw_count == 0:
            return
        if self._use_mdi:
            vk.vkCmdDrawIndexedIndirect(cmd, self.indirect_buffer, 0, self.draw_count, INDIRECT_DRAW_DTYPE.itemsize)
        else:
            self._draw_individual(cmd, 0, self.draw_count)

    def draw_range(self, cmd: Any, offset: int, count: int) -> None:
        """Draw a sub-range of commands.

        Args:
            cmd: Vulkan command buffer
            offset: First draw command index (not byte offset)
            count: Number of draw commands to execute

        Raises:
            ValueError: If ``offset`` or ``count`` is negative, or the range
                extends past ``max_draws`` (the capacity of the indirect buffer).
        """
        if count == 0:
            return
        # A range past the capacity would make the GPU read beyond the indirect
        # buffer in MDI mode and would be silently truncated by the fallback.
        if offset < 0 or count < 0 or offset + count > self.max_draws:
            raise ValueError(
                f"Draw range [{offset}, {offset + count}) is outside the batch capacity (max {self.max_draws} draws)"
            )
        if self._use_mdi:
            byte_offset = offset * INDIRECT_DRAW_DTYPE.itemsize
            vk.vkCmdDrawIndexedIndirect(cmd, self.indirect_buffer, byte_offset, count, INDIRECT_DRAW_DTYPE.itemsize)
        else:
            self._draw_individual(cmd, offset, count)

    def _draw_individual(self, cmd: Any, offset: int, count: int) -> None:
        """Fallback: issue individual vkCmdDrawIndexed calls from the CPU-side command array."""
        cmds = self._commands[offset : offset + count]
        for c in cmds:
            # Convert NumPy scalars to plain ints before handing them to the Vulkan binding.
            vk.vkCmdDrawIndexed(
                cmd,
                int(c["index_count"]),
                int(c["instance_count"]),
                int(c["first_index"]),
                int(c["vertex_offset"]),
                int(c["first_instance"]),
            )

    def reset(self) -> None:
        """Clear batch for next frame.

        Only the command counter is reset; staged and uploaded data are left
        in place and overwritten by subsequent adds/uploads.
        """
        self.draw_count = 0

    def destroy(self) -> None:
        """Free GPU resources.

        Safe to call more than once: each handle is cleared after it is
        released, so a repeated call does not destroy or free it again.
        """
        if self.indirect_buffer is not None:
            vk.vkDestroyBuffer(self.device, self.indirect_buffer, None)
            self.indirect_buffer = None
        if self.indirect_memory is not None:
            vk.vkFreeMemory(self.device, self.indirect_memory, None)
            self.indirect_memory = None