"""GPU compute shader particle simulation.
Dispatches a compute shader to update particle positions, velocities, lifetimes,
and visual properties entirely on the GPU — avoiding per-frame CPU-to-GPU uploads
for particle data.
"""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import vulkan as vk
from ..gpu.memory import create_buffer, upload_numpy
from ..gpu.pipeline import create_shader_module
from ..materials.shader_compiler import compile_shader
__all__ = ["ParticleCompute"]
log = logging.getLogger(__name__)
# Must match PARTICLE_DTYPE from core/particles.py (16 floats x 4 bytes = 64 bytes)
_PARTICLE_GPU_STRIDE = 16 * 4
_WORKGROUP_SIZE = 256
# VK_WHOLE_SIZE as unsigned uint64 — the vulkan Python package exposes it as -1 (signed),
# which triggers OverflowError when assigned to a cffi unsigned field.
_VK_WHOLE_SIZE_U64 = 0xFFFFFFFFFFFFFFFF
# Push constant layout (must match particle_sim.comp):
# vec3 emitter_pos (12) + float dt (4) = 16
# vec3 gravity (12) + float damping (4) = 16
# vec3 initial_velocity (12) + float vel_spread (4) = 16
# vec4 start_colour (16) = 16
# vec4 end_colour (16) = 16
# float start_scale (4) + float end_scale (4)
# + float emission_radius(4) + uint max_particles (4) = 16
# uint frame_seed (4) + 3x uint pad (12) = 16
# Total = 112 bytes
_PUSH_CONSTANT_SIZE = 112
class ParticleCompute:
    """GPU-based particle simulation via Vulkan compute shader.

    Creates a compute pipeline that updates particle state (position, velocity,
    colour, scale, lifetime) in an SSBO. The same SSBO can be bound by the
    graphics particle pass for zero-copy rendering.

    Lifecycle: construct, call :meth:`setup` once the engine's Vulkan context
    exists, call :meth:`dispatch` once per frame, and call :meth:`cleanup`
    before the device is destroyed.
    """

    def __init__(self, engine: Any):
        """Store the engine reference; no GPU work happens until :meth:`setup`.

        Args:
            engine: Object exposing ``ctx.device``, ``ctx.physical_device`` and
                ``shader_dir`` (those are the attributes this class reads).
        """
        self._engine = engine
        self._max_particles: int = 0
        self._frame_counter: int = 0
        # GPU resources (all None until setup() creates them)
        self._particle_buf: Any = None
        self._particle_mem: Any = None
        self._compute_pipeline: Any = None
        self._compute_layout: Any = None
        self._compute_module: Any = None
        self._desc_layout: Any = None
        self._desc_pool: Any = None
        self._desc_set: Any = None
        self._ready = False

    def setup(self, max_particles: int = 65536) -> None:
        """Create compute pipeline, SSBO, and descriptor set.

        Any resources left over from an earlier ``setup()`` are released
        first, and resources created by a ``setup()`` that fails partway are
        released before the exception propagates, so no Vulkan handles leak.
        The caller is responsible for making sure the GPU is no longer using
        previously created resources when calling ``setup()`` again.

        Args:
            max_particles: Maximum number of particles in the simulation buffer.

        Raises:
            ValueError: If ``max_particles`` is not positive.
        """
        if max_particles <= 0:
            raise ValueError(f"max_particles must be positive, got {max_particles}")

        # No-op on a fresh instance; prevents leaking handles on re-setup.
        self.cleanup()

        self._max_particles = max_particles
        try:
            self._create_resources(max_particles)
        except Exception:
            # Destroy whatever was created before the failure, then re-raise.
            self.cleanup()
            raise

        self._ready = True
        log.debug("Particle compute initialized (max %d particles)", max_particles)

    def _create_resources(self, max_particles: int) -> None:
        """Create every GPU object and seed the SSBO with dead particles."""
        e = self._engine
        device = e.ctx.device
        phys = e.ctx.physical_device

        # Particle SSBO in host-visible, host-coherent memory so the CPU can
        # seed it directly through upload_numpy.
        buf_size = max_particles * _PARTICLE_GPU_STRIDE
        self._particle_buf, self._particle_mem = create_buffer(
            device,
            phys,
            buf_size,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        )

        # Descriptor set layout — single SSBO binding for compute
        self._desc_layout = _create_compute_ssbo_layout(device)

        # Descriptor pool and set
        pool_size = vk.VkDescriptorPoolSize(
            type=vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
            descriptorCount=1,
        )
        pool_info = vk.VkDescriptorPoolCreateInfo(
            maxSets=1,
            poolSizeCount=1,
            pPoolSizes=[pool_size],
        )
        self._desc_pool = vk.vkCreateDescriptorPool(device, pool_info, None)
        alloc_info = vk.VkDescriptorSetAllocateInfo(
            descriptorPool=self._desc_pool,
            descriptorSetCount=1,
            pSetLayouts=[self._desc_layout],
        )
        self._desc_set = vk.vkAllocateDescriptorSets(device, alloc_info)[0]

        # Write SSBO descriptor
        buf_info = vk.VkDescriptorBufferInfo(buffer=self._particle_buf, offset=0, range=buf_size)
        write = vk.VkWriteDescriptorSet(
            dstSet=self._desc_set,
            dstBinding=0,
            dstArrayElement=0,
            descriptorCount=1,
            descriptorType=vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
            pBufferInfo=[buf_info],
        )
        vk.vkUpdateDescriptorSets(device, 1, [write], 0, None)

        # Compile compute shader
        shader_dir = e.shader_dir
        comp_spv = compile_shader(shader_dir / "particle_sim.comp")
        self._compute_module = create_shader_module(device, comp_spv)

        # Compute pipeline
        self._compute_pipeline, self._compute_layout = _create_compute_pipeline(
            device,
            self._compute_module,
            self._desc_layout,
        )

        # Initialize buffer to zero, then set lifetime=1.0 and age=999.0 so
        # every particle counts as "dead" (age > lifetime) and will respawn.
        zeros = np.zeros(max_particles * _PARTICLE_GPU_STRIDE // 4, dtype=np.float32)
        particle_floats = zeros.reshape(max_particles, 16)  # view onto `zeros`
        particle_floats[:, 13] = 1.0  # lifetime
        particle_floats[:, 14] = 999.0  # age > lifetime → dead
        upload_numpy(device, self._particle_mem, zeros)

    @staticmethod
    def _pack_push_constants(
        dt: float,
        emitter_config: dict,
        max_particles: int,
        frame_seed: int,
    ) -> bytes:
        """Pack the 112-byte push-constant block consumed by the compute shader.

        The block is 28 four-byte words (this must stay equal to the
        module-level ``_PUSH_CONSTANT_SIZE`` divided by 4). Words 0-22 are
        floats, words 23-24 are uint32 values written through a reinterpreting
        view, and words 25-27 are zero padding.

        Args:
            dt: Delta time in seconds.
            emitter_config: Emitter parameters; missing keys use the defaults
                documented on :meth:`dispatch`.
            max_particles: Particle capacity written as a uint32.
            frame_seed: Per-frame seed; wrapped to 32 bits so a long-running
                counter never exceeds the uint32 slot.

        Returns:
            The packed push constants as ``bytes`` of length 112.
        """
        pc = np.zeros(28, dtype=np.float32)
        cfg = emitter_config.get

        pc[0:3] = cfg("emitter_pos", (0.0, 0.0, 0.0))
        pc[3] = dt
        pc[4:7] = cfg("gravity", (0.0, -9.8, 0.0))
        pc[7] = float(cfg("damping", 0.0))
        pc[8:11] = cfg("initial_velocity", (0.0, 5.0, 0.0))
        pc[11] = float(cfg("velocity_spread", 0.3))
        pc[12:16] = cfg("start_colour", (1.0, 1.0, 1.0, 1.0))
        pc[16:20] = cfg("end_colour", (1.0, 1.0, 1.0, 0.0))
        pc[20] = float(cfg("start_scale", 1.0))
        pc[21] = float(cfg("end_scale", 0.0))
        pc[22] = float(cfg("emission_radius", 1.0))

        # max_particles and frame_seed as uint32 — reinterpret float bits
        uint_view = pc.view(np.uint32)
        uint_view[23] = max_particles
        uint_view[24] = frame_seed & 0xFFFFFFFF
        # pad slots 25-27 are already zero
        return pc.tobytes()

    def dispatch(self, cmd: Any, dt: float, emitter_config: dict) -> None:
        """Dispatch the compute shader to simulate one step.

        Does nothing if :meth:`setup` has not completed successfully.

        Args:
            cmd: Active command buffer (must be outside a render pass).
            dt: Delta time in seconds.
            emitter_config: Dict with emitter parameters (defaults in brackets):
                - emitter_pos: (x, y, z) [(0, 0, 0)]
                - gravity: (x, y, z) [(0, -9.8, 0)]
                - damping: float [0.0]
                - initial_velocity: (x, y, z) [(0, 5, 0)]
                - velocity_spread: float [0.3]
                - start_colour: (r, g, b, a) [(1, 1, 1, 1)]
                - end_colour: (r, g, b, a) [(1, 1, 1, 0)]
                - start_scale: float [1.0]
                - end_scale: float [0.0]
                - emission_radius: float [1.0]
        """
        if not self._ready:
            return

        self._frame_counter += 1
        pc_bytes = self._pack_push_constants(
            dt,
            emitter_config,
            self._max_particles,
            self._frame_counter,
        )
        ffi = vk.ffi
        cbuf = ffi.new("char[]", pc_bytes)

        # Memory barrier: ensure previous frame's compute writes are visible.
        # NOTE(review): this only orders against earlier compute-stage work; it
        # assumes the previous frame's vertex-stage reads of the SSBO are
        # already synchronized elsewhere (e.g. by frame fences) — confirm.
        barrier = ffi.new("VkBufferMemoryBarrier*")
        barrier.sType = vk.VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER
        barrier.srcAccessMask = vk.VK_ACCESS_SHADER_WRITE_BIT
        barrier.dstAccessMask = vk.VK_ACCESS_SHADER_READ_BIT | vk.VK_ACCESS_SHADER_WRITE_BIT
        barrier.srcQueueFamilyIndex = vk.VK_QUEUE_FAMILY_IGNORED
        barrier.dstQueueFamilyIndex = vk.VK_QUEUE_FAMILY_IGNORED
        barrier.buffer = self._particle_buf
        barrier.offset = 0
        barrier.size = _VK_WHOLE_SIZE_U64
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            0,
            0,
            None,
            1,
            [barrier[0]],
            0,
            None,
        )

        # Bind compute pipeline
        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._compute_pipeline)
        vk.vkCmdBindDescriptorSets(
            cmd,
            vk.VK_PIPELINE_BIND_POINT_COMPUTE,
            self._compute_layout,
            0,
            1,
            [self._desc_set],
            0,
            None,
        )
        # Raw cffi entry point: the push-constant payload is passed as a
        # pre-built char buffer.
        vk._vulkan.lib.vkCmdPushConstants(
            cmd,
            self._compute_layout,
            vk.VK_SHADER_STAGE_COMPUTE_BIT,
            0,
            len(pc_bytes),
            cbuf,
        )

        # Dispatch enough workgroups to cover all particles (ceiling division)
        group_count = (self._max_particles + _WORKGROUP_SIZE - 1) // _WORKGROUP_SIZE
        vk.vkCmdDispatch(cmd, group_count, 1, 1)

        # Barrier: compute writes → vertex shader reads
        barrier.srcAccessMask = vk.VK_ACCESS_SHADER_WRITE_BIT
        barrier.dstAccessMask = vk.VK_ACCESS_SHADER_READ_BIT
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_VERTEX_SHADER_BIT,
            0,
            0,
            None,
            1,
            [barrier[0]],
            0,
            None,
        )

    def get_particle_ssbo(self) -> Any:
        """Return the particle SSBO buffer handle for use by the rendering pass.

        Returns ``None`` before :meth:`setup` and after :meth:`cleanup`.
        """
        return self._particle_buf

    def get_particle_memory(self) -> Any:
        """Return the particle SSBO memory handle.

        Returns ``None`` before :meth:`setup` and after :meth:`cleanup`.
        """
        return self._particle_mem

    @property
    def max_particles(self) -> int:
        """Particle capacity passed to the most recent :meth:`setup` call."""
        return self._max_particles

    @property
    def ready(self) -> bool:
        """True once :meth:`setup` has completed and until :meth:`cleanup`."""
        return self._ready

    def upload_initial_particles(self, particles: np.ndarray) -> None:
        """Seed the GPU buffer with CPU-generated particle data.

        Does nothing if the simulation is not ready or ``particles`` is empty.
        At most ``max_particles`` leading elements are uploaded.

        Args:
            particles: Numpy array with dtype matching PARTICLE_DTYPE.
                Length must not exceed max_particles.

        Raises:
            ValueError: If the selected elements occupy more bytes than the
                particle SSBO holds (element size does not match the GPU
                particle stride).
        """
        if not self._ready:
            return
        count = min(len(particles), self._max_particles)
        if count == 0:
            return

        chunk = particles[:count]
        capacity = self._max_particles * _PARTICLE_GPU_STRIDE
        payload = np.asarray(chunk).nbytes
        if payload > capacity:
            # Writing past the end of the SSBO would corrupt or fault the mapping.
            raise ValueError(
                f"particle data is {payload} bytes but the SSBO holds only {capacity} bytes"
            )

        upload_numpy(self._engine.ctx.device, self._particle_mem, chunk)
        log.debug("Uploaded %d initial particles to GPU", count)

    def cleanup(self) -> None:
        """Destroy all GPU resources.

        Safe to call at any time: it releases whatever handles currently
        exist (including those left by a ``setup()`` that failed partway),
        clears them so stale handles are never returned, and is a no-op when
        nothing is allocated.
        """
        # (attribute, vk destroy function) in destruction order: pipeline
        # objects first, then descriptor objects, then the buffer and memory.
        owned = (
            ("_compute_pipeline", "vkDestroyPipeline"),
            ("_compute_layout", "vkDestroyPipelineLayout"),
            ("_compute_module", "vkDestroyShaderModule"),
            ("_desc_layout", "vkDestroyDescriptorSetLayout"),
            ("_desc_pool", "vkDestroyDescriptorPool"),
            ("_particle_buf", "vkDestroyBuffer"),
            ("_particle_mem", "vkFreeMemory"),
        )
        live = [(attr, fn_name) for attr, fn_name in owned if getattr(self, attr)]

        self._ready = False
        # The descriptor set is owned by the pool; just drop the reference.
        self._desc_set = None
        if not live:
            return

        device = self._engine.ctx.device
        for attr, fn_name in live:
            getattr(vk, fn_name)(device, getattr(self, attr), None)
            setattr(self, attr, None)
        log.debug("Particle compute resources cleaned up")
