"""
Hot-reload system -- watch script files for changes and reload them live.
On file change: serialize node state -> importlib.reload(module) -> swap each
live node's class to the reloaded version -> restore state. State that cannot be
captured or restored is skipped and logged.
Public API:
from simvx.core.hot_reload import HotReloadManager
mgr = HotReloadManager(tree)
mgr.watch("game.py") # start watching a module file
mgr.poll(dt) # call each frame with the elapsed seconds since the last frame
"""
from __future__ import annotations

import importlib
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any

from .node import Node
from .scene_tree import SceneTree
from .descriptors import Signal
# Module-level logger shared by every helper and HotReloadManager method below.
log = logging.getLogger(__name__)
# Only the manager is public; the _serialize/_restore helpers and _WatchedFile are internal.
__all__ = ["HotReloadManager"]
def _serialize_settings(node: Node) -> dict[str, Any]:
    """Snapshot a node's declared Property values as a ``name -> value`` dict.

    Each value is read from the property's backing attribute (``prop.attr``),
    falling back to ``prop.default``. Callable values are left out, except
    classes, which are kept. A property whose lookup raises AttributeError or
    TypeError is skipped and reported at debug level.
    """
    snapshot: dict[str, Any] = {}
    for prop_name, prop in node.get_properties().items():
        try:
            value = getattr(node, prop.attr, prop.default)
            # Keep plain data and classes; drop functions, methods and other callables.
            if not callable(value) or isinstance(value, type):
                snapshot[prop_name] = value
        except (AttributeError, TypeError):
            log.debug("hot_reload: skipping non-serializable setting %s on %s", prop_name, node.name)
    return snapshot
def _serialize_node_state(node: Node) -> dict[str, Any]:
    """Snapshot a node and, recursively, its whole subtree.

    The snapshot records the node's name, class name and defining module,
    its Property settings, a copy of its group list, and its visibility. It
    also records whichever of ``position``, ``rotation``, ``scale`` and
    ``velocity`` the node exposes; ``scale`` is included only when it is not
    callable. Each child's snapshot goes in a ``"children"`` list.
    """
    cls = type(node)
    snapshot: dict[str, Any] = {
        "name": node.name,
        "class": cls.__name__,
        "module": cls.__module__,
        "settings": _serialize_settings(node),
        "groups": list(node._groups),
        "visible": node.visible,
    }
    for attr in ("position", "rotation", "scale", "velocity"):
        if not hasattr(node, attr):
            continue
        # A callable ``scale`` (e.g. a method) is not a value worth capturing.
        if attr == "scale" and callable(node.scale):
            continue
        snapshot[attr] = getattr(node, attr)
    snapshot["children"] = [_serialize_node_state(child) for child in node.children]
    return snapshot
def _restore_settings(node: Node, state: dict[str, Any]):
    """Write saved Property values back onto *node*, best effort.

    Only names that *node* currently declares in ``get_properties()`` are
    restored. Other names, such as a property removed by the reload, are
    skipped silently. Each failing assignment is logged as a warning and does
    not stop the remaining ones.

    Args:
        node: Node to update.
        state: Mapping of property name -> value, as produced by
            ``_serialize_settings``.
    """
    if not state:
        return
    try:
        # Fetched once instead of once per setting. This assumes assigning
        # values does not change which properties the node declares.
        declared = node.get_properties()
    except Exception as e:
        log.warning("hot_reload: failed to read properties of %s: %s", node.name, e)
        return
    for name, val in state.items():
        if name not in declared:
            continue
        try:
            setattr(node, name, val)
        except Exception as e:
            log.warning("hot_reload: failed to restore setting %s=%r on %s: %s", name, val, node.name, e)
def _restore_node_state(node: Node, state: dict[str, Any]):
    """Re-apply a snapshot produced by ``_serialize_node_state`` to *node*.

    Restores, in order:
      - the name;
      - visibility (default True);
      - group membership (groups are only added, never removed);
      - each of ``position``/``rotation``/``scale``/``velocity`` present in
        both the snapshot and the node (``scale`` only when not callable);
      - the Property settings.

    Entries under ``"children"`` are not used. A spatial assignment that
    raises KeyError, TypeError or ValueError is skipped and logged at debug
    level.
    """
    node.name = state.get("name", node.name)
    node.visible = state.get("visible", True)
    for group in state.get("groups", []):
        node.add_to_group(group)
    for attr in ("position", "rotation", "scale", "velocity"):
        if attr not in state or not hasattr(node, attr):
            continue
        # Leave a callable ``scale`` (e.g. a method) alone rather than overwrite it.
        if attr == "scale" and callable(node.scale):
            continue
        try:
            setattr(node, attr, state[attr])
        except (KeyError, TypeError, ValueError) as e:
            log.debug("hot_reload: could not restore %s on %s: %s", attr, node.name, e)
    _restore_settings(node, state.get("settings", {}))
