# Source code for simvx.graphics.assets.mesh_loader

"""Full glTF scene loading — meshes, materials, textures, node hierarchy."""


from __future__ import annotations

import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import numpy as np
from pygltflib import GLTF2

from simvx.graphics._types import SKINNED_VERTEX_DTYPE, VERTEX_DTYPE

__all__ = ["load_gltf", "GLTFScene", "GLTFMaterial", "GLTFNode"]

log = logging.getLogger(__name__)

# glTF component type → numpy dtype
# (integer codes are the glTF 2.0 accessor.componentType enum values)
_COMPONENT_DTYPES = {
    5120: np.int8,     # BYTE
    5121: np.uint8,    # UNSIGNED_BYTE
    5122: np.int16,    # SHORT
    5123: np.uint16,   # UNSIGNED_SHORT
    5125: np.uint32,   # UNSIGNED_INT
    5126: np.float32,  # FLOAT
}

# glTF accessor type → component count (elements per attribute value)
_TYPE_SIZES = {"SCALAR": 1, "VEC2": 2, "VEC3": 3, "VEC4": 4, "MAT4": 16}


@dataclass
class GLTFMaterial:
    """Extracted PBR metallic-roughness material.

    Texture fields hold either a filesystem path (str, external image file)
    or raw embedded image bytes (common in .glb files); None means the
    texture slot is unused.
    """

    name: str = ""
    # RGBA base color factor; (1, 1, 1, 1) is the glTF default (opaque white).
    albedo: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0)
    metallic: float = 1.0   # glTF default metallicFactor
    roughness: float = 1.0  # glTF default roughnessFactor
    albedo_texture: str | bytes | None = None
    normal_texture: str | bytes | None = None
    metallic_roughness_texture: str | bytes | None = None
    emissive_texture: str | bytes | None = None
    ao_texture: str | bytes | None = None  # glTF occlusionTexture
    double_sided: bool = False
    # One of "OPAQUE" / "MASK" / "BLEND" (glTF alphaMode values).
    alpha_mode: str = "OPAQUE"
@dataclass
class GLTFNode:
    """Scene graph node with optional mesh reference."""

    name: str = ""
    # Index into GLTFScene.meshes for this node's first primitive, if any.
    mesh_index: int | None = None
    # One material index per primitive of the referenced glTF mesh (-1 = none).
    material_indices: list[int] = field(default_factory=list)
    # Local 4x4 transform; translation lives in the last column.
    transform: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
    # Child node indices into GLTFScene.nodes.
    children: list[int] = field(default_factory=list)
    # Skinning (extracted but not applied until Phase 6)
    skin_index: int | None = None
@dataclass
class GLTFScene:
    """Complete loaded glTF scene."""

    # One (vertices, indices) pair per glTF *primitive* (not per glTF mesh).
    meshes: list[tuple[np.ndarray, np.ndarray]] = field(default_factory=list)
    materials: list[GLTFMaterial] = field(default_factory=list)
    nodes: list[GLTFNode] = field(default_factory=list)
    # Indices into `nodes` for the default scene's roots.
    root_nodes: list[int] = field(default_factory=list)
    # Skinning data: dicts with "name", "joints", optional "inverse_bind_matrices".
    skins: list[dict[str, Any]] = field(default_factory=list)
    # Animation data: dicts with "name", "duration", "tracks".
    animations: list[dict[str, Any]] = field(default_factory=list)
def load_gltf(file_path: str) -> GLTFScene:
    """Load complete glTF scene with all meshes, materials, textures, and hierarchy.

    Args:
        file_path: Path to a .gltf or .glb file.

    Returns:
        A GLTFScene with all data extracted and ready for import.
    """
    path = Path(file_path)
    gltf = GLTF2().load(str(path))
    base_dir = path.parent

    scene = GLTFScene()
    _load_materials(gltf, scene, base_dir)
    mesh_prim_map = _load_meshes(gltf, scene)
    _load_nodes(gltf, scene, mesh_prim_map)
    _load_skins(gltf, scene)
    _load_animations(gltf, scene)

    log.debug(
        "Loaded glTF: %d meshes, %d materials, %d nodes, %d animations from %s",
        len(scene.meshes),
        len(scene.materials),
        len(scene.nodes),
        len(scene.animations),
        path.name,
    )
    return scene


