# Source code for simvx.core.scene

"""
Scene serialization — save/load node trees as JSON or pickle.

Both backends serialize to the same plain-dict intermediate format.
Pickle stores the dict (not raw Node objects), avoiding class-path
brittleness and GPU state issues.

Public API:
    save_scene(node, path, format="json")
    load_scene(path, format=None)  # auto-detects from extension
"""


from __future__ import annotations

import json
import logging
import pickle
from collections.abc import Callable
from pathlib import Path

import numpy as np

from .audio import AudioStreamPlayer, AudioStreamPlayer2D, AudioStreamPlayer3D
from .node import Node
from .nodes_2d.node2d import Node2D
from .nodes_3d.mesh import MeshInstance3D
from .nodes_3d.node3d import Node3D
from .physics_nodes import CharacterBody2D, CharacterBody3D, CollisionShape2D, CollisionShape3D
from .math.types import Quat, Vec2, Vec3

log = logging.getLogger(__name__)


# ============================================================================
# Format versioning — incremental migrations for backward compatibility
# ============================================================================

CURRENT_FORMAT_VERSION = 1


def _identity_migration(data: dict) -> dict:
    """Return *data* unchanged (for version bumps with no structural change)."""
    return data


# Upgrade steps keyed by (from_version, to_version). Each step takes the scene
# dict and returns the upgraded dict; _migrate_scene applies them one version
# at a time.
MIGRATIONS: dict[tuple[int, int], Callable[[dict], dict]] = {
    # v0 -> v1: no structural changes; only the version stamp is added.
    (0, 1): _identity_migration,
}


def _migrate_scene(data: dict) -> dict:
    """Upgrade serialized scene data to ``CURRENT_FORMAT_VERSION``.

    Data without a ``format_version`` key is treated as version 0. Migrations
    from ``MIGRATIONS`` are applied one version step at a time, then the
    result is stamped with the current version.

    Args:
        data: Root scene dict as produced by the JSON/pickle loaders. It may be
            mutated in place (the version stamp is written into it).

    Returns:
        The migrated dict (possibly a new object returned by a migration step).

    Raises:
        ValueError: If the data comes from a newer, unsupported format version,
            or no migration step is registered for some intermediate version.
    """
    version = data.get("format_version", 0)
    # Migrations only go forward. Re-stamping newer data as the current version
    # would silently misinterpret any structural changes, so refuse it.
    if version > CURRENT_FORMAT_VERSION:
        raise ValueError(
            f"Scene format version {version} is newer than the supported version {CURRENT_FORMAT_VERSION}"
        )
    while version < CURRENT_FORMAT_VERSION:
        next_version = version + 1
        migration = MIGRATIONS.get((version, next_version))
        if migration is None:
            raise ValueError(f"No migration path from format version {version} to {next_version}")
        data = migration(data)
        version = next_version
    data["format_version"] = CURRENT_FORMAT_VERSION
    return data


# ============================================================================
# Type codec — tagged dicts for JSON-safe encoding
# ============================================================================

# Math type -> JSON tag key used by the tagged-dict codec.
_TYPE_TAGS = {Vec2: "__vec2__", Vec3: "__vec3__", Quat: "__quat__"}

# Reverse lookup (tag key -> math type), in the same order as _TYPE_TAGS.
_TAG_TO_TYPE = dict(zip(_TYPE_TAGS.values(), _TYPE_TAGS.keys()))


def _encode_type(v):
    """Return *v* as a JSON-safe tagged dict if it is a Quat/Vec3/Vec2.

    Quat is checked first, then Vec3, then Vec2. Components are converted to
    plain floats. Any other value is returned unchanged.
    """
    if isinstance(v, Quat):
        # Quaternions are stored in (w, x, y, z) order.
        tag, components = "__quat__", [v.w, v.x, v.y, v.z]
    elif isinstance(v, Vec3):
        tag, components = "__vec3__", list(v)
    elif isinstance(v, Vec2):
        tag, components = "__vec2__", list(v)
    else:
        return v
    return {tag: [float(c) for c in components]}


