"""Full glTF scene loading — meshes, materials, textures, node hierarchy."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import numpy as np
from pygltflib import GLTF2
from simvx.graphics._types import SKINNED_VERTEX_DTYPE, VERTEX_DTYPE
__all__ = ["load_gltf", "GLTFScene", "GLTFMaterial", "GLTFNode"]
log = logging.getLogger(__name__)
# glTF component type → numpy dtype
# (keys are the accessor.componentType GL enum constants)
_COMPONENT_DTYPES = {
    5120: np.int8,     # BYTE
    5121: np.uint8,    # UNSIGNED_BYTE
    5122: np.int16,    # SHORT
    5123: np.uint16,   # UNSIGNED_SHORT
    5125: np.uint32,   # UNSIGNED_INT
    5126: np.float32,  # FLOAT
}
# glTF accessor type → component count
_TYPE_SIZES = {"SCALAR": 1, "VEC2": 2, "VEC3": 3, "VEC4": 4, "MAT4": 16}
@dataclass
class GLTFMaterial:
    """Extracted PBR metallic-roughness material.

    Texture fields hold either a filesystem path (str) or raw embedded
    image bytes; ``None`` means the texture is absent.
    """

    name: str = ""
    # RGBA base color factor (glTF baseColorFactor); glTF default is opaque white.
    albedo: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0)
    # PBR scalar factors; the glTF default for both is 1.0.
    metallic: float = 1.0
    roughness: float = 1.0
    # Resolved texture sources (path, embedded bytes, or None).
    albedo_texture: str | bytes | None = None
    normal_texture: str | bytes | None = None
    metallic_roughness_texture: str | bytes | None = None
    emissive_texture: str | bytes | None = None
    ao_texture: str | bytes | None = None
    # Render-state flags copied straight from the glTF material.
    double_sided: bool = False
    alpha_mode: str = "OPAQUE"  # one of OPAQUE / MASK / BLEND
@dataclass
class GLTFNode:
    """Scene graph node with optional mesh reference."""

    name: str = ""
    # Index into GLTFScene.meshes for this node's first primitive, if any.
    mesh_index: int | None = None
    # One material index per primitive of the referenced glTF mesh (-1 = none).
    material_indices: list[int] = field(default_factory=list)
    # Local 4x4 transform, float32, translation in the last column.
    transform: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
    # Indices of child nodes within GLTFScene.nodes.
    children: list[int] = field(default_factory=list)
    # Skinning (extracted but not applied until Phase 6)
    skin_index: int | None = None
@dataclass
class GLTFScene:
    """Complete loaded glTF scene."""

    # One (vertices, indices) array pair per glTF primitive.
    meshes: list[tuple[np.ndarray, np.ndarray]] = field(default_factory=list)
    materials: list[GLTFMaterial] = field(default_factory=list)
    nodes: list[GLTFNode] = field(default_factory=list)
    # Indices into `nodes` of the default scene's root nodes.
    root_nodes: list[int] = field(default_factory=list)
    # Skinning data: dicts with "name", "joints", optional "inverse_bind_matrices".
    skins: list[dict[str, Any]] = field(default_factory=list)
    # Animation data: dicts with "name", "duration", "tracks".
    animations: list[dict[str, Any]] = field(default_factory=list)
def load_gltf(file_path: str) -> GLTFScene:
    """Load complete glTF scene with all meshes, materials, textures, and hierarchy.

    Args:
        file_path: Path to a ``.gltf`` or ``.glb`` file on disk.

    Returns:
        A GLTFScene with all data extracted and ready for import.
    """
    path = Path(file_path)
    gltf = GLTF2().load(str(path))
    base_dir = path.parent

    scene = GLTFScene()
    _collect_materials(gltf, scene, base_dir)
    mesh_prim_map = _collect_meshes(gltf, scene)
    _collect_nodes(gltf, scene, mesh_prim_map)
    _collect_skins(gltf, scene)
    _collect_animations(gltf, scene)

    log.debug(
        "Loaded glTF: %d meshes, %d materials, %d nodes, %d animations from %s",
        len(scene.meshes),
        len(scene.materials),
        len(scene.nodes),
        len(scene.animations),
        path.name,
    )
    return scene


def _collect_materials(gltf: GLTF2, scene: GLTFScene, base_dir: Path) -> None:
    """Extract PBR metallic-roughness materials and resolve their textures."""
    for gmat in gltf.materials or []:
        mat = GLTFMaterial(name=gmat.name or "")
        pbr = gmat.pbrMetallicRoughness
        if pbr:
            bc = pbr.baseColorFactor or [1, 1, 1, 1]
            mat.albedo = tuple(bc[:4])
            # glTF defaults both factors to 1.0 when omitted.
            mat.metallic = pbr.metallicFactor if pbr.metallicFactor is not None else 1.0
            mat.roughness = pbr.roughnessFactor if pbr.roughnessFactor is not None else 1.0
            if pbr.baseColorTexture is not None:
                mat.albedo_texture = _resolve_texture(gltf, pbr.baseColorTexture.index, base_dir)
            if pbr.metallicRoughnessTexture is not None:
                mat.metallic_roughness_texture = _resolve_texture(
                    gltf,
                    pbr.metallicRoughnessTexture.index,
                    base_dir,
                )
        # Normal / emissive / occlusion maps live on the material itself,
        # outside pbrMetallicRoughness.
        if gmat.normalTexture is not None:
            mat.normal_texture = _resolve_texture(gltf, gmat.normalTexture.index, base_dir)
        if gmat.emissiveTexture is not None:
            mat.emissive_texture = _resolve_texture(gltf, gmat.emissiveTexture.index, base_dir)
        if gmat.occlusionTexture is not None:
            mat.ao_texture = _resolve_texture(gltf, gmat.occlusionTexture.index, base_dir)
        mat.double_sided = gmat.doubleSided or False
        mat.alpha_mode = gmat.alphaMode or "OPAQUE"
        scene.materials.append(mat)


def _collect_meshes(gltf: GLTF2, scene: GLTFScene) -> dict[int, list[tuple[int, int]]]:
    """Extract meshes; each glTF primitive becomes a separate scene mesh.

    Returns:
        Mapping of glTF mesh index → list of (scene_mesh_idx, material_idx)
        pairs, one per primitive (material_idx is -1 when unassigned).
    """
    mesh_prim_map: dict[int, list[tuple[int, int]]] = {}
    for mesh_idx, gmesh in enumerate(gltf.meshes or []):
        prims = []
        for prim in gmesh.primitives:
            vertices, indices = _extract_primitive(gltf, prim)
            scene_mesh_idx = len(scene.meshes)
            scene.meshes.append((vertices, indices))
            mat_idx = prim.material if prim.material is not None else -1
            prims.append((scene_mesh_idx, mat_idx))
        mesh_prim_map[mesh_idx] = prims
    return mesh_prim_map


def _collect_nodes(
    gltf: GLTF2,
    scene: GLTFScene,
    mesh_prim_map: dict[int, list[tuple[int, int]]],
) -> None:
    """Build scene-graph nodes and record the default scene's root indices."""
    for gnode in gltf.nodes or []:
        node = GLTFNode(name=gnode.name or "")
        node.transform = _node_transform(gnode)
        node.children = list(gnode.children) if gnode.children else []
        if gnode.mesh is not None:
            prims = mesh_prim_map.get(gnode.mesh, [])
            if prims:
                node.mesh_index = prims[0][0]  # First primitive
                node.material_indices = [p[1] for p in prims]
        if gnode.skin is not None:
            node.skin_index = gnode.skin
        scene.nodes.append(node)
    # Root nodes: prefer the default glTF scene's roots, else fall back to node 0.
    gltf_scene = gltf.scenes[gltf.scene or 0] if gltf.scenes else None
    if gltf_scene and gltf_scene.nodes:
        scene.root_nodes = list(gltf_scene.nodes)
    elif scene.nodes:
        scene.root_nodes = [0]


