# Source code for simvx.core.hot_reload

"""
Hot-reload system -- watch script files for changes and reload them live.

On file change: serialize node state -> importlib.reload(module) -> swap each
live node's class to the reloaded version -> restore state.  State that cannot
be captured or restored is skipped; failed setting restores are logged as warnings.

Public API:
    from simvx.core.hot_reload import HotReloadManager

    mgr = HotReloadManager(tree)
    mgr.watch("game.py")          # start watching a module file
    mgr.poll(dt)                  # call each frame with the frame delta time (seconds)
"""


from __future__ import annotations

import importlib
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any

from .node import Node
from .scene_tree import SceneTree
from .descriptors import Signal

log = logging.getLogger(__name__)

__all__ = ["HotReloadManager"]


def _serialize_settings(node: Node) -> dict[str, Any]:
    """Snapshot a node's Property values as a ``{property_name: value}`` dict.

    Each value is read from the property's backing attribute (``prop.attr``),
    falling back to ``prop.default`` when the attribute is absent. Callable
    values are omitted unless they are classes. Properties whose lookup raises
    AttributeError or TypeError are skipped and reported at debug level.
    """
    snapshot: dict[str, Any] = {}
    for prop_name, prop in node.get_properties().items():
        try:
            value = getattr(node, prop.attr, prop.default)
        except (AttributeError, TypeError):
            log.debug("hot_reload: skipping non-serializable setting %s on %s", prop_name, node.name)
            continue
        # Functions/bound methods are not data worth carrying across a reload;
        # class objects are callable too, but are kept as values.
        if isinstance(value, type) or not callable(value):
            snapshot[prop_name] = value
    return snapshot


def _serialize_node_state(node: Node) -> dict[str, Any]:
    """Build a snapshot dict of *node* and, recursively, its children.

    The snapshot holds identity info (name, class name, module), Property
    settings, group names and visibility. It also holds whichever of
    position / rotation / scale / velocity the node exposes; a callable
    ``scale`` is left out. The children's snapshots are stored under
    ``"children"``.
    """
    snapshot: dict[str, Any] = {}
    snapshot["name"] = node.name
    snapshot["class"] = type(node).__name__
    snapshot["module"] = type(node).__module__
    snapshot["settings"] = _serialize_settings(node)
    snapshot["groups"] = list(node._groups)
    snapshot["visible"] = node.visible

    for attr in ("position", "rotation", "scale", "velocity"):
        if not hasattr(node, attr):
            continue
        value = getattr(node, attr)
        # A callable ``scale`` is not a plain value, so it is not captured.
        if attr == "scale" and callable(value):
            continue
        snapshot[attr] = value

    snapshot["children"] = [_serialize_node_state(child) for child in node.children]
    return snapshot


def _restore_settings(node: Node, state: dict[str, Any]):
    """Write saved Property values back onto *node*, warning on failures.

    Only names still declared by the node's (possibly reloaded) class, as
    reported by ``node.get_properties()``, are applied; other saved names are
    ignored. Each value is assigned with ``setattr(node, name, value)``. A
    failing assignment is logged as a warning and does not stop the remaining
    settings from being restored.

    Args:
        node: The node to update.
        state: Mapping of property name to saved value.
    """
    if not state:
        return
    # The property table does not change while we restore, so look it up once
    # instead of once per saved setting.
    try:
        known = node.get_properties()
    except Exception as e:
        log.warning("hot_reload: could not read properties of %s: %s", node.name, e)
        return
    for name, val in state.items():
        if name not in known:
            continue
        try:
            setattr(node, name, val)
        except Exception as e:
            log.warning("hot_reload: failed to restore setting %s=%r on %s: %s", name, val, node.name, e)


