"""Custom shader material system — user-facing API for custom GLSL shaders.
Provides ShaderMaterial for per-object custom shaders, UniformBuffer for GPU-side
uniform data, and ShaderMaterialManager for pipeline caching and hot-reload.
"""
from __future__ import annotations
import logging
import struct
import tempfile
from pathlib import Path
from typing import Any
import numpy as np
import vulkan as vk
from .shader_compiler import compile_shader, resolve_includes
__all__ = ["ShaderMaterial", "UniformBuffer", "ShaderMaterialManager"]
log = logging.getLogger(__name__)
# GLSL uniform type name -> (struct format string, payload size in bytes).
# The keys are GLSL type names (the strings produced by _infer_uniform_type),
# not Python/numpy types. Sizes are the unpadded payload sizes (e.g. vec3 is
# 12 bytes); alignment is looked up separately via _std140_alignment.
_UNIFORM_FORMATS: dict[str, tuple[str, int]] = {
    "float": ("f", 4),
    "int": ("i", 4),
    "uint": ("I", 4),
    "vec2": ("2f", 8),
    "vec3": ("3f", 12),
    "vec4": ("4f", 16),
    "ivec2": ("2i", 8),
    "ivec3": ("3i", 12),
    "ivec4": ("4i", 16),
    "mat4": ("16f", 64),
}
def _infer_uniform_type(value: Any) -> str:
    """Guess the GLSL uniform type name that best matches a Python value.

    Integers (including bool and numpy integers) map to ``"int"``. A numpy
    array of shape (4, 4) maps to ``"mat4"``; any other array, tuple or list
    maps to ``"vec2"``/``"vec3"``/``"vec4"`` according to its element count.
    Everything else, including sequences of any other length, is reported
    as ``"float"``.
    """
    vector_by_length = {2: "vec2", 3: "vec3", 4: "vec4"}

    if isinstance(value, (int, np.integer)):
        return "int"

    if isinstance(value, np.ndarray):
        if value.shape == (4, 4):
            return "mat4"
        # Any other array is classified purely by its total element count.
        return vector_by_length.get(value.size, "float")

    if isinstance(value, (tuple, list)):
        return vector_by_length.get(len(value), "float")

    # Python/numpy floats and every unrecognised type land here.
    return "float"
def _pack_uniform(value: Any, utype: str) -> bytes:
"""Pack a uniform value into bytes matching the GLSL layout."""
fmt, expected_size = _UNIFORM_FORMATS.get(utype, ("f", 4))
if isinstance(value, np.ndarray):
flat = value.astype(np.float32).ravel()
return flat.tobytes()[:expected_size]
if isinstance(value, tuple | list):
return struct.pack(fmt, *value)
if isinstance(value, int | np.integer):
return struct.pack("i", int(value))
if isinstance(value, float | np.floating):
return struct.pack("f", float(value))
return struct.pack("f", float(value))
def _align_to(offset: int, alignment: int) -> int:
    """Round ``offset`` up to the next multiple of ``alignment``.

    Implemented with a bitmask, so the result is only a true multiple when
    ``alignment`` is a power of two (as the std140 alignments 4/8/16 are).
    """
    mask = alignment - 1
    return (offset + mask) & ~mask
def _std140_alignment(utype: str) -> int:
    """Look up the std140 base alignment, in bytes, for a GLSL type name.

    Unrecognised type names fall back to 4, the scalar alignment.
    """
    alignment_groups = (
        (("float", "int", "uint"), 4),
        (("vec2", "ivec2"), 8),
        # Three-component vectors share the 16-byte alignment of vec4,
        # and mat4 is aligned the same way.
        (("vec3", "ivec3", "vec4", "ivec4", "mat4"), 16),
    )
    for type_names, alignment in alignment_groups:
        if utype in type_names:
            return alignment
    return 4
