# Source code for simvx.graphics.renderer.colour_grading

"""Colour grading pass — LUT-based colour correction via compute shader."""

from __future__ import annotations

import logging
from typing import Any

import numpy as np
import vulkan as vk

from ..gpu.memory import create_buffer
from ..gpu.pipeline import create_shader_module
from ..materials.shader_compiler import compile_shader

# Only the pass class is exported by ``from ... import *``; the
# ``generate_*_lut`` helpers defined in this module have to be imported by name.
__all__ = ["ColourGradingPass"]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Push constant: vec4 adjustments(16) + vec4 temperature(16) + vec4 resolution(16) = 48 bytes
# Used both as the push-constant range size of the pipeline layout and as the
# byte count pushed on every dispatch, so it must match what render() packs.
_PC_SIZE = 48


def generate_neutral_lut(size: int = 32) -> np.ndarray:
    """Generate an identity 3D LUT (no colour change).

    The table is indexed ``lut[b, g, r]`` and every texel stores
    ``[R, G, B, A]`` as ``uint8``; each colour channel is its own grid
    coordinate mapped linearly onto 0..255 (truncated) and alpha is 255.

    Args:
        size: Number of samples along each axis of the cube.

    Returns:
        Array of shape ``(size, size, size, 4)`` and dtype ``uint8``.
    """
    # max(..., 1) keeps a degenerate single-sample table from dividing by zero;
    # for size >= 2 the arithmetic is exactly r / (size - 1) * 255, truncated.
    ramp = (np.arange(size) / max(size - 1, 1) * 255).astype(np.uint8)

    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = ramp[np.newaxis, np.newaxis, :]  # red varies along the last grid axis
    lut[..., 1] = ramp[np.newaxis, :, np.newaxis]  # green along the middle axis
    lut[..., 2] = ramp[:, np.newaxis, np.newaxis]  # blue along the first axis
    lut[..., 3] = 255
    return lut
def generate_warm_lut(size: int = 32) -> np.ndarray:
    """Generate a warm-toned 3D LUT (shifted toward orange/amber).

    The table is indexed ``lut[b, g, r]`` and every texel stores
    ``[R, G, B, A]`` as ``uint8`` with alpha fixed at 255. Each output channel
    depends only on the matching input channel, so the three response curves
    are computed once per axis and broadcast over the cube.

    Args:
        size: Number of samples along each axis of the cube.

    Returns:
        Array of shape ``(size, size, size, 4)`` and dtype ``uint8``.
    """
    # Normalised 0..1 grid coordinate; max(..., 1) avoids dividing by zero
    # for a single-sample table.
    ramp = np.arange(size) / max(size - 1, 1)

    # Warm shift: boost reds, slight green reduction, reduce blues
    red = np.minimum(1.0, ramp * 1.1 + 0.02)
    green = ramp * 0.95 + 0.01
    blue = ramp * 0.8

    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (red * 255).astype(np.uint8)[np.newaxis, np.newaxis, :]
    lut[..., 1] = (green * 255).astype(np.uint8)[np.newaxis, :, np.newaxis]
    lut[..., 2] = (blue * 255).astype(np.uint8)[:, np.newaxis, np.newaxis]
    lut[..., 3] = 255
    return lut
def generate_cool_lut(size: int = 32) -> np.ndarray:
    """Generate a cool-toned 3D LUT (shifted toward blue/teal).

    The table is indexed ``lut[b, g, r]`` and every texel stores
    ``[R, G, B, A]`` as ``uint8`` with alpha fixed at 255. Each output channel
    depends only on the matching input channel, so the three response curves
    are computed once per axis and broadcast over the cube.

    Args:
        size: Number of samples along each axis of the cube.

    Returns:
        Array of shape ``(size, size, size, 4)`` and dtype ``uint8``.
    """
    # Normalised 0..1 grid coordinate; max(..., 1) avoids dividing by zero
    # for a single-sample table.
    ramp = np.arange(size) / max(size - 1, 1)

    # Cool shift: reduce reds, boost greens slightly, boost blues
    red = ramp * 0.85
    green = np.minimum(1.0, ramp * 1.02 + 0.01)
    blue = np.minimum(1.0, ramp * 1.15 + 0.02)

    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (red * 255).astype(np.uint8)[np.newaxis, np.newaxis, :]
    lut[..., 1] = (green * 255).astype(np.uint8)[np.newaxis, :, np.newaxis]
    lut[..., 2] = (blue * 255).astype(np.uint8)[:, np.newaxis, np.newaxis]
    lut[..., 3] = 255
    return lut
