"""Skeleton and bone hierarchy for skeletal animation.
Skeleton is a Node3D that manages a flat array of Bone data. Each frame it
walks the parent chain, computes world transforms, and produces GPU-ready
joint matrices (world * inverse_bind) suitable for SSBO upload.
SkeletonProfile provides a standard naming convention so retargeting,
animation libraries, and humanoid IK can agree on bone names.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
import numpy as np
from .descriptors import Signal
from .nodes_3d.node3d import Node3D
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = [
    "Bone",
    "Skeleton",
    "SkeletonProfile",
    "PROFILE_HUMANOID",
]
# ============================================================================
# Bone — data record (no scene-tree node, just structured data)
# ============================================================================
@dataclass(eq=False)
class Bone:
    """Single bone in a skeleton hierarchy.

    A plain data record (not a scene-tree node).

    Attributes:
        name: Bone name used for lookups by name.
        parent_index: Index of the parent bone in the owning bone list;
            a negative value (default ``-1``) marks a root bone.
        inverse_bind_matrix: 4x4 matrix multiplied onto the bone's world
            transform to form its joint matrix. Defaults to identity (float32).
        local_transform: 4x4 default transform relative to the parent bone.
            Defaults to identity (float32).
    """

    name: str = ""
    parent_index: int = -1
    inverse_bind_matrix: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
    local_transform: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))

    def __eq__(self, other: object) -> bool:
        """Compare bones field by field, matrices element-wise.

        The dataclass-generated ``__eq__`` is disabled because it evaluates
        ``ndarray == ndarray`` in a boolean context, which raises
        ``ValueError`` for 4x4 matrices as soon as the scalar fields match.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self.name == other.name
            and self.parent_index == other.parent_index
            and np.array_equal(self.inverse_bind_matrix, other.inverse_bind_matrix)
            and np.array_equal(self.local_transform, other.local_transform)
        )

    # Mutable record: stay unhashable, exactly as a default (eq=True) dataclass is.
    __hash__ = None