def _load_materials(gltf: GLTF2, scene: GLTFScene, base_dir: Path) -> None:
    """Extract PBR metallic-roughness materials into scene.materials."""
    for gmat in gltf.materials or []:
        mat = GLTFMaterial(name=gmat.name or "")
        pbr = gmat.pbrMetallicRoughness
        if pbr:
            bc = pbr.baseColorFactor or [1, 1, 1, 1]
            mat.albedo = tuple(bc[:4])
            mat.metallic = pbr.metallicFactor if pbr.metallicFactor is not None else 1.0
            mat.roughness = pbr.roughnessFactor if pbr.roughnessFactor is not None else 1.0
            if pbr.baseColorTexture is not None:
                mat.albedo_texture = _resolve_texture(gltf, pbr.baseColorTexture.index, base_dir)
            if pbr.metallicRoughnessTexture is not None:
                mat.metallic_roughness_texture = _resolve_texture(
                    gltf, pbr.metallicRoughnessTexture.index, base_dir,
                )
        if gmat.normalTexture is not None:
            mat.normal_texture = _resolve_texture(gltf, gmat.normalTexture.index, base_dir)
        if gmat.emissiveTexture is not None:
            mat.emissive_texture = _resolve_texture(gltf, gmat.emissiveTexture.index, base_dir)
        if gmat.occlusionTexture is not None:
            mat.ao_texture = _resolve_texture(gltf, gmat.occlusionTexture.index, base_dir)
        mat.double_sided = gmat.doubleSided or False
        mat.alpha_mode = gmat.alphaMode or "OPAQUE"
        scene.materials.append(mat)


def _load_meshes(gltf: GLTF2, scene: GLTFScene) -> dict[int, list[tuple[int, int]]]:
    """Extract every primitive as its own scene mesh.

    Returns:
        Mapping of glTF mesh index -> list of (scene_mesh_idx, material_idx),
        one entry per primitive (material_idx is -1 when unassigned).
    """
    mesh_prim_map: dict[int, list[tuple[int, int]]] = {}
    for mesh_idx, gmesh in enumerate(gltf.meshes or []):
        prims = []
        for prim in gmesh.primitives:
            vertices, indices = _extract_primitive(gltf, prim)
            scene_mesh_idx = len(scene.meshes)
            scene.meshes.append((vertices, indices))
            mat_idx = prim.material if prim.material is not None else -1
            prims.append((scene_mesh_idx, mat_idx))
        mesh_prim_map[mesh_idx] = prims
    return mesh_prim_map


def _load_nodes(
    gltf: GLTF2,
    scene: GLTFScene,
    mesh_prim_map: dict[int, list[tuple[int, int]]],
) -> None:
    """Build the node hierarchy and determine the scene's root nodes."""
    for gnode in gltf.nodes or []:
        node = GLTFNode(name=gnode.name or "")
        node.transform = _node_transform(gnode)
        node.children = list(gnode.children) if gnode.children else []
        if gnode.mesh is not None:
            prims = mesh_prim_map.get(gnode.mesh, [])
            if prims:
                node.mesh_index = prims[0][0]  # First primitive
                node.material_indices = [p[1] for p in prims]
        if gnode.skin is not None:
            node.skin_index = gnode.skin
        scene.nodes.append(node)

    # Root nodes come from the default scene; fall back to node 0 when the
    # file declares nodes but no scene root list.
    gltf_scene = gltf.scenes[gltf.scene or 0] if gltf.scenes else None
    if gltf_scene and gltf_scene.nodes:
        scene.root_nodes = list(gltf_scene.nodes)
    elif scene.nodes:
        scene.root_nodes = [0]


def _load_skins(gltf: GLTF2, scene: GLTFScene) -> None:
    """Extract skin joint lists and inverse bind matrices (applied in Phase 6)."""
    for gskin in gltf.skins or []:
        skin_data: dict[str, Any] = {
            "name": gskin.name or "",
            "joints": list(gskin.joints) if gskin.joints else [],
        }
        if gskin.inverseBindMatrices is not None:
            skin_data["inverse_bind_matrices"] = _read_accessor(gltf, gskin.inverseBindMatrices)
        scene.skins.append(skin_data)