def generate_vintage_lut(size: int = 32) -> np.ndarray:
    """Generate a vintage/desaturated warm 3D LUT.

    The table is indexed ``lut[b, g, r]`` and every texel stores
    ``[R, G, B, A]`` as ``uint8`` with alpha fixed at 255. Three steps are
    applied in order to each grid colour: desaturate toward luminance, apply
    a warm tint, then lift the blacks.

    Args:
        size: Number of samples along each axis of the cube.

    Returns:
        Array of shape ``(size, size, size, 4)`` and dtype ``uint8``.
    """
    # Normalised 0..1 grid coordinate; max(..., 1) avoids dividing by zero
    # for a single-sample table.
    ramp = np.arange(size) / max(size - 1, 1)

    # Broadcastable views of the input colour, laid out to match lut[b, g, r].
    rf = ramp[np.newaxis, np.newaxis, :]
    gf = ramp[np.newaxis, :, np.newaxis]
    bf = ramp[:, np.newaxis, np.newaxis]

    # Desaturate toward luminance
    luma = 0.2126 * rf + 0.7152 * gf + 0.0722 * bf
    sat = 0.6  # reduced saturation
    rf = luma + (rf - luma) * sat
    gf = luma + (gf - luma) * sat
    bf = luma + (bf - luma) * sat

    # Warm tint
    rf = np.minimum(1.0, rf * 1.05 + 0.03)
    gf = gf * 0.95
    bf = bf * 0.75

    # Lifted blacks (fade effect)
    rf = rf * 0.9 + 0.05
    gf = gf * 0.9 + 0.04
    bf = bf * 0.9 + 0.06

    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (np.minimum(1.0, rf) * 255).astype(np.uint8)
    lut[..., 1] = (np.minimum(1.0, gf) * 255).astype(np.uint8)
    lut[..., 2] = (np.minimum(1.0, bf) * 255).astype(np.uint8)
    lut[..., 3] = 255
    return lut