def _create_compute_ssbo_layout(device: Any) -> Any:
    """Build the descriptor set layout used by the particle compute shader.

    The layout has exactly one binding: a storage buffer at binding 0 that is
    visible to the compute stage only.

    Args:
        device: Vulkan logical device handle.

    Returns:
        The handle returned by ``vkCreateDescriptorSetLayout``.
    """
    ssbo_binding = vk.VkDescriptorSetLayoutBinding(
        binding=0,
        descriptorType=vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
        descriptorCount=1,
        stageFlags=vk.VK_SHADER_STAGE_COMPUTE_BIT,
    )
    create_info = vk.VkDescriptorSetLayoutCreateInfo(
        bindingCount=1,
        pBindings=[ssbo_binding],
    )
    return vk.vkCreateDescriptorSetLayout(device, create_info, None)
def _create_compute_pipeline(
    device: Any,
    compute_module: Any,
    desc_layout: Any,
) -> tuple[Any, Any]:
    """Create a compute pipeline with push constants for particle simulation.

    The pipeline layout has one descriptor set layout (``desc_layout``) and a
    single compute-stage push-constant range of ``_PUSH_CONSTANT_SIZE`` bytes
    starting at offset 0. The shader entry point is ``main``.

    Args:
        device: Vulkan logical device handle.
        compute_module: Shader module containing the compute shader.
        desc_layout: Descriptor set layout bound at set index 0.

    Returns: (pipeline, pipeline_layout)

    Raises:
        RuntimeError: If pipeline layout or pipeline creation does not return
            ``VK_SUCCESS``. The pipeline layout is destroyed before the error
            propagates when pipeline creation fails, so it does not leak.
    """
    ffi = vk.ffi

    # Push constant range
    push_range = ffi.new("VkPushConstantRange*")
    push_range.stageFlags = vk.VK_SHADER_STAGE_COMPUTE_BIT
    push_range.offset = 0
    push_range.size = _PUSH_CONSTANT_SIZE

    # Pipeline layout. `set_layouts` and `push_range` are kept in locals so
    # the cffi memory they own stays alive for the duration of the call.
    set_layouts = ffi.new("VkDescriptorSetLayout[1]", [desc_layout])
    layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
    layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO
    layout_ci.setLayoutCount = 1
    layout_ci.pSetLayouts = set_layouts
    layout_ci.pushConstantRangeCount = 1
    layout_ci.pPushConstantRanges = push_range
    layout_out = ffi.new("VkPipelineLayout*")
    result = vk._vulkan._callApi(
        vk._vulkan.lib.vkCreatePipelineLayout,
        device,
        layout_ci,
        ffi.NULL,
        layout_out,
    )
    if result != vk.VK_SUCCESS:
        raise RuntimeError(f"vkCreatePipelineLayout failed: {result}")
    pipeline_layout = layout_out[0]

    try:
        # Compute pipeline create info. `main_name` must outlive the create
        # call because `stage.pName` only stores a pointer to it.
        main_name = ffi.new("char[]", b"main")
        stage = ffi.new("VkPipelineShaderStageCreateInfo*")
        stage.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
        stage.stage = vk.VK_SHADER_STAGE_COMPUTE_BIT
        stage.module = compute_module
        stage.pName = main_name
        ci = ffi.new("VkComputePipelineCreateInfo*")
        ci.sType = vk.VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO
        ci.stage = stage[0]
        ci.layout = pipeline_layout
        pipeline_out = ffi.new("VkPipeline*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreateComputePipelines,
            device,
            ffi.NULL,
            1,
            ci,
            ffi.NULL,
            pipeline_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreateComputePipelines failed: {result}")
    except Exception:
        # The layout was created above and nothing else owns it yet; destroy
        # it so a failed pipeline build does not leak the handle.
        vk.vkDestroyPipelineLayout(device, pipeline_layout, None)
        raise

    log.debug("Particle compute pipeline created")
    return pipeline_out[0], pipeline_layout