"""Sprite nodes with frame-based animation support."""
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
from ..descriptors import Property
from ..math.types import Vec2
from ..nodes_2d.node2d import Node2D
from ..properties import Colour
from ..signals import Signal
# ============================================================================
# Sprite2D
# ============================================================================
class Sprite2D(Node2D):
    """2D sprite node -- renders a texture via ``Draw2D.draw_texture_region()``.

    The ``texture`` property holds a texture source (file path, PNG bytes, or
    RGBA uint8 ndarray); the graphics backend loads it via TextureManager and
    stores the GPU index in ``_texture_id``. The ``on_draw()`` callback emits a
    textured quad through the renderer (Draw2D).

    Attributes:
        texture: Texture source (file path, PNG bytes, or RGBA uint8 ndarray).
            Reassigning it resets ``_texture_id`` to -1 so it is re-resolved.
        colour: RGBA tint (0.0-1.0 floats).
        width: Display width in pixels (0 = native size, adopted by
            ``preload()``; ``on_draw`` falls back to 64 while it is still 0).
        height: Display height in pixels (same convention as ``width``).
        filter: Sampler filter, ``"linear"`` or ``"nearest"``.
        flip_h: Mirror horizontally by swapping the U coordinates.
        flip_v: Mirror vertically by swapping the V coordinates.
    """

    texture = Property(
        None,
        hint="Texture source: file path, PNG bytes, or RGBA uint8 ndarray",
        on_change="_invalidate_texture_id",
    )
    colour = Colour((1.0, 1.0, 1.0, 1.0))
    width = Property(0, hint="Display width (0 = native)")
    height = Property(0, hint="Display height (0 = native)")
    # ``"linear"``: bilinear filter (default, smooth scaling).
    # ``"nearest"``: nearest-neighbour, the right choice for pixel-art ports
    # so up-scaled sprites stay crisp instead of going to mush. Backends that
    # don't honour the flag fall back to linear silently; the property still
    # round-trips through serialisation so the authoring intent is preserved.
    filter = Property("linear", enum=("linear", "nearest"),
                      hint="Texture sampler filter mode")
    flip_h = Property(False, hint="Flip the sprite horizontally (UV-based, pivot preserved)")
    flip_v = Property(False, hint="Flip the sprite vertically (UV-based, pivot preserved)")

    def __init__(
        self,
        texture: Any = None,
        position=None,
        rotation: float = 0.0,
        scale=None,
        colour: tuple = (1.0, 1.0, 1.0, 1.0),
        width: int = 0,
        height: int = 0,
        filter: str = "linear",
        flip_h: bool = False,
        flip_v: bool = False,
        **kwargs,
    ):
        super().__init__(position=position, rotation=rotation, scale=scale, **kwargs)
        if texture is not None:
            self.texture = texture
        self.colour = colour
        self.width = width
        self.height = height
        self.filter = filter
        self.flip_h = flip_h
        self.flip_v = flip_v
        # Set by SceneAdapter when the texture is loaded on GPU
        self._texture_id: int = -1

    def _invalidate_texture_id(self) -> None:
        """Force the GPU texture to reload on next frame after `texture` is reassigned."""
        self._texture_id = -1

    def preload(self) -> None:
        """Synchronously upload the texture and resolve ``width`` / ``height``.

        Call after ``enter_tree`` (so ``self.app`` is available) and before
        the first frame to avoid the size-flash that happens when SceneAdapter
        lazily resolves the texture: until that resolve runs, ``_texture_id``
        is ``-1`` and the renderer skips the draw, so a sprite that relies on
        the texture's native dimensions briefly renders at the default 64x64
        fallback once the resolve catches up.

        After a successful ``preload()``:

        * ``_texture_id`` is non-negative.
        * ``width`` / ``height`` are set from the texture's native pixels
          if they were both 0 (i.e. "native size" was requested).

        No-ops when there is no texture, when the texture is already loaded,
        or when ``self.app`` exposes no ``engine.texture_manager`` (e.g. not
        in a tree, headless / unit tests). When the texture manager returns
        a negative id, logs a WARNING and returns without raising; the
        renderer will fall back to its draw-time path.
        """
        if self.texture is None or self._texture_id >= 0:
            return
        # getattr(None, ..., None) is None, so a missing app/engine simply
        # propagates down to ``tm is None``.
        engine = getattr(self.app, "engine", None)
        tm = getattr(engine, "texture_manager", None)
        if tm is None:
            return
        tex_id = tm.resolve(self.texture)
        if tex_id < 0:
            import logging

            # Only echo paths verbatim; bytes / ndarray payloads can be huge.
            source = (
                self.texture
                if isinstance(self.texture, (str, Path))
                else type(self.texture).__name__
            )
            logging.getLogger(__name__).warning(
                "Sprite2D.preload: failed to resolve texture %r; "
                "deferring to draw-time resolve",
                source,
            )
            return
        self._texture_id = tex_id
        # Adopt the native pixel size when the caller didn't pick a size.
        if self.width == 0 and self.height == 0:
            tw, th = tm.get_texture_size(tex_id)
            if tw > 0 and th > 0:
                self.width = tw
                self.height = th

    def on_draw(self, renderer) -> None:
        """Emit a textured quad via the renderer (Draw2D).

        ``flip_h`` / ``flip_v`` flip the source UV rectangle rather than the
        node's scale, so the sprite pivot stays at ``world_position`` (the
        Godot semantics ports rely on). Skipped while the texture is
        unresolved (``_texture_id < 0``) or the node is hidden.
        """
        if self._texture_id < 0 or not self.visible:
            return
        pos, s, rot = self.world_transform
        # 0 means "native size"; until it is known, draw a 64 px placeholder.
        w = self.width if self.width > 0 else 64
        h = self.height if self.height > 0 else 64
        u0, u1 = (1.0, 0.0) if self.flip_h else (0.0, 1.0)
        v0, v1 = (1.0, 0.0) if self.flip_v else (0.0, 1.0)
        # Quad is centred on the world position.
        renderer.draw_texture_region(
            self._texture_id,
            pos.x - w * s.x * 0.5,
            pos.y - h * s.y * 0.5,
            w * s.x,
            h * s.y,
            u0, v0, u1, v1,
            colour=self.colour,
            rotation=rot,
        )

    def to_dict(self) -> dict:
        """Serialize sprite state (used by clipboard + PlayMode snapshot).

        ``texture`` is stored as-is, so it may be a path, bytes, or ndarray.
        """
        return {
            "texture": self.texture,
            "position": [self.position.x, self.position.y],
            "scale": [self.scale.x, self.scale.y],
            "rotation": self.rotation,
            "colour": list(self.colour),
            "visible": self.visible,
            "width": self.width,
            "height": self.height,
            "filter": self.filter,
            "flip_h": self.flip_h,
            "flip_v": self.flip_v,
        }

    @classmethod
    def from_dict(cls, data: dict):
        """Deserialize sprite state produced by ``to_dict``; missing keys use defaults."""
        sprite = cls(
            texture=data.get("texture"),
            position=Vec2(*data.get("position", [0, 0])),
            scale=Vec2(*data.get("scale", [1, 1])),
            rotation=data.get("rotation", 0.0),
            colour=tuple(data.get("colour", [1, 1, 1, 1])),
            width=data.get("width", 0),
            height=data.get("height", 0),
            filter=data.get("filter", "linear"),
            flip_h=data.get("flip_h", False),
            flip_v=data.get("flip_v", False),
        )
        sprite.visible = data.get("visible", True)
        return sprite