def _kelvin_to_rgb_multipliers(kelvin: float) -> tuple[float, float]: """Approximate colour temperature as R and B multipliers (G stays at 1.0). Returns (r_mult, b_mult) relative to 6500K neutral. Uses simplified Planckian locus approximation. """ # Normalize around 6500K t = (kelvin - 6500.0) / 6500.0 if t > 0: # Warmer: boost red, reduce blue r = 1.0 + t * 0.15 b = 1.0 - t * 0.25 else: # Cooler: reduce red, boost blue r = 1.0 + t * 0.25 b = 1.0 - t * 0.15 return (max(0.5, min(1.5, r)), max(0.5, min(1.5, b)))
class ColourGradingPass:
    """Compute-based colour grading: LUT lookup + brightness/contrast/saturation/temperature.

    Operates in-place on the HDR colour image. Apply after fog, before tone mapping.

    Lifecycle: ``setup()`` creates the GPU objects and uploads a neutral LUT,
    ``render()`` records one compute dispatch, ``set_lut()`` swaps the LUT
    texture, ``resize()`` re-points the pass at a new colour view and
    ``cleanup()`` destroys whatever has been created.

    The public settings below are plain attributes that ``render()`` reads on
    every call; the ranges are the intended ones and are not clamped here.

    Attributes:
        enabled: The pass records nothing while this is False (the default).
        brightness: Brightness adjustment, -1 to 1.
        contrast: Contrast adjustment, 0 to 2.
        saturation: Saturation adjustment, 0 to 2.
        colour_temperature: Colour temperature in Kelvin, 1000-12000.
    """

    def __init__(self, engine: Any):
        """Store the engine; no GPU work happens until ``setup()`` is called."""
        self._engine = engine
        self._ready = False

        # Compute pipeline
        self._pipeline: Any = None
        self._layout: Any = None
        self._module: Any = None

        # Descriptors
        self._desc_pool: Any = None
        self._desc_layout: Any = None
        self._desc_set: Any = None

        # LUT 3D texture
        self._lut_image: Any = None
        self._lut_memory: Any = None
        self._lut_view: Any = None
        self._lut_sampler: Any = None
        self._lut_size: int = 0

        # Dimensions
        self._width: int = 0
        self._height: int = 0

        # Public settings
        self.enabled: bool = False
        self.brightness: float = 0.0  # -1 to 1
        self.contrast: float = 1.0  # 0 to 2
        self.saturation: float = 1.0  # 0 to 2
        self.colour_temperature: float = 6500.0  # Kelvin, 1000-12000

    def setup(self, width: int, height: int, color_view: Any) -> None:
        """Initialize colour grading pipeline and upload default neutral LUT.

        If any step raises, the objects created so far stay referenced by this
        instance and are released by ``cleanup()``.

        Args:
            width: Width of the colour image in pixels.
            height: Height of the colour image in pixels.
            color_view: Image view bound as the storage image (binding 0).
        """
        self._width = width
        self._height = height
        self._create_lut_sampler()
        self._upload_lut(generate_neutral_lut(32))
        self._create_descriptors(color_view)
        self._create_pipeline()
        self._ready = True
        log.debug("Colour grading pass initialized (%dx%d)", width, height)

    def set_lut(self, lut_data: np.ndarray) -> None:
        """Upload a new 3D LUT texture. Shape should be (size, size, size, 4) uint8.

        Does nothing when the pass has not been set up.

        Raises:
            ValueError: If ``lut_data`` is not a ``(size, size, size, 4)``
                ``uint8`` array. The current LUT is left untouched in that case.
        """
        if not self._ready:
            return

        # Reject bad input before the old texture is destroyed.
        self._validate_lut(lut_data)

        device = self._engine.ctx.device

        # The old image/view may still be referenced by submitted work, and the
        # descriptor set is rewritten below; both require the GPU to be done.
        vk.vkDeviceWaitIdle(device)

        # Destroy old LUT
        self._destroy_lut(device)

        self._upload_lut(lut_data)

        # Re-write LUT descriptor (binding 1)
        self._update_lut_descriptor()

    @staticmethod
    def _validate_lut(lut_data: np.ndarray) -> None:
        """Raise ValueError unless ``lut_data`` is a (size, size, size, 4) uint8 array.

        The upload copies ``size**3`` RGBA8 texels out of a staging buffer that
        is only ``lut_data.nbytes`` long, so a mismatched shape or dtype would
        make the GPU copy read outside the buffer.
        """
        shape = getattr(lut_data, "shape", None)
        if (
            shape is None
            or len(shape) != 4
            or shape[3] != 4
            or not (shape[0] == shape[1] == shape[2])
            or shape[0] < 1
        ):
            raise ValueError(f"LUT must have shape (size, size, size, 4), got {shape}")
        if lut_data.dtype != np.uint8:
            raise ValueError(f"LUT must have dtype uint8, got {lut_data.dtype}")

    @staticmethod
    def _colour_subresource_range() -> Any:
        """Return a subresource range covering mip 0 / layer 0 of a colour image."""
        return vk.VkImageSubresourceRange(
            aspectMask=vk.VK_IMAGE_ASPECT_COLOR_BIT,
            baseMipLevel=0,
            levelCount=1,
            baseArrayLayer=0,
            layerCount=1,
        )

    def _destroy_lut(self, device: Any) -> None:
        """Destroy the LUT view, image and memory (not the sampler) and clear the handles."""
        if self._lut_view:
            vk.vkDestroyImageView(device, self._lut_view, None)
            self._lut_view = None
        if self._lut_image:
            vk.vkDestroyImage(device, self._lut_image, None)
            self._lut_image = None
        if self._lut_memory:
            vk.vkFreeMemory(device, self._lut_memory, None)
            self._lut_memory = None

    def _create_lut_sampler(self) -> None:
        """Create trilinear sampler for 3D LUT texture."""
        self._lut_sampler = vk.vkCreateSampler(
            self._engine.ctx.device,
            vk.VkSamplerCreateInfo(
                magFilter=vk.VK_FILTER_LINEAR,
                minFilter=vk.VK_FILTER_LINEAR,
                addressModeU=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                addressModeV=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                addressModeW=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                anisotropyEnable=vk.VK_FALSE,
                unnormalizedCoordinates=vk.VK_FALSE,
                mipmapMode=vk.VK_SAMPLER_MIPMAP_MODE_LINEAR,
            ),
            None,
        )

    def _upload_lut(self, lut_data: np.ndarray) -> None:
        """Upload a 3D LUT as a VK_IMAGE_TYPE_3D texture.

        Creates the image, its device-local memory and a 3D view, copies the
        texels through a temporary staging buffer and leaves the image in
        SHADER_READ_ONLY_OPTIMAL layout. Also records the LUT edge length in
        ``self._lut_size``.

        Raises:
            ValueError: If ``lut_data`` is not a (size, size, size, 4) uint8 array.
        """
        from ..gpu.memory import begin_single_time_commands, end_single_time_commands

        self._validate_lut(lut_data)

        e = self._engine
        device = e.ctx.device
        size = lut_data.shape[0]
        self._lut_size = size

        pixel_data = np.ascontiguousarray(lut_data)

        # Staging buffer
        staging_buf, staging_mem = _create_staging_buffer(device, e.ctx.physical_device, pixel_data)

        try:
            # Create 3D image
            img_ci = vk.VkImageCreateInfo(
                imageType=vk.VK_IMAGE_TYPE_3D,
                format=vk.VK_FORMAT_R8G8B8A8_UNORM,
                extent=vk.VkExtent3D(width=size, height=size, depth=size),
                mipLevels=1,
                arrayLayers=1,
                samples=vk.VK_SAMPLE_COUNT_1_BIT,
                tiling=vk.VK_IMAGE_TILING_OPTIMAL,
                usage=vk.VK_IMAGE_USAGE_TRANSFER_DST_BIT | vk.VK_IMAGE_USAGE_SAMPLED_BIT,
                sharingMode=vk.VK_SHARING_MODE_EXCLUSIVE,
                initialLayout=vk.VK_IMAGE_LAYOUT_UNDEFINED,
            )
            self._lut_image = vk.vkCreateImage(device, img_ci, None)

            mem_reqs = vk.vkGetImageMemoryRequirements(device, self._lut_image)
            mem_props = vk.vkGetPhysicalDeviceMemoryProperties(e.ctx.physical_device)
            mem_type = _find_memory_type(
                mem_props, mem_reqs.memoryTypeBits, vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT
            )
            self._lut_memory = vk.vkAllocateMemory(
                device,
                vk.VkMemoryAllocateInfo(
                    allocationSize=mem_reqs.size,
                    memoryTypeIndex=mem_type,
                ),
                None,
            )
            vk.vkBindImageMemory(device, self._lut_image, self._lut_memory, 0)

            cmd = begin_single_time_commands(device, e.ctx.command_pool)

            # Transition UNDEFINED -> TRANSFER_DST
            to_transfer_dst = vk.VkImageMemoryBarrier(
                srcAccessMask=0,
                dstAccessMask=vk.VK_ACCESS_TRANSFER_WRITE_BIT,
                oldLayout=vk.VK_IMAGE_LAYOUT_UNDEFINED,
                newLayout=vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
                srcQueueFamilyIndex=vk.VK_QUEUE_FAMILY_IGNORED,
                dstQueueFamilyIndex=vk.VK_QUEUE_FAMILY_IGNORED,
                image=self._lut_image,
                subresourceRange=self._colour_subresource_range(),
            )
            vk.vkCmdPipelineBarrier(
                cmd,
                vk.VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                vk.VK_PIPELINE_STAGE_TRANSFER_BIT,
                0,
                0,
                None,
                0,
                None,
                1,
                [to_transfer_dst],
            )

            # Copy staging -> image (tightly packed: row length / image height 0)
            region = vk.VkBufferImageCopy(
                bufferOffset=0,
                bufferRowLength=0,
                bufferImageHeight=0,
                imageSubresource=vk.VkImageSubresourceLayers(
                    aspectMask=vk.VK_IMAGE_ASPECT_COLOR_BIT,
                    mipLevel=0,
                    baseArrayLayer=0,
                    layerCount=1,
                ),
                imageOffset=vk.VkOffset3D(x=0, y=0, z=0),
                imageExtent=vk.VkExtent3D(width=size, height=size, depth=size),
            )
            vk.vkCmdCopyBufferToImage(
                cmd,
                staging_buf,
                self._lut_image,
                vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
                1,
                [region],
            )

            # Transition TRANSFER_DST -> SHADER_READ_ONLY
            to_shader_read = vk.VkImageMemoryBarrier(
                srcAccessMask=vk.VK_ACCESS_TRANSFER_WRITE_BIT,
                dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT,
                oldLayout=vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
                newLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
                srcQueueFamilyIndex=vk.VK_QUEUE_FAMILY_IGNORED,
                dstQueueFamilyIndex=vk.VK_QUEUE_FAMILY_IGNORED,
                image=self._lut_image,
                subresourceRange=self._colour_subresource_range(),
            )
            vk.vkCmdPipelineBarrier(
                cmd,
                vk.VK_PIPELINE_STAGE_TRANSFER_BIT,
                vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                0,
                0,
                None,
                0,
                None,
                1,
                [to_shader_read],
            )

            end_single_time_commands(device, e.ctx.graphics_queue, e.ctx.command_pool, cmd)
        finally:
            # Cleanup staging, also when image creation or the copy failed.
            vk.vkDestroyBuffer(device, staging_buf, None)
            vk.vkFreeMemory(device, staging_mem, None)

        # Image view (3D)
        self._lut_view = vk.vkCreateImageView(
            device,
            vk.VkImageViewCreateInfo(
                image=self._lut_image,
                viewType=vk.VK_IMAGE_VIEW_TYPE_3D,
                format=vk.VK_FORMAT_R8G8B8A8_UNORM,
                subresourceRange=self._colour_subresource_range(),
            ),
            None,
        )

    def _create_descriptors(self, color_view: Any) -> None:
        """Create descriptor set: colour storage image (binding 0) + LUT sampler (binding 1)."""
        device = self._engine.ctx.device

        bindings = [
            vk.VkDescriptorSetLayoutBinding(
                binding=0,
                descriptorType=vk.VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
                descriptorCount=1,
                stageFlags=vk.VK_SHADER_STAGE_COMPUTE_BIT,
            ),
            vk.VkDescriptorSetLayoutBinding(
                binding=1,
                descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
                descriptorCount=1,
                stageFlags=vk.VK_SHADER_STAGE_COMPUTE_BIT,
            ),
        ]
        self._desc_layout = vk.vkCreateDescriptorSetLayout(
            device, vk.VkDescriptorSetLayoutCreateInfo(bindingCount=2, pBindings=bindings), None
        )

        # The pool holds exactly the one set this pass needs.
        pool_sizes = [
            vk.VkDescriptorPoolSize(type=vk.VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, descriptorCount=1),
            vk.VkDescriptorPoolSize(type=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, descriptorCount=1),
        ]
        self._desc_pool = vk.vkCreateDescriptorPool(
            device, vk.VkDescriptorPoolCreateInfo(maxSets=1, poolSizeCount=2, pPoolSizes=pool_sizes), None
        )

        sets = vk.vkAllocateDescriptorSets(
            device,
            vk.VkDescriptorSetAllocateInfo(
                descriptorPool=self._desc_pool,
                descriptorSetCount=1,
                pSetLayouts=[self._desc_layout],
            ),
        )
        self._desc_set = sets[0]

        self._write_descriptors(color_view)

    def _write_descriptors(self, color_view: Any) -> None:
        """Write colour storage image and LUT sampler to descriptor set."""
        device = self._engine.ctx.device

        color_info = vk.VkDescriptorImageInfo(
            imageView=color_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_GENERAL,
        )
        lut_info = vk.VkDescriptorImageInfo(
            sampler=self._lut_sampler,
            imageView=self._lut_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )

        vk.vkUpdateDescriptorSets(
            device,
            2,
            [
                vk.VkWriteDescriptorSet(
                    dstSet=self._desc_set,
                    dstBinding=0,
                    dstArrayElement=0,
                    descriptorCount=1,
                    descriptorType=vk.VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
                    pImageInfo=[color_info],
                ),
                vk.VkWriteDescriptorSet(
                    dstSet=self._desc_set,
                    dstBinding=1,
                    dstArrayElement=0,
                    descriptorCount=1,
                    descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
                    pImageInfo=[lut_info],
                ),
            ],
            0,
            None,
        )

    def _update_lut_descriptor(self) -> None:
        """Update LUT descriptor (binding 1) after LUT replacement."""
        lut_info = vk.VkDescriptorImageInfo(
            sampler=self._lut_sampler,
            imageView=self._lut_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )
        vk.vkUpdateDescriptorSets(
            self._engine.ctx.device,
            1,
            [
                vk.VkWriteDescriptorSet(
                    dstSet=self._desc_set,
                    dstBinding=1,
                    dstArrayElement=0,
                    descriptorCount=1,
                    descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
                    pImageInfo=[lut_info],
                )
            ],
            0,
            None,
        )

    def _create_pipeline(self) -> None:
        """Create colour grading compute pipeline.

        Compiles ``color_grade.comp`` from the engine's shader directory and
        builds the pipeline layout (one descriptor set, ``_PC_SIZE`` bytes of
        compute-stage push constants) and the compute pipeline through the raw
        cffi entry points.

        Raises:
            RuntimeError: If layout or pipeline creation does not return VK_SUCCESS.
        """
        e = self._engine
        device = e.ctx.device
        ffi = vk.ffi

        spv = compile_shader(e.shader_dir / "color_grade.comp")
        self._module = create_shader_module(device, spv)

        # The cffi objects below must stay referenced until the create calls
        # return, which is why they are all kept in named locals.
        push_range = ffi.new("VkPushConstantRange*")
        push_range.stageFlags = vk.VK_SHADER_STAGE_COMPUTE_BIT
        push_range.offset = 0
        push_range.size = _PC_SIZE

        layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
        layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO
        set_layouts = ffi.new("VkDescriptorSetLayout[1]", [self._desc_layout])
        layout_ci.setLayoutCount = 1
        layout_ci.pSetLayouts = set_layouts
        layout_ci.pushConstantRangeCount = 1
        layout_ci.pPushConstantRanges = push_range

        layout_out = ffi.new("VkPipelineLayout*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreatePipelineLayout,
            device,
            layout_ci,
            ffi.NULL,
            layout_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreatePipelineLayout (colour grading) failed: {result}")
        self._layout = layout_out[0]

        main_name = ffi.new("char[]", b"main")

        ci = ffi.new("VkComputePipelineCreateInfo*")
        ci.sType = vk.VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO
        ci.layout = self._layout

        stage = ffi.new("VkPipelineShaderStageCreateInfo*")
        stage.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
        stage.stage = vk.VK_SHADER_STAGE_COMPUTE_BIT
        stage.module = self._module
        stage.pName = main_name
        ci.stage = stage[0]

        pipeline_out = ffi.new("VkPipeline*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreateComputePipelines,
            device,
            ffi.NULL,
            1,
            ci,
            ffi.NULL,
            pipeline_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreateComputePipelines (colour grading) failed: {result}")
        self._pipeline = pipeline_out[0]

    def render(self, cmd: Any) -> None:
        """Dispatch colour grading compute shader. Call after fog, before tonemap.

        Records nothing when the pass is not set up, is disabled, or the
        target currently has a zero (or negative) width or height.

        Args:
            cmd: Active command buffer (outside any render pass).
        """
        if not self._ready or not self.enabled:
            return
        # A zero-sized target (e.g. a minimised window) has nothing to grade,
        # and the reciprocal resolution below would divide by zero.
        if self._width <= 0 or self._height <= 0:
            return

        ffi = vk.ffi

        # One workgroup per 8x8 tile, rounded up
        # (presumably the shader's local size is 8x8 — confirm in color_grade.comp).
        groups_x = (self._width + 7) // 8
        groups_y = (self._height + 7) // 8

        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._pipeline)
        vk.vkCmdBindDescriptorSets(
            cmd,
            vk.VK_PIPELINE_BIND_POINT_COMPUTE,
            self._layout,
            0,
            1,
            [self._desc_set],
            0,
            None,
        )

        # Compute temperature multipliers
        r_mult, b_mult = _kelvin_to_rgb_multipliers(self.colour_temperature)

        # Pack push constants: 3 * vec4 = 48 bytes
        pc_data = np.array(
            [
                # vec4 adjustments
                self.brightness,
                self.contrast,
                self.saturation,
                float(self._lut_size),
                # vec4 temperature (last two components unused)
                r_mult,
                b_mult,
                0.0,
                0.0,
                # vec4 resolution: size and reciprocal size
                float(self._width),
                float(self._height),
                1.0 / self._width,
                1.0 / self._height,
            ],
            dtype=np.float32,
        ).tobytes()
        cbuf = ffi.new("char[]", pc_data)
        vk._vulkan.lib.vkCmdPushConstants(
            cmd,
            self._layout,
            vk.VK_SHADER_STAGE_COMPUTE_BIT,
            0,
            _PC_SIZE,
            cbuf,
        )

        vk.vkCmdDispatch(cmd, groups_x, groups_y, 1)

        # Barrier: colour grading write -> next pass read
        barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_SHADER_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT | vk.VK_ACCESS_SHADER_WRITE_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT | vk.VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT,
            0,
            1,
            [barrier],
            0,
            None,
            0,
            None,
        )

    def resize(self, width: int, height: int, color_view: Any) -> None:
        """Update descriptors for new dimensions.

        Does nothing when the pass has not been set up.

        Args:
            width: New width of the colour image in pixels.
            height: New height of the colour image in pixels.
            color_view: New image view to bind as the storage image.
        """
        if not self._ready:
            return
        self._width = width
        self._height = height
        self._write_descriptors(color_view)

    def cleanup(self) -> None:
        """Release all GPU resources.

        Safe to call repeatedly and after a ``setup()`` that failed part-way:
        every handle is checked individually and cleared once destroyed.
        """
        self._ready = False

        owned = (
            self._pipeline,
            self._layout,
            self._module,
            self._desc_pool,
            self._desc_layout,
            self._lut_view,
            self._lut_image,
            self._lut_memory,
            self._lut_sampler,
        )
        if not any(owned):
            # Nothing was ever created, so the device is not needed either.
            return

        device = self._engine.ctx.device

        if self._pipeline:
            vk.vkDestroyPipeline(device, self._pipeline, None)
            self._pipeline = None
        if self._layout:
            vk.vkDestroyPipelineLayout(device, self._layout, None)
            self._layout = None
        if self._module:
            vk.vkDestroyShaderModule(device, self._module, None)
            self._module = None
        if self._desc_pool:
            vk.vkDestroyDescriptorPool(device, self._desc_pool, None)
            self._desc_pool = None
        # The set was allocated from the pool destroyed above (or never allocated).
        self._desc_set = None
        if self._desc_layout:
            vk.vkDestroyDescriptorSetLayout(device, self._desc_layout, None)
            self._desc_layout = None
        self._destroy_lut(device)
        if self._lut_sampler:
            vk.vkDestroySampler(device, self._lut_sampler, None)
            self._lut_sampler = None
def _find_memory_type(mem_props: Any, type_filter: int, properties: int) -> int: """Find a suitable memory type index.""" for i in range(mem_props.memoryTypeCount): if (type_filter & (1 << i)) and (mem_props.memoryTypes[i].propertyFlags & properties) == properties: return i raise RuntimeError("Failed to find suitable memory type") def _create_staging_buffer(device: Any, physical_device: Any, data: np.ndarray) -> tuple[Any, Any]: """Create and fill a host-visible staging buffer from numpy data.""" ffi = vk.ffi size = data.nbytes buf, mem = create_buffer( device, physical_device, size, vk.VK_BUFFER_USAGE_TRANSFER_SRC_BIT, vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, ) dst = vk.vkMapMemory(device, mem, 0, size, 0) ffi.memmove(dst, ffi.cast("void*", data.ctypes.data), size) vk.vkUnmapMemory(device, mem) return buf, mem