"""Descriptor pool, layout, and set management."""
from __future__ import annotations
import logging
from typing import Any
import vulkan as vk
from .._types import MAX_TEXTURES
# Module logger; the pool and layout factories below emit debug messages
# through it after creating their Vulkan objects.
log = logging.getLogger(__name__)
# Public API of this module. The ``_make_*`` write builders are not exported;
# they are reached through the ``write_*`` helpers and ``DescriptorWriteBatch``.
__all__ = [
    "DescriptorWriteBatch",
    "create_descriptor_pool",
    "create_ssbo_layout",
    "allocate_descriptor_set",
    "write_ssbo_descriptor",
    "write_image_descriptor",
    "create_texture_descriptor_pool",
    "create_texture_descriptor_layout",
    "write_texture_descriptor",
]
[docs]
def create_descriptor_pool(
    device: Any,
    max_sets: int = 4,
    extra_samplers: int = 0,
    ssbo_count: int = 0,
) -> Any:
    """Create a descriptor pool for SSBO descriptors (+ optional image samplers).

    Args:
        device: Logical device handle passed to ``vkCreateDescriptorPool``.
        max_sets: Maximum number of descriptor sets the pool may allocate.
        extra_samplers: Number of combined-image-sampler descriptors to
            reserve. No sampler pool size is added when this is 0 or less.
        ssbo_count: Storage-buffer descriptor count. When positive it
            overrides the default of ``max_sets * 4``.

    Returns:
        The pool handle returned by ``vkCreateDescriptorPool``.
    """
    if ssbo_count > 0:
        storage_descriptors = ssbo_count
    else:
        # Default budget: four storage buffers per set.
        storage_descriptors = max_sets * 4

    sizes = [
        vk.VkDescriptorPoolSize(
            type=vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
            descriptorCount=storage_descriptors,
        )
    ]
    if extra_samplers > 0:
        sizes.append(
            vk.VkDescriptorPoolSize(
                type=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
                descriptorCount=extra_samplers,
            )
        )

    create_info = vk.VkDescriptorPoolCreateInfo(
        maxSets=max_sets,
        poolSizeCount=len(sizes),
        pPoolSizes=sizes,
    )
    created_pool = vk.vkCreateDescriptorPool(device, create_info, None)
    log.debug("Descriptor pool created (max_sets=%d)", max_sets)
    return created_pool
[docs]
def create_ssbo_layout(
    device: Any,
    binding_count: int = 3,
    extra_samplers: int = 0,
    trailing_ssbos: int = 0,
) -> Any:
    """Create a descriptor set layout of SSBO bindings plus optional samplers.

    Bindings are numbered consecutively in this order:

    1. ``binding_count`` storage buffers, visible to vertex and fragment stages;
    2. ``extra_samplers`` combined image samplers, fragment stage only;
    3. ``trailing_ssbos`` storage buffers, fragment stage only
       (e.g. tile light data).

    Every binding holds exactly one descriptor.

    Returns:
        The layout handle returned by ``vkCreateDescriptorSetLayout``.
    """
    storage = vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
    combined_sampler = vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
    fragment_only = vk.VK_SHADER_STAGE_FRAGMENT_BIT
    vertex_and_fragment = vk.VK_SHADER_STAGE_VERTEX_BIT | vk.VK_SHADER_STAGE_FRAGMENT_BIT

    sampler_base = binding_count
    trailing_base = binding_count + extra_samplers

    # (binding slot, descriptor type, stage mask) for each binding, in order.
    specs = [(slot, storage, vertex_and_fragment) for slot in range(binding_count)]
    specs += [(sampler_base + k, combined_sampler, fragment_only) for k in range(extra_samplers)]
    specs += [(trailing_base + k, storage, fragment_only) for k in range(trailing_ssbos)]

    layout_bindings = [
        vk.VkDescriptorSetLayoutBinding(
            binding=slot,
            descriptorType=kind,
            descriptorCount=1,
            stageFlags=stages,
        )
        for slot, kind, stages in specs
    ]

    create_info = vk.VkDescriptorSetLayoutCreateInfo(
        bindingCount=len(layout_bindings),
        pBindings=layout_bindings,
    )
    result = vk.vkCreateDescriptorSetLayout(device, create_info, None)
    log.debug("SSBO descriptor set layout created (%d bindings)", len(layout_bindings))
    return result
