"""Cascaded Shadow Map (CSM) rendering pass."""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import vulkan as vk
from .gpu_batch import GPUBatch
from .pass_helpers import load_shader_modules
from .passes import create_shadow_pass
__all__ = ["ShadowPass"]
log = logging.getLogger(__name__)
# Number of shadow cascades; each cascade gets its own square tile in the
# shadow atlas (tiles are laid out side by side horizontally).
CASCADE_COUNT = 3
# Edge length, in texels, of one cascade tile. The atlas image is
# SHADOW_MAP_SIZE * CASCADE_COUNT wide and SHADOW_MAP_SIZE tall.
SHADOW_MAP_SIZE = 2048
# 32-bit float depth format used for the atlas image and its image view.
DEPTH_FORMAT = vk.VK_FORMAT_D32_SFLOAT
class ShadowPass:
    """Render scene depth from a directional light into a cascaded shadow atlas.

    Atlas layout: ``CASCADE_COUNT`` square tiles of ``SHADOW_MAP_SIZE`` texels
    placed side by side horizontally, so the depth image is
    ``SHADOW_MAP_SIZE * CASCADE_COUNT`` texels wide and ``SHADOW_MAP_SIZE``
    texels tall.

    Per frame: call :meth:`compute_cascades` to refresh :attr:`cascade_vps`
    and :attr:`cascade_splits`, then :meth:`render` to record the depth pass.
    The forward pass consumes those two arrays plus the atlas, which is
    registered in the bindless texture array at :attr:`shadow_texture_index`.
    """

    def __init__(self, engine: Any):
        # Engine object; setup() reads ctx.device, ctx.physical_device,
        # shader_dir and the bindless texture state from it.
        self._engine = engine
        # Vulkan handles: created in setup(), released in cleanup().
        self._render_pass: Any = None
        self._framebuffer: Any = None
        self._pipeline: Any = None
        self._pipeline_layout: Any = None
        self._vert_module: Any = None
        self._frag_module: Any = None
        self._depth_image: Any = None
        self._depth_memory: Any = None
        self._depth_view: Any = None
        self._sampler: Any = None
        # Bindless slot of the atlas; -1 until setup() registers it.
        self._texture_index: int = -1
        # True only after setup() has completed successfully.
        self._ready = False
        # Indirect-draw command batch used by render().
        self._batch: GPUBatch | None = None
        # Outputs for the forward pass: one light view-projection per cascade
        # (stored transposed for GLSL column-major upload) and the
        # CASCADE_COUNT + 1 split distances [near, split_1, ..., far].
        self.cascade_vps = np.zeros((CASCADE_COUNT, 4, 4), dtype=np.float32)
        self.cascade_splits = np.zeros(CASCADE_COUNT + 1, dtype=np.float32)

    def setup(self, ssbo_layout: Any) -> None:
        """Create all shadow-map GPU resources and register the atlas texture.

        Creates the depth-only render pass, the depth atlas image (with its
        memory, view and framebuffer), a sampler, a bindless descriptor slot
        for the atlas, the shadow pipeline and the indirect-draw batch, then
        marks the pass ready.

        Args:
            ssbo_layout: Descriptor set layout of the transform SSBO; it
                becomes set 0 of the shadow pipeline layout.

        If this raises part-way, :meth:`cleanup` still releases the Vulkan
        objects that were created before the failure.
        """
        e = self._engine
        device = e.ctx.device
        phys = e.ctx.physical_device
        atlas_w = SHADOW_MAP_SIZE * CASCADE_COUNT
        atlas_h = SHADOW_MAP_SIZE

        # Render pass (depth-only)
        self._render_pass = create_shadow_pass(device)

        # Depth atlas image: written as a depth attachment, later sampled.
        img_info = vk.VkImageCreateInfo(
            imageType=vk.VK_IMAGE_TYPE_2D,
            format=DEPTH_FORMAT,
            extent=vk.VkExtent3D(width=atlas_w, height=atlas_h, depth=1),
            mipLevels=1, arrayLayers=1,
            samples=vk.VK_SAMPLE_COUNT_1_BIT,
            tiling=vk.VK_IMAGE_TILING_OPTIMAL,
            usage=(vk.VK_IMAGE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT
                   | vk.VK_IMAGE_USAGE_SAMPLED_BIT),
            sharingMode=vk.VK_SHARING_MODE_EXCLUSIVE,
            initialLayout=vk.VK_IMAGE_LAYOUT_UNDEFINED,
        )
        self._depth_image = vk.vkCreateImage(device, img_info, None)
        mem_reqs = vk.vkGetImageMemoryRequirements(device, self._depth_image)
        from ..gpu.memory import _find_memory_type
        alloc_info = vk.VkMemoryAllocateInfo(
            allocationSize=mem_reqs.size,
            memoryTypeIndex=_find_memory_type(
                phys, mem_reqs.memoryTypeBits,
                vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
            ),
        )
        self._depth_memory = vk.vkAllocateMemory(device, alloc_info, None)
        vk.vkBindImageMemory(device, self._depth_image, self._depth_memory, 0)

        # Image view covering the whole atlas (single mip, single layer)
        view_ci = vk.VkImageViewCreateInfo(
            image=self._depth_image,
            viewType=vk.VK_IMAGE_VIEW_TYPE_2D,
            format=DEPTH_FORMAT,
            subresourceRange=vk.VkImageSubresourceRange(
                aspectMask=vk.VK_IMAGE_ASPECT_DEPTH_BIT,
                baseMipLevel=0, levelCount=1,
                baseArrayLayer=0, layerCount=1,
            ),
        )
        self._depth_view = vk.vkCreateImageView(device, view_ci, None)

        # Framebuffer with the atlas view as its only attachment
        fb_ci = vk.VkFramebufferCreateInfo(
            renderPass=self._render_pass,
            attachmentCount=1,
            pAttachments=[self._depth_view],
            width=atlas_w, height=atlas_h, layers=1,
        )
        self._framebuffer = vk.vkCreateFramebuffer(device, fb_ci, None)

        # Plain (non-comparison) sampler: compareEnable is off, so sampling
        # returns raw depth and the shadow test has to be done in the shader.
        # Samples outside the atlas read the white (depth 1.0) border.
        sampler_ci = vk.VkSamplerCreateInfo(
            magFilter=vk.VK_FILTER_LINEAR,
            minFilter=vk.VK_FILTER_LINEAR,
            addressModeU=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER,
            addressModeV=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER,
            addressModeW=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER,
            borderColor=vk.VK_BORDER_COLOR_FLOAT_OPAQUE_WHITE,
            compareEnable=vk.VK_FALSE,
            anisotropyEnable=vk.VK_FALSE,
            unnormalizedCoordinates=vk.VK_FALSE,
            mipmapMode=vk.VK_SAMPLER_MIPMAP_MODE_NEAREST,
        )
        self._sampler = vk.vkCreateSampler(device, sampler_ci, None)

        # Register the atlas in the bindless texture array with its own sampler
        from ..gpu.descriptors import write_texture_descriptor
        if not e.texture_descriptor_set:
            e._init_texture_system()
        self._texture_index = e._next_texture_index
        write_texture_descriptor(
            device, e.texture_descriptor_set,
            self._texture_index, self._depth_view, self._sampler,
        )
        e._next_texture_index += 1

        # Shadow pipeline
        self._vert_module, self._frag_module = load_shader_modules(
            device, e.shader_dir, "shadow.vert", "shadow.frag",
        )
        self._pipeline, self._pipeline_layout = _create_shadow_pipeline(
            device, self._vert_module, self._frag_module,
            self._render_pass, ssbo_layout,
        )
        self._batch = GPUBatch(device, phys, max_draws=10_000)
        self._ready = True
        log.debug("Shadow pass initialized (%dx%d atlas, %d cascades)",
                  atlas_w, atlas_h, CASCADE_COUNT)

    @property
    def shadow_texture_index(self) -> int:
        """Bindless index of the shadow map atlas texture (-1 before setup)."""
        return self._texture_index

    def compute_cascades(
        self,
        view: np.ndarray,
        proj: np.ndarray,
        light_dir: np.ndarray,
        near: float = 0.0,
        far: float = 0.0,
    ) -> None:
        """Compute cascade split distances and light-space VP matrices.

        Splits follow the "practical" scheme: an even (lambda = 0.5) blend of
        logarithmic and uniform splits between *near* and *far*, with *far*
        clamped to 300 units. Each cascade's slice of the camera frustum is
        unprojected to world space and enclosed in a padded orthographic box
        aligned with the light.

        Args:
            view: 4x4 camera view matrix (column-vector convention, ``M @ v``).
            proj: 4x4 Vulkan perspective projection (depth range 0..1).
            light_dir: Direction the light looks along; normalized here, so it
                only needs to be non-zero.
            near: Camera near distance. If either *near* or *far* is <= 0,
                both are recovered from *proj*.
            far: Camera far distance (see *near*).

        Results are written in place to :attr:`cascade_splits` and
        :attr:`cascade_vps` (each VP transposed for GLSL column-major upload).
        """
        if near <= 0 or far <= 0:
            # Vulkan perspective: proj[2,2] = far/(near-far) and
            # proj[2,3] = near*far/(near-far), which gives
            # near = proj[2,3] / proj[2,2] and far = proj[2,3] / (proj[2,2] + 1).
            p22, p23 = proj[2, 2], proj[2, 3]
            if abs(p22) > 1e-6:
                near = p23 / p22
                far = p23 / (p22 + 1.0)
            else:
                near, far = 0.1, 100.0
        # Clamp far to a reasonable shadow distance (very far = poor resolution)
        far = min(far, 300.0)

        # Split distances: blend of logarithmic and linear distributions.
        lambda_split = 0.5  # even blend — better far-cascade coverage
        splits = self.cascade_splits
        splits[0] = near
        for i in range(CASCADE_COUNT):
            p = (i + 1) / CASCADE_COUNT
            log_split = near * (far / near) ** p
            lin_split = near + (far - near) * p
            splits[i + 1] = lambda_split * log_split + (1.0 - lambda_split) * lin_split

        # Light basis (right, up, light_dir) is the same for every cascade.
        light_dir_n = light_dir / np.linalg.norm(light_dir)
        up = np.array([0, 1, 0], dtype=np.float32)
        if abs(np.dot(light_dir_n, up)) > 0.99:
            # Light is (nearly) vertical: pick another helper axis.
            up = np.array([1, 0, 0], dtype=np.float32)
        right = np.cross(up, light_dir_n)
        right /= np.linalg.norm(right)
        up = np.cross(light_dir_n, right)

        # Inverse camera VP for unprojecting frustum corners to world space
        inv_vp = np.linalg.inv(proj @ view)
        for c in range(CASCADE_COUNT):
            # NDC depths of this slice's near/far split planes
            ndc_near = _depth_to_ndc(splits[c], proj)
            ndc_far = _depth_to_ndc(splits[c + 1], proj)
            # 8 corners of the frustum slice in NDC
            corners_ndc = np.array([
                [-1, -1, ndc_near, 1], [1, -1, ndc_near, 1],
                [-1, 1, ndc_near, 1], [1, 1, ndc_near, 1],
                [-1, -1, ndc_far, 1], [1, -1, ndc_far, 1],
                [-1, 1, ndc_far, 1], [1, 1, ndc_far, 1],
            ], dtype=np.float32)
            # Unproject to world space (perspective divide by w)
            corners_world = (inv_vp @ corners_ndc.T).T
            corners_world /= corners_world[:, 3:4]
            corners_xyz = corners_world[:, :3]
            centroid = corners_xyz.mean(axis=0)

            # Light view: rows are the light basis, origin at the slice centroid
            light_view = np.eye(4, dtype=np.float32)
            light_view[0, :3] = right
            light_view[1, :3] = up
            light_view[2, :3] = light_dir_n
            light_view[:3, 3] = -light_view[:3, :3] @ centroid

            # Light-space AABB of the slice corners
            corners_ls = (light_view @ np.hstack(
                [corners_xyz, np.ones((8, 1), dtype=np.float32)]).T).T
            mins = corners_ls[:, :3].min(axis=0)
            maxs = corners_ls[:, :3].max(axis=0)
            # Pad the box by 10% per axis to leave a margin at the slice edges.
            # (Acne is handled by the pipeline's depth bias, not by this.)
            # NOTE(review): only the slice itself is covered in z; casters
            # between the light and the slice but outside mins[2] are clipped
            # (depth clamp is disabled) and cast no shadow — consider
            # extending mins[2] toward the light.
            pad = (maxs - mins) * 0.1
            mins -= pad
            maxs += pad

            # Orthographic projection mapping the box to x,y in [-1,1], z in [0,1]
            light_proj = np.zeros((4, 4), dtype=np.float32)
            light_proj[0, 0] = 2.0 / (maxs[0] - mins[0])
            light_proj[1, 1] = 2.0 / (maxs[1] - mins[1])
            light_proj[2, 2] = 1.0 / (maxs[2] - mins[2])
            light_proj[0, 3] = -(maxs[0] + mins[0]) / (maxs[0] - mins[0])
            light_proj[1, 3] = -(maxs[1] + mins[1]) / (maxs[1] - mins[1])
            light_proj[2, 3] = -mins[2] / (maxs[2] - mins[2])
            light_proj[3, 3] = 1.0
            self.cascade_vps[c] = (light_proj @ light_view).T  # Transpose for GLSL column-major

    def render(
        self,
        cmd: Any,
        instances: list,
        ssbo_set: Any,
        mesh_registry: Any,
    ) -> None:
        """Record the shadow depth pass into *cmd*.

        Once :meth:`setup` has run, the render pass is recorded every frame —
        even with no instances — so the atlas is cleared to depth 1.0 instead
        of keeping stale depth from an earlier frame (or never being written
        at all).

        Args:
            cmd: Command buffer in the recording state.
            instances: Sequence of 4-tuples whose first element is a mesh
                handle exposing ``id`` and ``index_count``; the other three
                fields are not used here.
            ssbo_set: Descriptor set bound at set 0 (the transform SSBO).
            mesh_registry: Object whose ``get_buffers(handle)`` returns the
                ``(vertex_buffer, index_buffer)`` pair of a mesh.
        """
        if not self._ready:
            return
        atlas_w = SHADOW_MAP_SIZE * CASCADE_COUNT
        atlas_h = SHADOW_MAP_SIZE

        clear = vk.VkClearValue(
            depthStencil=vk.VkClearDepthStencilValue(depth=1.0, stencil=0),
        )
        rp_info = vk.VkRenderPassBeginInfo(
            renderPass=self._render_pass,
            framebuffer=self._framebuffer,
            renderArea=vk.VkRect2D(
                offset=vk.VkOffset2D(x=0, y=0),
                extent=vk.VkExtent2D(width=atlas_w, height=atlas_h),
            ),
            clearValueCount=1,
            pClearValues=[clear],
        )
        vk.vkCmdBeginRenderPass(cmd, rp_info, vk.VK_SUBPASS_CONTENTS_INLINE)
        if instances:
            self._record_cascades(cmd, instances, ssbo_set, mesh_registry)
        vk.vkCmdEndRenderPass(cmd)

    def _record_cascades(
        self,
        cmd: Any,
        instances: list,
        ssbo_set: Any,
        mesh_registry: Any,
    ) -> None:
        """Bind shadow pipeline state and draw *instances* into every cascade tile."""
        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_GRAPHICS, self._pipeline)
        # Bind transform SSBO at set 0
        vk.vkCmdBindDescriptorSets(
            cmd, vk.VK_PIPELINE_BIND_POINT_GRAPHICS, self._pipeline_layout,
            0, 1, [ssbo_set], 0, None,
        )

        # Group instance indices by mesh so each mesh is bound once per cascade
        mesh_groups: dict[int, list[int]] = {}
        for i, (mesh_handle, _, _, _) in enumerate(instances):
            mesh_groups.setdefault(mesh_handle.id, []).append(i)

        # The indirect draw commands are identical for every cascade (only the
        # pushed light VP differs), so build and upload them once.
        batch = self._batch
        batch.reset()
        draws: list[tuple[Any, Any, int, int]] = []  # (vb, ib, batch_offset, count)
        for indices in mesh_groups.values():
            mesh_handle = instances[indices[0]][0]
            batch_offset = batch.add_draws(mesh_handle.index_count, indices)
            vb, ib = mesh_registry.get_buffers(mesh_handle)
            draws.append((vb, ib, batch_offset, len(indices)))
        batch.upload()

        ffi = vk.ffi
        for cascade in range(CASCADE_COUNT):
            tile_x = cascade * SHADOW_MAP_SIZE
            # Restrict rasterization to this cascade's tile of the atlas
            vk_vp = vk.VkViewport(
                x=float(tile_x), y=0.0,
                width=float(SHADOW_MAP_SIZE), height=float(SHADOW_MAP_SIZE),
                minDepth=0.0, maxDepth=1.0,
            )
            vk.vkCmdSetViewport(cmd, 0, 1, [vk_vp])
            scissor = vk.VkRect2D(
                offset=vk.VkOffset2D(x=tile_x, y=0),
                extent=vk.VkExtent2D(width=SHADOW_MAP_SIZE, height=SHADOW_MAP_SIZE),
            )
            vk.vkCmdSetScissor(cmd, 0, 1, [scissor])

            # Push light VP matrix for this cascade (64 bytes, already transposed)
            pc_data = np.ascontiguousarray(self.cascade_vps[cascade]).tobytes()
            cbuf = ffi.new("char[]", pc_data)
            vk._vulkan.lib.vkCmdPushConstants(
                cmd, self._pipeline_layout,
                vk.VK_SHADER_STAGE_VERTEX_BIT,
                0, 64, cbuf,
            )

            for vb, ib, batch_offset, count in draws:
                vk.vkCmdBindVertexBuffers(cmd, 0, 1, [vb], [0])
                vk.vkCmdBindIndexBuffer(cmd, ib, 0, vk.VK_INDEX_TYPE_UINT32)
                batch.draw_range(cmd, batch_offset, count)

    def cleanup(self) -> None:
        """Release every GPU resource owned by this pass.

        Works after a partial :meth:`setup` (only non-null handles are
        destroyed) and resets every handle to ``None``, so calling it twice is
        harmless. Does nothing, without touching the engine, when nothing was
        created. The bindless slot index is left as is.
        """
        handles = (
            self._framebuffer, self._pipeline, self._pipeline_layout,
            self._vert_module, self._frag_module, self._depth_view,
            self._depth_image, self._sampler, self._render_pass,
            self._depth_memory,
        )
        if not any(handles) and self._batch is None:
            self._ready = False
            return

        device = self._engine.ctx.device
        # Framebuffer before render pass, view before image, image before memory
        destroyers = (
            ("_framebuffer", vk.vkDestroyFramebuffer),
            ("_pipeline", vk.vkDestroyPipeline),
            ("_pipeline_layout", vk.vkDestroyPipelineLayout),
            ("_vert_module", vk.vkDestroyShaderModule),
            ("_frag_module", vk.vkDestroyShaderModule),
            ("_depth_view", vk.vkDestroyImageView),
            ("_depth_image", vk.vkDestroyImage),
            ("_sampler", vk.vkDestroySampler),
            ("_render_pass", vk.vkDestroyRenderPass),
        )
        for attr, destroy in destroyers:
            handle = getattr(self, attr)
            if handle:
                destroy(device, handle, None)
                setattr(self, attr, None)
        if self._depth_memory:
            vk.vkFreeMemory(device, self._depth_memory, None)
            self._depth_memory = None
        if self._batch is not None:
            self._batch.destroy()
            self._batch = None
        self._ready = False