class ShaderMaterial:
    """User-facing custom shader material.

    Allows using custom GLSL vertex/fragment shaders with user-defined uniforms.
    Works alongside the engine's existing uber-shader pipeline by creating its own
    separate Vulkan pipeline.

    Each stage is taken from a file path when one is given, otherwise from the
    inline source string.

    Example::

        mat = ShaderMaterial(
            vertex_path="shaders/wave.vert",
            fragment_path="shaders/gradient.frag",
        )
        mat.set_uniform("time", 0.0)
        mat.set_uniform("colour", (1.0, 0.5, 0.2, 1.0))

    Or with inline source::

        mat = ShaderMaterial(
            vertex_source=\"\"\"
                #version 450
                layout(location=0) in vec3 pos;
                void main() { gl_Position = vec4(pos, 1.0); }
            \"\"\",
            fragment_source=\"\"\"
                #version 450
                layout(location=0) out vec4 out_color;
                void main() { out_color = vec4(1.0, 0.0, 0.0, 1.0); }
            \"\"\",
        )
    """

    def __init__(
        self,
        vertex_path: str | Path | None = None,
        fragment_path: str | Path | None = None,
        *,
        vertex_source: str | None = None,
        fragment_source: str | None = None,
        language: str = "glsl",
    ) -> None:
        """Store shader locations/sources; nothing is compiled yet.

        Args:
            vertex_path: Path to the vertex shader file, if file-based.
            fragment_path: Path to the fragment shader file, if file-based.
            vertex_source: Inline vertex shader source, used when no path is given.
            fragment_source: Inline fragment shader source, used when no path is given.
            language: Shader language label; stored on the instance only.
        """
        self._vertex_path = Path(vertex_path) if vertex_path else None
        self._fragment_path = Path(fragment_path) if fragment_path else None
        self.vertex_source = vertex_source
        self.fragment_source = fragment_source
        self.language = language
        self._uniforms: dict[str, Any] = {}
        self._uniform_types: dict[str, str] = {}
        self._vert_module: Any = None
        self._frag_module: Any = None
        self._is_compiled = False
        self._vert_mtime: float = 0.0
        self._frag_mtime: float = 0.0
        # Paths actually opened by the last successful compile (i.e. after
        # resolving relative paths against shader_dir); used for hot-reload.
        self._vert_resolved: Path | None = None
        self._frag_resolved: Path | None = None

    @property
    def is_compiled(self) -> bool:
        """Whether shaders have been compiled to SPIR-V and loaded."""
        return self._is_compiled

    @property
    def uniforms(self) -> dict[str, Any]:
        """A shallow copy of all current uniform values."""
        return dict(self._uniforms)

    def set_uniform(self, name: str, value: Any, utype: str | None = None) -> None:
        """Set a uniform value and record its GLSL type.

        Args:
            name: Uniform name.
            value: New value (scalar, tuple/list, or numpy array).
            utype: Explicit GLSL type name. When omitted, the type recorded
                by an earlier call for this name is kept; otherwise it is
                inferred from ``value``.
        """
        if utype is None:
            # Keep the first recorded type so that e.g. a later int value does
            # not silently turn a float uniform into an int one.
            utype = self._uniform_types.get(name) or _infer_uniform_type(value)
        self._uniforms[name] = value
        self._uniform_types[name] = utype

    def compile(self, device: Any, shader_dir: Path | None = None) -> None:
        """Compile shaders to SPIR-V and create Vulkan shader modules.

        Uses file paths if provided, otherwise writes inline source to temp files
        for compilation. Existing module handles are overwritten without being
        destroyed, so call :meth:`cleanup` first when recompiling.

        Args:
            device: Vulkan logical device handle.
            shader_dir: Base directory for resolving relative shader paths and
                includes. Defaults to the current working directory.

        Raises:
            FileNotFoundError: A shader path does not exist.
            ValueError: A stage has neither a path nor inline source.
        """
        from ..gpu.pipeline import create_shader_module

        base_dir = shader_dir or Path.cwd()

        # Compile both stages before creating any module so that a failing
        # fragment shader cannot leave a lone vertex module behind.
        vert_spv = self._compile_stage("vertex", base_dir)
        frag_spv = self._compile_stage("fragment", base_dir)

        vert_module = create_shader_module(device, vert_spv)
        try:
            frag_module = create_shader_module(device, frag_spv)
        except Exception:
            vk.vkDestroyShaderModule(device, vert_module, None)
            raise

        self._vert_module = vert_module
        self._frag_module = frag_module
        self._is_compiled = True
        log.debug("ShaderMaterial compiled successfully")

    @staticmethod
    def _compile_source(source: str, suffix: str) -> Path:
        """Write ``source`` to a temp file with ``suffix`` and compile it.

        The temp source file is always removed afterwards; the value returned
        by ``compile_shader`` is passed through unchanged.
        """
        # NamedTemporaryFile creates the file atomically, unlike the
        # deprecated mktemp() which only picks a name. delete=False lets the
        # file be closed before the compiler opens it.
        handle = tempfile.NamedTemporaryFile("w", suffix=suffix, delete=False)
        tmp = Path(handle.name)
        try:
            with handle:
                handle.write(source)
            return compile_shader(tmp)
        finally:
            tmp.unlink(missing_ok=True)

    def _compile_stage(self, stage: str, base_dir: Path) -> Path:
        """Compile a single shader stage, handling paths vs inline source.

        Args:
            stage: ``"vertex"``; any other value selects the fragment stage.
            base_dir: Directory used to resolve relative paths and includes.

        Returns:
            Whatever ``compile_shader`` returns for the processed source.
        """
        is_vertex = stage == "vertex"
        path = self._vertex_path if is_vertex else self._fragment_path
        source = self.vertex_source if is_vertex else self.fragment_source
        ext = ".vert" if is_vertex else ".frag"

        if path is not None:
            resolved = path if path.is_absolute() else base_dir / path
            if not resolved.exists():
                raise FileNotFoundError(f"Shader file not found: {resolved}")
            # Sample the mtime before reading: an edit that lands while we
            # are compiling must still be seen as newer than this build.
            mtime = resolved.stat().st_mtime
            raw_source = resolved.read_text()
            processed = resolve_includes(raw_source, resolved.parent)
            spv = self._compile_source(processed, ext)
            # Record what was compiled only after success, so a failed build
            # is still reported as changed by has_source_changed().
            if is_vertex:
                self._vert_resolved, self._vert_mtime = resolved, mtime
            else:
                self._frag_resolved, self._frag_mtime = resolved, mtime
            return spv

        if source is not None:
            processed = resolve_includes(source, base_dir)
            return self._compile_source(processed, ext)

        raise ValueError(f"No {stage} shader source or path provided")

    def get_pipeline_key(self) -> tuple:
        """Return a hashable key unique to this shader combination.

        Used for pipeline caching in ShaderMaterialManager. Each element is
        the resolved file path when the stage is file-based, the inline
        source text when it is not, or ``None`` when neither is set.
        """
        vert_key: str | None = None
        frag_key: str | None = None
        if self._vertex_path:
            vert_key = str(self._vertex_path.resolve())
        elif self.vertex_source:
            vert_key = self.vertex_source
        if self._fragment_path:
            frag_key = str(self._fragment_path.resolve())
        elif self.fragment_source:
            frag_key = self.fragment_source
        return (vert_key, frag_key)

    def has_source_changed(self) -> bool:
        """Check if shader source files have been modified since last compile.

        Prefers the path that the last successful compile actually opened
        (relative paths resolved against ``shader_dir``) and falls back to
        the path as given. Missing or unreadable files count as unchanged;
        inline sources never report a change.
        """
        watched = (
            (self._vert_resolved or self._vertex_path, self._vert_mtime),
            (self._frag_resolved or self._fragment_path, self._frag_mtime),
        )
        for path, compiled_mtime in watched:
            if path is None:
                continue
            try:
                if path.stat().st_mtime > compiled_mtime:
                    return True
            except OSError:
                # File vanished or cannot be inspected right now.
                continue
        return False

    def cleanup(self, device: Any) -> None:
        """Destroy Vulkan shader modules and mark the material uncompiled."""
        if self._vert_module:
            vk.vkDestroyShaderModule(device, self._vert_module, None)
            self._vert_module = None
        if self._frag_module:
            vk.vkDestroyShaderModule(device, self._frag_module, None)
            self._frag_module = None
        self._is_compiled = False
