"""Colour grading pass — LUT-based colour correction via compute shader."""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import vulkan as vk
from ..gpu.memory import create_buffer
from ..gpu.pipeline import create_shader_module
from ..materials.shader_compiler import compile_shader
__all__ = [
    "ColourGradingPass",
    "generate_cool_lut",
    "generate_neutral_lut",
    "generate_vintage_lut",
    "generate_warm_lut",
]
log = logging.getLogger(__name__)
# Push constants: three float32 vec4s (16 bytes each) = 48 bytes, packed in render():
#   adjustments = (brightness, contrast, saturation, lut_size)
#   temperature = (r_mult, b_mult, 0.0, 0.0)
#   resolution  = (width, height, 1/width, 1/height)
_PC_SIZE = 3 * 16
[docs]
def generate_neutral_lut(size: int = 32) -> np.ndarray:
    """Generate an identity 3D LUT (no colour change).

    Args:
        size: Number of samples per axis. Must be at least 2 so the grid spans
            both 0.0 and 1.0.

    Returns:
        A ``(size, size, size, 4)`` uint8 array indexed ``[b, g, r]``. The RGB
        channels hold the grid coordinate scaled to 0-255 and truncated toward
        zero, so interior texels can sit up to 1/255 below the exact identity.
        Alpha is 255.

    Raises:
        ValueError: If *size* is less than 2.
    """
    if size < 2:
        raise ValueError(f"LUT size must be at least 2, got {size}")
    # One ramp per axis; broadcasting fills the cube without a Python triple loop.
    # astype() truncates like int() did, so values are unchanged.
    ramp = (np.arange(size) / (size - 1) * 255).astype(np.uint8)
    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = ramp[np.newaxis, np.newaxis, :]  # red follows the last (r) axis
    lut[..., 1] = ramp[np.newaxis, :, np.newaxis]  # green follows the middle (g) axis
    lut[..., 2] = ramp[:, np.newaxis, np.newaxis]  # blue follows the first (b) axis
    lut[..., 3] = 255
    return lut
[docs]
def generate_warm_lut(size: int = 32) -> np.ndarray:
    """Generate a warm-toned 3D LUT (shifted toward orange/amber).

    Each grid colour has red boosted (clamped at 1.0), green slightly reduced
    and blue reduced. The result uses the same layout as the neutral LUT:
    ``(size, size, size, 4)`` uint8, indexed ``[b, g, r]``, channels truncated
    to 0-255, alpha 255.

    Raises:
        ValueError: If *size* is less than 2.
    """
    if size < 2:
        raise ValueError(f"LUT size must be at least 2, got {size}")
    coords = np.arange(size) / (size - 1)
    # Shape each axis so broadcasting produces the full [b, g, r] cube.
    rf = coords[np.newaxis, np.newaxis, :]
    gf = coords[np.newaxis, :, np.newaxis]
    bf = coords[:, np.newaxis, np.newaxis]
    # Warm shift: boost reds, slight green reduction, reduce blues
    red = np.minimum(1.0, rf * 1.1 + 0.02)
    green = gf * 0.95 + 0.01
    blue = bf * 0.8
    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (red * 255).astype(np.uint8)
    lut[..., 1] = (green * 255).astype(np.uint8)
    lut[..., 2] = (blue * 255).astype(np.uint8)
    lut[..., 3] = 255
    return lut
[docs]
def generate_cool_lut(size: int = 32) -> np.ndarray:
    """Generate a cool-toned 3D LUT (shifted toward blue/teal).

    Each grid colour has red reduced, green slightly boosted and blue boosted
    (both clamped at 1.0). The result uses the same layout as the neutral LUT:
    ``(size, size, size, 4)`` uint8, indexed ``[b, g, r]``, channels truncated
    to 0-255, alpha 255.

    Raises:
        ValueError: If *size* is less than 2.
    """
    if size < 2:
        raise ValueError(f"LUT size must be at least 2, got {size}")
    coords = np.arange(size) / (size - 1)
    # Shape each axis so broadcasting produces the full [b, g, r] cube.
    rf = coords[np.newaxis, np.newaxis, :]
    gf = coords[np.newaxis, :, np.newaxis]
    bf = coords[:, np.newaxis, np.newaxis]
    # Cool shift: reduce reds, boost greens slightly, boost blues
    red = rf * 0.85
    green = np.minimum(1.0, gf * 1.02 + 0.01)
    blue = np.minimum(1.0, bf * 1.15 + 0.02)
    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (red * 255).astype(np.uint8)
    lut[..., 1] = (green * 255).astype(np.uint8)
    lut[..., 2] = (blue * 255).astype(np.uint8)
    lut[..., 3] = 255
    return lut