# ============================================================================
# Skeleton — Node3D managing a bone hierarchy
# ============================================================================
class Skeleton(Node3D):
    """Bone hierarchy with GPU-ready joint matrix computation.

    Joint matrices = parent_world * local_transform * inverse_bind_matrix

    These are uploaded to an SSBO for vertex skinning in the shader.

    As a ``Node3D`` it participates in the scene tree, inherits a 3D
    transform, and can be animated/parented like any other spatial node.

    Signals:
        bone_pose_changed: Emitted after ``compute_pose`` updates joint
            matrices. Connected ``MeshInstance3D`` nodes (via ``skin``)
            can listen to know when to re-upload skinning data.
    """

    def __init__(self, bones: list[Bone] | None = None, **kwargs):
        """Create a skeleton.

        Args:
            bones: Initial bone list; a fresh empty list is used when omitted
                (or when an empty list is passed).
            **kwargs: Forwarded unchanged to the ``Node3D`` constructor.
        """
        super().__init__(**kwargs)
        self.bones: list[Bone] = bones or []
        # (bone_count, 4, 4) float32 caches; ``None`` means "not computed yet".
        self._joint_matrices: np.ndarray | None = None
        self._world_transforms: np.ndarray | None = None
        # bone index -> 4x4 local transform used instead of Bone.local_transform.
        self._bone_overrides: dict[int, np.ndarray] = {}
        self.bone_pose_changed = Signal()

    # -- Bone count -----------------------------------------------------------

    @property
    def bone_count(self) -> int:
        """Number of bones currently in the skeleton."""
        return len(self.bones)

    # -- Joint matrices -------------------------------------------------------

    @property
    def joint_matrices(self) -> np.ndarray:
        """Get computed joint matrices (bone_count, 4, 4). Call compute_pose() first.

        If no pose has been computed yet, or the cached array no longer
        matches ``bone_count``, the cache is reset to identity matrices so
        the returned array always has one matrix per bone.
        """
        n = self.bone_count
        if self._joint_matrices is None or len(self._joint_matrices) != n:
            identities = np.zeros((n, 4, 4), dtype=np.float32)
            identities[:] = np.eye(4, dtype=np.float32)  # broadcast over all bones
            self._joint_matrices = identities
        return self._joint_matrices

    # -- Pose management ------------------------------------------------------

    def set_bone_pose(self, bone_index: int, transform: np.ndarray) -> None:
        """Override a single bone's local transform for the current pose.

        The override persists until cleared via ``clear_bone_pose`` or another
        ``set_bone_pose`` call. Call ``compute_pose()`` afterwards (or let
        ``process()`` do it) to propagate the change.

        Args:
            bone_index: Index of the bone to override.
            transform: Anything convertible to a 4x4 float32 matrix
                (16 elements); it is copied/reshaped on store.

        Raises:
            ValueError: If ``transform`` cannot be reshaped to 4x4.
        """
        self._bone_overrides[bone_index] = np.asarray(transform, dtype=np.float32).reshape(4, 4)

    def get_bone_pose(self, bone_index: int) -> np.ndarray:
        """Return the current local transform for a bone (override or default).

        Returns:
            A copy of the stored matrix, so callers may modify it freely.

        Raises:
            IndexError: If the bone has no override and the index is out of range.
        """
        # Overrides are consulted first, matching what compute_pose() will use.
        if bone_index in self._bone_overrides:
            return self._bone_overrides[bone_index].copy()
        if 0 <= bone_index < self.bone_count:
            return self.bones[bone_index].local_transform.copy()
        raise IndexError(f"bone_index {bone_index} out of range (bone_count={self.bone_count})")

    def clear_bone_pose(self, bone_index: int) -> None:
        """Remove the per-bone pose override, reverting to the bone's default local_transform."""
        self._bone_overrides.pop(bone_index, None)

    def clear_all_bone_poses(self) -> None:
        """Remove all per-bone pose overrides."""
        self._bone_overrides.clear()

    # -- Pose computation -----------------------------------------------------

    def compute_pose(self, bone_transforms: dict[int, np.ndarray] | None = None) -> None:
        """Compute final joint matrices from bone-local transforms.

        Does nothing (and emits no signal) when the skeleton has no bones.
        Otherwise the joint matrices are written into the cached array and
        ``bone_pose_changed`` is emitted.

        Args:
            bone_transforms: Optional *additional* overrides for bone-local
                transforms (merged on top of ``set_bone_pose`` overrides).
                Maps bone_index -> 4x4 local transform matrix.
                Bones not in any override use their default ``local_transform``.

        Raises:
            ValueError: If an entry of ``bone_transforms`` cannot be reshaped to 4x4.
        """
        n = self.bone_count
        if n == 0:
            return
        if self._joint_matrices is None or len(self._joint_matrices) != n:
            self._joint_matrices = np.zeros((n, 4, 4), dtype=np.float32)

        # Merge overrides: set_bone_pose overrides first, then explicit arg on top.
        # Explicit entries get the same 4x4 float32 coercion as set_bone_pose.
        merged = dict(self._bone_overrides)
        if bone_transforms:
            for index, matrix in bone_transforms.items():
                merged[index] = np.asarray(matrix, dtype=np.float32).reshape(4, 4)

        world = self._compute_world_transforms(merged)

        # Joint matrix = world * inverse_bind
        for i, bone in enumerate(self.bones):
            self._joint_matrices[i] = world[i] @ bone.inverse_bind_matrix

        # Keep the world-space bone matrices of the latest pose.
        self._world_transforms = world
        # Notify listeners (e.g. skinned MeshInstance3D nodes)
        self.bone_pose_changed()

    def _compute_world_transforms(self, overrides: dict[int, np.ndarray]) -> np.ndarray:
        """Return (bone_count, 4, 4) world transforms for the given local overrides.

        Parents are resolved on demand, so a bone may reference a parent that
        is stored *after* it in ``self.bones``. A bone whose ``parent_index``
        is out of range, or that is part of a parent cycle, is logged and
        treated as a root.
        """
        bones = self.bones
        n = len(bones)
        world = np.zeros((n, 4, 4), dtype=np.float32)
        resolved = [False] * n

        for start in range(n):
            if resolved[start]:
                continue

            # Walk up from ``start`` collecting unresolved ancestors until we
            # reach a root, an already-resolved parent, or a broken link.
            chain: list[int] = []
            on_chain: set[int] = set()
            parent_world: np.ndarray | None = None
            current = start
            while True:
                chain.append(current)
                on_chain.add(current)
                parent = bones[current].parent_index
                if parent < 0:
                    break
                if parent >= n:
                    log.warning("Bone %s has invalid parent_index %s (bone_count=%s)", current, parent, n)
                    break
                if resolved[parent]:
                    parent_world = world[parent]
                    break
                if parent in on_chain:
                    log.warning(
                        "Bone %s has a cyclic parent chain (parent_index=%s); treating it as a root",
                        current,
                        parent,
                    )
                    break
                current = parent

            # Unwind top-down so each bone sees its parent's finished transform.
            for index in reversed(chain):
                local = overrides[index] if index in overrides else bones[index].local_transform
                world[index] = local if parent_world is None else parent_world @ local
                resolved[index] = True
                parent_world = world[index]

        return world

    # -- Scene-tree integration -----------------------------------------------

    def process(self, dt: float) -> None:
        """Auto-recompute pose each frame if there are any overrides.

        Args:
            dt: Frame delta time; not used by this method.
        """
        if self._bone_overrides:
            self.compute_pose()

    # -- Bone lookup ----------------------------------------------------------

    def find_bone(self, name: str) -> int:
        """Find bone index by name. Returns -1 if not found.

        The first bone with a matching name wins. A warning is logged when
        no bone matches.
        """
        for i, bone in enumerate(self.bones):
            if bone.name == name:
                return i
        log.warning("Bone %r not found in skeleton (%s bones)", name, self.bone_count)
        return -1

    def add_bone(self, bone: Bone) -> int:
        """Append a bone and return its index."""
        idx = len(self.bones)
        self.bones.append(bone)
        # Both caches are sized per bone; drop them so they are rebuilt at the new size.
        self._joint_matrices = None
        self._world_transforms = None
        return idx