def _restore_node_state(node: Node, state: dict[str, Any]):
    """Apply a snapshot produced by ``_serialize_node_state`` back onto *node*.

    Restores, in order:
      1. name and visibility (visibility defaults to True if absent);
      2. group membership via ``add_to_group``;
      3. each of position / rotation / scale / velocity that is both in the
         snapshot and present on the node;
      4. the Property settings.

    The snapshot's ``"children"`` entry is not used here. A spatial attribute
    that fails to assign (KeyError/TypeError/ValueError) is logged as a
    warning and skipped rather than aborting the restore.
    """
    node.name = state.get("name", node.name)
    node.visible = state.get("visible", True)

    for group in state.get("groups", []):
        node.add_to_group(group)

    for attr in ("position", "rotation", "scale", "velocity"):
        if attr not in state or not hasattr(node, attr):
            continue
        # A callable ``scale`` is not a plain value slot, so it is not assigned.
        if attr == "scale" and callable(node.scale):
            continue
        try:
            setattr(node, attr, state[attr])
        except (KeyError, TypeError, ValueError) as e:
            # Best-effort: the reloaded class may expect a different type here.
            log.warning("hot_reload: failed to restore %s on %s: %s", attr, node.name, e)

    _restore_settings(node, state.get("settings", {}))


class _WatchedFile:
    """Remember one watched file's last-seen mtime and its associated module name."""

    __slots__ = ("path", "module_name", "mtime")

    def __init__(self, path: str, module_name: str):
        self.path = path
        self.module_name = module_name
        self.mtime = self._stat()

    def _stat(self) -> float:
        """Return the file's ``st_mtime``, or 0.0 when it cannot be stat'ed."""
        try:
            info = os.stat(self.path)
        except OSError:
            return 0.0
        return info.st_mtime

    def changed(self) -> bool:
        """Report whether the mtime moved since the last check, remembering the new value."""
        previous, self.mtime = self.mtime, self._stat()
        return previous != self.mtime


