"""Forward+ tiled light culling via compute shader.
Dispatches a compute shader that assigns lights to 16x16 screen-space tiles
using depth-aware frustum culling. The output buffers (light index list and
per-tile offset/count) are bound to the fragment shader so it can loop over
only the lights relevant to each pixel's tile.
"""
from __future__ import annotations
import logging
import math
from typing import Any
import numpy as np
import vulkan as vk
from ..gpu.memory import create_buffer, upload_numpy
from ..gpu.pipeline import create_shader_module
from ..materials.shader_compiler import compile_shader
from .render_pass import RenderPass, FrameContext
__all__ = ["LightCullPass"]

log = logging.getLogger(__name__)

# Edge length, in pixels, of one square screen-space culling tile.
_TILE_SIZE = 16

# CullParams SSBO layout (must match light_cull.comp binding 0), in 32-bit words:
#   word 0      uint  grid_dims_x
#   word 1      uint  grid_dims_y
#   word 2      uint  light_count
#   word 3      float near_plane
#   word 4      float far_plane
#   words 5..7  uint  _pad[3]
#   word 8      uint  global_index  (atomic counter)
# Nine words (36 bytes) are used; the buffer is rounded up to twelve words
# (48 bytes, a multiple of 16).
_CULL_PARAMS_SIZE = 12 * 4

