Source code for simvx.core.skeleton

"""Skeleton and bone hierarchy for skeletal animation.

Skeleton is a Node3D that manages a flat array of Bone data.  Each frame it
walks the parent chain, computes world transforms, and produces GPU-ready
joint matrices (world * inverse_bind) suitable for SSBO upload.

SkeletonProfile provides a standard naming convention so retargeting,
animation libraries, and humanoid IK can agree on bone names.
"""


from __future__ import annotations

import logging
from dataclasses import dataclass, field

import numpy as np

from .descriptors import Signal
from .nodes_3d.node3d import Node3D

# Module-level logger; Skeleton logs its warnings here (e.g. invalid parent
# indices, bones not found by name).
log = logging.getLogger(__name__)

# Public names exported by a star import of this module.
__all__ = [
    "Bone",
    "Skeleton",
    "SkeletonProfile",
    "PROFILE_HUMANOID",
]


# ============================================================================
# Bone — data record (no scene-tree node, just structured data)
# ============================================================================


def _identity_matrix() -> np.ndarray:
    """Return a fresh 4x4 float32 identity matrix (default for Bone matrices)."""
    return np.eye(4, dtype=np.float32)


@dataclass
class Bone:
    """One joint of a skeleton, stored as plain data rather than a scene node.

    Attributes:
        name: Identifier used to look the bone up by name.
        parent_index: Index of the parent bone in the owning skeleton's bone
            list; a negative value marks a root bone.
        inverse_bind_matrix: 4x4 matrix multiplied after the bone's world
            transform to produce its joint matrix.
        local_transform: Default 4x4 parent-relative transform, used whenever
            no pose override is active for this bone.
    """

    name: str = ""
    parent_index: int = -1
    # Each instance gets its own identity matrices (never a shared array).
    inverse_bind_matrix: np.ndarray = field(default_factory=_identity_matrix)
    local_transform: np.ndarray = field(default_factory=_identity_matrix)