[docs]
def generate_vintage_lut(size: int = 32) -> np.ndarray:
    """Generate a vintage/desaturated warm 3D LUT.

    Each grid colour is processed in three steps:

    1. Desaturated to 60 % toward its Rec. 709 luma.
    2. Warm-tinted.
    3. Remapped into a narrower range with lifted blacks (a "fade").

    Channels are clamped at 1.0. The result uses the same layout as the neutral
    LUT: ``(size, size, size, 4)`` uint8, indexed ``[b, g, r]``, channels
    truncated to 0-255, alpha 255.

    Raises:
        ValueError: If *size* is less than 2.
    """
    if size < 2:
        raise ValueError(f"LUT size must be at least 2, got {size}")
    coords = np.arange(size) / (size - 1)
    # Shape each axis so broadcasting produces the full [b, g, r] cube.
    rf = coords[np.newaxis, np.newaxis, :]
    gf = coords[np.newaxis, :, np.newaxis]
    bf = coords[:, np.newaxis, np.newaxis]
    # Desaturate toward luminance
    luma = 0.2126 * rf + 0.7152 * gf + 0.0722 * bf
    sat = 0.6  # reduced saturation
    red = luma + (rf - luma) * sat
    green = luma + (gf - luma) * sat
    blue = luma + (bf - luma) * sat
    # Warm tint
    red = np.minimum(1.0, red * 1.05 + 0.03)
    green = green * 0.95
    blue = blue * 0.75
    # Lifted blacks (fade effect)
    red = red * 0.9 + 0.05
    green = green * 0.9 + 0.04
    blue = blue * 0.9 + 0.06
    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (np.minimum(1.0, red) * 255).astype(np.uint8)
    lut[..., 1] = (np.minimum(1.0, green) * 255).astype(np.uint8)
    lut[..., 2] = (np.minimum(1.0, blue) * 255).astype(np.uint8)
    lut[..., 3] = 255
    return lut