# ============================================================================
# SpriteAnimation / AnimatedSprite2D
# ============================================================================
@dataclass
class SpriteAnimation:
    """A named clip: an ordered list of sprite-sheet frame indices.

    Attributes:
        name: Name the clip is registered and played under.
        frames: Sheet frame indices, stepped through in list order.
        fps: Playback rate in frames per second.
        loop: When True, playback wraps back to the first entry of ``frames``.
    """

    name: str
    frames: list[int]  # indices into the sprite sheet grid
    fps: float = 10.0  # frames advanced per second
    loop: bool = True  # wrap around instead of holding the last frame
class AnimatedSprite2D(Sprite2D):
    """Sprite with frame-based animation from sprite sheets.

    Inherits from Sprite2D (Node2D), so it participates in the scene tree and
    gets ``on_process(dt)`` and ``on_draw(renderer)`` called automatically.

    The sheet is treated as a uniform grid of ``frames_horizontal`` columns
    by ``frames_vertical`` rows; a frame index maps to cell
    ``(row, col) = divmod(index, frames_horizontal)`` (see ``frame_uv``).

    Example:
        sprite = AnimatedSprite2D(
            texture="player.png",
            frames_horizontal=4,
            frames_vertical=4
        )
        sprite.add_animation("walk", frames=[0, 1, 2, 3], fps=10, loop=True)
        sprite.add_animation("jump", frames=[4, 5, 6], fps=15, loop=False)
        sprite.play("walk")
    """

    def __init__(
        self,
        texture: Any = None,
        frames_horizontal: int = 1,
        frames_vertical: int = 1,
        frame_width: int | None = None,
        frame_height: int | None = None,
        **kwargs,
    ):
        """Create an animated sprite over a sprite-sheet grid.

        Args:
            texture: Sheet texture source, forwarded to ``Sprite2D``.
            frames_horizontal: Number of frame columns in the sheet.
            frames_vertical: Number of frame rows in the sheet.
            frame_width: Optional manual frame width in pixels. Stored and
                serialised, but not consulted when computing UVs.
            frame_height: Optional manual frame height (same caveat).
            **kwargs: Forwarded to ``Sprite2D`` (position, scale, width, ...).
        """
        super().__init__(texture=texture, **kwargs)
        self.frames_h = frames_horizontal
        self.frames_v = frames_vertical
        self.frame_width = frame_width  # Manual frame size (optional)
        self.frame_height = frame_height
        # Animation state
        self.animations: dict[str, SpriteAnimation] = {}
        self.current_animation: str | None = None
        self.frame = 0  # Position within the current animation's ``frames``
        self.frame_time = 0.0  # Seconds accumulated toward the next advance
        self.playing = False
        self.animation_finished = False
        # Emitted once when a non-looping animation reaches its end.
        self.animation_finished_signal = Signal()
        # Optional callback invoked with the new ``frame`` after it changes.
        self.on_frame_changed: Callable[[int], None] | None = None

    def add_animation(self, name: str, frames: list[int], fps: float = 10.0, loop: bool = True):
        """Register (or replace) a named animation."""
        self.animations[name] = SpriteAnimation(name, frames, fps, loop)

    def play(self, animation_name: str = "default"):
        """Play a named animation from its first frame.

        An unregistered name is auto-registered as a clip covering every
        cell of the sheet (``frames_h * frames_v`` frames, default fps/loop).
        """
        if animation_name not in self.animations:
            # Fallback: play all frames
            total_frames = self.frames_h * self.frames_v
            self.add_animation(animation_name, list(range(total_frames)))
        self.current_animation = animation_name
        self.frame = 0
        self.frame_time = 0.0
        self.playing = True
        self.animation_finished = False

    def stop(self):
        """Stop animation and reset to the start of the current animation.

        ``playing`` becomes ``False`` and the frame counter resets so a
        subsequent ``play()`` or ``resume()`` begins from frame 0.
        """
        self.playing = False
        self.frame = 0
        self.frame_time = 0.0

    def pause(self):
        """Pause animation, preserving the current frame and frame time.

        ``playing`` becomes ``False`` but no state is reset; ``resume()``
        continues from where playback left off.
        """
        self.playing = False

    def resume(self):
        """Resume animation from the current frame."""
        self.playing = True

    def on_process(self, dt: float):
        """Advance the current animation by ``dt`` seconds.

        Advances as many frames as ``dt`` covers at the animation's fps; a
        non-positive fps never advances. Looping animations wrap to frame 0.
        Non-looping animations hold the last frame, clear ``playing``, set
        ``animation_finished`` and emit ``animation_finished_signal`` exactly
        once, discarding any time left over in ``dt``. ``on_frame_changed``
        is called after each advance that changes ``frame``. Advancement also
        stops if a callback clears ``playing`` (e.g. via ``stop``/``pause``).
        """
        if not self.playing or not self.current_animation:
            return
        anim = self.animations[self.current_animation]
        self.frame_time += dt
        if anim.fps <= 0:
            return
        frame_duration = 1.0 / anim.fps
        last = len(anim.frames) - 1
        while self.frame_time >= frame_duration:
            self.frame_time -= frame_duration
            old_frame = self.frame
            self.frame += 1
            # Loop or finish
            if self.frame > last:
                if anim.loop:
                    self.frame = 0
                else:
                    self.frame = max(last, 0)
                    self.playing = False
                    self.animation_finished = True
                    # Reset before emitting so a handler that calls play()
                    # starts from a clean state.
                    self.frame_time = 0.0
                    self.animation_finished_signal()
            if self.frame != old_frame and self.on_frame_changed:
                self.on_frame_changed(self.frame)
            # Without this, a large dt would re-run the finish branch (and
            # re-emit the signal) once per leftover frame duration.
            if not self.playing:
                break

    def on_draw(self, renderer) -> None:
        """Draw the current animation frame as a textured quad with proper UVs.

        ``flip_h`` / ``flip_v`` (inherited from ``Sprite2D``) swap the UV
        endpoints: pivot remains at ``world_position`` regardless.
        """
        if self._texture_id < 0 or not self.visible:
            return
        pos, s, rot = self.world_transform
        # 0 means "native size"; until it is known, draw a 64 px placeholder.
        w = self.width if self.width > 0 else 64
        h = self.height if self.height > 0 else 64
        uv0, uv1 = self.frame_uv
        u0, u1 = (uv1.x, uv0.x) if self.flip_h else (uv0.x, uv1.x)
        v0, v1 = (uv1.y, uv0.y) if self.flip_v else (uv0.y, uv1.y)
        renderer.draw_texture_region(
            self._texture_id,
            pos.x - w * s.x * 0.5,
            pos.y - h * s.y * 0.5,
            w * s.x,
            h * s.y,
            u0, v0, u1, v1,
            colour=self.colour,
            rotation=rot,
        )

    @property
    def current_frame_index(self) -> int:
        """Absolute frame index in the sprite sheet (0 when nothing valid is selected)."""
        anim = self.animations.get(self.current_animation) if self.current_animation else None
        # Reject negative positions too: they would silently index from the end.
        if anim is None or not 0 <= self.frame < len(anim.frames):
            return 0
        return anim.frames[self.frame]

    @property
    def frame_uv(self) -> tuple[Vec2, Vec2]:
        """UV coordinates for the current frame (top-left, bottom-right).

        A grid dimension below 1 is treated as 1 so a misconfigured sheet
        draws the whole texture instead of raising ZeroDivisionError.
        """
        cols = max(self.frames_h, 1)
        rows = max(self.frames_v, 1)
        row, col = divmod(self.current_frame_index, cols)
        return (
            Vec2(col / cols, row / rows),
            Vec2((col + 1) / cols, (row + 1) / rows),
        )

    def to_dict(self) -> dict:
        """Serialize animated sprite (base sprite state plus animation data)."""
        data = super().to_dict()
        data.update(
            {
                "frames_h": self.frames_h,
                "frames_v": self.frames_v,
                "frame_width": self.frame_width,
                "frame_height": self.frame_height,
                "animations": {
                    name: {
                        # Copy so the snapshot does not alias live state.
                        "frames": list(anim.frames),
                        "fps": anim.fps,
                        "loop": anim.loop,
                    }
                    for name, anim in self.animations.items()
                },
                "current_animation": self.current_animation,
                "frame": self.frame,
            }
        )
        return data

    @classmethod
    def from_dict(cls, data: dict):
        """Deserialize an animated sprite produced by ``to_dict``.

        Restores every field ``to_dict`` writes (including the inherited
        size / filter / flip state); missing keys fall back to defaults.
        Playback is not restarted: ``playing`` stays False.
        """
        sprite = cls(
            texture=data.get("texture"),
            frames_horizontal=data.get("frames_h", 1),
            frames_vertical=data.get("frames_v", 1),
            frame_width=data.get("frame_width"),
            frame_height=data.get("frame_height"),
            position=Vec2(*data.get("position", [0, 0])),
            scale=Vec2(*data.get("scale", [1, 1])),
            rotation=data.get("rotation", 0.0),
            colour=tuple(data.get("colour", [1, 1, 1, 1])),
            width=data.get("width", 0),
            height=data.get("height", 0),
            filter=data.get("filter", "linear"),
            flip_h=data.get("flip_h", False),
            flip_v=data.get("flip_v", False),
        )
        sprite.visible = data.get("visible", True)
        # Restore animations (copy frame lists so the sprite does not alias `data`).
        for name, anim_data in data.get("animations", {}).items():
            sprite.add_animation(
                name,
                list(anim_data["frames"]),
                anim_data.get("fps", 10.0),
                anim_data.get("loop", True),
            )
        sprite.current_animation = data.get("current_animation")
        sprite.frame = data.get("frame", 0)
        return sprite

    @classmethod
    def from_frames(
        cls,
        frames: list[Any] | str | Path,
        fps: float = 10.0,
        *,
        name: str = "default",
        loop: bool = True,
        play: bool = True,
        **kwargs,
    ) -> "AnimatedSprite2D":  # quoted: the class name is unbound while its body executes
        """Build a flipbook AnimatedSprite2D from a list of frame textures or a folder.

        ``frames`` accepts any of:

        * ``list``: each element is a per-frame texture source (file path,
          PNG bytes, or ``HxWx4`` uint8 ndarray). Frames are stitched into
          a single horizontal strip atlas.
        * ``str`` / ``Path`` to a directory: every top-level ``*.png`` (not
          recursive) is sorted alphabetically and treated as one frame.

        An empty input, empty directory, or directory without PNG files
        raises ``FileNotFoundError``.

        The resulting sprite uses a single sheet texture (so it follows the
        same fast GPU path as a hand-authored atlas: no per-frame upload at
        runtime) with ``frames_horizontal = N``, ``frames_vertical = 1``. The
        animation named ``name`` is registered with all N frames; ``play=True``
        starts playback immediately. ``width`` / ``height`` default to one
        frame's pixel size but may be overridden through ``kwargs``.

        All frames must have the same pixel dimensions; mismatched sizes
        raise ``ValueError``.
        """
        frame_list = _resolve_frame_inputs(frames)
        if not frame_list:
            raise FileNotFoundError(f"from_frames: no frames provided ({frames!r})")
        # Decode each frame to an RGBA ndarray, then check sizes.
        pixels = [_decode_frame(src) for src in frame_list]
        first_h, first_w = pixels[0].shape[:2]
        for i, p in enumerate(pixels):
            if p.shape[:2] != (first_h, first_w):
                raise ValueError(
                    f"from_frames: frame {i} size {p.shape[:2]} != frame 0 size {(first_h, first_w)}"
                )
        # Horizontal strip atlas.
        atlas = np.concatenate(pixels, axis=1)  # shape (H, W*N, 4)
        n = len(pixels)
        # Defaults only: passing width/height explicitly must not collide.
        kwargs.setdefault("width", first_w)
        kwargs.setdefault("height", first_h)
        sprite = cls(
            texture=atlas,
            frames_horizontal=n,
            frames_vertical=1,
            **kwargs,
        )
        sprite.add_animation(name, frames=list(range(n)), fps=fps, loop=loop)
        if play:
            sprite.play(name)
        return sprite
def _resolve_frame_inputs(frames: list[Any] | str | Path) -> list[Any]:
"""Normalise ``frames`` to a concrete list of per-frame sources."""
if isinstance(frames, (str, Path)):
folder = Path(frames)
if not folder.is_dir():
raise FileNotFoundError(f"from_frames: not a directory: {folder}")
return sorted(folder.glob("*.png"))
return list(frames)
def _decode_frame(source: Any) -> np.ndarray:
"""Return an RGBA uint8 ``(H, W, 4)`` ndarray for any frame source."""
if isinstance(source, np.ndarray):
if source.ndim != 3 or source.shape[2] != 4 or source.dtype != np.uint8:
raise ValueError(
f"from_frames: ndarray frames must be RGBA uint8 (H, W, 4); got {source.shape}/{source.dtype}"
)
return source
# File path or bytes: defer to PIL (matches TextureManager's loader).
try:
from PIL import Image # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError("from_frames: PIL is required to decode file/bytes frames") from exc
if isinstance(source, (bytes, bytearray, memoryview)):
import io
img = Image.open(io.BytesIO(bytes(source))).convert("RGBA")
else:
img = Image.open(str(source)).convert("RGBA")
return np.ascontiguousarray(np.array(img, dtype=np.uint8))