# All 64 bits set, i.e. the value of Vulkan's VK_WHOLE_SIZE (~0ULL).
_VK_WHOLE_SIZE_U64 = (1 << 64) - 1
class LightCullPass(RenderPass):
    """GPU tiled light culling for Forward+ rendering.

    Creates a compute pipeline that reads a depth texture and light SSBO,
    then outputs per-tile light index lists consumed by the fragment shader.

    Typical lifecycle: ``setup`` once, ``update_descriptors`` whenever the
    depth view or light buffer changes, ``record`` every frame, ``resize`` on
    viewport changes and ``destroy`` at shutdown. When the tile grid changes,
    ``resize`` rebuilds every resource (including the descriptor set and the
    output buffers), so afterwards the caller must call ``update_descriptors``
    again and rebind ``tile_buffer`` / ``light_index_buffer``.
    """

    name = "light_cull"
    stage = "pre_render"
    inputs = ()
    outputs = ("light_tiles",)

    def __init__(self, engine: Any):
        super().__init__()
        self._engine = engine
        self._ready = False
        # Pipeline
        self._pipeline: Any = None
        self._layout: Any = None
        self._module: Any = None
        # Descriptors
        self._desc_pool: Any = None
        self._desc_layout: Any = None
        self._desc_set: Any = None
        self._depth_sampler: Any = None
        # Buffers (handle + backing memory)
        self._cull_params_buf: Any = None
        self._cull_params_mem: Any = None
        self._light_index_buf: Any = None
        self._light_index_mem: Any = None
        self._tile_buf: Any = None
        self._tile_mem: Any = None
        # Byte sizes actually allocated for the output SSBOs (0 = not allocated)
        self._index_buf_size: int = 0
        self._tile_buf_size: int = 0
        # Dimensions
        self._grid_x: int = 0
        self._grid_y: int = 0
        self._max_tiles: int = 0
        self._max_light_indices: int = 0
        # Remembered so resize() rebuilds with the limit the caller chose
        self._max_lights: int = 256

    def setup(self, width: int, height: int, max_lights: int = 256) -> None:
        """Create compute pipeline, SSBOs, and descriptor set.

        Any resources from a previous ``setup`` are destroyed first, and if
        creation fails part-way the handles created so far are released
        before the exception propagates.

        Args:
            width: Viewport width in pixels.
            height: Viewport height in pixels.
            max_lights: Maximum number of lights in the scene.
        """
        # Re-running setup on a live pass must not orphan the old handles.
        self._release_resources()
        try:
            self._create_resources(width, height, max_lights)
        except Exception:
            # Free whatever was created before the failure, then propagate.
            self._release_resources()
            raise
        self._ready = True
        log.debug("Light cull pass initialised (%dx%d tiles)", self._grid_x, self._grid_y)

    def _create_resources(self, width: int, height: int, max_lights: int) -> None:
        """Allocate buffers, sampler, descriptors and the compute pipeline.

        Every handle is stored on ``self`` as soon as it exists so that
        ``_release_resources`` can free it if a later step raises.
        """
        e = self._engine
        device = e._device
        phys = e._physical_device
        self._max_lights = max_lights
        self._grid_x = math.ceil(width / _TILE_SIZE)
        self._grid_y = math.ceil(height / _TILE_SIZE)
        self._max_tiles = self._grid_x * self._grid_y
        # Worst case: every tile references every light (per-tile capacity capped at 256)
        self._max_light_indices = self._max_tiles * min(max_lights, 256)

        # CullParams SSBO (binding 0) — host-visible so it can be rewritten each
        # frame; it also holds the shader's atomic counter.
        self._cull_params_buf, self._cull_params_mem = create_buffer(
            device, phys, _CULL_PARAMS_SIZE,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        )

        # Light index SSBO (binding 3) — output, one uint32 per entry.
        # Clamped to >= 1 entry: Vulkan rejects zero-sized buffers/ranges.
        self._index_buf_size = max(self._max_light_indices * 4, 4)
        self._light_index_buf, self._light_index_mem = create_buffer(
            device, phys, self._index_buf_size,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        )

        # Tile SSBO (binding 4) — output, one uvec2 (8 bytes) per tile, clamped likewise.
        self._tile_buf_size = max(self._max_tiles * 8, 8)
        self._tile_buf, self._tile_mem = create_buffer(
            device, phys, self._tile_buf_size,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        )

        # Depth sampler: nearest filtering, clamped addressing
        sampler_ci = vk.VkSamplerCreateInfo(
            magFilter=vk.VK_FILTER_NEAREST,
            minFilter=vk.VK_FILTER_NEAREST,
            addressModeU=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
            addressModeV=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
            addressModeW=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
        )
        self._depth_sampler = vk.vkCreateSampler(device, sampler_ci, None)

        # Descriptor set layout: 5 bindings, all compute-stage
        #   0: CullParams SSBO
        #   1: depth texture (combined image sampler)
        #   2: LightBuffer SSBO (read)
        #   3: LightIndexBuffer SSBO (write)
        #   4: TileBuffer SSBO (write)
        binding_types = (
            vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
            vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
            vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
            vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
        )
        bindings = [
            vk.VkDescriptorSetLayoutBinding(
                binding=index,
                descriptorType=desc_type,
                descriptorCount=1,
                stageFlags=vk.VK_SHADER_STAGE_COMPUTE_BIT,
            )
            for index, desc_type in enumerate(binding_types)
        ]
        layout_ci = vk.VkDescriptorSetLayoutCreateInfo(
            bindingCount=len(bindings),
            pBindings=bindings,
        )
        self._desc_layout = vk.vkCreateDescriptorSetLayout(device, layout_ci, None)

        # Descriptor pool sized for exactly one set of the layout above
        pool_sizes = [
            vk.VkDescriptorPoolSize(type=vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, descriptorCount=4),
            vk.VkDescriptorPoolSize(type=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, descriptorCount=1),
        ]
        pool_ci = vk.VkDescriptorPoolCreateInfo(maxSets=1, poolSizeCount=len(pool_sizes), pPoolSizes=pool_sizes)
        self._desc_pool = vk.vkCreateDescriptorPool(device, pool_ci, None)
        alloc_info = vk.VkDescriptorSetAllocateInfo(
            descriptorPool=self._desc_pool,
            descriptorSetCount=1,
            pSetLayouts=[self._desc_layout],
        )
        self._desc_set = vk.vkAllocateDescriptorSets(device, alloc_info)[0]

        # Write static buffer descriptors (bindings 0, 3, 4); bindings 1 and 2
        # are written by update_descriptors().
        from ..gpu.descriptors import DescriptorWriteBatch
        with DescriptorWriteBatch(device) as batch:
            batch.ssbo(self._desc_set, 0, self._cull_params_buf, _CULL_PARAMS_SIZE)
            batch.ssbo(self._desc_set, 3, self._light_index_buf, self._index_buf_size)
            batch.ssbo(self._desc_set, 4, self._tile_buf, self._tile_buf_size)

        # Compile compute shader and create pipeline
        comp_spv = compile_shader(e.shader_dir / "light_cull.comp")
        self._module = create_shader_module(device, comp_spv)
        self._pipeline, self._layout = _create_cull_pipeline(device, self._module, self._desc_layout)

    def update_descriptors(self, depth_view: Any, light_ssbo: Any, light_ssbo_size: int) -> None:
        """Update depth texture and light SSBO descriptors.

        Called when the depth view or light buffer changes (e.g. on resize).

        Args:
            depth_view: Image view bound at binding 1; expected to be in
                SHADER_READ_ONLY_OPTIMAL layout when the dispatch runs.
            light_ssbo: Buffer bound at binding 2 (offset 0).
            light_ssbo_size: Byte range of ``light_ssbo`` to expose.

        Raises:
            RuntimeError: if called before ``setup`` (no descriptor set exists).
        """
        if self._desc_set is None:
            raise RuntimeError("LightCullPass.update_descriptors() called before setup()")
        device = self._engine._device
        # Binding 1: depth texture
        img_info = vk.VkDescriptorImageInfo(
            sampler=self._depth_sampler,
            imageView=depth_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )
        write_depth = vk.VkWriteDescriptorSet(
            dstSet=self._desc_set,
            dstBinding=1,
            dstArrayElement=0,
            descriptorCount=1,
            descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            pImageInfo=[img_info],
        )
        # Binding 2: light data SSBO
        buf_info = vk.VkDescriptorBufferInfo(buffer=light_ssbo, offset=0, range=light_ssbo_size)
        write_lights = vk.VkWriteDescriptorSet(
            dstSet=self._desc_set,
            dstBinding=2,
            dstArrayElement=0,
            descriptorCount=1,
            descriptorType=vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
            pBufferInfo=[buf_info],
        )
        vk.vkUpdateDescriptorSets(device, 2, [write_depth, write_lights], 0, None)

    def record(self, cmd: Any, frame: FrameContext) -> None:
        """RenderPass interface — delegates to _dispatch_impl with frame data."""
        if self._ready:
            self._dispatch_impl(cmd, frame.camera_view, frame.camera_proj, frame.light_count, frame.near, frame.far)

    def _dispatch_impl(
        self,
        cmd: Any,
        view_mat: np.ndarray,
        proj_mat: np.ndarray,
        light_count: int,
        near: float,
        far: float,
    ) -> None:
        """Record the light culling compute dispatch into a command buffer.

        Must be called outside a render pass, after depth is available and
        before the main geometry pass that reads the tile data.

        Args:
            cmd: Command buffer being recorded.
            view_mat: View matrix; transposed before upload for column-major GLSL.
            proj_mat: Projection matrix; transposed likewise.
            light_count: Number of lights to cull; 0 skips the dispatch.
            near: Near plane written into CullParams.
            far: Far plane written into CullParams.

        Raises:
            ValueError: if the two matrices do not add up to exactly the
                128-byte push-constant range (2 x mat4 of float32).
        """
        if not self._ready or light_count == 0:
            # NOTE(review): with no lights the dispatch is skipped, so the tile
            # buffer keeps whatever the previous dispatch (or the fresh
            # allocation) left there; presumably the fragment shader also gates
            # on the light count — confirm.
            return
        device = self._engine._device

        # Push-constant payload: view + proj, transposed for column-major GLSL.
        # Validated before anything is recorded so a bad input leaves cmd untouched.
        view_t = np.ascontiguousarray(view_mat.T, dtype=np.float32)
        proj_t = np.ascontiguousarray(proj_mat.T, dtype=np.float32)
        pc_bytes = view_t.tobytes() + proj_t.tobytes()
        if len(pc_bytes) != 128:  # must match the range created by _create_cull_pipeline
            raise ValueError(
                f"view/proj must be two 4x4 matrices (128 bytes of push constants), got {len(pc_bytes)} bytes"
            )

        # Upload cull parameters (also resets the atomic counter, word 8, to 0)
        params = np.zeros(12, dtype=np.uint32)  # 48 bytes = 12 uint32
        params[0] = self._grid_x
        params[1] = self._grid_y
        params[2] = light_count
        params.view(np.float32)[3] = near
        params.view(np.float32)[4] = far
        # params[5..7] = padding, params[8] = atomic counter = 0
        upload_numpy(device, self._cull_params_mem, params)

        # Order the dispatch after (a) the host write above and (b) earlier
        # fragment-shader reads of the tile/index buffers, which this dispatch
        # overwrites (write-after-read when frames are recorded back to back).
        pre_barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_HOST_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT | vk.VK_ACCESS_SHADER_WRITE_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_HOST_BIT | vk.VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            0, 1, [pre_barrier], 0, None, 0, None,
        )

        # Bind pipeline and descriptors
        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._pipeline)
        vk.vkCmdBindDescriptorSets(
            cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._layout,
            0, 1, [self._desc_set], 0, None,
        )

        # Push constants go through the raw cffi entry point with a char buffer
        cbuf = vk.ffi.new("char[]", pc_bytes)
        vk._vulkan.lib.vkCmdPushConstants(
            cmd, self._layout, vk.VK_SHADER_STAGE_COMPUTE_BIT,
            0, len(pc_bytes), cbuf,
        )

        # Dispatch: one workgroup per tile
        vk.vkCmdDispatch(cmd, self._grid_x, self._grid_y, 1)

        # Barrier: compute writes → fragment reads
        compute_barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_SHADER_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT,
            0, 1, [compute_barrier], 0, None, 0, None,
        )

    @property
    def ready(self) -> bool:
        """True once setup() has completed and until destroy()."""
        return self._ready

    @property
    def tile_buffer(self) -> Any:
        """Tile SSBO handle for binding in fragment shader descriptors."""
        return self._tile_buf

    @property
    def tile_buffer_size(self) -> int:
        """Byte size of the tile SSBO as allocated (never 0 once set up)."""
        return self._tile_buf_size

    @property
    def light_index_buffer(self) -> Any:
        """Light index SSBO handle for binding in fragment shader descriptors."""
        return self._light_index_buf

    @property
    def light_index_buffer_size(self) -> int:
        """Byte size of the light index SSBO as allocated (never 0 once set up)."""
        return self._index_buf_size

    @property
    def grid_dims(self) -> tuple[int, int]:
        """Tile grid dimensions (tiles across, tiles down)."""
        return (self._grid_x, self._grid_y)

    def resize(self, width: int, height: int) -> None:
        """Handle viewport resize — recreate all resources if the tile grid changed.

        Reuses the ``max_lights`` from the most recent ``setup`` call. The
        caller must re-run ``update_descriptors`` afterwards if a rebuild happened.
        """
        new_gx = math.ceil(width / _TILE_SIZE)
        new_gy = math.ceil(height / _TILE_SIZE)
        if new_gx == self._grid_x and new_gy == self._grid_y:
            return
        # Tear down and recreate with new dimensions
        self.destroy()
        self.setup(width, height, self._max_lights)

    def destroy(self) -> None:
        """Destroy all GPU resources.

        Safe to call more than once and after a partially failed ``setup``:
        every live handle is destroyed and reset to ``None``.
        """
        if self._release_resources():
            log.debug("Light cull pass resources cleaned up")

    def _release_resources(self) -> bool:
        """Destroy every live handle, reset it to None; return True if any existed."""
        device = self._engine._device
        released = False
        # Pipeline objects first, then each buffer before its backing memory.
        for attr, destroy_fn in (
            ("_pipeline", vk.vkDestroyPipeline),
            ("_layout", vk.vkDestroyPipelineLayout),
            ("_module", vk.vkDestroyShaderModule),
            ("_desc_layout", vk.vkDestroyDescriptorSetLayout),
            ("_desc_pool", vk.vkDestroyDescriptorPool),
            ("_depth_sampler", vk.vkDestroySampler),
            ("_cull_params_buf", vk.vkDestroyBuffer),
            ("_cull_params_mem", vk.vkFreeMemory),
            ("_light_index_buf", vk.vkDestroyBuffer),
            ("_light_index_mem", vk.vkFreeMemory),
            ("_tile_buf", vk.vkDestroyBuffer),
            ("_tile_mem", vk.vkFreeMemory),
        ):
            handle = getattr(self, attr)
            if handle:
                destroy_fn(device, handle, None)
                released = True
            setattr(self, attr, None)
        # The descriptor set was allocated from the pool destroyed above.
        self._desc_set = None
        self._index_buf_size = 0
        self._tile_buf_size = 0
        self._ready = False
        return released