def _decode_type(d):
    """Decode a tagged dict back to Vec2/Vec3/Quat."""
    if not isinstance(d, dict):
        return d
    for tag, cls in _TAG_TO_TYPE.items():
        if tag in d:
            return cls(*d[tag])
    return d


# ============================================================================
# Serialize — node tree → plain dict
# ============================================================================

# Spatial defaults per base class for skip-if-default logic
# Default transform values; attributes equal to these are left out of the
# serialized output.
_SPATIAL_DEFAULTS_2D = dict(
    position=Vec2(),
    rotation=0.0,
    scale=Vec2(1, 1),
)

_SPATIAL_DEFAULTS_3D = dict(
    position=Vec3(),
    rotation=Quat(),
    scale=Vec3(1, 1, 1),
)


def _type_eq(a, b) -> bool:
    """Compare two Vec/Quat values (or floats) for approximate equality."""
    if isinstance(a, int | float) and isinstance(b, int | float):
        return abs(a - b) < 1e-9
    if isinstance(a, Quat) and isinstance(b, Quat):
        return all(abs(getattr(a, c) - getattr(b, c)) < 1e-9 for c in "wxyz")
    if isinstance(a, Vec3) and isinstance(b, Vec3):
        return all(abs(av - bv) < 1e-9 for av, bv in zip(a, b, strict=True))
    if isinstance(a, Vec2) and isinstance(b, Vec2):
        return all(abs(av - bv) < 1e-9 for av, bv in zip(a, b, strict=True))
    return a == b


def _serialize_material(mat) -> dict:
    """Serialize a Material to a plain dict (including texture URIs)."""
    d = {
        "colour": list(mat.colour),
        "metallic": mat.metallic,
        "roughness": mat.roughness,
        "wireframe": mat.wireframe,
        "blend": mat.blend,
        "double_sided": mat.double_sided,
        "unlit": mat.unlit,
    }
    # Include texture URIs if present
    if mat.albedo_uri:
        d["albedo_uri"] = mat.albedo_uri
    if mat.normal_uri:
        d["normal_uri"] = mat.normal_uri
    if mat.metallic_roughness_uri:
        d["metallic_roughness_uri"] = mat.metallic_roughness_uri
    if mat.emissive_uri:
        d["emissive_uri"] = mat.emissive_uri
    if mat.ao_uri:
        d["ao_uri"] = mat.ao_uri
    return d


def _serialize_spatial(node: Node, d: dict) -> None:
    """Write a spatial node's non-default position/rotation/scale into *d*.

    Node3D writes position/rotation/scale as tagged Vec3/Quat/Vec3 dicts.
    Node2D writes position/scale as tagged Vec2 dicts and rotation as a raw
    number. Values matching ``_SPATIAL_DEFAULTS_*`` are omitted. Nodes that
    are neither Node3D nor Node2D are left untouched.
    """
    if isinstance(node, Node3D):
        for attr in ("position", "rotation", "scale"):
            val = getattr(node, attr)
            if not _type_eq(val, _SPATIAL_DEFAULTS_3D[attr]):
                d[attr] = _encode_type(val)
    elif isinstance(node, Node2D):
        for attr in ("position", "scale"):
            val = getattr(node, attr)
            if not _type_eq(val, _SPATIAL_DEFAULTS_2D[attr]):
                d[attr] = _encode_type(val)
        # 2D rotation is a scalar angle, so it is stored as-is rather than tagged.
        if not _type_eq(node.rotation, 0.0):
            d["rotation"] = node.rotation


