"""Node graph editor widget for visual programming, animation blend trees, shader graphs.
Provides GraphNode (a draggable node with typed input/output ports) and
GraphEdit (a zoomable canvas that manages nodes and connections).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from ..descriptors import Signal
from ..math.types import Vec2
from .containers import Container
from .core import Control
# Module-level logger; currently not referenced by any code in this file.
log = logging.getLogger(__name__)
# Public API exposed by ``from <this module> import *``.
__all__ = ["GraphPort", "GraphConnection", "GraphNode", "GraphEdit"]
@dataclass
class GraphPort:
    """Describe one connection point (input or output) on a GraphNode.

    Attributes:
        name: Label rendered beside the port.
        type: Type tag meant for deciding which ports may be connected.
            NOTE: GraphEdit in this module does not yet compare it when
            creating connections.
        colour: Indicator colour as an ``(r, g, b, a)`` float tuple.
            NOTE: GraphEdit currently draws ports with its own fixed
            input/output colours rather than this value.
    """

    name: str
    # Every port starts out with the generic "default" type tag.
    type: str = "default"
    # Light grey, fully opaque.
    colour: tuple[float, float, float, float] = (0.8, 0.8, 0.8, 1.0)
@dataclass
class GraphConnection:
    """A directed wire from one node's output port to another node's input port.

    Attributes:
        from_node: Name of the node owning the output (source) port.
        from_port: Index of the source port among that node's outputs.
        to_node: Name of the node owning the input (destination) port.
        to_port: Index of the destination port among that node's inputs.
    """

    # Source side (an output port).
    from_node: str
    from_port: int
    # Destination side (an input port).
    to_node: str
    to_port: int
class GraphNode(Container):
    """A node in the graph editor with a titled header and typed input/output ports.

    Ports are laid out vertically in local node space: input ports sit on the
    left edge (x = 0) and output ports on the right edge (x = ``size.x``).
    Both columns start at ``_PORT_Y_START`` and are spaced ``_PORT_SPACING``
    apart, so input *i* and output *i* share the same row.

    Example:
        node = GraphNode(title="Add")
        node.add_input("A", type="float")
        node.add_input("B", type="float")
        node.add_output("Result", type="float")
        graph.add_graph_node(node)
    """

    # Local-space y of the first port row, and vertical distance between rows.
    _PORT_Y_START = 40.0
    _PORT_SPACING = 24.0

    def __init__(self, title: str = "Node", **kwargs):
        super().__init__(**kwargs)
        self.title = title
        self.inputs: list[GraphPort] = []
        self.outputs: list[GraphPort] = []
        # Top-left corner in graph space; GraphEdit reads it when drawing and
        # writes it while the node is being dragged.
        self.graph_position: tuple[float, float] = (0.0, 0.0)
        self.selected = False
        self.dragging = False
        # Cursor offset from graph_position, captured when a drag starts so the
        # node does not jump to the cursor.
        self._drag_offset = (0.0, 0.0)
        self.size = Vec2(160, 80)

    def add_input(self, name: str, type: str = "default") -> int:
        """Add an input port. Returns port index."""
        self.inputs.append(GraphPort(name=name, type=type))
        return len(self.inputs) - 1

    def add_output(self, name: str, type: str = "default") -> int:
        """Add an output port. Returns port index."""
        self.outputs.append(GraphPort(name=name, type=type))
        return len(self.outputs) - 1

    def _port_y(self, index: int) -> float:
        """Local-space y coordinate of port row *index* (shared by inputs and outputs)."""
        return self._PORT_Y_START + index * self._PORT_SPACING

    def get_input_port_position(self, index: int) -> tuple[float, float]:
        """Position of input port *index* in local node space (on the left edge)."""
        return (0.0, self._port_y(index))

    def get_output_port_position(self, index: int) -> tuple[float, float]:
        """Position of output port *index* in local node space (on the right edge)."""
        return (self.size.x, self._port_y(index))
class GraphEdit(Control):
    """Zoomable node graph editor with pan, zoom, and connection management.

    Nodes are placed in graph space. The viewport applies zoom and scroll offset
    to determine which portion of the graph is visible::

        screen = widget_origin + (graph - scroll_offset) * zoom

    Mouse controls (as handled by ``_on_gui_input``):
        * Wheel up/down zooms in/out (clamped to [0.25, 2.0]).
        * Left button on a node selects and drags it; left button from one port
          to another creates a connection, always stored output -> input.
        * Right button (button 2) near a wire removes that connection.
        * Middle button (button 3) drag pans the view.

    Example:
        graph = GraphEdit()
        node_a = GraphNode(name="A", title="Source")
        node_a.add_output("Out")
        graph.add_graph_node(node_a)
        node_b = GraphNode(name="B", title="Sink")
        node_b.add_input("In")
        graph.add_graph_node(node_b)
        graph.connect_node("A", 0, "B", 0)
    """

    # Screen-pixel half-size of the square in which a click counts as hitting a port.
    _PORT_HIT_RADIUS = 10.0
    # Screen-pixel distance within which a right-click removes a connection.
    _CONN_HIT_DISTANCE = 15.0
    # Number of straight segments used to approximate every wire. Drawing and
    # hit-testing share it so a click lands on exactly the curve that is shown.
    _WIRE_SEGMENTS = 12
    # Maximum horizontal control-point offset of a wire, in graph units.
    _WIRE_CURVE = 80.0

    # Colours
    _NODE_BG = (0.22, 0.22, 0.28, 0.95)
    _NODE_HEADER = (0.30, 0.35, 0.55, 1.0)
    _NODE_SELECTED = (0.45, 0.55, 0.80, 1.0)
    _NODE_BORDER = (0.40, 0.40, 0.50, 1.0)
    _PORT_IN = (0.4, 0.8, 0.4, 1.0)
    _PORT_OUT = (0.8, 0.5, 0.3, 1.0)
    _CONN_COLOUR = (0.7, 0.7, 0.8, 0.8)
    _WIRE_PREVIEW_COLOUR = (0.9, 0.9, 0.3, 0.9)
    _GRID_COLOUR = (0.18, 0.18, 0.22, 1.0)
    _BG = (0.12, 0.12, 0.15, 1.0)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._nodes: dict[str, GraphNode] = {}
        self._connections: list[GraphConnection] = []
        self._zoom = 1.0
        # Graph-space point displayed at the widget's top-left corner.
        self._scroll_offset = (0.0, 0.0)
        # Interaction state
        self._dragging_node: GraphNode | None = None
        self._panning = False
        self._pan_start = (0.0, 0.0)
        self._pan_start_offset = (0.0, 0.0)
        self._connecting_from: tuple[str, int, bool] | None = None  # (node_name, port_idx, is_output)
        # Latest cursor position (same coords as event.position) while dragging a new wire.
        self._connecting_mouse: tuple[float, float] = (0.0, 0.0)
        # Signals
        self.connection_request = Signal()  # (from_node, from_port, to_node, to_port)
        self.disconnection_request = Signal()  # (from_node, from_port, to_node, to_port)
        self.node_selected = Signal()  # (node_name,)
        self.size = Vec2(600, 400)

    # ------------------------------------------------------------------ nodes

    def add_graph_node(self, node: GraphNode):
        """Register *node* under ``node.name``; an existing node with that name is replaced."""
        self._nodes[node.name] = node

    def remove_graph_node(self, name: str):
        """Remove the named node (if present) and every connection touching it."""
        self._nodes.pop(name, None)
        self._connections = [c for c in self._connections if name not in (c.from_node, c.to_node)]

    def get_node(self, name: str) -> GraphNode | None:
        """Retrieve a graph node by name, or None if it is not registered."""
        return self._nodes.get(name)

    # ------------------------------------------------------------------ connections

    def connect_node(self, from_node: str, from_port: int, to_node: str, to_port: int) -> bool:
        """Create a connection from an output port to an input port.

        Identical connections are stored only once, so a single right-click or
        ``disconnect_node`` call always removes the wire the user sees.

        Returns:
            True if a new connection was added, False if it already existed.
        """
        conn = GraphConnection(from_node, from_port, to_node, to_port)
        if conn in self._connections:
            return False
        self._connections.append(conn)
        return True

    def disconnect_node(self, from_node: str, from_port: int, to_node: str, to_port: int):
        """Remove the specific connection (if present). Does not emit any signal."""
        target = GraphConnection(from_node, from_port, to_node, to_port)
        self._connections = [c for c in self._connections if c != target]

    def get_connections(self) -> list[GraphConnection]:
        """Return a copy of the current connections list."""
        return list(self._connections)

    def clear_connections(self):
        """Remove all connections."""
        self._connections.clear()

    # ------------------------------------------------------------------ viewport

    def set_zoom(self, zoom: float):
        """Set zoom level, clamped to [0.25, 2.0]."""
        self._zoom = max(0.25, min(2.0, zoom))

    @property
    def zoom(self) -> float:
        """Current zoom factor (screen pixels per graph unit)."""
        return self._zoom

    @property
    def scroll_offset(self) -> tuple[float, float]:
        """Graph-space point shown at the widget's top-left corner."""
        return self._scroll_offset

    def center_on_node(self, name: str):
        """Pan the viewport so the named node's centre is in the middle of the widget.

        The widget size is in screen pixels, so it is divided by the zoom to get
        the visible extent in graph units. Unknown names are ignored.
        """
        node = self._nodes.get(name)
        if node is None:
            return
        cx = node.graph_position[0] + node.size.x / 2
        cy = node.graph_position[1] + node.size.y / 2
        self._scroll_offset = (
            cx - self.size.x / (2 * self._zoom),
            cy - self.size.y / (2 * self._zoom),
        )

    # ------------------------------------------------------------------ coordinates

    def _screen_to_graph(self, sx: float, sy: float) -> tuple[float, float]:
        """Convert absolute screen coords to graph space.

        The widget's global origin is subtracted first, so *sx*/*sy* must be in
        the same coordinate system as ``get_global_rect()``.
        """
        x, y, _, _ = self.get_global_rect()
        return ((sx - x) / self._zoom + self._scroll_offset[0],
                (sy - y) / self._zoom + self._scroll_offset[1])

    def _graph_to_screen(self, gx: float, gy: float) -> tuple[float, float]:
        """Convert graph-space coords to screen space relative to this widget's origin."""
        ox, oy = self._scroll_offset
        return ((gx - ox) * self._zoom, (gy - oy) * self._zoom)

    # ------------------------------------------------------------------ geometry helpers

    @staticmethod
    def _wire_points(x1, y1, x2, y2, cp1x, cp2x, segments):
        """Sample a cubic Bezier from (x1, y1) to (x2, y2).

        The control points are (cp1x, y1) and (cp2x, y2), which gives the wire
        horizontal tangents at both ends. Returns ``segments + 1`` points,
        including both endpoints.
        """
        points = []
        for i in range(segments + 1):
            t = i / segments
            it = 1 - t
            a = it**3
            b = 3 * it**2 * t
            c = 3 * it * t**2
            d = t**3
            points.append((a * x1 + b * cp1x + c * cp2x + d * x2,
                           a * y1 + b * y1 + c * y2 + d * y2))
        return points

    @staticmethod
    def _segment_distance(px, py, ax, ay, bx, by) -> float:
        """Euclidean distance from point (px, py) to the segment (ax, ay)-(bx, by)."""
        dx, dy = bx - ax, by - ay
        len_sq = dx * dx + dy * dy
        if len_sq == 0.0:
            t = 0.0
        else:
            # Project onto the infinite line, then clamp onto the segment.
            t = max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / len_sq))
        cx, cy = ax + t * dx, ay + t * dy
        return ((px - cx) ** 2 + (py - cy) ** 2) ** 0.5

    def _connection_endpoints(self, conn: GraphConnection):
        """Graph-space ((x1, y1), (x2, y2)) of *conn*'s ports, or None if a node is missing."""
        fn = self._nodes.get(conn.from_node)
        tn = self._nodes.get(conn.to_node)
        if fn is None or tn is None:
            return None
        fpx, fpy = fn.get_output_port_position(conn.from_port)
        tpx, tpy = tn.get_input_port_position(conn.to_port)
        return (
            (fn.graph_position[0] + fpx, fn.graph_position[1] + fpy),
            (tn.graph_position[0] + tpx, tn.graph_position[1] + tpy),
        )

    # ------------------------------------------------------------------ input

    def _hit_port(self, gx: float, gy: float) -> tuple[str, int, bool] | None:
        """Return (node_name, port_index, is_output) if (gx, gy) is near a port, else None."""
        # The hit radius is constant in screen pixels, so convert it to graph units.
        r = self._PORT_HIT_RADIUS / self._zoom
        for node in self._nodes.values():
            nx, ny = node.graph_position
            for pi, _ in enumerate(node.outputs):
                px, py = node.get_output_port_position(pi)
                if abs(gx - (nx + px)) < r and abs(gy - (ny + py)) < r:
                    return (node.name, pi, True)
            for pi, _ in enumerate(node.inputs):
                px, py = node.get_input_port_position(pi)
                if abs(gx - (nx + px)) < r and abs(gy - (ny + py)) < r:
                    return (node.name, pi, False)
        return None

    def _hit_node(self, gx: float, gy: float) -> GraphNode | None:
        """Return the topmost node under graph-space point, or None.

        Nodes are drawn in insertion order, so the last match is on top.
        """
        for node in reversed(list(self._nodes.values())):
            nx, ny = node.graph_position
            if nx <= gx <= nx + node.size.x and ny <= gy <= ny + node.size.y:
                return node
        return None

    def _on_gui_input(self, event):
        """Handle wheel zoom, left-button select/drag/connect, right-button
        disconnect, middle-button pan and mouse motion.

        ``event.position`` is treated as absolute screen coordinates; it is
        converted with ``_screen_to_graph``, which subtracts the widget's
        global origin.
        """
        # Zoom
        if event.key == "scroll_up":
            self.set_zoom(self._zoom * 1.1)
            return
        if event.key == "scroll_down":
            self.set_zoom(self._zoom / 1.1)
            return
        pos = event.position
        if pos is None:
            # Nothing below can be resolved without a pointer position.
            return
        px = pos.x if hasattr(pos, "x") else pos[0]
        py = pos.y if hasattr(pos, "y") else pos[1]

        # --- Left-click: select/drag nodes, connect/disconnect ports ---
        if event.button == 1:
            gx, gy = self._screen_to_graph(px, py)
            if event.pressed:
                # Ports take priority over the node body they sit on.
                port_hit = self._hit_port(gx, gy)
                if port_hit is not None:
                    self._connecting_from = port_hit
                    self._connecting_mouse = (px, py)
                    self.grab_mouse()
                    return
                node = self._hit_node(gx, gy)
                # Single selection: clicking a node or empty space clears the rest.
                for n in self._nodes.values():
                    n.selected = False
                if node is not None:
                    node.selected = True
                    node.dragging = True
                    node._drag_offset = (gx - node.graph_position[0], gy - node.graph_position[1])
                    self._dragging_node = node
                    self.node_selected.emit(node.name)
                    self.grab_mouse()
            else:
                # Release: finish connection and/or drag
                if self._connecting_from is not None:
                    target = self._hit_port(gx, gy)
                    if target is not None and target != self._connecting_from:
                        src, dst = self._connecting_from, target
                        # Always store output -> input, whichever end was dragged from.
                        if src[2] and not dst[2]:
                            self.connect_node(src[0], src[1], dst[0], dst[1])
                            self.connection_request.emit(src[0], src[1], dst[0], dst[1])
                        elif not src[2] and dst[2]:
                            self.connect_node(dst[0], dst[1], src[0], src[1])
                            self.connection_request.emit(dst[0], dst[1], src[0], src[1])
                    self._connecting_from = None
                    self.release_mouse()
                if self._dragging_node is not None:
                    self._dragging_node.dragging = False
                    self._dragging_node = None
                    self.release_mouse()

        # --- Right-click (button 2): disconnect ---
        if event.button == 2 and event.pressed:
            gx, gy = self._screen_to_graph(px, py)
            self._try_remove_connection(gx, gy)

        # --- Middle-click (button 3): pan ---
        if event.button == 3:
            if event.pressed:
                self._panning = True
                self._pan_start = (px, py)
                self._pan_start_offset = self._scroll_offset
            else:
                self._panning = False

        # --- Mouse motion ---
        if not event.key:
            if self._dragging_node is not None:
                gx, gy = self._screen_to_graph(px, py)
                node = self._dragging_node
                node.graph_position = (gx - node._drag_offset[0], gy - node._drag_offset[1])
            elif self._connecting_from is not None:
                self._connecting_mouse = (px, py)
            elif self._panning:
                # Screen-space drag distance converted to graph units.
                dx = px - self._pan_start[0]
                dy = py - self._pan_start[1]
                self._scroll_offset = (
                    self._pan_start_offset[0] - dx / self._zoom,
                    self._pan_start_offset[1] - dy / self._zoom,
                )

    def _try_remove_connection(self, gx: float, gy: float) -> bool:
        """Remove the connection nearest to graph-space point, if close enough.

        Distance is measured to the same segmented curve that ``draw`` renders,
        and the threshold is constant in screen pixels (scaled by zoom). Emits
        ``disconnection_request`` for the removed connection.

        Returns:
            True if a connection was removed.
        """
        best_dist = self._CONN_HIT_DISTANCE / self._zoom
        best_conn = None
        for conn in self._connections:
            ends = self._connection_endpoints(conn)
            if ends is None:
                continue
            (x1, y1), (x2, y2) = ends
            cp = min(self._WIRE_CURVE, abs(x2 - x1) * 0.5)
            points = self._wire_points(x1, y1, x2, y2, x1 + cp, x2 - cp, self._WIRE_SEGMENTS)
            for (ax, ay), (bx, by) in zip(points, points[1:]):
                d = self._segment_distance(gx, gy, ax, ay, bx, by)
                if d < best_dist:
                    best_dist = d
                    best_conn = conn
        if best_conn is None:
            return False
        self._connections.remove(best_conn)
        self.disconnection_request.emit(
            best_conn.from_node, best_conn.from_port, best_conn.to_node, best_conn.to_port,
        )
        return True

    # -------------------------------------------------------------------- draw

    def draw(self, renderer):
        """Render background, grid, connections, nodes and any wire being dragged.

        Everything after the background is clipped to the widget rect; the clip
        is popped in a ``finally`` so the renderer's clip stack stays balanced
        even if a draw call raises.
        """
        x, y, w, h = self.get_global_rect()
        renderer.draw_filled_rect(x, y, w, h, self._BG)
        renderer.push_clip(int(x), int(y), int(w), int(h))
        try:
            self._draw_grid(renderer, x, y, w, h)
            self._draw_connections(renderer, x, y)
            self._draw_nodes(renderer, x, y)
            self._draw_pending_wire(renderer, x, y)
        finally:
            renderer.pop_clip()

    @staticmethod
    def _draw_polyline(renderer, points, colour):
        """Draw straight segments between consecutive points."""
        for (ax, ay), (bx, by) in zip(points, points[1:]):
            renderer.draw_line_coloured(ax, ay, bx, by, colour)

    def _draw_grid(self, renderer, x, y, w, h):
        """Draw grid lines every 40 graph units; skipped when closer than 8 pixels."""
        z = self._zoom
        step = 40.0 * z
        if step <= 8:
            return
        # Screen offset of the first line: (k*40 - scroll) * zoom wrapped into [0, step).
        ox_off = -(self._scroll_offset[0] * z) % step
        oy_off = -(self._scroll_offset[1] * z) % step
        gx = x + ox_off
        while gx < x + w:
            renderer.draw_line_coloured(gx, y, gx, y + h, self._GRID_COLOUR)
            gx += step
        gy = y + oy_off
        while gy < y + h:
            renderer.draw_line_coloured(x, gy, x + w, gy, self._GRID_COLOUR)
            gy += step

    def _draw_connections(self, renderer, x, y):
        """Draw every connection as a segmented cubic Bezier wire."""
        z = self._zoom
        for conn in self._connections:
            ends = self._connection_endpoints(conn)
            if ends is None:
                continue
            (gx1, gy1), (gx2, gy2) = ends
            sx1, sy1 = self._graph_to_screen(gx1, gy1)
            sx2, sy2 = self._graph_to_screen(gx2, gy2)
            sx1 += x
            sy1 += y
            sx2 += x
            sy2 += y
            cp = min(self._WIRE_CURVE * z, abs(sx2 - sx1) * 0.5)
            points = self._wire_points(sx1, sy1, sx2, sy2, sx1 + cp, sx2 - cp, self._WIRE_SEGMENTS)
            self._draw_polyline(renderer, points, self._CONN_COLOUR)

    def _draw_nodes(self, renderer, x, y):
        """Draw each node's body, header, border, title and labelled ports."""
        z = self._zoom
        font_scale = 12.0 / 14.0
        label_scale = font_scale * z * 0.85
        label_colour = (0.8, 0.8, 0.8, 1.0)
        port_r = 5.0 * z
        header_h = 28.0 * z
        for node in self._nodes.values():
            nx, ny = self._graph_to_screen(node.graph_position[0], node.graph_position[1])
            nx += x
            ny += y
            nw = node.size.x * z
            nh = node.size.y * z
            renderer.draw_filled_rect(nx, ny, nw, nh, self._NODE_BG)
            hdr = self._NODE_SELECTED if node.selected else self._NODE_HEADER
            renderer.draw_filled_rect(nx, ny, nw, header_h, hdr)
            renderer.draw_rect_coloured(nx, ny, nw, nh, self._NODE_BORDER)
            renderer.draw_text_coloured(
                node.title, nx + 8 * z, ny + 6 * z, font_scale * z, (1.0, 1.0, 1.0, 1.0),
            )
            # Input ports: label to the right of the port.
            for pi, port in enumerate(node.inputs):
                px, py = node.get_input_port_position(pi)
                spx = nx + px * z
                spy = ny + py * z
                renderer.draw_filled_circle(spx, spy, port_r, self._PORT_IN)
                renderer.draw_text_coloured(
                    port.name, spx + port_r + 4 * z, spy - 6 * z, label_scale, label_colour,
                )
            # Output ports: label right-aligned to the left of the port.
            for pi, port in enumerate(node.outputs):
                px, py = node.get_output_port_position(pi)
                spx = nx + px * z
                spy = ny + py * z
                renderer.draw_filled_circle(spx, spy, port_r, self._PORT_OUT)
                tw = renderer.text_width(port.name, label_scale)
                renderer.draw_text_coloured(
                    port.name, spx - port_r - 4 * z - tw, spy - 6 * z, label_scale, label_colour,
                )

    def _draw_pending_wire(self, renderer, x, y):
        """Draw the preview wire from the port being dragged to the cursor, if any."""
        if self._connecting_from is None:
            return
        name, pidx, is_out = self._connecting_from
        node = self._nodes.get(name)
        if node is None:
            return
        pp = node.get_output_port_position(pidx) if is_out else node.get_input_port_position(pidx)
        sx1, sy1 = self._graph_to_screen(node.graph_position[0] + pp[0], node.graph_position[1] + pp[1])
        sx1 += x
        sy1 += y
        # The cursor position is already in absolute screen coordinates.
        mx, my = self._connecting_mouse
        cp = min(self._WIRE_CURVE * self._zoom, abs(mx - sx1) * 0.5)
        # Bulge away from the port's side: rightwards from outputs, leftwards from inputs.
        if is_out:
            cp1x, cp2x = sx1 + cp, mx - cp
        else:
            cp1x, cp2x = sx1 - cp, mx + cp
        points = self._wire_points(sx1, sy1, mx, my, cp1x, cp2x, self._WIRE_SEGMENTS)
        self._draw_polyline(renderer, points, self._WIRE_PREVIEW_COLOUR)