# Source code for simvx.core.ui.graph_edit

"""Node graph editor widget for visual programming, animation blend trees, and shader graphs.

Provides GraphNode (a draggable node with typed input/output ports) and
GraphEdit (a zoomable canvas that manages nodes and connections).
"""


from __future__ import annotations

import logging
from dataclasses import dataclass

from ..descriptors import Signal
from ..math.types import Vec2
from .containers import Container
from .core import Control

# Public API of this module.
__all__ = [
    "GraphPort",
    "GraphConnection",
    "GraphNode",
    "GraphEdit",
]

# Module-level logger, named after the module.
log = logging.getLogger(__name__)


[docs] @dataclass class GraphPort: """Input or output port on a GraphNode. Attributes: name: Display name. type: Type identifier for connection validation. colour: RGBA colour for the port indicator. """ name: str type: str = "default" colour: tuple[float, float, float, float] = (0.8, 0.8, 0.8, 1.0)
@dataclass
class GraphConnection:
    """A wire from an output port of one GraphNode to an input port of another.

    Attributes:
        from_node: Name of the node owning the source (output) port.
        from_port: Index into the source node's outputs.
        to_node: Name of the node owning the destination (input) port.
        to_port: Index into the destination node's inputs.
    """

    from_node: str
    from_port: int
    to_node: str
    to_port: int
class GraphNode(Container):
    """A node in the graph editor with a titled header and typed input/output ports.

    Port ``i`` is laid out at local y ``_PORT_START_Y + i * _PORT_SPACING``.
    Inputs sit on the left edge (x = 0) and outputs on the right edge
    (x = ``size.x``). Adding a port grows the node's height when needed, so
    every port stays inside the node body. The height never shrinks.

    Example:
        node = GraphNode(title="Add")
        node.add_input("A", type="float")
        node.add_input("B", type="float")
        node.add_output("Result", type="float")
        graph.add_graph_node(node)
    """

    # Local-space y of the first port's centre, and the spacing between ports.
    _PORT_START_Y = 40.0
    _PORT_SPACING = 24.0
    # Space kept below the last port's centre when auto-growing the node.
    # With the defaults, the 80-unit initial height fits exactly two ports.
    _PORT_BOTTOM_PAD = 16.0

    def __init__(self, title: str = "Node", **kwargs):
        super().__init__(**kwargs)
        self.title = title
        self.inputs: list[GraphPort] = []
        self.outputs: list[GraphPort] = []
        # Top-left corner of the node in graph space.
        self.graph_position: tuple[float, float] = (0.0, 0.0)
        self.selected = False
        self.dragging = False
        # Graph-space offset from the node origin to the grab point while dragging.
        self._drag_offset = (0.0, 0.0)
        self.size = Vec2(160, 80)

    def add_input(self, name: str, type: str = "default") -> int:
        """Append an input port and return its index."""
        self.inputs.append(GraphPort(name=name, type=type))
        self._fit_height_to_ports()
        return len(self.inputs) - 1

    def add_output(self, name: str, type: str = "default") -> int:
        """Append an output port and return its index."""
        self.outputs.append(GraphPort(name=name, type=type))
        self._fit_height_to_ports()
        return len(self.outputs) - 1

    def get_input_port_position(self, index: int) -> tuple[float, float]:
        """Position of input port ``index`` in local node space (left edge)."""
        return (0.0, self._port_y(index))

    def get_output_port_position(self, index: int) -> tuple[float, float]:
        """Position of output port ``index`` in local node space (right edge)."""
        return (self.size.x, self._port_y(index))

    def _port_y(self, index: int) -> float:
        """Local-space y coordinate of port row ``index``."""
        return self._PORT_START_Y + index * self._PORT_SPACING

    def _fit_height_to_ports(self) -> None:
        """Grow ``size.y`` so the lowest port row lies inside the node body."""
        rows = max(len(self.inputs), len(self.outputs))
        if rows == 0:
            return
        needed = self._port_y(rows - 1) + self._PORT_BOTTOM_PAD
        if self.size.y < needed:
            self.size = Vec2(self.size.x, needed)