class HotReloadManager:
    """Watches script files for changes and hot-reloads node classes.

    Usage:
        mgr = HotReloadManager(tree)
        mgr.watch("my_game.py")

        # In your game loop:
        mgr.poll(dt)  # checks every poll_interval seconds

    ``poll()`` may also be called without ``dt``; the wall-clock time elapsed
    since the previous ``poll()`` call (``time.monotonic``) is used instead.
    """

    def __init__(self, tree: SceneTree, poll_interval: float = 0.5):
        self.tree = tree
        self.poll_interval = poll_interval
        self.reloaded = Signal()  # emitted with (module_name, class_names) on reload
        self._watched: list[_WatchedFile] = []
        self._time_accumulator: float = 0.0
        # time.monotonic() reading taken at the previous poll() call (None before the first).
        self._last_poll: float | None = None

    def watch(self, file_path: str) -> None:
        """Start watching a Python file for changes.

        Args:
            file_path: Path to a .py file. The corresponding module must be
                importable. If no already-imported module was loaded from
                this file, the file's stem is imported as a top-level module.
                In that case the file's directory is prepended to
                ``sys.path`` if missing. Nothing is watched (a warning is
                logged) if the file does not exist, cannot be imported, or the
                stem resolves to a module loaded from a different file.
        """
        path = str(Path(file_path).resolve())
        if not Path(path).is_file():
            log.warning("hot_reload: file not found: %s", path)
            return

        # Don't double-watch; checked first so repeat calls have no import side effects.
        if any(w.path == path for w in self._watched):
            return

        # Find module name from sys.modules by matching file path
        module_name = self._find_module_name(path)
        if module_name is None:
            # Try importing from filename
            module_name = Path(path).stem
            parent = str(Path(path).parent)
            if parent not in sys.path:
                sys.path.insert(0, parent)
            try:
                module = importlib.import_module(module_name)
            except ImportError:
                log.warning("hot_reload: could not import module for %s", path)
                return
            # An earlier sys.path entry (or a cached module) with the same name
            # would otherwise be watched and reloaded in place of this file.
            mod_file = getattr(module, "__file__", None)
            if not isinstance(mod_file, str) or Path(mod_file).resolve() != Path(path):
                log.warning(
                    "hot_reload: module %r was loaded from %s, not %s; not watching",
                    module_name, mod_file, path,
                )
                return

        self._watched.append(_WatchedFile(path, module_name))
        log.debug("hot_reload: watching %s (%s)", path, module_name)

    def unwatch(self, file_path: str) -> None:
        """Stop watching a file."""
        path = str(Path(file_path).resolve())
        self._watched = [w for w in self._watched if w.path != path]

    def poll(self, dt: float | None = None) -> list[str]:
        """Check watched files for changes, at most once per ``poll_interval``.

        Args:
            dt: Seconds elapsed since the previous call (e.g. the frame delta).
                When omitted, the wall-clock time since the previous
                ``poll()`` call is measured with ``time.monotonic()``; the
                very first such call counts as 0 seconds.

        Returns:
            List of module names that were reloaded. The list is empty if
            the interval has not elapsed yet or nothing changed. Reload
            failures are logged and omitted.
        """
        now = time.monotonic()
        if dt is None:
            # Without an explicit dt the accumulator would never advance, so a
            # bare poll() would never reload; fall back to real elapsed time.
            dt = 0.0 if self._last_poll is None else now - self._last_poll
        self._last_poll = now

        self._time_accumulator += dt
        if self._time_accumulator < self.poll_interval:
            return []
        self._time_accumulator = 0.0

        reloaded: list[str] = []
        for wf in self._watched:
            if wf.changed():
                try:
                    self._reload_module(wf.module_name)
                    reloaded.append(wf.module_name)
                except Exception as e:
                    log.error("hot_reload: failed to reload %s: %s", wf.module_name, e)
        return reloaded

    def force_reload(self, module_name: str) -> bool:
        """Force-reload a specific module by name.

        Returns:
            True if reload succeeded, False otherwise.
        """
        try:
            self._reload_module(module_name)
            return True
        except Exception as e:
            log.error("hot_reload: failed to reload %s: %s", module_name, e)
            return False

    @staticmethod
    def _node_classes(module: Any, module_name: str) -> dict[str, type]:
        """Return the Node subclasses defined (not merely imported) in *module*, by name."""
        return {
            name: obj
            for name, obj in vars(module).items()
            if isinstance(obj, type) and issubclass(obj, Node) and obj.__module__ == module_name
        }

    def _reload_module(self, module_name: str):
        """Reload a module and move its live node instances onto the new class versions.

        Exceptions from ``importlib.import_module`` / ``importlib.reload``
        propagate to the caller. In that case the class swap has not started.
        A node whose class cannot be swapped (``TypeError``, e.g. an
        incompatible instance layout) is logged and left on its old class,
        and the remaining nodes are still processed.
        """
        module = sys.modules.get(module_name)
        if module is None:
            module = importlib.import_module(module_name)

        old_classes = self._node_classes(module, module_name)
        live_nodes = self._find_live_nodes(old_classes)

        # Serialize state of live nodes before reload
        saved_states: dict[int, dict[str, Any]] = {
            id(node): _serialize_node_state(node) for node in live_nodes
        }

        module = importlib.reload(module)
        new_classes = self._node_classes(module, module_name)

        replaced_classes: list[str] = []
        for node in live_nodes:
            state = saved_states.get(id(node))
            if state is None:
                continue
            class_name = state["class"]
            new_cls = new_classes.get(class_name)
            if new_cls is None:
                continue
            # Swap the class on the existing instance (avoids re-parenting)
            try:
                node.__class__ = new_cls
            except TypeError as e:
                log.warning("hot_reload: could not swap class %s on node %s: %s", class_name, node.name, e)
                continue
            _restore_node_state(node, state)
            if class_name not in replaced_classes:
                replaced_classes.append(class_name)

        if replaced_classes:
            log.info("hot_reload: reloaded %s -> %s", module_name, replaced_classes)
            self.reloaded.emit(module_name, replaced_classes)

    def _find_live_nodes(self, classes: dict[str, type]) -> list[Node]:
        """Find all nodes in the tree whose class is one of the given classes."""
        if not self.tree.root or not classes:
            return []
        class_set = set(classes.values())
        result: list[Node] = []
        self._collect_nodes(self.tree.root, class_set, result)
        return result

    def _collect_nodes(self, node: Node, class_set: set[type], result: list[Node]):
        """Recursively collect nodes whose exact type is in class_set (pre-order)."""
        if type(node) in class_set:
            result.append(node)
        for child in node.children:
            self._collect_nodes(child, class_set, result)

    @staticmethod
    def _find_module_name(path: str) -> str | None:
        """Find the module name in sys.modules that corresponds to a file path."""
        target = Path(path).resolve()
        # Iterate a snapshot: sys.modules can change size if an import runs meanwhile.
        for name, mod in list(sys.modules.items()):
            if mod is None:
                continue
            mod_file = getattr(mod, "__file__", None)
            if isinstance(mod_file, str) and mod_file and Path(mod_file).resolve() == target:
                return name
        return None