def _create_custom_pipeline(
    device: Any,
    vert_module: Any,
    frag_module: Any,
    render_pass: Any,
    extent: tuple[int, int],
    ssbo_layout: Any,
    uniform_layout: Any | None = None,
    texture_layout: Any | None = None,
) -> tuple[Any, Any]:
    """Create a Vulkan graphics pipeline for a custom shader.

    Vertex format: position(vec3) + normal(vec3) + uv(vec2) = 32 bytes stride.
    Push constants: view (mat4) + proj (mat4) = 128 bytes.
    Descriptor sets: 0=SSBOs, then custom uniforms (if given), then textures
    (if given). Set indices are assigned by position, so when
    ``uniform_layout`` is omitted the texture layout becomes set 1, not set 2.

    Args:
        device: Vulkan logical device.
        vert_module: Compiled vertex shader module.
        frag_module: Compiled fragment shader module.
        render_pass: Render pass the pipeline will be used with.
        extent: (width, height) used for the initial viewport and scissor.
        ssbo_layout: Descriptor set layout for set 0.
        uniform_layout: Optional descriptor set layout for custom uniforms.
        texture_layout: Optional descriptor set layout for textures.

    Returns:
        (pipeline, pipeline_layout).

    Raises:
        RuntimeError: If layout or pipeline creation does not return
            VK_SUCCESS. The layout is destroyed again if only the pipeline
            creation fails.
    """
    ffi = vk.ffi

    # NOTE: every ffi.new() object below must stay bound to a local name
    # until vkCreateGraphicsPipelines has returned -- the create-info structs
    # only hold raw pointers to them. Do not split this into helpers that
    # let the allocations go out of scope.

    # Push constant range: 2x mat4 = 128 bytes
    push_range = ffi.new("VkPushConstantRange*")
    push_range.stageFlags = vk.VK_SHADER_STAGE_VERTEX_BIT | vk.VK_SHADER_STAGE_FRAGMENT_BIT
    push_range.offset = 0
    push_range.size = 128

    # Collect descriptor set layouts (index in this list == set number)
    layouts = [ssbo_layout]
    if uniform_layout:
        layouts.append(uniform_layout)
    if texture_layout:
        layouts.append(texture_layout)
    set_layouts = ffi.new(f"VkDescriptorSetLayout[{len(layouts)}]", layouts)

    layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
    layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO
    layout_ci.setLayoutCount = len(layouts)
    layout_ci.pSetLayouts = set_layouts
    layout_ci.pushConstantRangeCount = 1
    layout_ci.pPushConstantRanges = push_range

    layout_out = ffi.new("VkPipelineLayout*")
    result = vk._vulkan._callApi(
        vk._vulkan.lib.vkCreatePipelineLayout,
        device,
        layout_ci,
        ffi.NULL,
        layout_out,
    )
    if result != vk.VK_SUCCESS:
        raise RuntimeError(f"vkCreatePipelineLayout failed: {result}")
    pipeline_layout = layout_out[0]

    # Build pipeline create info
    pi = ffi.new("VkGraphicsPipelineCreateInfo*")
    pi.sType = vk.VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO

    # Shader stages (both use the "main" entry point)
    stages = ffi.new("VkPipelineShaderStageCreateInfo[2]")
    main_name = ffi.new("char[]", b"main")
    stages[0].sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
    stages[0].stage = vk.VK_SHADER_STAGE_VERTEX_BIT
    stages[0].module = vert_module
    stages[0].pName = main_name
    stages[1].sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
    stages[1].stage = vk.VK_SHADER_STAGE_FRAGMENT_BIT
    stages[1].module = frag_module
    stages[1].pName = main_name
    pi.stageCount = 2
    pi.pStages = stages

    # Vertex input: position(vec3) + normal(vec3) + uv(vec2) = 32 bytes
    binding_desc = ffi.new("VkVertexInputBindingDescription*")
    binding_desc.binding = 0
    binding_desc.stride = 32
    binding_desc.inputRate = vk.VK_VERTEX_INPUT_RATE_VERTEX

    attr_descs = ffi.new("VkVertexInputAttributeDescription[3]")
    # location 0: position at byte offset 0
    attr_descs[0].location = 0
    attr_descs[0].binding = 0
    attr_descs[0].format = vk.VK_FORMAT_R32G32B32_SFLOAT
    attr_descs[0].offset = 0
    # location 1: normal at byte offset 12
    attr_descs[1].location = 1
    attr_descs[1].binding = 0
    attr_descs[1].format = vk.VK_FORMAT_R32G32B32_SFLOAT
    attr_descs[1].offset = 12
    # location 2: uv at byte offset 24
    attr_descs[2].location = 2
    attr_descs[2].binding = 0
    attr_descs[2].format = vk.VK_FORMAT_R32G32_SFLOAT
    attr_descs[2].offset = 24

    vi = ffi.new("VkPipelineVertexInputStateCreateInfo*")
    vi.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO
    vi.vertexBindingDescriptionCount = 1
    vi.pVertexBindingDescriptions = binding_desc
    vi.vertexAttributeDescriptionCount = 3
    vi.pVertexAttributeDescriptions = attr_descs
    pi.pVertexInputState = vi

    # Input assembly
    ia = ffi.new("VkPipelineInputAssemblyStateCreateInfo*")
    ia.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO
    ia.topology = vk.VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST
    pi.pInputAssemblyState = ia

    # Viewport state: full-extent viewport and scissor. Both are also listed
    # as dynamic state below.
    vps = ffi.new("VkPipelineViewportStateCreateInfo*")
    vps.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO
    vps.viewportCount = 1
    viewport = ffi.new("VkViewport*")
    viewport.width = float(extent[0])
    viewport.height = float(extent[1])
    viewport.maxDepth = 1.0
    vps.pViewports = viewport
    scissor = ffi.new("VkRect2D*")
    scissor.extent.width = extent[0]
    scissor.extent.height = extent[1]
    vps.scissorCount = 1
    vps.pScissors = scissor
    pi.pViewportState = vps

    # Rasterization: filled polygons, back-face culling, clockwise front faces
    rs = ffi.new("VkPipelineRasterizationStateCreateInfo*")
    rs.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO
    rs.polygonMode = vk.VK_POLYGON_MODE_FILL
    rs.lineWidth = 1.0
    rs.cullMode = vk.VK_CULL_MODE_BACK_BIT
    rs.frontFace = vk.VK_FRONT_FACE_CLOCKWISE
    pi.pRasterizationState = rs

    # Multisample: single sample
    ms = ffi.new("VkPipelineMultisampleStateCreateInfo*")
    ms.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO
    ms.rasterizationSamples = vk.VK_SAMPLE_COUNT_1_BIT
    pi.pMultisampleState = ms

    # Depth stencil: depth test and write enabled, LESS comparison
    dss = ffi.new("VkPipelineDepthStencilStateCreateInfo*")
    dss.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO
    dss.depthTestEnable = 1
    dss.depthWriteEnable = 1
    dss.depthCompareOp = vk.VK_COMPARE_OP_LESS
    pi.pDepthStencilState = dss

    # Colour blend: one attachment, all channels written, blending left off
    cba = ffi.new("VkPipelineColorBlendAttachmentState*")
    cba.colorWriteMask = (
        vk.VK_COLOR_COMPONENT_R_BIT
        | vk.VK_COLOR_COMPONENT_G_BIT
        | vk.VK_COLOR_COMPONENT_B_BIT
        | vk.VK_COLOR_COMPONENT_A_BIT
    )
    cb = ffi.new("VkPipelineColorBlendStateCreateInfo*")
    cb.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO
    cb.attachmentCount = 1
    cb.pAttachments = cba
    pi.pColorBlendState = cb

    # Dynamic state: viewport and scissor
    dyn_states = ffi.new("VkDynamicState[2]", [vk.VK_DYNAMIC_STATE_VIEWPORT, vk.VK_DYNAMIC_STATE_SCISSOR])
    ds = ffi.new("VkPipelineDynamicStateCreateInfo*")
    ds.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO
    ds.dynamicStateCount = 2
    ds.pDynamicStates = dyn_states
    pi.pDynamicState = ds

    pi.layout = pipeline_layout
    pi.renderPass = render_pass

    pipeline_out = ffi.new("VkPipeline*")
    result = vk._vulkan._callApi(
        vk._vulkan.lib.vkCreateGraphicsPipelines,
        device,
        ffi.NULL,
        1,
        pi,
        ffi.NULL,
        pipeline_out,
    )
    if result != vk.VK_SUCCESS:
        # The layout created above has no owner yet; release it so a failed
        # pipeline build does not leak it.
        vk.vkDestroyPipelineLayout(device, pipeline_layout, None)
        raise RuntimeError(f"vkCreateGraphicsPipelines failed: {result}")

    pipeline = pipeline_out[0]
    log.debug("Custom shader pipeline created")
    return pipeline, pipeline_layout