# glTF channel target path -> our track key-list field.
_PATH_TO_KEYS = {
    "translation": "position_keys",
    "rotation": "rotation_keys",
    "scale": "scale_keys",
}


def _load_animations(gltf: GLTF2, scene: GLTFScene) -> None:
    """Extract keyframe animations, remapping node targets to per-skin bone indices."""
    # Per skin: glTF node index -> bone index (position in the skin's joint list).
    skin_joint_maps: list[dict[int, int]] = [
        {node_idx: bone_idx for bone_idx, node_idx in enumerate(skin.get("joints", []))}
        for skin in scene.skins
    ]

    for ganim in gltf.animations or []:
        anim_data: dict[str, Any] = {
            "name": ganim.name or "",
            "duration": 0.0,
            "tracks": [],
        }
        for channel in ganim.channels or []:
            sampler = ganim.samplers[channel.sampler]
            target_node = channel.target.node
            target_path = channel.target.path  # translation/rotation/scale

            # Skip channels that do not drive a joint of any skin.
            bone_index = -1
            for jmap in skin_joint_maps:
                if target_node in jmap:
                    bone_index = jmap[target_node]
                    break
            if bone_index < 0:
                continue

            # Read keyframe times and values; duration is the latest keyframe.
            times = _read_accessor(gltf, sampler.input)
            values = _read_accessor(gltf, sampler.output)
            if len(times) > 0:
                anim_data["duration"] = max(anim_data["duration"], float(times[-1]))

            # Find or create the track for this bone.
            track = None
            for t in anim_data["tracks"]:
                if t["bone_index"] == bone_index:
                    track = t
                    break
            if track is None:
                track = {
                    "bone_index": bone_index,
                    "position_keys": [],
                    "rotation_keys": [],
                    "scale_keys": [],
                }
                anim_data["tracks"].append(track)

            # Convert to (time, value) keyframe lists for recognized paths only.
            key_field = _PATH_TO_KEYS.get(target_path)
            if key_field is not None:
                track[key_field] = [
                    (float(t), v.astype(np.float32))
                    for t, v in zip(times, values, strict=True)
                ]
        if anim_data["tracks"]:
            scene.animations.append(anim_data)
def _extract_primitive(gltf: GLTF2, prim: Any) -> tuple[np.ndarray, np.ndarray]:
    """Extract vertices and indices from a single glTF primitive.

    Returns skinned vertices (SKINNED_VERTEX_DTYPE) if JOINTS_0/WEIGHTS_0 are
    present, otherwise standard vertices (VERTEX_DTYPE).
    """
    attrs = prim.attributes
    positions = _read_accessor(gltf, attrs.POSITION)
    normals = _read_accessor(gltf, attrs.NORMAL) if attrs.NORMAL is not None else None
    uvs = _read_accessor(gltf, attrs.TEXCOORD_0) if attrs.TEXCOORD_0 is not None else None
    has_skin = (
        hasattr(attrs, "JOINTS_0")
        and attrs.JOINTS_0 is not None
        and hasattr(attrs, "WEIGHTS_0")
        and attrs.WEIGHTS_0 is not None
    )
    count = len(positions)
    if has_skin:
        vertices = np.zeros(count, dtype=SKINNED_VERTEX_DTYPE)
        vertices["joints"] = _read_accessor(gltf, attrs.JOINTS_0).astype(np.uint16)
        # NOTE(review): glTF also permits normalized uint8/uint16 weights; this
        # assumes float WEIGHTS_0 — confirm against the exporters in use.
        vertices["weights"] = _read_accessor(gltf, attrs.WEIGHTS_0).astype(np.float32)
    else:
        vertices = np.zeros(count, dtype=VERTEX_DTYPE)
    vertices["position"] = positions
    if normals is not None:
        vertices["normal"] = normals
    if uvs is not None:
        vertices["uv"] = uvs
    if prim.indices is not None:
        indices = _read_accessor(gltf, prim.indices).astype(np.uint32)
    else:
        # Non-indexed geometry: synthesize sequential indices.
        indices = np.arange(count, dtype=np.uint32)
    return vertices, indices