def _serialize_node(node: Node) -> dict:
    """Recursively serialize *node* and its subtree into a plain dict.

    Nodes with a ``_packed_scene_path`` become a compact sub-scene reference
    (``__scene__``, name and transform overrides); their subtree is not written.
    All other nodes record their type name, name, non-default transform,
    non-default properties ("settings"), scripts, groups, mesh/material/audio
    resource references, and children.
    """
    # Sub-scene reference — emit compact form
    if node._packed_scene_path:
        d = {"__scene__": node._packed_scene_path, "name": node.name}
        _serialize_spatial(node, d)
        return d

    d = {"__type__": type(node).__name__, "name": node.name}
    _serialize_spatial(node, d)

    # Properties: keep only values that differ from the declared default.
    settings = {}
    for name, prop in node.get_properties().items():
        val = getattr(node, name)
        try:
            eq = val == prop.default
            # Array-valued properties compare element-wise; default only if all match.
            is_default = eq.all() if isinstance(eq, np.ndarray) else bool(eq)
        except Exception:
            # Values that cannot be compared are always written out.
            is_default = False
        if not is_default:
            settings[name] = val
    if settings:
        d["settings"] = settings

    # Script
    if node.script:
        d["script"] = node.script
    if node._script_inline:
        d["script_inline"] = node._script_inline
    if node._script_embedded:
        d["script_embedded"] = node._script_embedded

    # Groups (sorted for deterministic output)
    if node._groups:
        d["groups"] = sorted(node._groups)

    if isinstance(node, MeshInstance3D):
        # Only meshes that carry a resource URI can be referenced from the file.
        if node.mesh is not None and getattr(node.mesh, "resource_uri", None):
            d["mesh"] = node.mesh.resource_uri
        if node.material is not None:
            d["material"] = _serialize_material(node.material)

    # Audio stream: prefer the resource URI, fall back to the stream's path.
    if isinstance(node, AudioStreamPlayer | AudioStreamPlayer2D | AudioStreamPlayer3D) and node.stream is not None:
        if getattr(node.stream, "resource_uri", None):
            d["stream"] = node.stream.resource_uri
        else:
            d["stream"] = node.stream.path

    # Children
    children = [_serialize_node(child) for child in node.children]
    if children:
        d["children"] = children

    return d


# ============================================================================
# Deserialize — plain dict → node tree
# ============================================================================


# Serialized material texture keys -> Material constructor argument names.
_MATERIAL_URI_TO_ARG = {
    "albedo_uri": "albedo_map",
    "normal_uri": "normal_map",
    "metallic_roughness_uri": "metallic_roughness_map",
    "emissive_uri": "emissive_map",
    "ao_uri": "ao_map",
}


def _deserialize_material(data: dict):
    """Build a Material from a serialized material dict, renaming ``*_uri`` keys to ``*_map`` args."""
    from .graphics.material import Material

    kwargs = {_MATERIAL_URI_TO_ARG.get(key, key): val for key, val in data.items()}
    return Material(**kwargs)