# ============================================================================ # Skeleton — Node3D managing a bone hierarchy # ============================================================================
class Skeleton(Node3D):
    """Bone hierarchy with GPU-ready joint matrix computation.

    Joint matrices = parent_world * local_transform * inverse_bind_matrix
    These are uploaded to an SSBO for vertex skinning in the shader.

    As a ``Node3D`` it participates in the scene tree, inherits a 3D transform,
    and can be animated/parented like any other spatial node.

    Signals:
        bone_pose_changed: Emitted after ``compute_pose`` updates joint matrices.
            Connected ``MeshInstance3D`` nodes (via ``skin``) can listen to know
            when to re-upload skinning data.
    """

    def __init__(self, bones: list[Bone] | None = None, **kwargs):
        super().__init__(**kwargs)
        # A non-empty ``bones`` list is kept by reference, not copied.
        self.bones: list[Bone] = bones or []
        self._joint_matrices: np.ndarray | None = None
        self._world_transforms: np.ndarray | None = None
        self._bone_overrides: dict[int, np.ndarray] = {}
        # True when the cached pose no longer reflects the bone data (overrides
        # were removed or a bone was added). ``process`` then recomputes once,
        # even though no override is active any more.
        self._pose_dirty = False
        self.bone_pose_changed = Signal()

    # -- Bone count -----------------------------------------------------------

    @property
    def bone_count(self) -> int:
        """Number of bones currently in ``bones``."""
        return len(self.bones)

    # -- Joint matrices -------------------------------------------------------

    @property
    def joint_matrices(self) -> np.ndarray:
        """Get computed joint matrices (bone_count, 4, 4). Call compute_pose() first.

        When nothing is cached yet, or the cache no longer matches
        ``bone_count``, one identity matrix per bone is returned (and cached).
        """
        n = self.bone_count
        if self._joint_matrices is None or len(self._joint_matrices) != n:
            self._joint_matrices = np.tile(np.eye(4, dtype=np.float32), (n, 1, 1))
        return self._joint_matrices

    # -- Pose management ------------------------------------------------------

    def set_bone_pose(self, bone_index: int, transform: np.ndarray) -> None:
        """Override a single bone's local transform for the current pose.

        The override persists until cleared via ``clear_bone_pose`` or replaced
        by another ``set_bone_pose`` call. Call ``compute_pose()`` afterwards
        (or let ``process()`` do it) to propagate the change.

        Args:
            bone_index: Index of the bone to override.
            transform: Anything reshapeable to a 4x4 matrix. It is copied into
                a new float32 array.
        """
        # np.array (not np.asarray) always copies, so the override never
        # aliases the caller's buffer. Reusing one scratch matrix for several
        # bones must not make them share a pose.
        self._bone_overrides[bone_index] = np.array(transform, dtype=np.float32).reshape(4, 4)

    def get_bone_pose(self, bone_index: int) -> np.ndarray:
        """Return a copy of a bone's current local transform (override or default).

        Raises:
            IndexError: If there is no override and ``bone_index`` is out of range.
        """
        if bone_index in self._bone_overrides:
            return self._bone_overrides[bone_index].copy()
        if 0 <= bone_index < self.bone_count:
            return self.bones[bone_index].local_transform.copy()
        raise IndexError(f"bone_index {bone_index} out of range (bone_count={self.bone_count})")

    def clear_bone_pose(self, bone_index: int) -> None:
        """Remove the per-bone pose override, reverting to the bone's default local_transform."""
        if self._bone_overrides.pop(bone_index, None) is not None:
            self._pose_dirty = True

    def clear_all_bone_poses(self) -> None:
        """Remove all per-bone pose overrides."""
        if self._bone_overrides:
            self._bone_overrides.clear()
            self._pose_dirty = True

    def get_bone_global_transform(self, bone_index: int) -> np.ndarray:
        """Return a copy of a bone's world-space transform after the last ``compute_pose``.

        This is *not* the joint matrix (which includes inverse-bind). It is the
        raw world transform, useful for attachment points, IK targets, etc.
        Returns identity if no pose has been computed yet or ``bone_index`` is
        outside the last computed pose.
        """
        if self._world_transforms is not None and 0 <= bone_index < len(self._world_transforms):
            return self._world_transforms[bone_index].copy()
        return np.eye(4, dtype=np.float32)

    # -- Pose computation -----------------------------------------------------

    def compute_pose(self, bone_transforms: dict[int, np.ndarray] | None = None) -> None:
        """Compute final joint matrices from bone-local transforms.

        Args:
            bone_transforms: Optional *additional* overrides for bone-local
                transforms, merged on top of ``set_bone_pose`` overrides. Maps
                bone_index -> 4x4 local transform matrix. Bones not in any
                override use their default ``local_transform``.

        Bones may be stored in any order: parents are resolved before their
        children. A bone whose ``parent_index`` is out of range, or whose parent
        chain forms a cycle, is logged and treated as a root.
        Emits ``bone_pose_changed`` when there is at least one bone.
        """
        n = self.bone_count
        if n == 0:
            # Drop caches left over from an earlier, non-empty bone list.
            self._joint_matrices = np.zeros((0, 4, 4), dtype=np.float32)
            self._world_transforms = np.zeros((0, 4, 4), dtype=np.float32)
            self._pose_dirty = False
            return
        if self._joint_matrices is None or len(self._joint_matrices) != n:
            self._joint_matrices = np.zeros((n, 4, 4), dtype=np.float32)

        # Merge overrides: set_bone_pose overrides first, then explicit arg on top.
        merged = dict(self._bone_overrides)
        if bone_transforms:
            merged.update(bone_transforms)

        world = self._compute_world_transforms(merged)
        for i, bone in enumerate(self.bones):
            # Joint matrix = world * inverse_bind
            self._joint_matrices[i] = world[i] @ bone.inverse_bind_matrix

        # Cache world transforms for get_bone_global_transform.
        self._world_transforms = world
        # Clear before notifying, so a listener that clears overrides can
        # re-mark the pose dirty.
        self._pose_dirty = False
        # Notify listeners (e.g. skinned MeshInstance3D nodes).
        self.bone_pose_changed()

    def _compute_world_transforms(self, overrides: dict[int, np.ndarray]) -> np.ndarray:
        """Return per-bone world transforms, shape (bone_count, 4, 4), float32.

        Parents are always resolved before their children, whatever order the
        bones are stored in. A bone whose parent index is out of range, or whose
        parent chain loops back on itself, is logged and treated as a root.

        Args:
            overrides: bone_index -> local transform that replaces that bone's
                default ``local_transform``.
        """
        bones = self.bones
        n = len(bones)
        world = np.zeros((n, 4, 4), dtype=np.float32)
        unseen, on_chain, resolved = 0, 1, 2
        state = [unseen] * n

        for start in range(n):
            if state[start] == resolved:
                continue
            # Climb from ``start`` through ancestors that are not resolved yet.
            # Stop at a root, an invalid index, an already-resolved parent, or a
            # bone already on this chain (a cycle).
            chain = []
            idx = start
            while True:
                state[idx] = on_chain
                chain.append(idx)
                parent = bones[idx].parent_index
                if not 0 <= parent < n or state[parent] != unseen:
                    break
                idx = parent

            # Resolve top-down so each parent is finished before its child.
            for i in reversed(chain):
                parent = bones[i].parent_index
                local = overrides.get(i, bones[i].local_transform)
                if parent < 0:
                    world[i] = local
                elif parent >= n:
                    log.warning("Bone %s has invalid parent_index %s (bone_count=%s)", i, parent, n)
                    world[i] = local
                elif state[parent] == resolved:
                    world[i] = world[parent] @ local
                else:
                    # The parent is still on this chain, so the hierarchy loops
                    # (self-parenting included). Break the loop here.
                    log.warning(
                        "Bone %s is part of a parent cycle (parent_index=%s); treating it as a root",
                        i,
                        parent,
                    )
                    world[i] = local
                state[i] = resolved
        return world

    # -- Scene-tree integration -----------------------------------------------

    def process(self, dt: float) -> None:
        """Recompute the pose each frame while overrides are active.

        Also recomputes once after overrides are cleared or a bone is added, so
        the skeleton does not stay frozen in its last overridden pose.
        """
        if self._bone_overrides or self._pose_dirty:
            self.compute_pose()

    # -- Bone lookup ----------------------------------------------------------

    def find_bone(self, name: str) -> int:
        """Find bone index by name. Returns -1 (and logs a warning) if not found."""
        for i, bone in enumerate(self.bones):
            if bone.name == name:
                return i
        log.warning("Bone %r not found in skeleton (%s bones)", name, self.bone_count)
        return -1

    def add_bone(self, bone: Bone) -> int:
        """Append a bone and return its index."""
        idx = len(self.bones)
        self.bones.append(bone)
        # Invalidate cached matrices so they resize on next access, and let
        # ``process`` recompute the pose for the new bone count.
        self._joint_matrices = None
        self._pose_dirty = True
        return idx
