Source code for simvx.core.ui.tree

"""TreeView -- hierarchical tree display with expand/collapse.

TreeItem is a lightweight data node. TreeView renders the hierarchy
with indentation, selection highlight, and expand/collapse arrows.
Uses virtual scrolling: only visible rows are drawn each frame.
"""


from __future__ import annotations

import logging
from typing import Any

from ..descriptors import Signal
from ..math.types import Vec2
from .core import Control, ThemeColour

# Module-level logger, named after this module's import path.
log = logging.getLogger(name=__name__)

# Names exported by ``from ... import *``.
__all__ = [
    "TreeItem",
    "TreeView",
]


class TreeItem:
    """Data node for a tree hierarchy.

    Attributes:
        text: Display text for this item.
        children: Child items.
        expanded: Whether children are visible.
        data: Arbitrary user data attached to this item.
        parent: Parent TreeItem (set internally by add_child/remove_child).

    Example:
        root = TreeItem("Root")
        root.add_child(TreeItem("Child A"))
        root.add_child(TreeItem("Child B", data={"key": 42}))
    """

    __slots__ = ("text", "children", "expanded", "data", "parent")

    def __init__(
        self,
        text: str = "",
        children: list[TreeItem] | None = None,
        expanded: bool = True,
        data: Any = None,
    ):
        self.text = text
        self.expanded = expanded
        self.data = data
        self.parent: TreeItem | None = None
        self.children: list[TreeItem] = []
        # Route through add_child so every child gets its parent link set.
        if children:
            for child in children:
                self.add_child(child)

    def add_child(self, item: TreeItem) -> TreeItem:
        """Append *item* as the last child of this node and return it.

        If *item* already has a parent (including this node), it is detached
        from that parent first, so an item is never listed under two parents
        or twice under the same parent.

        Raises:
            ValueError: if *item* is this node or one of its ancestors; adding
                it would create a cycle in the hierarchy.
        """
        # Reject cycles before mutating anything so a failed call is a no-op.
        node: TreeItem | None = self
        while node is not None:
            if node is item:
                raise ValueError(f"cannot add {item!r} beneath itself or its own descendant {self!r}")
            node = node.parent
        if item.parent is not None:
            item.parent.remove_child(item)
        item.parent = self
        self.children.append(item)
        return item

    def remove_child(self, item: TreeItem):
        """Remove *item* from this node's children and clear its parent.

        Does nothing if *item* is not a child of this node.
        """
        try:
            self.children.remove(item)
        except ValueError:
            return
        item.parent = None

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.text!r}, children={len(self.children)})"
class TreeView(Control):
    """Scrollable tree widget with expand/collapse and selection.

    Renders a TreeItem hierarchy with indentation, clickable expand/collapse
    arrows, and selection highlight. Uses virtual scrolling -- only rows
    visible in the viewport are drawn.

    Signals (each emitted with the affected TreeItem):
        item_selected: a row was left-clicked outside its arrow area.
        item_expanded: a row's arrow was clicked and the row is now expanded.
        item_collapsed: a row's arrow was clicked and the row is now collapsed.

    Example:
        root = TreeItem("Scene")
        root.add_child(TreeItem("Player"))
        root.add_child(TreeItem("Enemies", children=[
            TreeItem("Goblin"),
            TreeItem("Dragon"),
        ]))
        tree = TreeView(root=root)
        tree.item_selected.connect(lambda item: print(item.text))
    """

    _draw_caching = True
    _draws_children = True

    bg_colour = ThemeColour("tree_bg")
    text_colour = ThemeColour("text")
    select_colour = ThemeColour("tree_select")
    hover_colour = ThemeColour("tree_hover")
    arrow_colour = ThemeColour("tree_arrow")

    def __init__(self, root: TreeItem | None = None, **kwargs):
        super().__init__(**kwargs)
        self._root: TreeItem | None = root
        self.selected: TreeItem | None = None
        self.indent = 20.0
        self.row_height = 22.0
        self.font_size = 14.0
        self._hovered_item: TreeItem | None = None
        # Rows drawn by the last _draw_visible_rows call, for hit testing:
        # (item, region left x, row top y, depth).
        self._row_map: list[tuple[TreeItem, float, float, int]] = []
        self._scroll_y: float = 0.0
        self._content_height: float = 0.0
        # Virtual scroll: flattened tree cache of (item, depth); None means stale.
        self._flat_rows: list[tuple[TreeItem, int]] | None = None
        # Signals
        self.item_selected = Signal()
        self.item_expanded = Signal()
        self.item_collapsed = Signal()
        self.size = Vec2(250, 300)

    # -------------------------------------------------------------- root property

    @property
    def root(self) -> TreeItem | None:
        """Top-level TreeItem of the view; it is drawn as the first row."""
        return self._root

    @root.setter
    def root(self, value: TreeItem | None):
        self._root = value
        self._flat_rows = None
        # The hovered row belongs to the previous layout; drop it.
        self._hovered_item = None
        # A selection outside the new tree would highlight nothing and keep
        # the old tree referenced, so clear it. getattr: this setter may run
        # before __init__ has assigned ``selected``.
        selected = getattr(self, "selected", None)
        if selected is not None and not self._is_in_subtree(selected, value):
            self.selected = None
        # Only request a redraw once Control's draw state (_draw_dirty) exists.
        if hasattr(self, "_draw_dirty"):
            self.queue_redraw()

    @staticmethod
    def _is_in_subtree(item: TreeItem, root: TreeItem | None) -> bool:
        """Return True if *item* is *root* or reaches it by following parent links."""
        seen: set[int] = set()  # guards against a malformed, cyclic parent chain
        node: TreeItem | None = item
        while node is not None and id(node) not in seen:
            if node is root:
                return True
            seen.add(id(node))
            node = node.parent
        return False

    # -------------------------------------------------------- flat row management

    def _invalidate_flat_rows(self):
        """Mark the flattened row cache as stale and request redraw."""
        self._flat_rows = None
        self.queue_redraw()

    def _ensure_flat_rows(self):
        """Rebuild the flat row list if stale."""
        if self._flat_rows is not None:
            return
        self._flat_rows = []
        if self._root:
            self._flatten(self._root, 0)

    def _flatten(self, item: TreeItem, depth: int):
        """Append *item*, then (if expanded) its descendants, in pre-order to _flat_rows."""
        self._flat_rows.append((item, depth))
        if item.expanded:
            for child in item.children:
                self._flatten(child, depth + 1)

    # ------------------------------------------------------------------- input

    def _row_at(self, py: float) -> tuple[TreeItem, float, float, int] | None:
        """Return the first _row_map entry whose vertical span contains *py*, or None."""
        for row in self._row_map:
            row_y = row[2]
            if row_y <= py < row_y + self.row_height:
                return row
        return None

    def _toggle_expanded(self, item: TreeItem):
        """Flip *item*'s expanded flag, invalidate rows, and emit the matching signal."""
        item.expanded = not item.expanded
        self._invalidate_flat_rows()
        if item.expanded:
            self.item_expanded.emit(item)
        else:
            self.item_collapsed.emit(item)

    def _on_gui_input(self, event):
        # Scroll handling
        if event.key == "scroll_up":
            self._scroll_y = max(0.0, self._scroll_y - self.row_height * 2)
            self.queue_redraw()
            return
        if event.key == "scroll_down":
            _, _, _, h = self.get_global_rect()
            max_scroll = max(0.0, self._content_height - h)
            self._scroll_y = min(max_scroll, self._scroll_y + self.row_height * 2)
            self.queue_redraw()
            return
        # Mouse-move hover tracking (button=0, no key/char)
        if event.button == 0 and not event.key and not event.char:
            self._resolve_hover(event.position)
            return
        if event.button != 1 or not event.pressed:
            return
        if not self._row_map:
            return
        # Position may be a Vec2-like object or an indexable (x, y) pair.
        px = event.position.x if hasattr(event.position, "x") else event.position[0]
        py = event.position.y if hasattr(event.position, "y") else event.position[1]
        row = self._row_at(py)
        if row is None:
            return
        item, row_x, _row_y, depth = row
        # The arrow cell is row_height wide, starting at the item's indent.
        arrow_x = row_x + depth * self.indent
        if item.children and arrow_x <= px < arrow_x + self.row_height:
            self._toggle_expanded(item)
        else:
            self.selected = item
            self.queue_redraw()
            self.item_selected.emit(item)

    def _resolve_hover(self, position):
        """Set _hovered_item from mouse position using _row_map."""
        py = position.y if hasattr(position, "y") else position[1]
        old = self._hovered_item
        row = self._row_at(py)
        self._hovered_item = row[0] if row is not None else None
        if self._hovered_item is not old:
            self.queue_redraw()

    def _update_mouse_over(self, mouse_pos):
        """Override to clear hover highlight when the mouse leaves the widget."""
        super()._update_mouse_over(mouse_pos)
        if not self.mouse_over and self._hovered_item is not None:
            self._hovered_item = None
            self.queue_redraw()

    # -------------------------------------------------------------------- draw

    def draw(self, renderer):
        x, y, w, h = self.get_global_rect()
        # Background
        renderer.draw_filled_rect(x, y, w, h, self.bg_colour)
        # Clip to widget bounds; pop in ``finally`` so an exception while drawing
        # rows cannot leave the renderer's clip stack unbalanced.
        renderer.push_clip(x, y, w, h)
        try:
            self._draw_visible_rows(renderer, x, y, w, h)
            # Scrollbar thumb (if content overflows)
            if self._content_height > h and h > 0:
                self._draw_scrollbar(renderer, x, y, w, h)
        finally:
            renderer.pop_clip()

    def _draw_scrollbar(self, renderer, x: float, y: float, w: float, h: float):
        """Draw the scrollbar thumb; caller ensures content overflows and h > 0."""
        sb_w = 6.0
        sb_x = x + w - sb_w - 2
        visible_ratio = h / self._content_height
        # At least 20px tall, but never taller than the viewport itself
        # (otherwise h - thumb_h goes negative and the thumb is drawn above it).
        thumb_h = min(h, max(20.0, h * visible_ratio))
        max_scroll = self._content_height - h
        scroll_ratio = self._scroll_y / max_scroll if max_scroll > 0 else 0
        thumb_y = y + scroll_ratio * (h - thumb_h)
        renderer.draw_filled_rect(sb_x, thumb_y, sb_w, thumb_h, (0.4, 0.4, 0.4, 0.5))

    def _draw_visible_rows(self, renderer, x: float, y: float, w: float, h: float):
        """Draw only the rows visible in the viewport defined by (x, y, w, h).

        Populates ``_row_map`` with the visible rows for hit testing, updates
        ``_content_height``, and clamps ``_scroll_y`` to the scrollable range
        of this viewport. Can be called externally (e.g. by FileBrowser) to
        render the tree into a custom region.
        """
        self._ensure_flat_rows()
        total_rows = len(self._flat_rows)
        self._content_height = total_rows * self.row_height
        # Collapsing rows or replacing the root can shrink the content below the
        # current offset; clamp so the viewport never scrolls past the last row.
        max_scroll = max(0.0, self._content_height - h)
        if self._scroll_y > max_scroll:
            self._scroll_y = max_scroll

        # Compute visible range
        first_visible = max(0, int(self._scroll_y / self.row_height))
        visible_count = int(h / self.row_height) + 2  # +2 for partial rows at top/bottom
        last_visible = min(total_rows, first_visible + visible_count)

        # Build row map for hit testing (only visible rows)
        self._row_map.clear()
        scale = self.font_size / 14.0
        for i in range(first_visible, last_visible):
            item, depth = self._flat_rows[i]
            row_y = y + i * self.row_height - self._scroll_y
            row_x = x + depth * self.indent

            # Store for hit testing (left edge of the region, not the indent)
            self._row_map.append((item, x, row_y, depth))

            # Selection / hover highlight
            if item is self.selected:
                renderer.draw_filled_rect(x, row_y, w, self.row_height, self.select_colour)
            elif item is self._hovered_item:
                renderer.draw_filled_rect(x, row_y, w, self.row_height, self.hover_colour)

            if item.children:
                self._draw_arrow(renderer, item.expanded, row_x, row_y)

            # Item text, offset past the arrow cell
            text_x = row_x + self.row_height
            text_y = row_y + (self.row_height - self.font_size) / 2
            renderer.draw_text_coloured(item.text, text_x, text_y, scale, self.text_colour)

    def _draw_arrow(self, renderer, expanded: bool, row_x: float, row_y: float):
        """Draw the two-segment expand/collapse chevron centred in the row's arrow cell."""
        cx = row_x + self.row_height * 0.5
        cy = row_y + self.row_height * 0.5
        half = 4.0
        colour = self.arrow_colour
        if expanded:
            # Down-pointing chevron v
            renderer.draw_line_coloured(cx - half, cy - half * 0.5, cx, cy + half * 0.5, colour)
            renderer.draw_line_coloured(cx, cy + half * 0.5, cx + half, cy - half * 0.5, colour)
        else:
            # Right-pointing chevron >
            renderer.draw_line_coloured(cx - half * 0.5, cy - half, cx + half * 0.5, cy, colour)
            renderer.draw_line_coloured(cx + half * 0.5, cy, cx - half * 0.5, cy + half, colour)