def _deserialize_node(d: dict, scene_dir: str = "") -> Node:
    """Recursively reconstruct a node tree from a serialized plain dict.

    The input dict is not modified.

    Args:
        d: Node dict, either a full node (``__type__``) or a sub-scene
            reference (``__scene__``).
        scene_dir: Directory of the scene file being loaded. Sub-scene paths
            are joined onto it, and it is set as the resource cache base path.

    Returns:
        The reconstructed root node of this subtree.

    Raises:
        ValueError: If ``__type__`` names a class missing from ``Node._registry``.
    """
    # Sub-scene reference
    if "__scene__" in d:
        # NOTE(review): PackedScene.instance() stores str(self.path) as-is, but it is
        # resolved here relative to scene_dir; relative paths may get double-prefixed.
        scene_path = Path(scene_dir) / d["__scene__"] if scene_dir else Path(d["__scene__"])
        sub = PackedScene(scene_path).instance()
        # Apply overrides
        if "name" in d:
            sub.name = d["name"]
        for attr in ("position", "rotation", "scale"):
            if attr in d:
                setattr(sub, attr, _decode_type(d[attr]))
        return sub

    cls_name = d["__type__"]
    cls = Node._registry.get(cls_name)
    if cls is None:
        raise ValueError(f"Unknown node type: {cls_name!r}")

    kwargs = {"name": d.get("name", cls_name)}

    # Spatial properties
    for attr in ("position", "rotation", "scale"):
        if attr in d:
            kwargs[attr] = _decode_type(d[attr])

    # Settings — JSON turns tuples (colours, vectors) into lists; restore them as
    # tuples. Values go straight into kwargs so the caller's dict (e.g. a cached
    # PackedScene payload) is never rewritten.
    for key, val in d.get("settings", {}).items():
        kwargs[key] = tuple(val) if isinstance(val, list) else val

    node = cls(**kwargs)

    # Script
    if "script" in d:
        node.script = d["script"]
    if "script_inline" in d:
        node._script_inline = d["script_inline"]
    if "script_embedded" in d:
        node._script_embedded = d["script_embedded"]

    # Groups
    for group in d.get("groups", ()):
        node.add_to_group(group)

    # Material
    if "material" in d and isinstance(node, MeshInstance3D):
        node.material = _deserialize_material(d["material"])

    # Mesh resource
    if "mesh" in d and isinstance(node, MeshInstance3D):
        from .resource import ResourceCache

        cache = ResourceCache.get()
        # NOTE: sets base_path on the object returned by ResourceCache.get()
        # (presumably a shared cache) as a side effect.
        if scene_dir:
            cache.base_path = scene_dir
        node.mesh = cache.resolve_mesh(d["mesh"])

    # Audio stream resource
    if "stream" in d and isinstance(node, AudioStreamPlayer | AudioStreamPlayer2D | AudioStreamPlayer3D):
        from .resource import ResourceCache

        cache = ResourceCache.get()
        if scene_dir:
            cache.base_path = scene_dir
        # "audio://" URIs go through the resource cache; anything else is a file path.
        if d["stream"].startswith("audio://"):
            node.stream = cache.resolve_audio(d["stream"])
        else:
            from .audio import AudioStream

            node.stream = AudioStream(d["stream"])

    # Children
    for child_data in d.get("children", ()):
        node.add_child(_deserialize_node(child_data, scene_dir=scene_dir))

    # Re-wire collision for CharacterBody nodes from a direct child shape
    if isinstance(node, CharacterBody2D | CharacterBody3D) and node.collision is None:
        shape_type = CollisionShape2D if isinstance(node, CharacterBody2D) else CollisionShape3D
        shape = node.find(shape_type, recursive=False)
        if shape is not None:
            node.collision = shape

    return node


# ============================================================================
# Public API
# ============================================================================