def _collect_skins(gltf: GLTF2, scene: GLTFScene) -> None:
    """Extract skin joint lists and inverse bind matrices (for Phase 6)."""
    for gskin in gltf.skins or []:
        skin_data: dict[str, Any] = {
            "name": gskin.name or "",
            "joints": list(gskin.joints) if gskin.joints else [],
        }
        if gskin.inverseBindMatrices is not None:
            skin_data["inverse_bind_matrices"] = _read_accessor(gltf, gskin.inverseBindMatrices)
        scene.skins.append(skin_data)


def _collect_animations(gltf: GLTF2, scene: GLTFScene) -> None:
    """Extract animation channels as per-bone keyframe tracks.

    Channels whose target node is not a joint of any extracted skin are
    skipped. Requires ``scene.skins`` to be populated first.
    """
    # Mapping of glTF node index → bone index, one map per skin.
    skin_joint_maps: list[dict[int, int]] = [
        {node_idx: bone_idx for bone_idx, node_idx in enumerate(s.get("joints", []))}
        for s in scene.skins
    ]
    # Channel target path → track key-list name.
    path_to_field = {
        "translation": "position_keys",
        "rotation": "rotation_keys",
        "scale": "scale_keys",
    }
    for ganim in gltf.animations or []:
        anim_data: dict[str, Any] = {
            "name": ganim.name or "",
            "duration": 0.0,
            "tracks": [],
        }
        for channel in ganim.channels or []:
            sampler = ganim.samplers[channel.sampler]
            target_node = channel.target.node
            target_path = channel.target.path  # translation/rotation/scale
            # Find which bone this node maps to.
            bone_index = -1
            for jmap in skin_joint_maps:
                if target_node in jmap:
                    bone_index = jmap[target_node]
                    break
            if bone_index < 0:
                continue  # not a joint of any skin
            # Read keyframe times and values.
            times = _read_accessor(gltf, sampler.input)
            values = _read_accessor(gltf, sampler.output)
            if len(times) > 0:
                anim_data["duration"] = max(anim_data["duration"], float(times[-1]))
            # Find or create the track for this bone.
            track = None
            for t in anim_data["tracks"]:
                if t["bone_index"] == bone_index:
                    track = t
                    break
            if track is None:
                track = {"bone_index": bone_index, "position_keys": [], "rotation_keys": [], "scale_keys": []}
                anim_data["tracks"].append(track)
            # Convert to keyframe lists (other paths, e.g. "weights", are ignored).
            key_field = path_to_field.get(target_path)
            if key_field is not None:
                track[key_field] = [
                    (float(t), v.astype(np.float32)) for t, v in zip(times, values, strict=True)
                ]
        if anim_data["tracks"]:
            scene.animations.append(anim_data)