def _depth_to_ndc(z_eye: float, proj: np.ndarray) -> float:
"""Convert eye-space depth to Vulkan NDC z [0, 1]."""
# For Vulkan perspective: ndc_z = (proj[2,2] * (-z) + proj[2,3]) / z
clip_z = proj[2, 2] * (-z_eye) + proj[2, 3]
clip_w = z_eye # -(-z_eye)
return clip_z / clip_w
def _create_shadow_pipeline(
    device: Any,
    vert_module: Any,
    frag_module: Any,
    render_pass: Any,
    ssbo_layout: Any,
) -> tuple[Any, Any]:
    """Create the depth-only shadow pipeline and its pipeline layout.

    Pipeline layout: descriptor set 0 = *ssbo_layout*, plus one vertex-stage
    push-constant range holding a ``mat4`` light view-projection (64 bytes).
    Vertex input: one interleaved 32-byte binding — vec3 position at offset
    0, vec3 normal at 12, vec2 uv at 24 — the same layout as the forward
    pass. There are no colour attachments; viewport and scissor are dynamic.

    All create-info structs are cffi allocations local to this function; they
    must stay referenced until the ``vkCreate*`` calls return, which is why
    everything is built in one place.

    Returns:
        ``(pipeline, pipeline_layout)``.

    Raises:
        RuntimeError: if layout or pipeline creation does not return
            ``VK_SUCCESS``. If pipeline creation fails, the pipeline layout
            is destroyed before the error propagates.
    """
    ffi = vk.ffi

    # Push constant: 1x mat4 = 64 bytes
    push_range = ffi.new("VkPushConstantRange*")
    push_range.stageFlags = vk.VK_SHADER_STAGE_VERTEX_BIT
    push_range.offset = 0
    push_range.size = 64

    # Pipeline layout with SSBO descriptor set
    set_layouts = ffi.new("VkDescriptorSetLayout[1]", [ssbo_layout])
    layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
    layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO
    layout_ci.setLayoutCount = 1
    layout_ci.pSetLayouts = set_layouts
    layout_ci.pushConstantRangeCount = 1
    layout_ci.pPushConstantRanges = push_range
    layout_out = ffi.new("VkPipelineLayout*")
    result = vk._vulkan._callApi(
        vk._vulkan.lib.vkCreatePipelineLayout,
        device, layout_ci, ffi.NULL, layout_out,
    )
    if result != vk.VK_SUCCESS:
        raise RuntimeError(f"vkCreatePipelineLayout failed: {result}")
    pipeline_layout = layout_out[0]

    pi = ffi.new("VkGraphicsPipelineCreateInfo*")
    pi.sType = vk.VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO

    # Shader stages
    stages = ffi.new("VkPipelineShaderStageCreateInfo[2]")
    main_name = ffi.new("char[]", b"main")
    stages[0].sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
    stages[0].stage = vk.VK_SHADER_STAGE_VERTEX_BIT
    stages[0].module = vert_module
    stages[0].pName = main_name
    stages[1].sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
    stages[1].stage = vk.VK_SHADER_STAGE_FRAGMENT_BIT
    stages[1].module = frag_module
    stages[1].pName = main_name
    pi.stageCount = 2
    pi.pStages = stages

    # Vertex input — same as forward (position + normal + uv = 32 bytes)
    binding_desc = ffi.new("VkVertexInputBindingDescription*")
    binding_desc.binding = 0
    binding_desc.stride = 32
    binding_desc.inputRate = vk.VK_VERTEX_INPUT_RATE_VERTEX
    attr_descs = ffi.new("VkVertexInputAttributeDescription[3]")
    attr_descs[0].location = 0
    attr_descs[0].binding = 0
    attr_descs[0].format = vk.VK_FORMAT_R32G32B32_SFLOAT
    attr_descs[0].offset = 0
    attr_descs[1].location = 1
    attr_descs[1].binding = 0
    attr_descs[1].format = vk.VK_FORMAT_R32G32B32_SFLOAT
    attr_descs[1].offset = 12
    attr_descs[2].location = 2
    attr_descs[2].binding = 0
    attr_descs[2].format = vk.VK_FORMAT_R32G32_SFLOAT
    attr_descs[2].offset = 24
    vi = ffi.new("VkPipelineVertexInputStateCreateInfo*")
    vi.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO
    vi.vertexBindingDescriptionCount = 1
    vi.pVertexBindingDescriptions = binding_desc
    vi.vertexAttributeDescriptionCount = 3
    vi.pVertexAttributeDescriptions = attr_descs
    pi.pVertexInputState = vi

    # Input assembly
    ia = ffi.new("VkPipelineInputAssemblyStateCreateInfo*")
    ia.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO
    ia.topology = vk.VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST
    pi.pInputAssemblyState = ia

    # Viewport state: values are placeholders, the real viewport/scissor are
    # set per cascade as dynamic state.
    vps = ffi.new("VkPipelineViewportStateCreateInfo*")
    vps.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO
    vps.viewportCount = 1
    viewport = ffi.new("VkViewport*")
    viewport.width = float(SHADOW_MAP_SIZE)
    viewport.height = float(SHADOW_MAP_SIZE)
    viewport.maxDepth = 1.0
    vps.pViewports = viewport
    scissor = ffi.new("VkRect2D*")
    scissor.extent.width = SHADOW_MAP_SIZE
    scissor.extent.height = SHADOW_MAP_SIZE
    vps.scissorCount = 1
    vps.pScissors = scissor
    pi.pViewportState = vps

    # Rasterization — depth bias for shadow acne
    rs = ffi.new("VkPipelineRasterizationStateCreateInfo*")
    rs.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO
    rs.polygonMode = vk.VK_POLYGON_MODE_FILL
    rs.lineWidth = 1.0
    rs.cullMode = vk.VK_CULL_MODE_FRONT_BIT  # Front-face culling reduces peter-panning
    rs.frontFace = vk.VK_FRONT_FACE_CLOCKWISE
    rs.depthBiasEnable = 1
    rs.depthBiasConstantFactor = 1.25
    rs.depthBiasSlopeFactor = 1.75
    pi.pRasterizationState = rs

    # Multisample
    ms = ffi.new("VkPipelineMultisampleStateCreateInfo*")
    ms.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO
    ms.rasterizationSamples = vk.VK_SAMPLE_COUNT_1_BIT
    pi.pMultisampleState = ms

    # Depth stencil: test and write, atlas is cleared to 1.0
    dss = ffi.new("VkPipelineDepthStencilStateCreateInfo*")
    dss.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO
    dss.depthTestEnable = 1
    dss.depthWriteEnable = 1
    dss.depthCompareOp = vk.VK_COMPARE_OP_LESS_OR_EQUAL
    pi.pDepthStencilState = dss

    # No colour blend (depth-only, no colour attachment)
    cb = ffi.new("VkPipelineColorBlendStateCreateInfo*")
    cb.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO
    cb.attachmentCount = 0
    pi.pColorBlendState = cb

    # Dynamic state
    dyn_states = ffi.new("VkDynamicState[2]", [
        vk.VK_DYNAMIC_STATE_VIEWPORT, vk.VK_DYNAMIC_STATE_SCISSOR,
    ])
    ds = ffi.new("VkPipelineDynamicStateCreateInfo*")
    ds.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO
    ds.dynamicStateCount = 2
    ds.pDynamicStates = dyn_states
    pi.pDynamicState = ds

    pi.layout = pipeline_layout
    pi.renderPass = render_pass

    pipeline_out = ffi.new("VkPipeline*")
    try:
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreateGraphicsPipelines,
            device, ffi.NULL, 1, pi, ffi.NULL, pipeline_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreateGraphicsPipelines failed: {result}")
    except Exception:
        # The caller never receives the layout on failure, so release it here.
        vk.vkDestroyPipelineLayout(device, pipeline_layout, None)
        raise
    log.debug("Shadow pipeline created")
    return pipeline_out[0], pipeline_layout