# ============================================================================
# SkeletonProfile — standard bone naming convention
# ============================================================================
@dataclass
class SkeletonProfile:
    """Standard bone naming convention for retargeting and humanoid IK.

    A profile defines an ordered list of bone names, optional parent
    relationships (by name), and bone groups for logical grouping.

    Usage::

        profile = PROFILE_HUMANOID
        idx = skel.find_bone(profile.bone_names[0])  # "Hips"

        # Validate a skeleton against the profile
        missing = profile.validate(skel)

    Attributes:
        name: Profile name.
        bone_names: Ordered bone names defined by the profile.
        bone_parents: Child bone name -> parent bone name; bones without an
            entry are roots.
        bone_groups: Group name -> list of bone names in that group.
    """

    name: str
    bone_names: list[str] = field(default_factory=list)
    bone_parents: dict[str, str] = field(default_factory=dict)
    bone_groups: dict[str, list[str]] = field(default_factory=dict)

    def validate(self, skeleton: Skeleton) -> list[str]:
        """Return a list of profile bone names missing from *skeleton*.

        Names are checked against a set built from ``skeleton.bones`` rather
        than via ``skeleton.find_bone``: missing bones are the expected
        outcome here, so the per-miss warning that ``find_bone`` logs would
        only be noise, and the set makes the check linear.
        """
        present = {bone.name for bone in skeleton.bones}
        return [name for name in self.bone_names if name not in present]

    def find_in_skeleton(self, skeleton: Skeleton, profile_bone_name: str) -> int:
        """Look up a profile bone name in a skeleton. Returns -1 if not found."""
        return skeleton.find_bone(profile_bone_name)

    def get_parent(self, bone_name: str) -> str | None:
        """Return the profile-defined parent bone name, or None for root bones."""
        return self.bone_parents.get(bone_name)

    def get_group(self, group_name: str) -> list[str]:
        """Return bone names belonging to a group, or empty list.

        A new list is returned on every call, so callers can modify the
        result without altering the (possibly shared) profile data.
        """
        return list(self.bone_groups.get(group_name, ()))
# -- Built-in humanoid profile -----------------------------------------------
# Ordered bone names of the built-in humanoid profile. "Hips" is the only
# bone without an entry in _HUMANOID_PARENTS, i.e. the root.
_HUMANOID_BONES = [
    # torso and head
    "Hips", "Spine", "Chest", "UpperChest", "Neck", "Head",
    # left arm
    "LeftShoulder", "LeftUpperArm", "LeftLowerArm", "LeftHand",
    # right arm
    "RightShoulder", "RightUpperArm", "RightLowerArm", "RightHand",
    # left leg
    "LeftUpperLeg", "LeftLowerLeg", "LeftFoot",
    # right leg
    "RightUpperLeg", "RightLowerLeg", "RightFoot",
]

# Child bone name -> parent bone name.
_HUMANOID_PARENTS = {
    # spine chain
    "Spine": "Hips",
    "Chest": "Spine",
    "UpperChest": "Chest",
    "Neck": "UpperChest",
    "Head": "Neck",
    # left arm hangs off the upper chest
    "LeftShoulder": "UpperChest",
    "LeftUpperArm": "LeftShoulder",
    "LeftLowerArm": "LeftUpperArm",
    "LeftHand": "LeftLowerArm",
    # right arm hangs off the upper chest
    "RightShoulder": "UpperChest",
    "RightUpperArm": "RightShoulder",
    "RightLowerArm": "RightUpperArm",
    "RightHand": "RightLowerArm",
    # left leg hangs off the hips
    "LeftUpperLeg": "Hips",
    "LeftLowerLeg": "LeftUpperLeg",
    "LeftFoot": "LeftLowerLeg",
    # right leg hangs off the hips
    "RightUpperLeg": "Hips",
    "RightLowerLeg": "RightUpperLeg",
    "RightFoot": "RightLowerLeg",
}

# Group name -> member bones; every bone belongs to exactly one group.
_HUMANOID_GROUPS = dict(
    torso=["Hips", "Spine", "Chest", "UpperChest", "Neck", "Head"],
    left_arm=["LeftShoulder", "LeftUpperArm", "LeftLowerArm", "LeftHand"],
    right_arm=["RightShoulder", "RightUpperArm", "RightLowerArm", "RightHand"],
    left_leg=["LeftUpperLeg", "LeftLowerLeg", "LeftFoot"],
    right_leg=["RightUpperLeg", "RightLowerLeg", "RightFoot"],
)

# Built-in humanoid profile assembled from the tables above (shared by reference).
PROFILE_HUMANOID = SkeletonProfile(
    name="Humanoid",
    bone_names=_HUMANOID_BONES,
    bone_parents=_HUMANOID_PARENTS,
    bone_groups=_HUMANOID_GROUPS,
)