"""Physical/logical device selection and queue management."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
import vulkan as vk
__all__ = [
    "select_physical_device",
    "create_logical_device",
    "QueueFamilies",
]
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
@dataclass
class QueueFamilies:
    """Queue-family indices picked on a physical device.

    Both fields may hold the same index when one family can do both jobs;
    ``create_logical_device`` folds duplicate indices into one queue request.
    """

    # Index of a family whose queueFlags include VK_QUEUE_GRAPHICS_BIT.
    graphics: int
    # Index of a family that reported surface-presentation support.
    present: int
def _find_queue_families(physical_device: Any, surface: Any, vk_surface_support: Any) -> QueueFamilies | None:
    """Locate graphics- and present-capable queue families on *physical_device*.

    A single family that supports both graphics and presentation is preferred,
    because then rendering and presentation share one queue family (and
    ``create_logical_device`` creates only one queue). Otherwise the
    lowest-indexed graphics family is paired with the lowest-indexed present
    family.

    Args:
        physical_device: The VkPhysicalDevice to inspect.
        surface: The VkSurfaceKHR that presentation must target.
        vk_surface_support: The loaded ``vkGetPhysicalDeviceSurfaceSupportKHR``
            entry point, called as ``(physical_device, family_index, surface)``.

    Returns:
        The chosen ``QueueFamilies``, or ``None`` if the device has no
        graphics-capable family or no present-capable family.
    """
    props = vk.vkGetPhysicalDeviceQueueFamilyProperties(physical_device)
    graphics: int | None = None
    present: int | None = None
    for i, p in enumerate(props):
        has_graphics = bool(p.queueFlags & vk.VK_QUEUE_GRAPHICS_BIT)
        has_present = bool(vk_surface_support(physical_device, i, surface))
        if has_graphics and has_present:
            # Best case: one family does both, so no need to keep scanning.
            return QueueFamilies(i, i)
        # Remember the first family of each kind; never overwrite with later ones.
        if has_graphics and graphics is None:
            graphics = i
        if has_present and present is None:
            present = i
    if graphics is None or present is None:
        return None
    return QueueFamilies(graphics, present)
def _query_device_features(physical_device: Any) -> dict[str, bool]:
    """Report which optional Vulkan features *physical_device* advertises.

    Args:
        physical_device: The VkPhysicalDevice to query.

    Returns:
        A mapping from a snake_case feature name to a plain ``bool``. At present
        the only key is ``"multi_draw_indirect"``.
    """
    supported = vk.vkGetPhysicalDeviceFeatures(physical_device)
    flags: dict[str, bool] = {}
    # VkBool32 fields come back as ints; normalise to real booleans.
    flags["multi_draw_indirect"] = bool(supported.multiDrawIndirect)
    return flags
def _as_str(value: Any) -> str:
    """Return *value* as ``str``, decoding UTF-8 if the binding returned bytes."""
    return value if isinstance(value, str) else value.decode("utf-8")


def _supports_swapchain(physical_device: Any) -> bool:
    """Return True if *physical_device* exposes the VK_KHR_swapchain extension."""
    available = {
        _as_str(ext.extensionName)
        for ext in vk.vkEnumerateDeviceExtensionProperties(physical_device, None)
    }
    return _as_str(vk.VK_KHR_SWAPCHAIN_EXTENSION_NAME) in available


def select_physical_device(instance: Any, surface: Any) -> tuple[Any, QueueFamilies]:
    """Pick a suitable VkPhysicalDevice. Returns (physical_device, queue_families).

    Devices are tried in enumeration order. The first one that has both a
    graphics-capable and a present-capable queue family *and* exposes the
    VK_KHR_swapchain device extension (which ``create_logical_device`` always
    enables) is returned.

    Args:
        instance: The VkInstance used for enumeration and to load the
            ``vkGetPhysicalDeviceSurfaceSupportKHR`` entry point.
        surface: The VkSurfaceKHR the device must be able to present to.

    Raises:
        RuntimeError: If no physical devices are enumerated, or none of them
            meets the requirements above.
    """
    vk_surface_support = vk.vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceSurfaceSupportKHR")
    devices = vk.vkEnumeratePhysicalDevices(instance)
    if not devices:
        raise RuntimeError("No Vulkan-capable GPU found")
    for dev in devices:
        qf = _find_queue_families(dev, surface, vk_surface_support)
        if qf is None:
            continue
        # The logical device is always created with VK_KHR_swapchain enabled;
        # a device lacking it would be rejected later by vkCreateDevice.
        if not _supports_swapchain(dev):
            continue
        props = vk.vkGetPhysicalDeviceProperties(dev)
        log.debug("Selected GPU: %s", _as_str(props.deviceName))
        return dev, qf
    raise RuntimeError("No suitable GPU with graphics+present queues and swapchain support found")
def create_logical_device(
    physical_device: Any,
    queue_families: QueueFamilies,
) -> tuple[Any, Any, Any]:
    """Create a VkDevice. Returns (device, graphics_queue, present_queue).

    One queue with priority 1.0 is requested per distinct family index. When
    graphics and present share a family, only one queue is created, and both
    returned handles come from queue 0 of that family.

    The device is created with:
      * the VK_KHR_swapchain extension;
      * ``multiDrawIndirect``, only if the physical device reports it;
      * the Vulkan 1.2 features ``runtimeDescriptorArray``,
        ``shaderSampledImageArrayNonUniformIndexing`` and
        ``descriptorBindingPartiallyBound`` (requested unconditionally).
    """
    # A set collapses a shared graphics/present family into a single request.
    wanted_families = {queue_families.graphics, queue_families.present}
    queue_infos = [
        vk.VkDeviceQueueCreateInfo(queueFamilyIndex=idx, queueCount=1, pQueuePriorities=[1.0])
        for idx in wanted_families
    ]
    extensions = [vk.VK_KHR_SWAPCHAIN_EXTENSION_NAME]

    # multiDrawIndirect is optional: enable it only when the hardware has it.
    supports_mdi = _query_device_features(physical_device)["multi_draw_indirect"]
    core_features = vk.VkPhysicalDeviceFeatures(multiDrawIndirect=supports_mdi)

    # Shaders that use nonuniformEXT need these Vulkan 1.2 features; the
    # struct is chained onto the create-info through pNext.
    vulkan12_features = vk.VkPhysicalDeviceVulkan12Features(
        sType=vk.VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES,
        runtimeDescriptorArray=True,
        shaderSampledImageArrayNonUniformIndexing=True,
        descriptorBindingPartiallyBound=True,
    )
    device_info = vk.VkDeviceCreateInfo(
        pNext=vulkan12_features,
        pEnabledFeatures=core_features,
        queueCreateInfoCount=len(queue_infos),
        pQueueCreateInfos=queue_infos,
        enabledExtensionCount=len(extensions),
        ppEnabledExtensionNames=extensions,
    )
    device = vk.vkCreateDevice(physical_device, device_info, None)

    # Fetch queue index 0 of each family, graphics first, then present.
    gfx_queue, present_queue = [
        vk.vkGetDeviceQueue(device, family_index, 0)
        for family_index in (queue_families.graphics, queue_families.present)
    ]
    log.debug("Logical device created")
    return device, gfx_queue, present_queue