"""Animation state machine with parameter-based transitions."""
from __future__ import annotations
import logging
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from ..node import Node
from ._interpolate import _interpolate
from .track import AnimationClip
# Module-level logger for this file; named after the importing module path.
log = logging.getLogger(name=__name__)
@dataclass()
class AnimationState:
    """One named node of the animation state machine.

    Bundles a clip with the playback settings the owning AnimationTree uses
    while this state is active: ``speed_scale`` multiplies the frame delta
    when advancing the clip, and ``loop`` selects wrapping back to the start
    versus holding the final frame once the clip's duration is reached.
    """

    name: str  # key under which the state is registered
    clip: AnimationClip  # clip sampled while this state is active
    speed_scale: float = 1.0  # multiplier applied to dt when advancing time
    loop: bool = True  # wrap at the clip's end instead of holding the last frame
@dataclass()
class Transition:
    """Directed edge of the state machine, from ``from_state`` to ``to_state``.

    ``condition`` is a zero-argument callable polled by the owning
    AnimationTree while ``from_state`` is active; a truthy result starts a
    cross-fade into ``to_state`` lasting ``blend_time`` (same time units as
    the ``dt`` passed to ``process``).
    """

    from_state: str  # name of the state this transition leaves
    to_state: str  # name of the state this transition enters
    condition: Callable[[], bool]  # zero-arg predicate; truthy -> take the transition
    blend_time: float = 0.2  # cross-fade duration