def _create_cull_pipeline(device: Any, compute_module: Any, desc_layout: Any) -> tuple[Any, Any]:
    """Create the light culling compute pipeline.

    Builds a pipeline layout with one descriptor set (``desc_layout``) and a
    128-byte compute-stage push-constant range (2 x mat4: view + proj), then a
    compute pipeline using the ``main`` entry point of ``compute_module``.

    Returns:
        (pipeline, pipeline_layout)

    Raises:
        RuntimeError: if vkCreatePipelineLayout or vkCreateComputePipelines
            does not return VK_SUCCESS. If pipeline creation fails, the
            pipeline layout created earlier is destroyed first so it is not leaked.
    """
    ffi = vk.ffi
    # Push constant: 2x mat4 = 128 bytes
    push_range = ffi.new("VkPushConstantRange*")
    push_range.stageFlags = vk.VK_SHADER_STAGE_COMPUTE_BIT
    push_range.offset = 0
    push_range.size = 128
    set_layouts = ffi.new("VkDescriptorSetLayout[1]", [desc_layout])
    layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
    layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO
    layout_ci.setLayoutCount = 1
    layout_ci.pSetLayouts = set_layouts
    layout_ci.pushConstantRangeCount = 1
    layout_ci.pPushConstantRanges = push_range
    layout_out = ffi.new("VkPipelineLayout*")
    # Raw lib call: returns a VkResult that must be checked explicitly.
    result = vk._vulkan._callApi(
        vk._vulkan.lib.vkCreatePipelineLayout,
        device, layout_ci, ffi.NULL, layout_out,
    )
    if result != vk.VK_SUCCESS:
        raise RuntimeError(f"vkCreatePipelineLayout failed: {result}")
    pipeline_layout = layout_out[0]

    # main_name must stay alive until the create call: stage.pName points into it.
    main_name = ffi.new("char[]", b"main")
    stage = ffi.new("VkPipelineShaderStageCreateInfo*")
    stage.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
    stage.stage = vk.VK_SHADER_STAGE_COMPUTE_BIT
    stage.module = compute_module
    stage.pName = main_name
    ci = ffi.new("VkComputePipelineCreateInfo*")
    ci.sType = vk.VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO
    ci.stage = stage[0]
    ci.layout = pipeline_layout
    pipeline_out = ffi.new("VkPipeline*")
    result = vk._vulkan._callApi(
        vk._vulkan.lib.vkCreateComputePipelines,
        device, ffi.NULL, 1, ci, ffi.NULL, pipeline_out,
    )
    if result != vk.VK_SUCCESS:
        # The layout above was created successfully; free it so a failed
        # pipeline build does not leak it.
        vk.vkDestroyPipelineLayout(device, pipeline_layout, None)
        raise RuntimeError(f"vkCreateComputePipelines failed: {result}")
    log.debug("Light cull compute pipeline created")
    return pipeline_out[0], pipeline_layout