[docs]
def allocate_descriptor_set(device: Any, pool: Any, layout: Any) -> Any:
    """Allocate exactly one descriptor set with *layout* from *pool*.

    Returns the first (and only) element of the ``vkAllocateDescriptorSets``
    result.
    """
    request = vk.VkDescriptorSetAllocateInfo(
        descriptorPool=pool,
        descriptorSetCount=1,
        pSetLayouts=[layout],
    )
    return vk.vkAllocateDescriptorSets(device, request)[0]
[docs]
def create_texture_descriptor_pool(device: Any, max_textures: int = MAX_TEXTURES) -> Any:
    """Create a single-set pool holding *max_textures* combined image samplers.

    Returns:
        The pool handle returned by ``vkCreateDescriptorPool``.
    """
    sizes = [
        vk.VkDescriptorPoolSize(
            type=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            descriptorCount=max_textures,
        )
    ]
    create_info = vk.VkDescriptorPoolCreateInfo(
        maxSets=1,
        poolSizeCount=len(sizes),
        pPoolSizes=sizes,
    )
    texture_pool = vk.vkCreateDescriptorPool(device, create_info, None)
    return texture_pool
[docs]
def create_texture_descriptor_layout(device: Any, max_textures: int = MAX_TEXTURES) -> Any:
    """Create a layout with one fixed-size texture array at binding 0.

    The array holds *max_textures* combined image samplers and is visible to
    the fragment stage only.

    Returns:
        The layout handle returned by ``vkCreateDescriptorSetLayout``.
    """
    array_bindings = [
        vk.VkDescriptorSetLayoutBinding(
            binding=0,
            descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            descriptorCount=max_textures,
            stageFlags=vk.VK_SHADER_STAGE_FRAGMENT_BIT,
        )
    ]
    create_info = vk.VkDescriptorSetLayoutCreateInfo(
        bindingCount=len(array_bindings),
        pBindings=array_bindings,
    )
    texture_layout = vk.vkCreateDescriptorSetLayout(device, create_info, None)
    return texture_layout
def _make_texture_write(
    descriptor_set: Any, texture_index: int, image_view: Any, sampler: Any,
) -> vk.VkWriteDescriptorSet:
    """Return an unsubmitted write for one element of the texture array.

    Targets binding 0 at array element *texture_index*. Binding 0 is the single
    binding declared by ``create_texture_descriptor_layout``. The image is
    described as being in ``SHADER_READ_ONLY_OPTIMAL`` layout.
    """
    sampled_images = [
        vk.VkDescriptorImageInfo(
            sampler=sampler,
            imageView=image_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )
    ]
    return vk.VkWriteDescriptorSet(
        dstSet=descriptor_set,
        dstBinding=0,
        dstArrayElement=texture_index,
        descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
        descriptorCount=1,
        pImageInfo=sampled_images,
    )
def _make_image_write(
    descriptor_set: Any, binding: int, image_view: Any, sampler: Any,
) -> vk.VkWriteDescriptorSet:
    """Return an unsubmitted combined-image-sampler write for *binding*.

    Writes a single descriptor at array element 0. The image is described as
    being in ``SHADER_READ_ONLY_OPTIMAL`` layout.
    """
    sampled_images = [
        vk.VkDescriptorImageInfo(
            sampler=sampler,
            imageView=image_view,
            imageLayout=vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
        )
    ]
    return vk.VkWriteDescriptorSet(
        dstSet=descriptor_set,
        dstBinding=binding,
        dstArrayElement=0,
        descriptorType=vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
        descriptorCount=1,
        pImageInfo=sampled_images,
    )
def _make_ssbo_write(
    descriptor_set: Any, binding: int, buffer: Any, size: int,
) -> vk.VkWriteDescriptorSet:
    """Return an unsubmitted storage-buffer write for *binding*.

    The descriptor covers *size* bytes of *buffer*, starting at offset 0.
    """
    buffer_ranges = [vk.VkDescriptorBufferInfo(buffer=buffer, offset=0, range=size)]
    return vk.VkWriteDescriptorSet(
        dstSet=descriptor_set,
        dstBinding=binding,
        dstArrayElement=0,
        descriptorType=vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
        descriptorCount=1,
        pBufferInfo=buffer_ranges,
    )
[docs]
def write_texture_descriptor(
    device: Any, descriptor_set: Any, texture_index: int, image_view: Any, sampler: Any,
) -> None:
    """Immediately write one texture into slot *texture_index* of the array.

    Issues a single ``vkUpdateDescriptorSets`` call with no descriptor copies.
    """
    pending = [_make_texture_write(descriptor_set, texture_index, image_view, sampler)]
    vk.vkUpdateDescriptorSets(device, len(pending), pending, 0, None)