def _extract_primitive(gltf: GLTF2, prim: Any) -> tuple[np.ndarray, np.ndarray]:
    """Extract a vertex array and an index array from one glTF primitive.

    Produces SKINNED_VERTEX_DTYPE vertices when both JOINTS_0 and
    WEIGHTS_0 attributes are present, plain VERTEX_DTYPE otherwise.
    """
    attrs = prim.attributes
    positions = _read_accessor(gltf, attrs.POSITION)
    vertex_count = len(positions)

    joints_acc = getattr(attrs, "JOINTS_0", None)
    weights_acc = getattr(attrs, "WEIGHTS_0", None)
    if joints_acc is not None and weights_acc is not None:
        # Skinned mesh: carry per-vertex joint indices and blend weights.
        vertices = np.zeros(vertex_count, dtype=SKINNED_VERTEX_DTYPE)
        vertices["joints"] = _read_accessor(gltf, joints_acc).astype(np.uint16)
        vertices["weights"] = _read_accessor(gltf, weights_acc).astype(np.float32)
    else:
        vertices = np.zeros(vertex_count, dtype=VERTEX_DTYPE)

    vertices["position"] = positions
    if attrs.NORMAL is not None:
        vertices["normal"] = _read_accessor(gltf, attrs.NORMAL)
    if attrs.TEXCOORD_0 is not None:
        vertices["uv"] = _read_accessor(gltf, attrs.TEXCOORD_0)

    if prim.indices is None:
        # Non-indexed geometry: synthesize sequential indices.
        indices = np.arange(vertex_count, dtype=np.uint32)
    else:
        indices = _read_accessor(gltf, prim.indices).astype(np.uint32)
    return vertices, indices