def _kelvin_to_rgb_multipliers(kelvin: float) -> tuple[float, float]:
"""Approximate colour temperature as R and B multipliers (G stays at 1.0).
Returns (r_mult, b_mult) relative to 6500K neutral.
Uses simplified Planckian locus approximation.
"""
# Normalize around 6500K
t = (kelvin - 6500.0) / 6500.0
if t > 0:
# Warmer: boost red, reduce blue
r = 1.0 + t * 0.15
b = 1.0 - t * 0.25
else:
# Cooler: reduce red, boost blue
r = 1.0 + t * 0.25
b = 1.0 - t * 0.15
return (max(0.5, min(1.5, r)), max(0.5, min(1.5, b)))
[docs]
class ColourGradingPass:
    """Compute-based colour grading: LUT lookup + brightness/contrast/saturation/temperature.

    Operates in-place on the HDR colour image. Apply after fog, before tone mapping.

    Lifecycle:

    - ``setup()`` creates every GPU resource.
    - ``render()`` records the dispatch while ``enabled`` is True.
    - ``set_lut()`` and ``resize()`` update bindings as needed.
    - ``cleanup()`` releases everything.
    """

    def __init__(self, engine: Any):
        # engine is used via engine.ctx (device, physical_device, command_pool,
        # graphics_queue) and engine.shader_dir.
        self._engine = engine
        self._ready = False
        # Compute pipeline
        self._pipeline: Any = None
        self._layout: Any = None
        self._module: Any = None
        # Descriptors
        self._desc_pool: Any = None
        self._desc_layout: Any = None
        self._desc_set: Any = None
        # LUT 3D texture
        self._lut_image: Any = None
        self._lut_memory: Any = None
        self._lut_view: Any = None
        self._lut_sampler: Any = None
        self._lut_size: int = 0
        # Dimensions
        self._width: int = 0
        self._height: int = 0
        # Public settings
        self.enabled: bool = False
        self.brightness: float = 0.0  # -1 to 1
        self.contrast: float = 1.0  # 0 to 2
        self.saturation: float = 1.0  # 0 to 2
        self.colour_temperature: float = 6500.0  # Kelvin, 1000-12000

    def setup(self, width: int, height: int, color_view: Any) -> None:
        """Initialize colour grading pipeline and upload default neutral LUT.

        Args:
            width: Width of the colour target in pixels.
            height: Height of the colour target in pixels.
            color_view: View of the colour target. It is bound as a storage
                image in VK_IMAGE_LAYOUT_GENERAL.

        If any step raises, the resources created so far are released before
        the exception propagates.
        """
        self._width = width
        self._height = height
        try:
            self._create_lut_sampler()
            self._upload_lut(generate_neutral_lut(32))
            self._create_descriptors(color_view)
            self._create_pipeline()
        except Exception:
            # e.g. a shader compile failure would otherwise leak the sampler,
            # LUT texture, descriptors and pipeline layout created before it.
            self.cleanup()
            raise
        self._ready = True
        log.debug("Colour grading pass initialized (%dx%d)", width, height)

    def set_lut(self, lut_data: np.ndarray) -> None:
        """Upload a new 3D LUT texture. Shape should be (size, size, size, 4) uint8.

        The input is validated before any GPU work. The device is then idled,
        because the old LUT and the descriptor set may still be referenced by
        in-flight command buffers.

        If the upload fails, the pass is left not-ready, so render() no-ops
        instead of binding a destroyed LUT. Call cleanup() to release what
        remains.

        Raises:
            ValueError: If *lut_data* is not a non-empty (size, size, size, 4)
                uint8 array.
        """
        self._validate_lut(lut_data)
        if not self._ready:
            log.warning("ColourGradingPass.set_lut() called before setup(); LUT ignored")
            return
        device = self._engine.ctx.device
        # Destroying resources or updating a descriptor set that pending command
        # buffers still use is invalid; wait for the GPU first.
        vk.vkDeviceWaitIdle(device)
        self._ready = False
        self._destroy_lut()
        self._upload_lut(lut_data)
        # Re-write LUT descriptor (binding 1)
        self._update_lut_descriptor()
        self._ready = True

    @staticmethod
    def _validate_lut(lut_data: np.ndarray) -> None:
        """Raise ValueError unless *lut_data* is a non-empty (size, size, size, 4) uint8 array.

        _upload_lut() takes the image extent from ``shape[0]`` alone and copies
        the raw bytes into an R8G8B8A8_UNORM image. Any other shape or dtype
        would make the staging buffer disagree with the image size.
        """
        shape = getattr(lut_data, "shape", ())
        dtype = getattr(lut_data, "dtype", None)
        is_cube = len(shape) == 4 and shape[3] == 4 and shape[0] == shape[1] == shape[2] >= 1
        if not is_cube or dtype != np.uint8:
            raise ValueError(
                "LUT must be a (size, size, size, 4) uint8 array with size >= 1; "
                f"got shape={shape!r}, dtype={dtype}"
            )

    def _destroy_lut(self) -> None:
        """Destroy the LUT view, image and memory (not the sampler) and clear their handles."""
        device = self._engine.ctx.device
        if self._lut_view:
            vk.vkDestroyImageView(device, self._lut_view, None)
            self._lut_view = None
        if self._lut_image:
            vk.vkDestroyImage(device, self._lut_image, None)
            self._lut_image = None
        if self._lut_memory:
            vk.vkFreeMemory(device, self._lut_memory, None)
            self._lut_memory = None

    def _create_lut_sampler(self) -> None:
        """Create trilinear sampler for 3D LUT texture."""
        self._lut_sampler = vk.vkCreateSampler(
            self._engine.ctx.device,
            vk.VkSamplerCreateInfo(
                magFilter=vk.VK_FILTER_LINEAR,
                minFilter=vk.VK_FILTER_LINEAR,
                addressModeU=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                addressModeV=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                addressModeW=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                anisotropyEnable=vk.VK_FALSE,
                unnormalizedCoordinates=vk.VK_FALSE,
                mipmapMode=vk.VK_SAMPLER_MIPMAP_MODE_LINEAR,
            ),
            None,
        )

    @staticmethod
    def _colour_subresource_range() -> Any:
        """Return a subresource range covering the single mip level and array layer of the LUT."""
        return vk.VkImageSubresourceRange(
            aspectMask=vk.VK_IMAGE_ASPECT_COLOR_BIT,
            baseMipLevel=0,
            levelCount=1,
            baseArrayLayer=0,
            layerCount=1,
        )

    def _upload_lut(self, lut_data: np.ndarray) -> None:
        """Upload a 3D LUT as a VK_IMAGE_TYPE_3D texture.

        Creates ``_lut_image``, ``_lut_memory`` and a 3D ``_lut_view``. The image
        is left in SHADER_READ_ONLY_OPTIMAL. The staging buffer is released even
        if an upload step raises.
        """
        e = self._engine
        device = e.ctx.device
        size = lut_data.shape[0]
        self._lut_size = size
        pixel_data = np.ascontiguousarray(lut_data)
        staging_buf, staging_mem = _create_staging_buffer(device, e.ctx.physical_device, pixel_data)
        try:
            self._create_lut_image(size)
            self._copy_staging_to_lut(staging_buf, size)
        finally:
            # NOTE(review): freeing the staging buffer here assumes
            # end_single_time_commands() waits for the copy to finish - confirm.
            vk.vkDestroyBuffer(device, staging_buf, None)
            vk.vkFreeMemory(device, staging_mem, None)
        # Image view (3D)
        self._lut_view = vk.vkCreateImageView(
            device,
            vk.VkImageViewCreateInfo(
                image=self._lut_image,
                viewType=vk.VK_IMAGE_VIEW_TYPE_3D,
                format=vk.VK_FORMAT_R8G8B8A8_UNORM,
                subresourceRange=self._colour_subresource_range(),
            ),
            None,
        )

    def _create_lut_image(self, size: int) -> None:
        """Create the size^3 RGBA8 3D image and bind device-local memory to it."""
        e = self._engine
        device = e.ctx.device
        img_ci = vk.VkImageCreateInfo(
            imageType=vk.VK_IMAGE_TYPE_3D,
            format=vk.VK_FORMAT_R8G8B8A8_UNORM,
            extent=vk.VkExtent3D(width=size, height=size, depth=size),
            mipLevels=1,
            arrayLayers=1,
            samples=vk.VK_SAMPLE_COUNT_1_BIT,
            tiling=vk.VK_IMAGE_TILING_OPTIMAL,
            usage=vk.VK_IMAGE_USAGE_TRANSFER_DST_BIT | vk.VK_IMAGE_USAGE_SAMPLED_BIT,
            sharingMode=vk.VK_SHARING_MODE_EXCLUSIVE,
            initialLayout=vk.VK_IMAGE_LAYOUT_UNDEFINED,
        )
        self._lut_image = vk.vkCreateImage(device, img_ci, None)
        mem_reqs = vk.vkGetImageMemoryRequirements(device, self._lut_image)
        mem_props = vk.vkGetPhysicalDeviceMemoryProperties(e.ctx.physical_device)
        mem_type = _find_memory_type(mem_props, mem_reqs.memoryTypeBits, vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)
        self._lut_memory = vk.vkAllocateMemory(
            device,
            vk.VkMemoryAllocateInfo(
                allocationSize=mem_reqs.size,
                memoryTypeIndex=mem_type,
            ),
            None,
        )
        vk.vkBindImageMemory(device, self._lut_image, self._lut_memory, 0)

    def _lut_layout_barrier(self, src_access: int, dst_access: int, old_layout: int, new_layout: int) -> Any:
        """Build an image memory barrier that transitions the whole LUT image between layouts."""
        return vk.VkImageMemoryBarrier(
            srcAccessMask=src_access,
            dstAccessMask=dst_access,
            oldLayout=old_layout,
            newLayout=new_layout,
            srcQueueFamilyIndex=vk.VK_QUEUE_FAMILY_IGNORED,
            dstQueueFamilyIndex=vk.VK_QUEUE_FAMILY_IGNORED,
            image=self._lut_image,
            subresourceRange=self._colour_subresource_range(),
        )

    def _copy_staging_to_lut(self, staging_buf: Any, size: int) -> None:
        """Record and submit three one-shot steps for the LUT image.

        The steps are: UNDEFINED -> TRANSFER_DST, the buffer->image copy, and
        TRANSFER_DST -> SHADER_READ_ONLY.
        """
        from ..gpu.memory import begin_single_time_commands, end_single_time_commands

        e = self._engine
        device = e.ctx.device
        cmd = begin_single_time_commands(device, e.ctx.command_pool)
        # Transition UNDEFINED -> TRANSFER_DST
        to_transfer = self._lut_layout_barrier(
            0,
            vk.VK_ACCESS_TRANSFER_WRITE_BIT,
            vk.VK_IMAGE_LAYOUT_UNDEFINED,
            vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
            vk.VK_PIPELINE_STAGE_TRANSFER_BIT,
            0,
            0,
            None,
            0,
            None,
            1,
            [to_transfer],
        )
        # Copy staging -> image (row length / image height 0 = tightly packed)
        region = vk.VkBufferImageCopy(
            bufferOffset=0,
            bufferRowLength=0,
            bufferImageHeight=0,
            imageSubresource=vk.VkImageSubresourceLayers(
                aspectMask=vk.VK_IMAGE_ASPECT_COLOR_BIT,
                mipLevel=0,
                baseArrayLayer=0,
                layerCount=1,
            ),
            imageOffset=vk.VkOffset3D(x=0, y=0, z=0),
            imageExtent=vk.VkExtent3D(width=size, height=size, depth=size),
        )
        vk.vkCmdCopyBufferToImage(
            cmd,
            staging_buf,
            self._lut_image,
            vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
            1,
            [region],
        )
        # Transition TRANSFER_DST -> SHADER_READ_ONLY for sampling in the compute shader
        to_shader = self._lut_layout_barrier(
            vk.VK_ACCESS_TRANSFER_WRITE_BIT,
            vk.VK_ACCESS_SHADER_READ_BIT,
            vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
            vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_TRANSFER_BIT,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            0,
            0,
            None,
            0,
            None,
            1,
            [to_shader],
        )
        end_single_time_commands(device, e.ctx.graphics_queue, e.ctx.command_pool, cmd)

    def _create_descriptors(self, color_view: Any) -> None:
        """Create descriptor set: colour storage image (binding 0) + LUT sampler (binding 1)."""
        device = self._engine.ctx.device
        bindings = [
            vk.VkDescriptorSetLayoutBinding(
                binding=0,
                descriptorType=vk.VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
                descriptorCount=1,
                stageFlags=vk.VK_SHADER_STAGE_COMPUTE_BIT,
            ),
            vk.VkDescriptorSetLayoutBinding(
                binding=1,
                descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
                descriptorCount=1,
                stageFlags=vk.VK_SHADER_STAGE_COMPUTE_BIT,
            ),
        ]
        self._desc_layout = vk.vkCreateDescriptorSetLayout(
            device, vk.VkDescriptorSetLayoutCreateInfo(bindingCount=2, pBindings=bindings), None
        )
        # Pool sized for exactly one set with these two descriptors.
        pool_sizes = [
            vk.VkDescriptorPoolSize(type=vk.VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, descriptorCount=1),
            vk.VkDescriptorPoolSize(type=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, descriptorCount=1),
        ]
        self._desc_pool = vk.vkCreateDescriptorPool(
            device, vk.VkDescriptorPoolCreateInfo(maxSets=1, poolSizeCount=2, pPoolSizes=pool_sizes), None
        )
        sets = vk.vkAllocateDescriptorSets(
            device,
            vk.VkDescriptorSetAllocateInfo(
                descriptorPool=self._desc_pool,
                descriptorSetCount=1,
                pSetLayouts=[self._desc_layout],
            ),
        )
        self._desc_set = sets[0]
        self._write_descriptors(color_view)

    def _write_descriptors(self, color_view: Any) -> None:
        """Write colour storage image and LUT sampler to descriptor set.

        The colour view is expected in VK_IMAGE_LAYOUT_GENERAL. The LUT view is
        expected in SHADER_READ_ONLY_OPTIMAL.
        """
        device = self._engine.ctx.device
        color_info = vk.VkDescriptorImageInfo(
            imageView=color_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_GENERAL,
        )
        lut_info = vk.VkDescriptorImageInfo(
            sampler=self._lut_sampler,
            imageView=self._lut_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )
        vk.vkUpdateDescriptorSets(
            device,
            2,
            [
                vk.VkWriteDescriptorSet(
                    dstSet=self._desc_set,
                    dstBinding=0,
                    dstArrayElement=0,
                    descriptorCount=1,
                    descriptorType=vk.VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
                    pImageInfo=[color_info],
                ),
                vk.VkWriteDescriptorSet(
                    dstSet=self._desc_set,
                    dstBinding=1,
                    dstArrayElement=0,
                    descriptorCount=1,
                    descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
                    pImageInfo=[lut_info],
                ),
            ],
            0,
            None,
        )

    def _update_lut_descriptor(self) -> None:
        """Update LUT descriptor (binding 1) after LUT replacement."""
        lut_info = vk.VkDescriptorImageInfo(
            sampler=self._lut_sampler,
            imageView=self._lut_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )
        vk.vkUpdateDescriptorSets(
            self._engine.ctx.device,
            1,
            [
                vk.VkWriteDescriptorSet(
                    dstSet=self._desc_set,
                    dstBinding=1,
                    dstArrayElement=0,
                    descriptorCount=1,
                    descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
                    pImageInfo=[lut_info],
                )
            ],
            0,
            None,
        )

    def _create_pipeline(self) -> None:
        """Create colour grading compute pipeline.

        Compiles ``color_grade.comp``. The layout has one descriptor set and a
        _PC_SIZE-byte compute push-constant range. Raises RuntimeError if
        pipeline layout or pipeline creation fails.
        """
        e = self._engine
        device = e.ctx.device
        ffi = vk.ffi
        spv = compile_shader(e.shader_dir / "color_grade.comp")
        self._module = create_shader_module(device, spv)
        push_range = ffi.new("VkPushConstantRange*")
        push_range.stageFlags = vk.VK_SHADER_STAGE_COMPUTE_BIT
        push_range.offset = 0
        push_range.size = _PC_SIZE
        layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
        layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO
        # The cffi allocations must stay referenced until the create call returns.
        set_layouts = ffi.new("VkDescriptorSetLayout[1]", [self._desc_layout])
        layout_ci.setLayoutCount = 1
        layout_ci.pSetLayouts = set_layouts
        layout_ci.pushConstantRangeCount = 1
        layout_ci.pPushConstantRanges = push_range
        layout_out = ffi.new("VkPipelineLayout*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreatePipelineLayout,
            device,
            layout_ci,
            ffi.NULL,
            layout_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreatePipelineLayout (colour grading) failed: {result}")
        self._layout = layout_out[0]
        main_name = ffi.new("char[]", b"main")
        ci = ffi.new("VkComputePipelineCreateInfo*")
        ci.sType = vk.VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO
        ci.layout = self._layout
        stage = ffi.new("VkPipelineShaderStageCreateInfo*")
        stage.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
        stage.stage = vk.VK_SHADER_STAGE_COMPUTE_BIT
        stage.module = self._module
        stage.pName = main_name
        ci.stage = stage[0]
        pipeline_out = ffi.new("VkPipeline*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreateComputePipelines,
            device,
            ffi.NULL,
            1,
            ci,
            ffi.NULL,
            pipeline_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreateComputePipelines (colour grading) failed: {result}")
        self._pipeline = pipeline_out[0]

    def render(self, cmd: Any) -> None:
        """Dispatch colour grading compute shader. Call after fog, before tonemap.

        Args:
            cmd: Active command buffer (outside any render pass).

        Does nothing when the pass is not set up, is disabled, or the target
        has a zero width or height.
        """
        if not self._ready or not self.enabled:
            return
        if self._width <= 0 or self._height <= 0:
            # Nothing to grade, and 1.0 / 0 in the push constants would raise.
            return
        ffi = vk.ffi
        # NOTE(review): 8x8 groups presumably match color_grade.comp's local size - confirm.
        groups_x = (self._width + 7) // 8
        groups_y = (self._height + 7) // 8
        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._pipeline)
        vk.vkCmdBindDescriptorSets(
            cmd,
            vk.VK_PIPELINE_BIND_POINT_COMPUTE,
            self._layout,
            0,
            1,
            [self._desc_set],
            0,
            None,
        )
        r_mult, b_mult = _kelvin_to_rgb_multipliers(self.colour_temperature)
        # Pack push constants: 3 * vec4 of float32 = _PC_SIZE (48) bytes
        pc_data = np.array(
            [
                # adjustments
                self.brightness,
                self.contrast,
                self.saturation,
                float(self._lut_size),
                # temperature (z, w padded with 0.0)
                r_mult,
                b_mult,
                0.0,
                0.0,
                # resolution
                float(self._width),
                float(self._height),
                1.0 / self._width,
                1.0 / self._height,
            ],
            dtype=np.float32,
        ).tobytes()
        cbuf = ffi.new("char[]", pc_data)
        vk._vulkan.lib.vkCmdPushConstants(
            cmd,
            self._layout,
            vk.VK_SHADER_STAGE_COMPUTE_BIT,
            0,
            _PC_SIZE,
            cbuf,
        )
        vk.vkCmdDispatch(cmd, groups_x, groups_y, 1)
        # Barrier: make this pass's shader writes visible to later compute and
        # fragment shader reads/writes.
        barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_SHADER_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT | vk.VK_ACCESS_SHADER_WRITE_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT | vk.VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT,
            0,
            1,
            [barrier],
            0,
            None,
            0,
            None,
        )

    def resize(self, width: int, height: int, color_view: Any) -> None:
        """Update descriptors for new dimensions.

        Rebinds *color_view* as the storage image and rewrites the LUT binding.
        This is a no-op before setup().
        """
        # NOTE(review): the descriptor set must not be in use by pending command
        # buffers here; presumably the caller idles the device on resize - confirm.
        if not self._ready:
            return
        self._width = width
        self._height = height
        self._write_descriptors(color_view)

    def cleanup(self) -> None:
        """Release all GPU resources.

        This is not gated on setup() having completed. Each handle is
        destroyed only if it was created, then reset to None. So this also
        releases the leftovers of a partially failed setup() and is safe to
        call more than once.
        """
        self._ready = False
        handles = (
            self._pipeline,
            self._layout,
            self._module,
            self._desc_pool,
            self._desc_layout,
            self._lut_view,
            self._lut_image,
            self._lut_memory,
            self._lut_sampler,
        )
        if not any(handles):
            return
        device = self._engine.ctx.device
        if self._pipeline:
            vk.vkDestroyPipeline(device, self._pipeline, None)
            self._pipeline = None
        if self._layout:
            vk.vkDestroyPipelineLayout(device, self._layout, None)
            self._layout = None
        if self._module:
            vk.vkDestroyShaderModule(device, self._module, None)
            self._module = None
        if self._desc_pool:
            # Destroying the pool also frees the descriptor set allocated from it.
            vk.vkDestroyDescriptorPool(device, self._desc_pool, None)
            self._desc_pool = None
            self._desc_set = None
        if self._desc_layout:
            vk.vkDestroyDescriptorSetLayout(device, self._desc_layout, None)
            self._desc_layout = None
        self._destroy_lut()
        if self._lut_sampler:
            vk.vkDestroySampler(device, self._lut_sampler, None)
            self._lut_sampler = None