[docs]
def write_image_descriptor(
    device: Any, descriptor_set: Any, binding: int, image_view: Any, sampler: Any,
) -> None:
    """Immediately write a combined image sampler at *binding*.

    Issues a single ``vkUpdateDescriptorSets`` call with no descriptor copies.
    """
    pending = [_make_image_write(descriptor_set, binding, image_view, sampler)]
    vk.vkUpdateDescriptorSets(device, len(pending), pending, 0, None)
[docs]
def write_ssbo_descriptor(
    device: Any, descriptor_set: Any, binding: int, buffer: Any, size: int,
) -> None:
    """Immediately bind *size* bytes of *buffer* to SSBO slot *binding*.

    Issues a single ``vkUpdateDescriptorSets`` call with no descriptor copies.
    """
    pending = [_make_ssbo_write(descriptor_set, binding, buffer, size)]
    vk.vkUpdateDescriptorSets(device, len(pending), pending, 0, None)
[docs]
class DescriptorWriteBatch:
    """Collects VkWriteDescriptorSet structs and flushes them in a single Vulkan call.

    Usage::

        batch = DescriptorWriteBatch(device)
        batch.ssbo(ds, 0, buf_a, size_a)
        batch.ssbo(ds, 1, buf_b, size_b)
        batch.image(ds, 2, view, sampler)
        batch.flush()

    Can also be used as a context manager. ``flush()`` is called when the
    ``with`` body finishes normally. If the body raises, the queued writes are
    discarded instead and the exception propagates unchanged::

        with DescriptorWriteBatch(device) as batch:
            batch.ssbo(ds, 0, buf, size)

    All queueing methods return ``self`` so calls can be chained.
    """

    __slots__ = ("_device", "_writes")

    def __init__(self, device: Any) -> None:
        self._device = device
        # Pending writes, submitted in insertion order by flush().
        self._writes: list[vk.VkWriteDescriptorSet] = []

    # -- Accumulate writes --------------------------------------------------

    def ssbo(self, descriptor_set: Any, binding: int, buffer: Any, size: int) -> DescriptorWriteBatch:
        """Queue an SSBO descriptor write covering *size* bytes of *buffer* from offset 0."""
        self._writes.append(_make_ssbo_write(descriptor_set, binding, buffer, size))
        return self

    def image(self, descriptor_set: Any, binding: int, image_view: Any, sampler: Any) -> DescriptorWriteBatch:
        """Queue a combined image sampler descriptor write at array element 0 of *binding*."""
        self._writes.append(_make_image_write(descriptor_set, binding, image_view, sampler))
        return self

    def texture(
        self, descriptor_set: Any, texture_index: int, image_view: Any, sampler: Any,
    ) -> DescriptorWriteBatch:
        """Queue a write to element *texture_index* of the texture array at binding 0."""
        self._writes.append(_make_texture_write(descriptor_set, texture_index, image_view, sampler))
        return self

    def raw(self, write: vk.VkWriteDescriptorSet) -> DescriptorWriteBatch:
        """Queue a pre-built VkWriteDescriptorSet as-is."""
        self._writes.append(write)
        return self

    # -- Submit --------------------------------------------------------------

    def flush(self) -> int:
        """Submit all queued writes in a single ``vkUpdateDescriptorSets`` call.

        No Vulkan call is made when nothing is queued.

        Returns:
            The number of writes submitted. The internal queue is cleared
            afterwards.
        """
        count = len(self._writes)
        if count:
            vk.vkUpdateDescriptorSets(self._device, count, self._writes, 0, None)
            self._writes.clear()
        return count

    # -- Context manager -----------------------------------------------------

    def __enter__(self) -> DescriptorWriteBatch:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc: BaseException | None = None,
        tb: object = None,
    ) -> None:
        """Flush on a clean exit; discard queued writes if the body raised.

        Pushing a half-built batch after a failure could leave descriptor sets
        partially updated. If the flush itself then failed, its error would
        also replace the original exception, so on error the queue is simply
        dropped. Returns None, so exceptions are never suppressed.
        """
        if exc_type is None:
            self.flush()
        else:
            self._writes.clear()

    def __len__(self) -> int:
        """Return the number of queued, not-yet-flushed writes.

        Note that, because of this method, an empty batch is falsy.
        """
        return len(self._writes)