class AnimationTree(Node):
    """Animation state machine with parameter-based transitions.

    As a Node subclass, it participates in the scene tree and gets
    ``process(dt)`` called automatically. By default it animates its parent;
    pass ``target`` to animate a different node.

    Each frame, transitions leaving the current state are checked in the
    order they were added; the first whose condition is truthy (and whose
    destination state exists) starts a cross-fade. While a cross-fade is in
    progress no further transitions are evaluated.

    Example::

        tree = AnimationTree(target=player)
        tree.add_state("idle", idle_clip, loop=True)
        tree.add_state("run", run_clip, loop=True)
        tree.add_state("jump", jump_clip, loop=False)
        tree.add_transition("idle", "run", lambda: tree.get_parameter("speed", 0.0) > 0.1)
        tree.add_transition("run", "idle", lambda: tree.get_parameter("speed", 0.0) < 0.1)
        tree.add_transition("idle", "jump", lambda: tree.get_parameter("jump_pressed", False))
        tree.set_parameter("speed", 0.0)
        tree.start("idle")
    """

    def __init__(self, target: Node | None = None, **kwargs):
        super().__init__(**kwargs)
        self.target = target  # Node to animate; None means "use self.parent"
        self.parameters: dict[str, Any] = {}  # user-defined values read by conditions
        self.states: dict[str, AnimationState] = {}
        self.transitions: list[Transition] = []  # evaluated in insertion order
        self.current_state: str | None = None
        self.current_time = 0.0  # playback position within the current state's clip
        self.playing = False
        # Blending state
        self.blending = False
        self.blend_from: str | None = None
        self.blend_to: str | None = None
        self.blend_time = 0.0  # time elapsed in the current blend
        self.blend_duration = 0.2  # total length of the current blend

    def add_state(self, name: str, clip: AnimationClip, speed_scale: float = 1.0, loop: bool = True):
        """Register (or replace) the state ``name`` playing ``clip``."""
        self.states[name] = AnimationState(name, clip, speed_scale, loop)

    def add_transition(self, from_state: str, to_state: str, condition: Callable[[], bool], blend_time: float = 0.2):
        """Add a transition taken when ``condition()`` is truthy.

        States may be registered before or after their transitions; a
        transition whose destination is unknown is simply skipped at runtime.
        """
        self.transitions.append(Transition(from_state, to_state, condition, blend_time))

    def set_parameter(self, name: str, value: Any):
        """Set animation parameter ``name`` to ``value``."""
        self.parameters[name] = value

    def get_parameter(self, name: str, default: Any = None) -> Any:
        """Return animation parameter ``name``, or ``default`` if unset."""
        return self.parameters.get(name, default)

    def start(self, state_name: str):
        """Start playback in ``state_name`` from time 0, cancelling any blend.

        An unknown state name is ignored (with a warning) and the tree's
        current status is left unchanged.
        """
        if state_name not in self.states:
            log.warning("AnimationTree.start: unknown state %r", state_name)
            return
        self.current_state = state_name
        self.current_time = 0.0
        self.playing = True
        self._clear_blend()

    def stop(self):
        """Pause the tree; the current state and time are kept."""
        self.playing = False

    def _resolve_target(self):
        """Resolve target: use explicit target, else fall back to parent."""
        return self.target if self.target is not None else self.parent

    def _clear_blend(self):
        """Leave blending mode and forget the blend endpoints."""
        self.blending = False
        self.blend_from = None
        self.blend_to = None

    def _check_transitions(self):
        """Start the first eligible transition out of the current state.

        A condition that raises is logged and treated as false so one faulty
        callback cannot halt the state machine. A transition whose
        destination is not a registered state is skipped, letting later
        transitions still fire.
        """
        if not self.current_state or self.blending:
            return
        for trans in self.transitions:
            if trans.from_state != self.current_state:
                continue
            try:
                triggered = trans.condition()
            except Exception as exc:  # best-effort: user callbacks must not break playback
                log.warning("Transition condition %s->%s failed: %s", trans.from_state, trans.to_state, exc)
                continue
            if triggered and self._start_transition(trans.to_state, trans.blend_time):
                break

    def _start_transition(self, to_state: str, blend_time: float) -> bool:
        """Begin blending from the current state into ``to_state``.

        Returns True if the blend started, False if ``to_state`` is unknown.
        Not logged here because the condition may stay true every frame.
        """
        if to_state not in self.states:
            return False
        self.blending = True
        self.blend_from = self.current_state
        self.blend_to = to_state
        self.blend_time = 0.0
        self.blend_duration = blend_time
        return True

    def process(self, dt: float):
        """Advance animation state machine each frame (called by SceneTree)."""
        target = self._resolve_target()
        # Compare against None explicitly: a Node may define __len__/__bool__
        # and be falsy while still being a valid target.
        if not self.playing or target is None:
            return
        self._check_transitions()
        if self.blending:
            self._update_blend(dt, target)
        else:
            self._update_normal_playback(dt, target)

    def _advance_time(self, state: AnimationState, dt: float):
        """Advance ``current_time`` by ``dt * state.speed_scale``.

        Looping states wrap into ``[0, duration)`` (also when playing in
        reverse); one-shot states clamp to ``[0, duration]``. A clip with a
        non-positive duration pins time at 0 instead of dividing by zero.
        """
        self.current_time += dt * state.speed_scale
        duration = state.clip.duration
        if duration <= 0:
            self.current_time = 0.0
        elif state.loop:
            if not 0.0 <= self.current_time < duration:
                self.current_time %= duration
        else:
            self.current_time = min(max(self.current_time, 0.0), duration)

    @staticmethod
    def _apply_values(target, values):
        """Copy sampled property values onto ``target``, skipping attributes it lacks."""
        for prop, value in values.items():
            if hasattr(target, prop):
                setattr(target, prop, value)

    def _update_blend(self, dt: float, target):
        """Advance the cross-fade between ``blend_from`` and ``blend_to``.

        The source clip keeps playing (wrapped/clamped like normal playback)
        while the destination clip is sampled at time 0. When the blend
        completes, the destination becomes the current state and its first
        frame is applied right away.
        """
        self.blend_time += dt
        if self.blend_duration > 0:
            blend_factor = min(1.0, self.blend_time / self.blend_duration)
        else:
            blend_factor = 1.0  # zero/negative duration: switch instantly

        if blend_factor >= 1.0:
            to_name = self.blend_to
            self.current_state = to_name
            self.current_time = 0.0
            self._clear_blend()
            if to_name is not None:
                # Avoid holding the last partially-blended pose for an extra frame.
                self._apply_values(target, self.states[to_name].clip.evaluate(0.0))
            return

        if self.blend_from and self.blend_to:
            from_state = self.states[self.blend_from]
            to_state = self.states[self.blend_to]
            from_values = from_state.clip.evaluate(self.current_time)
            to_values = to_state.clip.evaluate(0.0)
            for prop in set(from_values) | set(to_values):
                v0 = from_values.get(prop)
                v1 = to_values.get(prop)
                # Only properties sampled by both clips can be interpolated.
                if v0 is not None and v1 is not None and hasattr(target, prop):
                    setattr(target, prop, _interpolate(v0, v1, blend_factor))
            self._advance_time(from_state, dt)

    def _update_normal_playback(self, dt: float, target):
        """Advance normal (non-blending) playback for the current state."""
        if not self.current_state:
            return
        state = self.states[self.current_state]
        self._advance_time(state, dt)
        self._apply_values(target, state.clip.evaluate(self.current_time))

    def to_dict(self) -> dict:
        """Serialize parameters, states and playback position.

        Transitions are not included: their conditions are arbitrary
        callables and must be re-added after ``from_dict``.
        """
        return {
            "parameters": dict(self.parameters),
            "states": {
                name: {
                    "clip": state.clip.to_dict(),
                    "speed_scale": state.speed_scale,
                    "loop": state.loop,
                }
                for name, state in self.states.items()
            },
            "current_state": self.current_state,
            "current_time": self.current_time,
            "playing": self.playing,
        }

    @classmethod
    def from_dict(cls, data: dict, target=None):
        """Rebuild a tree from ``to_dict`` output (transitions must be re-added).

        A ``current_state`` that names no serialized state is dropped rather
        than left to raise KeyError on the first ``process``. ``playing``
        defaults to False for data produced before it was serialized.
        """
        tree = cls(target=target)
        tree.parameters = dict(data.get("parameters", {}))
        for name, state_data in data.get("states", {}).items():
            clip = AnimationClip.from_dict(state_data["clip"])
            tree.add_state(name, clip, state_data.get("speed_scale", 1.0), state_data.get("loop", True))
        current = data.get("current_state")
        tree.current_state = current if current in tree.states else None
        tree.current_time = data.get("current_time", 0.0)
        tree.playing = bool(data.get("playing", False)) and tree.current_state is not None
        return tree