# ============================================================================ # SkeletonProfile — standard bone naming convention # ============================================================================
@dataclass
class SkeletonProfile:
    """Standard bone naming convention for retargeting and humanoid IK.

    A profile defines an ordered list of bone names, optional parent
    relationships (by name), and bone groups for logical grouping.

    Usage::

        profile = PROFILE_HUMANOID
        idx = skel.find_bone(profile.bone_names[0])  # "Hips"

        # Validate a skeleton against the profile
        missing = profile.validate(skel)
    """

    name: str
    bone_names: list[str] = field(default_factory=list)
    bone_parents: dict[str, str] = field(default_factory=dict)
    bone_groups: dict[str, list[str]] = field(default_factory=dict)

    def validate(self, skeleton: Skeleton) -> list[str]:
        """Return the profile bone names missing from *skeleton*, in profile order.

        The skeleton's bone names are collected into a set once. The previous
        approach called ``Skeleton.find_bone`` per profile bone, which rescans
        the bone list each time and logs a warning for every miss. Reporting
        missing bones is the expected result of validation, not an error.
        """
        present = {bone.name for bone in skeleton.bones}
        return [name for name in self.bone_names if name not in present]

    def find_in_skeleton(self, skeleton: Skeleton, profile_bone_name: str) -> int:
        """Look up a profile bone name in a skeleton. Returns -1 if not found."""
        return skeleton.find_bone(profile_bone_name)

    def get_parent(self, bone_name: str) -> str | None:
        """Return the profile-defined parent bone name, or None for root/unknown bones."""
        return self.bone_parents.get(bone_name)

    def get_group(self, group_name: str) -> list[str]:
        """Return a copy of the bone names in a group, or an empty list.

        A copy is returned so callers cannot mutate the profile's own group
        lists, which may be shared, e.g. by the module-level humanoid profile.
        """
        return list(self.bone_groups.get(group_name, ()))
# -- Built-in humanoid profile -----------------------------------------------

# The groups are the single source of truth. Each group is an ordered chain in
# which every bone's parent is the bone listed just before it. The first bone
# of a chain attaches to the bone named in ``_HUMANOID_GROUP_ANCHORS``; ``None``
# marks the skeleton root ("Hips").
_HUMANOID_GROUPS = {
    "torso": ["Hips", "Spine", "Chest", "UpperChest", "Neck", "Head"],
    "left_arm": ["LeftShoulder", "LeftUpperArm", "LeftLowerArm", "LeftHand"],
    "right_arm": ["RightShoulder", "RightUpperArm", "RightLowerArm", "RightHand"],
    "left_leg": ["LeftUpperLeg", "LeftLowerLeg", "LeftFoot"],
    "right_leg": ["RightUpperLeg", "RightLowerLeg", "RightFoot"],
}
_HUMANOID_GROUP_ANCHORS = {
    "torso": None,
    "left_arm": "UpperChest",
    "right_arm": "UpperChest",
    "left_leg": "Hips",
    "right_leg": "Hips",
}

# Canonical bone order: torso, left arm, right arm, left leg, right leg.
_HUMANOID_BONES = [bone for chain in _HUMANOID_GROUPS.values() for bone in chain]

# child -> parent for every bone except the root.
_HUMANOID_PARENTS = {
    bone: chain[pos - 1] if pos else _HUMANOID_GROUP_ANCHORS[group]
    for group, chain in _HUMANOID_GROUPS.items()
    for pos, bone in enumerate(chain)
    if pos or _HUMANOID_GROUP_ANCHORS[group] is not None
}

PROFILE_HUMANOID = SkeletonProfile(
    name="Humanoid",
    bone_names=_HUMANOID_BONES,
    bone_parents=_HUMANOID_PARENTS,
    bone_groups=_HUMANOID_GROUPS,
)