class _WatchedFile:
    """Track one watched file's modification time and the module it maps to.

    Attributes (``__slots__``):
        path: Path of the watched file.
        module_name: Name of the module to reload when the file changes.
        mtime: Last observed ``st_mtime``; 0.0 if the file could not be stat'ed.
    """

    __slots__ = ("path", "module_name", "mtime")

    def __init__(self, path: str, module_name: str):
        self.path = path
        self.module_name = module_name
        self.mtime = self._stat()

    def _stat(self) -> float:
        """Return the file's mtime, or 0.0 if ``os.stat`` fails."""
        try:
            return os.stat(self.path).st_mtime
        except OSError:
            return 0.0

    def changed(self) -> bool:
        """Return True if the file was modified since the last check.

        A file that cannot currently be stat'ed (deleted, or briefly absent
        during an editor's write-and-rename save) is not reported as changed,
        because reloading a module whose source is missing can only fail. The
        last seen mtime is kept, so when the file reappears with a different
        mtime it is reported on a later check.
        """
        new_mtime = self._stat()
        if new_mtime == 0.0 or new_mtime == self.mtime:
            return False
        self.mtime = new_mtime
        return True
class HotReloadManager:
    """Watch script files for changes and hot-reload the node classes they define.

    When a watched file's modification time changes, its module is reloaded.
    Every live node in ``tree`` whose exact class came from that module is
    then switched to the reloaded class of the same name. Its state is
    captured before the reload and restored afterwards.

    Usage:
        mgr = HotReloadManager(tree)
        mgr.watch("my_game.py")
        # In your game loop:
        mgr.poll(dt)  # checks every poll_interval seconds

    ``poll()`` may also be called without ``dt``. Elapsed wall-clock time is
    then used to throttle the checks.
    """

    def __init__(self, tree: SceneTree, poll_interval: float = 0.5):
        """Create a manager for *tree* that checks files every *poll_interval* seconds."""
        self.tree = tree
        self.poll_interval = poll_interval
        self.reloaded = Signal()  # emitted with (module_name, class_names) on reload
        self._watched: list[_WatchedFile] = []
        self._time_accumulator: float = 0.0
        # Monotonic timestamp of the previous poll(); drives throttling when no dt is given.
        self._last_poll: float = time.monotonic()

    def watch(self, file_path: str) -> None:
        """Start watching a Python file for changes.

        If no module in ``sys.modules`` was loaded from the file, the file's
        directory is prepended to ``sys.path`` (when missing). The module is
        then imported under the file's stem. Problems are logged as warnings
        and the file is not watched.

        Args:
            file_path: Path to a .py file. The corresponding module must be importable.
        """
        path = str(Path(file_path).resolve())
        if not Path(path).is_file():
            log.warning("hot_reload: file not found: %s", path)
            return
        # Don't double-watch
        if any(w.path == path for w in self._watched):
            return
        module_name = self._find_module_name(path)
        if module_name is None:
            module_name = Path(path).stem
            parent = str(Path(path).parent)
            if parent not in sys.path:
                sys.path.insert(0, parent)
            try:
                module = importlib.import_module(module_name)
            except ImportError:
                log.warning("hot_reload: could not import module for %s", path)
                return
            # A same-named module from elsewhere (earlier on sys.path, or already
            # imported) would be reloaded instead of this file, so edits to the
            # file would never take effect.
            mod_file = getattr(module, "__file__", None)
            if not mod_file or Path(mod_file).resolve() != Path(path):
                log.warning(
                    "hot_reload: module %r was not loaded from %s (found %s)", module_name, path, mod_file
                )
                return
        self._watched.append(_WatchedFile(path, module_name))
        log.debug("hot_reload: watching %s (%s)", path, module_name)

    def unwatch(self, file_path: str) -> None:
        """Stop watching a file (no-op if it is not watched)."""
        path = str(Path(file_path).resolve())
        self._watched = [w for w in self._watched if w.path != path]

    def poll(self, dt: float = 0.0) -> list[str]:
        """Check watched files for changes. Call once per frame.

        Files are checked only once at least ``poll_interval`` seconds have
        accumulated. Time comes from *dt* when it is positive. Otherwise the
        wall-clock time since the previous call is used, so a bare ``poll()``
        still checks files periodically.

        Args:
            dt: Seconds elapsed since the previous call (frame delta time).

        Returns:
            List of module names that were reloaded. Modules whose reload
            raised are logged as errors and left out.
        """
        now = time.monotonic()
        # Without this fallback, poll() with the default dt=0.0 would never
        # advance the accumulator and no file would ever be checked.
        self._time_accumulator += dt if dt > 0.0 else now - self._last_poll
        self._last_poll = now
        if self._time_accumulator < self.poll_interval:
            return []
        self._time_accumulator = 0.0
        reloaded: list[str] = []
        for wf in self._watched:
            if not wf.changed():
                continue
            try:
                self._reload_module(wf.module_name)
            except Exception as e:
                log.error("hot_reload: failed to reload %s: %s", wf.module_name, e)
            else:
                reloaded.append(wf.module_name)
        return reloaded

    def force_reload(self, module_name: str) -> bool:
        """Reload *module_name* now, regardless of file modification times.

        Args:
            module_name: Module name; it is imported first if not in ``sys.modules``.

        Returns:
            True if reload succeeded, False otherwise (the error is logged).
        """
        try:
            self._reload_module(module_name)
            return True
        except Exception as e:
            log.error("hot_reload: failed to reload %s: %s", module_name, e)
            return False

    @staticmethod
    def _node_classes(module, module_name: str) -> dict[str, type]:
        """Return Node subclasses defined in *module* itself, keyed by attribute name."""
        return {
            name: obj
            for name, obj in vars(module).items()
            if isinstance(obj, type) and issubclass(obj, Node) and obj.__module__ == module_name
        }

    def _reload_module(self, module_name: str):
        """Reload *module_name* and move its live node instances onto the new classes.

        Each affected node keeps its identity and place in the tree. Only its
        ``__class__`` is swapped, and then its pre-reload state is re-applied.
        A node stays on its old class in two cases:
          - its class name no longer exists in the reloaded module;
          - the class cannot be swapped (``TypeError``, e.g. a changed
            ``__slots__`` layout). This case is logged as a warning.

        Raises:
            Any exception raised while importing or reloading the module.
        """
        module = sys.modules.get(module_name)
        if module is None:
            module = importlib.import_module(module_name)
        old_classes = self._node_classes(module, module_name)
        # Capture state before reload re-executes the module's top-level code.
        snapshots = [(node, _serialize_node_state(node)) for node in self._find_live_nodes(old_classes)]
        module = importlib.reload(module)
        new_classes = self._node_classes(module, module_name)
        replaced_classes: list[str] = []
        for node, state in snapshots:
            class_name = state["class"]
            new_cls = new_classes.get(class_name)
            if new_cls is None:
                continue
            try:
                # Swap the class on the existing instance (avoids re-parenting)
                node.__class__ = new_cls
            except TypeError as e:
                log.warning("hot_reload: cannot switch %s to reloaded %s: %s", node.name, class_name, e)
                continue
            _restore_node_state(node, state)
            if class_name not in replaced_classes:
                replaced_classes.append(class_name)
        if replaced_classes:
            log.info("hot_reload: reloaded %s -> %s", module_name, replaced_classes)
            self.reloaded.emit(module_name, replaced_classes)

    def _find_live_nodes(self, classes: dict[str, type]) -> list[Node]:
        """Return tree nodes whose exact type is one of *classes*' values, in pre-order.

        Instances of subclasses do not match, because the comparison uses ``type()`` exactly.
        """
        if not self.tree.root:
            return []
        result: list[Node] = []
        self._collect_nodes(self.tree.root, set(classes.values()), result)
        return result

    def _collect_nodes(self, node: Node, class_set: set[type], result: list[Node]):
        """Append to *result* every node under *node* (inclusive) whose type is in *class_set*.

        Uses an explicit stack, so very deep trees cannot hit the recursion limit.
        Nodes are visited in pre-order.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if type(current) in class_set:
                result.append(current)
            # Push children reversed so the first child is visited next (pre-order).
            stack.extend(reversed(list(current.children)))

    @staticmethod
    def _find_module_name(path: str) -> str | None:
        """Return the name of the first loaded module whose ``__file__`` is *path*, else None."""
        target = Path(path).resolve()
        # Iterate a snapshot: sys.modules may change size if another thread imports meanwhile.
        for name, mod in list(sys.modules.items()):
            mod_file = getattr(mod, "__file__", None)
            if mod_file and Path(mod_file).resolve() == target:
                return name
        return None