def _resolve_texture(gltf: GLTF2, tex_index: int, base_dir: Path) -> str | bytes | None:
"""Resolve a glTF texture index to a file path or embedded image bytes."""
if tex_index is None or tex_index >= len(gltf.textures or []):
return None
tex = gltf.textures[tex_index]
if tex.source is None or tex.source >= len(gltf.images or []):
return None
image = gltf.images[tex.source]
if image.uri:
return str(base_dir / image.uri)
# Embedded texture via bufferView (common in .glb files)
if image.bufferView is not None:
bv = gltf.bufferViews[image.bufferView]
buffer = gltf.buffers[bv.buffer]
data = gltf.get_data_from_buffer_uri(buffer.uri)
offset = bv.byteOffset or 0
return bytes(data[offset : offset + bv.byteLength])
return None
def _node_transform(gnode: Any) -> np.ndarray:
"""Extract 4x4 transform from glTF node (TRS or matrix)."""
if gnode.matrix:
return np.array(gnode.matrix, dtype=np.float32).reshape(4, 4)
mat = np.eye(4, dtype=np.float32)
if gnode.scale:
s = gnode.scale
mat[0, 0], mat[1, 1], mat[2, 2] = s[0], s[1], s[2]
if gnode.rotation:
q = gnode.rotation # [x, y, z, w]
rot = _quat_to_mat3(q[0], q[1], q[2], q[3])
scale_diag = np.diag(mat[:3, :3]).copy()
mat[:3, :3] = rot * scale_diag[np.newaxis, :]
if gnode.translation:
t = gnode.translation
mat[0, 3], mat[1, 3], mat[2, 3] = t[0], t[1], t[2]
return mat
def _quat_to_mat3(x: float, y: float, z: float, w: float) -> np.ndarray:
"""Convert quaternion to 3x3 rotation matrix."""
x2, y2, z2 = x + x, y + y, z + z
xx, xy, xz = x * x2, x * y2, x * z2
yy, yz, zz = y * y2, y * z2, z * z2
wx, wy, wz = w * x2, w * y2, w * z2
return np.array(
[
[1 - (yy + zz), xy - wz, xz + wy],
[xy + wz, 1 - (xx + zz), yz - wx],
[xz - wy, yz + wx, 1 - (xx + yy)],
],
dtype=np.float32,
)
def _read_accessor(gltf: GLTF2, accessor_index: int) -> np.ndarray:
"""Extract numpy array from a glTF accessor."""
accessor = gltf.accessors[accessor_index]
bv = gltf.bufferViews[accessor.bufferView]
buffer = gltf.buffers[bv.buffer]
data = gltf.get_data_from_buffer_uri(buffer.uri)
dtype = _COMPONENT_DTYPES[accessor.componentType]
components = _TYPE_SIZES[accessor.type]
offset = (bv.byteOffset or 0) + (accessor.byteOffset or 0)
stride = bv.byteStride
if stride and stride != components * np.dtype(dtype).itemsize:
# Interleaved buffer — read with stride
np.dtype(dtype).itemsize * components
arr = np.zeros((accessor.count, components), dtype=dtype)
for i in range(accessor.count):
start = offset + i * stride
chunk = np.frombuffer(data, dtype=dtype, count=components, offset=start)
arr[i] = chunk
return arr
arr = np.frombuffer(data, dtype=dtype, count=accessor.count * components, offset=offset)
if components > 1:
arr = arr.reshape((accessor.count, components))
return arr