class ShaderMaterialManager:
    """Caches compiled pipelines by shader combination and manages hot-reload.

    Tracks all registered ShaderMaterial instances and their compiled pipelines.
    Pipelines are cached by the shader source/path combination so that multiple
    objects sharing the same shaders reuse one pipeline.

    Example::

        manager = ShaderMaterialManager()
        pipeline, layout = manager.get_or_create_pipeline(
            material, device, physical_device, render_pass, extent, ssbo_layout,
        )
    """

    def __init__(self) -> None:
        self._pipeline_cache: dict[tuple, tuple[Any, Any]] = {}  # key -> (pipeline, layout)
        self._materials: list[ShaderMaterial] = []
        self._uniform_buffers: dict[int, UniformBuffer] = {}  # id(material) -> UniformBuffer
        self._device: Any = None
        self._physical_device: Any = None

    def register_material(self, material: ShaderMaterial) -> None:
        """Track a ShaderMaterial for hot-reload monitoring (no duplicates)."""
        if material not in self._materials:
            self._materials.append(material)

    @staticmethod
    def _build_pipeline(
        material: ShaderMaterial,
        device: Any,
        render_pass: Any,
        extent: tuple[int, int],
        ssbo_layout: Any,
        uniform_layout: Any | None,
        texture_layout: Any | None,
    ) -> tuple[Any, Any]:
        """Create (pipeline, layout) from the material's current shader modules."""
        return _create_custom_pipeline(
            device,
            material._vert_module,
            material._frag_module,
            render_pass,
            extent,
            ssbo_layout,
            uniform_layout=uniform_layout,
            texture_layout=texture_layout,
        )

    @staticmethod
    def _restore_modules(
        material: ShaderMaterial,
        device: Any,
        vert_module: Any,
        frag_module: Any,
        was_compiled: bool,
    ) -> None:
        """Undo a failed reload: drop new modules and put the old ones back."""
        # Destroys whatever the failed attempt managed to create (if anything).
        material.cleanup(device)
        material._vert_module = vert_module
        material._frag_module = frag_module
        material._is_compiled = was_compiled

    def get_or_create_pipeline(
        self,
        material: ShaderMaterial,
        device: Any,
        physical_device: Any,
        render_pass: Any,
        extent: tuple[int, int],
        ssbo_layout: Any,
        texture_layout: Any | None = None,
        shader_dir: Path | None = None,
    ) -> tuple[Any, Any]:
        """Get a cached pipeline for this material, or compile and create one.

        Args:
            material: The ShaderMaterial to get/create a pipeline for.
            device: Vulkan logical device.
            physical_device: Vulkan physical device.
            render_pass: Vulkan render pass.
            extent: Swapchain extent (width, height).
            ssbo_layout: Descriptor set layout for SSBOs (set 0).
            texture_layout: Optional texture descriptor layout.
            shader_dir: Base directory for shader file resolution.

        Returns:
            Tuple of (VkPipeline, VkPipelineLayout).
        """
        self._device = device
        self._physical_device = physical_device

        key = material.get_pipeline_key()
        cached = self._pipeline_cache.get(key)
        if cached is not None:
            return cached

        # Compile if needed
        if not material.is_compiled:
            material.compile(device, shader_dir)
        self.register_material(material)

        # Create a uniform buffer if the material has uniforms. Reuse one
        # left over from an earlier attempt (e.g. pipeline creation raised)
        # instead of overwriting -- and thereby orphaning -- it.
        uniform_layout = None
        if material.uniforms:
            ubo = self._uniform_buffers.get(id(material))
            if ubo is None:
                ubo = UniformBuffer()
                ubo.create(device, physical_device)
                self._uniform_buffers[id(material)] = ubo
            uniform_layout = ubo.get_descriptor_layout()

        pipeline, layout = self._build_pipeline(
            material, device, render_pass, extent, ssbo_layout, uniform_layout, texture_layout
        )
        self._pipeline_cache[key] = (pipeline, layout)
        log.debug("Cached custom pipeline for key=%s", key)
        return pipeline, layout

    def check_hot_reload(
        self,
        device: Any,
        physical_device: Any,
        render_pass: Any,
        extent: tuple[int, int],
        ssbo_layout: Any,
        texture_layout: Any | None = None,
        shader_dir: Path | None = None,
    ) -> list[ShaderMaterial]:
        """Check all registered materials for source file changes and recompile.

        The replacement modules and pipeline are built first; the previous
        ones are destroyed only once the new pipeline exists. If compilation
        fails the error is logged and the material keeps its previous modules
        and cached pipeline. If pipeline creation fails the previous state is
        restored and the exception propagates.

        NOTE(review): a superseded pipeline is destroyed immediately --
        presumably callers make sure the GPU is no longer using it; confirm.

        Returns a list of materials that were recompiled.
        """
        recompiled = []
        for material in self._materials:
            if not material.has_source_changed():
                continue

            key = material.get_pipeline_key()
            log.info("Hot-reloading shader: %s", key)

            # Detach the live modules so compile() cannot overwrite (and
            # leak) them and so they can be put back if the reload fails.
            old_vert = material._vert_module
            old_frag = material._frag_module
            was_compiled = material.is_compiled
            material._vert_module = None
            material._frag_module = None

            try:
                material.compile(device, shader_dir)
            except Exception:
                log.exception("Hot-reload compilation failed for %s", key)
                self._restore_modules(material, device, old_vert, old_frag, was_compiled)
                continue

            uniform_layout = None
            ubo = self._uniform_buffers.get(id(material))
            if ubo is not None:
                uniform_layout = ubo.get_descriptor_layout()

            try:
                fresh = self._build_pipeline(
                    material, device, render_pass, extent, ssbo_layout, uniform_layout, texture_layout
                )
            except Exception:
                self._restore_modules(material, device, old_vert, old_frag, was_compiled)
                raise

            # Swap in the new pipeline, then release everything it replaced.
            stale = self._pipeline_cache.get(key)
            self._pipeline_cache[key] = fresh
            if stale:
                vk.vkDestroyPipeline(device, stale[0], None)
                vk.vkDestroyPipelineLayout(device, stale[1], None)
            for module in (old_vert, old_frag):
                if module:
                    vk.vkDestroyShaderModule(device, module, None)

            recompiled.append(material)
        return recompiled

    def cleanup(self, device: Any) -> None:
        """Destroy all cached pipelines, uniform buffers, and shader modules."""
        for pipeline, layout in self._pipeline_cache.values():
            vk.vkDestroyPipeline(device, pipeline, None)
            vk.vkDestroyPipelineLayout(device, layout, None)
        self._pipeline_cache.clear()

        for ubo in self._uniform_buffers.values():
            ubo.cleanup(device)
        self._uniform_buffers.clear()

        for material in self._materials:
            material.cleanup(device)
        self._materials.clear()