class GraphEdit(Control):
    """Canvas for editing a graph of GraphNodes: pan, zoom and port-to-port wiring.

    Node positions are stored in graph space. The visible region is
    determined by the zoom factor and the scroll offset.

    Example:
        graph = GraphEdit()
        node_a = GraphNode(name="A", title="Source")
        node_a.add_output("Out")
        graph.add_graph_node(node_a)
        node_b = GraphNode(name="B", title="Sink")
        node_b.add_input("In")
        graph.add_graph_node(node_b)
        graph.connect_node("A", 0, "B", 0)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # --- model: nodes keyed by name, and the wires between their ports
        self._nodes: dict[str, GraphNode] = {}
        self._connections: list[GraphConnection] = []

        # --- view transform
        self._zoom = 1.0
        self._scroll_offset = (0.0, 0.0)

        # --- transient pointer-interaction state
        self._dragging_node: GraphNode | None = None
        self._panning = False
        self._pan_start = self._pan_start_offset = (0.0, 0.0)
        # Port a wire is currently being dragged from: (node_name, port_idx, is_output).
        self._connecting_from: tuple[str, int, bool] | None = None
        self._connecting_mouse: tuple[float, float] = (0.0, 0.0)

        # --- signals
        self.connection_request = Signal()  # (from_node, from_port, to_node, to_port)
        self.disconnection_request = Signal()  # (from_node, from_port, to_node, to_port)
        self.node_selected = Signal()  # (node_name,)

        self.size = Vec2(600, 400)

    # ------------------------------------------------------------------ nodes
[docs] def add_graph_node(self, node: GraphNode): """Register a GraphNode in this editor.""" self._nodes[node.name] = node
[docs] def remove_graph_node(self, name: str): """Remove a GraphNode and all its connections.""" self._nodes.pop(name, None) self._connections = [c for c in self._connections if c.from_node != name and c.to_node != name]
[docs] def get_node(self, name: str) -> GraphNode | None: """Retrieve a graph node by name.""" return self._nodes.get(name)
# ------------------------------------------------------------------ connections
[docs] def connect_node(self, from_node: str, from_port: int, to_node: str, to_port: int): """Create a connection between two ports.""" conn = GraphConnection(from_node, from_port, to_node, to_port) self._connections.append(conn)
[docs] def disconnect_node(self, from_node: str, from_port: int, to_node: str, to_port: int): """Remove a specific connection.""" self._connections = [ c for c in self._connections if not ( c.from_node == from_node and c.from_port == from_port and c.to_node == to_node and c.to_port == to_port ) ]
[docs] def get_connections(self) -> list[GraphConnection]: """Return a copy of the current connections list.""" return list(self._connections)
[docs] def clear_connections(self): """Remove all connections.""" self._connections.clear()
# ------------------------------------------------------------------ viewport
[docs] def set_zoom(self, zoom: float): """Set zoom level, clamped to [0.25, 2.0].""" self._zoom = max(0.25, min(2.0, zoom))
@property def zoom(self) -> float: return self._zoom @property def scroll_offset(self) -> tuple[float, float]: return self._scroll_offset
[docs] def center_on_node(self, name: str): """Pan the viewport to center on the named node.""" node = self._nodes.get(name) if node: self._scroll_offset = ( node.graph_position[0] - self.size.x / 2, node.graph_position[1] - self.size.y / 2, )
# ------------------------------------------------------------------ input _PORT_HIT_RADIUS = 10.0 def _screen_to_graph(self, sx: float, sy: float) -> tuple[float, float]: """Convert screen coords (relative to widget origin) to graph-space.""" x, y, _, _ = self.get_global_rect() return ((sx - x) / self._zoom + self._scroll_offset[0], (sy - y) / self._zoom + self._scroll_offset[1]) def _hit_port(self, gx: float, gy: float) -> tuple[str, int, bool] | None: """Return (node_name, port_index, is_output) if (gx, gy) is near a port, else None.""" r = self._PORT_HIT_RADIUS / self._zoom for node in self._nodes.values(): nx, ny = node.graph_position for pi, _ in enumerate(node.outputs): px, py = node.get_output_port_position(pi) if abs(gx - (nx + px)) < r and abs(gy - (ny + py)) < r: return (node.name, pi, True) for pi, _ in enumerate(node.inputs): px, py = node.get_input_port_position(pi) if abs(gx - (nx + px)) < r and abs(gy - (ny + py)) < r: return (node.name, pi, False) return None def _hit_node(self, gx: float, gy: float) -> GraphNode | None: """Return the topmost node under graph-space point, or None.""" for node in reversed(list(self._nodes.values())): nx, ny = node.graph_position if nx <= gx <= nx + node.size.x and ny <= gy <= ny + node.size.y: return node return None def _on_gui_input(self, event): # Zoom if event.key == "scroll_up": self.set_zoom(self._zoom * 1.1) return if event.key == "scroll_down": self.set_zoom(self._zoom / 1.1) return px = event.position.x if hasattr(event.position, "x") else event.position[0] py = event.position.y if hasattr(event.position, "y") else event.position[1] # --- Left-click: select/drag nodes, connect/disconnect ports --- if event.button == 1: gx, gy = self._screen_to_graph(px, py) if event.pressed: # Check port hit first port_hit = self._hit_port(gx, gy) if port_hit: self._connecting_from = port_hit self._connecting_mouse = (px, py) self.grab_mouse() return # Check node hit — start dragging node = self._hit_node(gx, gy) if node: # 
Deselect others for n in self._nodes.values(): n.selected = False node.selected = True node.dragging = True node._drag_offset = (gx - node.graph_position[0], gy - node.graph_position[1]) self._dragging_node = node self.node_selected.emit(node.name) self.grab_mouse() else: # Click on empty space — deselect all for n in self._nodes.values(): n.selected = False elif not event.pressed: # Release: finish drag or connection if self._connecting_from: target = self._hit_port(gx, gy) if target and target != self._connecting_from: src, dst = self._connecting_from, target # Ensure output -> input direction if src[2] and not dst[2]: self.connect_node(src[0], src[1], dst[0], dst[1]) self.connection_request.emit(src[0], src[1], dst[0], dst[1]) elif not src[2] and dst[2]: self.connect_node(dst[0], dst[1], src[0], src[1]) self.connection_request.emit(dst[0], dst[1], src[0], src[1]) self._connecting_from = None self.release_mouse() if self._dragging_node: self._dragging_node.dragging = False self._dragging_node = None self.release_mouse() # --- Right-click (button 2): disconnect --- if event.button == 2 and event.pressed: gx, gy = self._screen_to_graph(px, py) self._try_remove_connection(gx, gy) # --- Middle-click (button 3): pan --- if event.button == 3: if event.pressed: self._panning = True self._pan_start = (px, py) self._pan_start_offset = self._scroll_offset elif not event.pressed: self._panning = False # --- Mouse motion --- if event.position and not event.key: if self._dragging_node: gx, gy = self._screen_to_graph(px, py) node = self._dragging_node node.graph_position = (gx - node._drag_offset[0], gy - node._drag_offset[1]) elif self._connecting_from: self._connecting_mouse = (px, py) elif self._panning: dx = px - self._pan_start[0] dy = py - self._pan_start[1] self._scroll_offset = ( self._pan_start_offset[0] - dx / self._zoom, self._pan_start_offset[1] - dy / self._zoom, ) def _try_remove_connection(self, gx: float, gy: float) -> bool: """Remove the connection nearest to 
graph-space point, if close enough. Returns True if removed.""" threshold = 15.0 best_dist = threshold best_conn = None for conn in self._connections: fn = self._nodes.get(conn.from_node) tn = self._nodes.get(conn.to_node) if not fn or not tn: continue fp = fn.get_output_port_position(conn.from_port) tp = tn.get_input_port_position(conn.to_port) x1, y1 = fn.graph_position[0] + fp[0], fn.graph_position[1] + fp[1] x2, y2 = tn.graph_position[0] + tp[0], tn.graph_position[1] + tp[1] # Sample along the cubic bezier and find minimum distance cp = min(80, abs(x2 - x1) * 0.5) cp1x, cp1y = x1 + cp, y1 cp2x, cp2y = x2 - cp, y2 for i in range(11): t = i / 10 it = 1 - t bx = it**3 * x1 + 3 * it**2 * t * cp1x + 3 * it * t**2 * cp2x + t**3 * x2 by = it**3 * y1 + 3 * it**2 * t * cp1y + 3 * it * t**2 * cp2y + t**3 * y2 d = ((gx - bx) ** 2 + (gy - by) ** 2) ** 0.5 if d < best_dist: best_dist = d best_conn = conn if best_conn: self._connections.remove(best_conn) self.disconnection_request.emit( best_conn.from_node, best_conn.from_port, best_conn.to_node, best_conn.to_port, ) return True return False # -------------------------------------------------------------------- draw _NODE_BG = (0.22, 0.22, 0.28, 0.95) _NODE_HEADER = (0.30, 0.35, 0.55, 1.0) _NODE_SELECTED = (0.45, 0.55, 0.80, 1.0) _NODE_BORDER = (0.40, 0.40, 0.50, 1.0) _PORT_IN = (0.4, 0.8, 0.4, 1.0) _PORT_OUT = (0.8, 0.5, 0.3, 1.0) _CONN_COLOUR = (0.7, 0.7, 0.8, 0.8) _GRID_COLOUR = (0.18, 0.18, 0.22, 1.0) _BG = (0.12, 0.12, 0.15, 1.0) def _graph_to_screen(self, gx: float, gy: float) -> tuple[float, float]: """Convert graph-space coords to screen-space relative to this widget.""" ox, oy = self._scroll_offset return ((gx - ox) * self._zoom, (gy - oy) * self._zoom)
[docs] def draw(self, renderer): x, y, w, h = self.get_global_rect() z = self._zoom # Background renderer.draw_filled_rect(x, y, w, h, self._BG) # Clip to widget bounds renderer.push_clip(int(x), int(y), int(w), int(h)) # Grid grid_step = 40.0 * z if grid_step > 8: ox_off = -(self._scroll_offset[0] * z) % grid_step oy_off = -(self._scroll_offset[1] * z) % grid_step gx = x + ox_off while gx < x + w: renderer.draw_line_coloured(gx, y, gx, y + h, self._GRID_COLOUR) gx += grid_step gy = y + oy_off while gy < y + h: renderer.draw_line_coloured(x, gy, x + w, gy, self._GRID_COLOUR) gy += grid_step # Connections (bezier-approximated as line segments) for conn in self._connections: fn = self._nodes.get(conn.from_node) tn = self._nodes.get(conn.to_node) if not fn or not tn: continue fp = fn.get_output_port_position(conn.from_port) tp = tn.get_input_port_position(conn.to_port) sx1, sy1 = self._graph_to_screen(fn.graph_position[0] + fp[0], fn.graph_position[1] + fp[1]) sx2, sy2 = self._graph_to_screen(tn.graph_position[0] + tp[0], tn.graph_position[1] + tp[1]) sx1 += x sy1 += y sx2 += x sy2 += y # Simple bezier with 2 control points cp_offset = min(80 * z, abs(sx2 - sx1) * 0.5) segments = 12 prev = (sx1, sy1) for i in range(1, segments + 1): t = i / segments it = 1 - t # Cubic bezier: P0, P1=(sx1+cp, sy1), P2=(sx2-cp, sy2), P3 bx = it**3 * sx1 + 3 * it**2 * t * (sx1 + cp_offset) + 3 * it * t**2 * (sx2 - cp_offset) + t**3 * sx2 by = it**3 * sy1 + 3 * it**2 * t * sy1 + 3 * it * t**2 * sy2 + t**3 * sy2 renderer.draw_line_coloured(prev[0], prev[1], bx, by, self._CONN_COLOUR) prev = (bx, by) # Nodes font_scale = 12.0 / 14.0 port_r = 5.0 * z for node in self._nodes.values(): nx, ny = self._graph_to_screen(node.graph_position[0], node.graph_position[1]) nx += x ny += y nw = node.size.x * z nh = node.size.y * z header_h = 28.0 * z # Node body renderer.draw_filled_rect(nx, ny, nw, nh, self._NODE_BG) # Header hdr = self._NODE_SELECTED if node.selected else self._NODE_HEADER 
renderer.draw_filled_rect(nx, ny, nw, header_h, hdr) # Border renderer.draw_rect_coloured(nx, ny, nw, nh, self._NODE_BORDER) # Title renderer.draw_text_coloured( node.title, nx + 8 * z, ny + 6 * z, font_scale * z, (1.0, 1.0, 1.0, 1.0), ) # Input ports for pi, port in enumerate(node.inputs): px, py = node.get_input_port_position(pi) spx = nx + px * z spy = ny + py * z renderer.draw_filled_circle(spx, spy, port_r, self._PORT_IN) renderer.draw_text_coloured( port.name, spx + port_r + 4 * z, spy - 6 * z, font_scale * z * 0.85, (0.8, 0.8, 0.8, 1.0), ) # Output ports for pi, port in enumerate(node.outputs): px, py = node.get_output_port_position(pi) spx = nx + px * z spy = ny + py * z renderer.draw_filled_circle(spx, spy, port_r, self._PORT_OUT) tw = renderer.text_width(port.name, font_scale * z * 0.85) renderer.draw_text_coloured( port.name, spx - port_r - 4 * z - tw, spy - 6 * z, font_scale * z * 0.85, (0.8, 0.8, 0.8, 1.0), ) # Connection-in-progress wire preview if self._connecting_from: name, pidx, is_out = self._connecting_from node = self._nodes.get(name) if node: pp = node.get_output_port_position(pidx) if is_out else node.get_input_port_position(pidx) sx1, sy1 = self._graph_to_screen(node.graph_position[0] + pp[0], node.graph_position[1] + pp[1]) sx1 += x sy1 += y mx, my = self._connecting_mouse wire_colour = (0.9, 0.9, 0.3, 0.9) cp = min(80 * z, abs(mx - sx1) * 0.5) segments = 12 prev = (sx1, sy1) if is_out: cp1x, cp2x = sx1 + cp, mx - cp else: cp1x, cp2x = sx1 - cp, mx + cp for i in range(1, segments + 1): t = i / segments it = 1 - t bx = it**3 * sx1 + 3 * it**2 * t * cp1x + 3 * it * t**2 * cp2x + t**3 * mx by = it**3 * sy1 + 3 * it**2 * t * sy1 + 3 * it * t**2 * my + t**3 * my renderer.draw_line_coloured(prev[0], prev[1], bx, by, wire_colour) prev = (bx, by) renderer.pop_clip()