class _TypeEncoder(json.JSONEncoder):
    """JSON encoder that also handles engine math types and NumPy values.

    Vec2/Vec3/Quat become tagged dicts via ``_encode_type``. NumPy floating and
    integer scalars become Python float/int, and ndarrays become nested lists.
    Anything else defers to the base encoder, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, Vec2 | Vec3 | Quat):
            return _encode_type(obj)
        # Checked in this order: float scalars, int scalars, then arrays.
        numpy_converters = (
            (np.floating, float),
            (np.integer, int),
            (np.ndarray, lambda arr: arr.tolist()),
        )
        for numpy_type, convert in numpy_converters:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)


[docs] def save_scene(node: Node, path: str | Path, format: str = "json"): """Save a node tree to disk. Args: node: Root node to serialize. path: Output file path. format: "json" or "pickle". """ data = _serialize_node(node) data["format_version"] = CURRENT_FORMAT_VERSION path = Path(path) if format == "json": path.write_text(json.dumps(data, cls=_TypeEncoder, indent=2)) elif format == "pickle": path.write_bytes(pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)) else: raise ValueError(f"Unknown format: {format!r}")
def load_scene(path: str | Path, format: str | None = None) -> Node:
    """Load a node tree from disk.

    Args:
        path: Input file path.
        format: "json", "pickle", or None to auto-detect from the extension
            (.json, or .pkl/.pickle; case-insensitive).

    Returns:
        Reconstructed root Node.

    Raises:
        ValueError: If the format cannot be detected or is unknown, if the file
            does not contain a scene dict, or if its format version cannot be
            migrated.
    """
    path = Path(path)
    if format is None:
        ext = path.suffix.lower()
        if ext == ".json":
            format = "json"
        elif ext in (".pkl", ".pickle"):
            format = "pickle"
        else:
            raise ValueError(f"Cannot auto-detect format from extension: {ext!r}")

    if format == "json":
        # JSON is defined as UTF-8; don't depend on the platform's default encoding.
        data = _decode_json_recursive(json.loads(path.read_text(encoding="utf-8")))
    elif format == "pickle":
        # SECURITY: unpickling can execute arbitrary code. Only load pickle
        # scenes from trusted sources.
        data = pickle.loads(path.read_bytes())
    else:
        raise ValueError(f"Unknown format: {format!r}")

    # A scene file's root must be a node dict; anything else (a list, a bare
    # tagged value, another pickled object) would fail obscurely further down.
    if not isinstance(data, dict):
        raise ValueError(f"Scene file {str(path)!r} does not contain a scene dict (got {type(data).__name__})")

    data = _migrate_scene(data)
    return _deserialize_node(data, scene_dir=str(path.parent))
class PackedScene:
    """Reusable scene resource that loads once and can be instanced many times.

    The scene file is lazily parsed on the first call to instance().
    Subsequent calls deep-copy the cached data to produce independent node trees.
    """

    def __init__(self, path: str | Path):
        self.path = Path(path)
        # Migrated scene dict, filled on the first instance(); None = not loaded yet.
        self._data: dict | None = None

    def _detect_format(self) -> str:
        """Return "json" or "pickle" from the file extension (case-insensitive).

        Raises:
            ValueError: If the extension is not .json, .pkl or .pickle.
        """
        ext = self.path.suffix.lower()
        if ext == ".json":
            return "json"
        if ext in (".pkl", ".pickle"):
            return "pickle"
        raise ValueError(f"Cannot auto-detect format from extension: {ext!r}")

    def instance(self) -> Node:
        """Load (once) and return a new, independent node tree from the scene file.

        The returned root has ``_packed_scene_path`` set to this scene's path.

        Raises:
            ValueError: For an unsupported extension, a file that does not hold
                a scene dict, or an unmigratable format version.
        """
        if self._data is None:
            if self._detect_format() == "json":
                data = _decode_json_recursive(json.loads(self.path.read_text(encoding="utf-8")))
            else:
                # SECURITY: unpickling can execute arbitrary code. Only use
                # pickle scenes from trusted sources.
                data = pickle.loads(self.path.read_bytes())
            if not isinstance(data, dict):
                raise ValueError(
                    f"Scene file {str(self.path)!r} does not contain a scene dict (got {type(data).__name__})"
                )
            # Cache only after migration succeeds, so a failed load is retried
            # from disk rather than leaving unmigrated data in the cache.
            self._data = _migrate_scene(data)
        import copy

        # Deep-copy so the cached dict is never shared with the new tree.
        node = _deserialize_node(copy.deepcopy(self._data), scene_dir=str(self.path.parent))
        node._packed_scene_path = str(self.path)
        return node

    def save(self, node: Node):
        """Save a node tree to this PackedScene's path.

        The format follows the extension: .pkl/.pickle writes pickle, anything
        else writes JSON. If *node* was instanced from this same file, it is
        written out in full instead of as a ``__scene__`` reference to itself,
        which would make the file include itself on load. The cached scene data
        is dropped so the next instance() re-reads the new file.
        """
        fmt = "pickle" if self.path.suffix.lower() in (".pkl", ".pickle") else "json"
        original_ref = node._packed_scene_path
        self_reference = bool(original_ref) and Path(original_ref) == self.path
        if self_reference:
            node._packed_scene_path = None
        try:
            save_scene(node, self.path, format=fmt)
        finally:
            if self_reference:
                node._packed_scene_path = original_ref
        self._data = None
def _decode_json_recursive(obj): """Walk a parsed JSON structure and decode tagged type dicts.""" if isinstance(obj, dict): decoded = _decode_type(obj) if decoded is not obj: return decoded return {k: _decode_json_recursive(v) for k, v in obj.items()} if isinstance(obj, list): return [_decode_json_recursive(item) for item in obj] return obj