def _resolve_texture(gltf: GLTF2, tex_index: int, base_dir: Path) -> str | bytes | None:
    """Resolve a glTF texture index to a file path or embedded image bytes.

    Returns None when the texture/image reference is missing or out of range.
    """
    if tex_index is None or tex_index >= len(gltf.textures or []):
        return None
    tex = gltf.textures[tex_index]
    if tex.source is None or tex.source >= len(gltf.images or []):
        return None
    image = gltf.images[tex.source]
    if image.uri:
        return str(base_dir / image.uri)
    # Embedded texture via bufferView (common in .glb files)
    if image.bufferView is not None:
        bv = gltf.bufferViews[image.bufferView]
        buffer = gltf.buffers[bv.buffer]
        data = gltf.get_data_from_buffer_uri(buffer.uri)
        offset = bv.byteOffset or 0
        return bytes(data[offset : offset + bv.byteLength])
    return None


def _node_transform(gnode: Any) -> np.ndarray:
    """Extract 4x4 transform from glTF node (TRS or matrix)."""
    if gnode.matrix:
        # BUGFIX: glTF stores node.matrix in COLUMN-major order, while numpy's
        # reshape reads row-major. Transpose so the translation lands in the
        # last column, matching the TRS path below.
        return np.array(gnode.matrix, dtype=np.float32).reshape(4, 4).T.copy()
    mat = np.eye(4, dtype=np.float32)
    if gnode.scale:
        s = gnode.scale
        mat[0, 0], mat[1, 1], mat[2, 2] = s[0], s[1], s[2]
    if gnode.rotation:
        q = gnode.rotation  # glTF quaternion component order: [x, y, z, w]
        rot = _quat_to_mat3(q[0], q[1], q[2], q[3])
        # Compose R @ diag(s): scaling each column of the rotation by the
        # matching scale factor is equivalent to right-multiplying by diag(s).
        scale_diag = np.diag(mat[:3, :3]).copy()
        mat[:3, :3] = rot * scale_diag[np.newaxis, :]
    if gnode.translation:
        t = gnode.translation
        mat[0, 3], mat[1, 3], mat[2, 3] = t[0], t[1], t[2]
    return mat


def _quat_to_mat3(x: float, y: float, z: float, w: float) -> np.ndarray:
    """Convert quaternion (x, y, z, w) to a 3x3 rotation matrix."""
    x2, y2, z2 = x + x, y + y, z + z
    xx, xy, xz = x * x2, x * y2, x * z2
    yy, yz, zz = y * y2, y * z2, z * z2
    wx, wy, wz = w * x2, w * y2, w * z2
    return np.array(
        [
            [1 - (yy + zz), xy - wz, xz + wy],
            [xy + wz, 1 - (xx + zz), yz - wx],
            [xz - wy, yz + wx, 1 - (xx + yy)],
        ],
        dtype=np.float32,
    )


def _read_accessor(gltf: GLTF2, accessor_index: int) -> np.ndarray:
    """Extract a numpy array from a glTF accessor (handles interleaved views)."""
    accessor = gltf.accessors[accessor_index]
    bv = gltf.bufferViews[accessor.bufferView]
    buffer = gltf.buffers[bv.buffer]
    data = gltf.get_data_from_buffer_uri(buffer.uri)

    dtype = _COMPONENT_DTYPES[accessor.componentType]
    components = _TYPE_SIZES[accessor.type]
    offset = (bv.byteOffset or 0) + (accessor.byteOffset or 0)
    stride = bv.byteStride
    elem_size = components * np.dtype(dtype).itemsize

    if stride and stride != elem_size:
        # Interleaved buffer: elements are `stride` bytes apart, so gather
        # them one element at a time.
        arr = np.zeros((accessor.count, components), dtype=dtype)
        for i in range(accessor.count):
            start = offset + i * stride
            arr[i] = np.frombuffer(data, dtype=dtype, count=components, offset=start)
        return arr

    # Tightly packed: one frombuffer call covers the whole accessor.
    arr = np.frombuffer(data, dtype=dtype, count=accessor.count * components, offset=offset)
    if components > 1:
        arr = arr.reshape((accessor.count, components))
    return arr