def _find_memory_type(mem_props: Any, type_filter: int, properties: int) -> int:
"""Find a suitable memory type index."""
for i in range(mem_props.memoryTypeCount):
if (type_filter & (1 << i)) and (mem_props.memoryTypes[i].propertyFlags & properties) == properties:
return i
raise RuntimeError("Failed to find suitable memory type")
def _create_staging_buffer(device: Any, physical_device: Any, data: np.ndarray) -> tuple[Any, Any]:
    """Create and fill a host-visible staging buffer from numpy data.

    The buffer has TRANSFER_SRC usage and lives in HOST_VISIBLE | HOST_COHERENT
    memory, so no explicit flush is needed after the copy.

    Args:
        device: Logical device.
        physical_device: Physical device used to choose the memory type.
        data: Array whose bytes are copied. A C-contiguous copy is made if needed.

    Returns:
        ``(buffer, memory)``. The caller owns both and must destroy/free them.

    If mapping or copying raises, the buffer and memory are released before the
    exception propagates.
    """
    ffi = vk.ffi
    # memmove copies nbytes starting at data.ctypes.data, which only matches the
    # array's logical contents when the array is C-contiguous.
    data = np.ascontiguousarray(data)
    size = data.nbytes
    buf, mem = create_buffer(
        device,
        physical_device,
        size,
        vk.VK_BUFFER_USAGE_TRANSFER_SRC_BIT,
        vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
    )
    try:
        dst = vk.vkMapMemory(device, mem, 0, size, 0)
        try:
            ffi.memmove(dst, ffi.cast("void*", data.ctypes.data), size)
        finally:
            vk.vkUnmapMemory(device, mem)
    except Exception:
        # Don't leak the buffer/memory when the fill fails.
        vk.vkDestroyBuffer(device, buf, None)
        vk.vkFreeMemory(device, mem, None)
        raise
    return buf, mem