"""Bloom post-processing — extract bright pixels, blur, composite."""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import vulkan as vk
from ..gpu.memory import create_sampler
from .pass_helpers import create_sampler_descriptor_pool, load_shader_modules
from .render_target import RenderTarget
# Public API of this module.
__all__ = ["BloomPass"]
# Module-level logger (used for the setup debug message).
log = logging.getLogger(__name__)
class BloomPass:
    """Two-pass bloom: bright-pixel extraction + separable Gaussian blur.

    Operates at half resolution for performance. Ping-pongs between two
    render targets for multi-pass blur; after :meth:`render` the result is
    in ``rt_a`` and is exposed through :attr:`bloom_image_view`.

    Lifecycle: :meth:`setup` once, :meth:`render` per frame, :meth:`resize`
    when the output size changes, :meth:`cleanup` to release GPU resources.
    """

    # Push-constant block size in bytes, shared by both pipelines.
    # Extract: vec2 texel_size + float threshold + float soft_knee.
    # Blur:    vec2 texel_size + vec2 direction ((1,0)=H, (0,1)=V).
    _PC_SIZE = 16

    def __init__(self, engine: Any):
        """Store the engine; GPU resources are created later in :meth:`setup`.

        Args:
            engine: Owning engine. This class reads ``engine.ctx.device``,
                ``engine.ctx.physical_device``, ``engine.extent`` and
                ``engine.shader_dir``.
        """
        self._engine = engine
        # Half-res render targets for ping-pong
        self._rt_a: RenderTarget | None = None
        self._rt_b: RenderTarget | None = None
        # Pipelines
        self._extract_pipeline: Any = None
        self._extract_layout: Any = None
        self._blur_pipeline: Any = None
        self._blur_layout: Any = None
        # Shader modules
        self._vert_module: Any = None
        self._extract_frag_module: Any = None
        self._blur_frag_module: Any = None
        # Descriptors (one per input texture: hdr, rt_a, rt_b)
        self._sampler: Any = None
        self._desc_pool: Any = None
        self._desc_layout: Any = None
        self._desc_hdr: Any = None  # HDR input -> extract
        self._desc_a: Any = None  # rt_a -> blur
        self._desc_b: Any = None  # rt_b -> blur
        # Settings
        self.threshold = 1.0
        self.soft_knee = 0.5
        self.blur_passes = 2  # number of H+V blur iterations
        self._ready = False

    @property
    def bloom_image_view(self) -> Any:
        """Final bloom result image view (for tonemap compositing).

        Returns ``None`` before :meth:`setup` and after :meth:`cleanup`.
        """
        return self._rt_a.color_view if self._rt_a else None

    def setup(self, hdr_view: Any) -> None:
        """Initialize bloom targets, shaders, and pipelines.

        Args:
            hdr_view: Image view sampled by the extract pass. The descriptor
                declares ``VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL``, so the
                image is expected to be in that layout when sampled.
        """
        e = self._engine
        device = e.ctx.device
        w, h = e.extent
        half_w, half_h = max(1, w // 2), max(1, h // 2)

        # Half-res render targets (no depth needed)
        self._create_targets(half_w, half_h)

        # Sampler
        self._sampler = create_sampler(device)

        # Descriptor layout: single combined image sampler at binding 0
        binding = vk.VkDescriptorSetLayoutBinding(
            binding=0,
            descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            descriptorCount=1,
            stageFlags=vk.VK_SHADER_STAGE_FRAGMENT_BIT,
        )
        self._desc_layout = vk.vkCreateDescriptorSetLayout(
            device,
            vk.VkDescriptorSetLayoutCreateInfo(bindingCount=1, pBindings=[binding]),
            None,
        )

        # Pool + 3 descriptor sets
        self._desc_pool, desc_sets = create_sampler_descriptor_pool(
            device, self._desc_layout, max_sets=3,
        )
        self._desc_hdr = desc_sets[0]
        self._desc_a = desc_sets[1]
        self._desc_b = desc_sets[2]
        self._write_descriptors(device, hdr_view)

        # Compile shaders (vert shared between extract and blur)
        shader_dir = e.shader_dir
        self._vert_module, self._extract_frag_module = load_shader_modules(
            device, shader_dir, "tonemap.vert", "bloom_extract.frag",
        )
        # Blur frag loaded separately since vert module is shared
        from ..gpu.pipeline import create_shader_module
        from ..materials.shader_compiler import compile_shader
        self._blur_frag_module = create_shader_module(
            device, compile_shader(shader_dir / "bloom_blur.frag"),
        )

        self._create_pipelines(device, half_w, half_h)

        self._ready = True
        log.debug("Bloom pass initialized (%dx%d half-res)", half_w, half_h)

    # ------------------------------------------------------------------
    # Resource helpers shared by setup / resize / cleanup
    # ------------------------------------------------------------------

    def _create_targets(self, half_w: int, half_h: int) -> None:
        """Allocate the two half-res RGBA16F ping-pong targets (no depth)."""
        ctx = self._engine.ctx
        self._rt_a = RenderTarget(
            ctx.device, ctx.physical_device, half_w, half_h,
            color_format=vk.VK_FORMAT_R16G16B16A16_SFLOAT,
            use_depth=False,
        )
        self._rt_b = RenderTarget(
            ctx.device, ctx.physical_device, half_w, half_h,
            color_format=vk.VK_FORMAT_R16G16B16A16_SFLOAT,
            use_depth=False,
        )

    def _destroy_targets(self) -> None:
        """Destroy both ping-pong targets and clear the handles."""
        if self._rt_a:
            self._rt_a.destroy()
        if self._rt_b:
            self._rt_b.destroy()
        # Clear so bloom_image_view never hands out a destroyed view.
        self._rt_a = None
        self._rt_b = None

    def _write_descriptors(self, device: Any, hdr_view: Any) -> None:
        """Point the three descriptor sets at the HDR input, rt_a and rt_b."""
        self._write_descriptor(device, self._desc_hdr, hdr_view)
        self._write_descriptor(device, self._desc_a, self._rt_a.color_view)
        self._write_descriptor(device, self._desc_b, self._rt_b.color_view)

    def _create_pipelines(self, device: Any, half_w: int, half_h: int) -> None:
        """Build the extract and blur pipelines against rt_a's render pass.

        The blur pipeline is also used while rendering into rt_b. Both
        targets are created with identical parameters, so their render
        passes are expected to be compatible.
        """
        extent = (half_w, half_h)
        render_pass = self._rt_a.render_pass
        self._extract_pipeline, self._extract_layout = self._create_pipeline(
            device, self._vert_module, self._extract_frag_module,
            render_pass, extent, pc_size=self._PC_SIZE,  # vec2 + float + float
        )
        self._blur_pipeline, self._blur_layout = self._create_pipeline(
            device, self._vert_module, self._blur_frag_module,
            render_pass, extent, pc_size=self._PC_SIZE,  # vec2 + vec2
        )

    def _destroy_pipelines(self, device: Any) -> None:
        """Destroy both pipelines and their layouts, then clear the handles."""
        if self._extract_pipeline:
            vk.vkDestroyPipeline(device, self._extract_pipeline, None)
        if self._extract_layout:
            vk.vkDestroyPipelineLayout(device, self._extract_layout, None)
        if self._blur_pipeline:
            vk.vkDestroyPipeline(device, self._blur_pipeline, None)
        if self._blur_layout:
            vk.vkDestroyPipelineLayout(device, self._blur_layout, None)
        self._extract_pipeline = None
        self._extract_layout = None
        self._blur_pipeline = None
        self._blur_layout = None

    def _write_descriptor(self, device: Any, desc_set: Any, image_view: Any) -> None:
        """Write an image sampler to a descriptor set."""
        image_info = vk.VkDescriptorImageInfo(
            sampler=self._sampler,
            imageView=image_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )
        vk.vkUpdateDescriptorSets(device, 1, [vk.VkWriteDescriptorSet(
            dstSet=desc_set,
            dstBinding=0,
            dstArrayElement=0,
            descriptorCount=1,
            descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            pImageInfo=[image_info],
        )], 0, None)

    def _create_pipeline(
        self, device: Any, vert_mod: Any, frag_mod: Any,
        render_pass: Any, extent: tuple[int, int], pc_size: int,
    ) -> tuple[Any, Any]:
        """Create a fullscreen post-processing pipeline.

        Args:
            device: Logical device.
            vert_mod: Vertex shader module (no vertex input is bound).
            frag_mod: Fragment shader module.
            render_pass: Render pass the pipeline will be used with.
            extent: Static viewport/scissor size; ignored at draw time
                because viewport and scissor are dynamic state.
            pc_size: Fragment-stage push-constant range size in bytes.

        Returns:
            ``(pipeline, pipeline_layout)``; the caller owns both.

        Raises:
            RuntimeError: If layout or pipeline creation fails. No Vulkan
                objects created by this call are leaked on failure.
        """
        ffi = vk.ffi
        # Push constant range
        push_range = ffi.new("VkPushConstantRange*")
        push_range.stageFlags = vk.VK_SHADER_STAGE_FRAGMENT_BIT
        push_range.offset = 0
        push_range.size = pc_size
        # Pipeline layout
        layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
        layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO
        set_layouts = ffi.new("VkDescriptorSetLayout[1]", [self._desc_layout])
        layout_ci.setLayoutCount = 1
        layout_ci.pSetLayouts = set_layouts
        layout_ci.pushConstantRangeCount = 1
        layout_ci.pPushConstantRanges = push_range
        layout_out = ffi.new("VkPipelineLayout*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreatePipelineLayout,
            device, layout_ci, ffi.NULL, layout_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreatePipelineLayout failed: {result}")
        pipeline_layout = layout_out[0]
        # Pipeline create info
        pi = ffi.new("VkGraphicsPipelineCreateInfo*")
        pi.sType = vk.VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO
        # Shader stages (the cdata locals below must stay alive until the
        # create call, which they do because they are locals of this frame)
        stages = ffi.new("VkPipelineShaderStageCreateInfo[2]")
        main_name = ffi.new("char[]", b"main")
        stages[0].sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
        stages[0].stage = vk.VK_SHADER_STAGE_VERTEX_BIT
        stages[0].module = vert_mod
        stages[0].pName = main_name
        stages[1].sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
        stages[1].stage = vk.VK_SHADER_STAGE_FRAGMENT_BIT
        stages[1].module = frag_mod
        stages[1].pName = main_name
        pi.stageCount = 2
        pi.pStages = stages
        # No vertex input
        vi = ffi.new("VkPipelineVertexInputStateCreateInfo*")
        vi.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO
        pi.pVertexInputState = vi
        # Input assembly
        ia = ffi.new("VkPipelineInputAssemblyStateCreateInfo*")
        ia.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO
        ia.topology = vk.VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST
        pi.pInputAssemblyState = ia
        # Viewport state: values are placeholders, overridden by the
        # dynamic viewport/scissor set in _set_viewport.
        vps = ffi.new("VkPipelineViewportStateCreateInfo*")
        vps.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO
        vps.viewportCount = 1
        viewport = ffi.new("VkViewport*")
        viewport.width = float(extent[0])
        viewport.height = float(extent[1])
        viewport.maxDepth = 1.0
        vps.pViewports = viewport
        scissor = ffi.new("VkRect2D*")
        scissor.extent.width = extent[0]
        scissor.extent.height = extent[1]
        vps.scissorCount = 1
        vps.pScissors = scissor
        pi.pViewportState = vps
        # Rasterization
        rs = ffi.new("VkPipelineRasterizationStateCreateInfo*")
        rs.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO
        rs.polygonMode = vk.VK_POLYGON_MODE_FILL
        rs.lineWidth = 1.0
        rs.cullMode = vk.VK_CULL_MODE_NONE
        pi.pRasterizationState = rs
        # Multisample
        ms = ffi.new("VkPipelineMultisampleStateCreateInfo*")
        ms.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO
        ms.rasterizationSamples = vk.VK_SAMPLE_COUNT_1_BIT
        pi.pMultisampleState = ms
        # No depth
        dss = ffi.new("VkPipelineDepthStencilStateCreateInfo*")
        dss.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO
        dss.depthTestEnable = 0
        dss.depthWriteEnable = 0
        pi.pDepthStencilState = dss
        # Colour blend (no blending; ffi.new zero-initialises blendEnable)
        cba = ffi.new("VkPipelineColorBlendAttachmentState*")
        cba.colorWriteMask = (
            vk.VK_COLOR_COMPONENT_R_BIT | vk.VK_COLOR_COMPONENT_G_BIT
            | vk.VK_COLOR_COMPONENT_B_BIT | vk.VK_COLOR_COMPONENT_A_BIT
        )
        cb = ffi.new("VkPipelineColorBlendStateCreateInfo*")
        cb.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO
        cb.attachmentCount = 1
        cb.pAttachments = cba
        pi.pColorBlendState = cb
        # Dynamic state
        dyn_states = ffi.new(
            "VkDynamicState[2]",
            [vk.VK_DYNAMIC_STATE_VIEWPORT, vk.VK_DYNAMIC_STATE_SCISSOR],
        )
        ds = ffi.new("VkPipelineDynamicStateCreateInfo*")
        ds.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO
        ds.dynamicStateCount = 2
        ds.pDynamicStates = dyn_states
        pi.pDynamicState = ds
        pi.layout = pipeline_layout
        pi.renderPass = render_pass
        pipeline_out = ffi.new("VkPipeline*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreateGraphicsPipelines,
            device, ffi.NULL, 1, pi, ffi.NULL, pipeline_out,
        )
        if result != vk.VK_SUCCESS:
            # Don't leak the layout we just created.
            vk.vkDestroyPipelineLayout(device, pipeline_layout, None)
            raise RuntimeError(f"vkCreateGraphicsPipelines failed: {result}")
        return pipeline_out[0], pipeline_layout

    def _begin_pass(self, cmd: Any, rt: RenderTarget) -> None:
        """Begin a render pass on a bloom render target (cleared to 0)."""
        clear = vk.VkClearValue(color=vk.VkClearColorValue(float32=[0, 0, 0, 0]))
        rp_begin = vk.VkRenderPassBeginInfo(
            renderPass=rt.render_pass,
            framebuffer=rt.framebuffer,
            renderArea=vk.VkRect2D(
                offset=vk.VkOffset2D(x=0, y=0),
                extent=vk.VkExtent2D(width=rt.width, height=rt.height),
            ),
            clearValueCount=1,
            pClearValues=[clear],
        )
        vk.vkCmdBeginRenderPass(cmd, rp_begin, vk.VK_SUBPASS_CONTENTS_INLINE)

    def _set_viewport(self, cmd: Any, rt: RenderTarget) -> None:
        """Set dynamic viewport and scissor to cover the whole render target."""
        vk_vp = vk.VkViewport(
            x=0.0, y=0.0,
            width=float(rt.width), height=float(rt.height),
            minDepth=0.0, maxDepth=1.0,
        )
        vk.vkCmdSetViewport(cmd, 0, 1, [vk_vp])
        sc = vk.VkRect2D(
            offset=vk.VkOffset2D(x=0, y=0),
            extent=vk.VkExtent2D(width=rt.width, height=rt.height),
        )
        vk.vkCmdSetScissor(cmd, 0, 1, [sc])

    def _draw_fullscreen(
        self, cmd: Any, rt: RenderTarget, pipeline: Any, layout: Any,
        desc_set: Any, push_values: tuple[float, ...],
    ) -> None:
        """Record one complete 3-vertex draw pass into *rt*.

        The draw binds no vertex buffers; the vertex shader presumably
        generates a fullscreen triangle from the vertex index.
        *push_values* are packed as float32 and pushed to the fragment stage.
        """
        self._begin_pass(cmd, rt)
        self._set_viewport(cmd, rt)
        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline)
        vk.vkCmdBindDescriptorSets(
            cmd, vk.VK_PIPELINE_BIND_POINT_GRAPHICS, layout,
            0, 1, [desc_set], 0, None,
        )
        data = np.array(push_values, dtype=np.float32).tobytes()
        cbuf = vk.ffi.new("char[]", data)
        vk._vulkan.lib.vkCmdPushConstants(
            cmd, layout, vk.VK_SHADER_STAGE_FRAGMENT_BIT, 0, len(data), cbuf,
        )
        vk.vkCmdDraw(cmd, 3, 1, 0, 0)
        vk.vkCmdEndRenderPass(cmd)

    def render(self, cmd: Any) -> None:
        """Execute bloom: extract -> blur H -> blur V (repeated).

        Does nothing until :meth:`setup` has run. The result ends up in
        rt_a (:attr:`bloom_image_view`).

        NOTE(review): no explicit barriers are recorded between passes; this
        relies on the RenderTarget render pass (final layout / subpass
        dependencies) to make each target readable by the next pass — confirm.
        """
        if not self._ready:
            return
        rt_a = self._rt_a
        rt_b = self._rt_b
        tw = 1.0 / rt_a.width
        th = 1.0 / rt_a.height

        # --- Pass 1: Extract bright pixels (HDR -> rt_a) ---
        # Push constants: texel_size(vec2) + threshold(float) + soft_knee(float)
        self._draw_fullscreen(
            cmd, rt_a, self._extract_pipeline, self._extract_layout,
            self._desc_hdr, (tw, th, self.threshold, self.soft_knee),
        )

        # --- Pass 2+: Separable Gaussian blur (ping-pong) ---
        # Push constants: texel_size(vec2) + direction(vec2)
        for _ in range(self.blur_passes):
            # Horizontal: rt_a -> rt_b
            self._draw_fullscreen(
                cmd, rt_b, self._blur_pipeline, self._blur_layout,
                self._desc_a, (tw, th, 1.0, 0.0),
            )
            # Vertical: rt_b -> rt_a
            self._draw_fullscreen(
                cmd, rt_a, self._blur_pipeline, self._blur_layout,
                self._desc_b, (tw, th, 0.0, 1.0),
            )
        # Result is in rt_a (bloom_image_view)

    def resize(self, width: int, height: int, hdr_view: Any) -> None:
        """Recreate bloom targets for new dimensions.

        Args:
            width: New full-resolution width (bloom runs at half of it).
            height: New full-resolution height.
            hdr_view: New HDR input view for the extract descriptor.

        NOTE(review): destroys and rewrites resources immediately; presumably
        the caller ensures the GPU is idle first — confirm.
        """
        if not self._ready:
            return
        device = self._engine.ctx.device

        # Destroy old targets and pipelines (handles are cleared so a failure
        # below cannot lead to a double destroy in cleanup()).
        self._destroy_pipelines(device)
        self._destroy_targets()

        half_w, half_h = max(1, width // 2), max(1, height // 2)
        self._create_targets(half_w, half_h)
        self._write_descriptors(device, hdr_view)
        self._create_pipelines(device, half_w, half_h)

    def cleanup(self) -> None:
        """Release all GPU resources and reset the pass to the unready state.

        Safe to call more than once; calls after the first are no-ops.
        """
        if not self._ready:
            return
        device = self._engine.ctx.device
        self._destroy_pipelines(device)
        if self._vert_module:
            vk.vkDestroyShaderModule(device, self._vert_module, None)
        if self._extract_frag_module:
            vk.vkDestroyShaderModule(device, self._extract_frag_module, None)
        if self._blur_frag_module:
            vk.vkDestroyShaderModule(device, self._blur_frag_module, None)
        self._vert_module = None
        self._extract_frag_module = None
        self._blur_frag_module = None
        if self._desc_pool:
            vk.vkDestroyDescriptorPool(device, self._desc_pool, None)
        if self._desc_layout:
            vk.vkDestroyDescriptorSetLayout(device, self._desc_layout, None)
        if self._sampler:
            vk.vkDestroySampler(device, self._sampler, None)
        self._desc_pool = None
        self._desc_layout = None
        self._sampler = None
        # The sets were allocated from the pool destroyed above.
        self._desc_hdr = None
        self._desc_a = None
        self._desc_b = None
        self._destroy_targets()
        self._ready = False