"""TreeView -- hierarchical tree display with expand/collapse.
TreeItem is a lightweight data node. TreeView renders the hierarchy
with indentation, selection highlight, and expand/collapse arrows.
Uses virtual scrolling: only visible rows are drawn each frame.
"""
from __future__ import annotations
import logging
from typing import Any
from ..descriptors import Signal
from ..math.types import Vec2
from .core import Control, ThemeColour
log = logging.getLogger(__name__)
__all__ = ["TreeItem", "TreeView"]
[docs]
class TreeItem:
    """Data node for a tree hierarchy.

    Attributes:
        text: Display text for this item.
        children: Child items, in display order.
        expanded: Whether children are visible.
        data: Arbitrary user data attached to this item.
        parent: Parent TreeItem, maintained by add_child/remove_child
            (``None`` for a root or detached item).

    Example:
        root = TreeItem("Root")
        root.add_child(TreeItem("Child A"))
        root.add_child(TreeItem("Child B", data={"key": 42}))
    """

    __slots__ = ("text", "children", "expanded", "data", "parent")

    def __init__(
        self,
        text: str = "",
        children: list[TreeItem] | None = None,
        expanded: bool = True,
        data: Any = None,
    ):
        self.text = text
        self.expanded = expanded
        self.data = data
        self.parent: TreeItem | None = None
        self.children: list[TreeItem] = []
        # Iterate over a copy: add_child detaches each item from its previous
        # parent, which would mutate the list if the caller passed another
        # node's ``children`` list directly.
        for child in list(children or ()):
            self.add_child(child)

    def add_child(self, item: TreeItem) -> TreeItem:
        """Append *item* as the last child of this node and return it.

        If *item* already has a parent (including this node), it is detached
        from it first, so a node is never listed under two parents or twice
        under the same parent. Re-adding an existing child moves it to the end.

        Raises:
            ValueError: if *item* is this node or one of its ancestors. Adding
                it would create a cycle, and a cycle makes recursive walks of
                the tree never terminate.
        """
        # Walk up from self; meeting *item* means it is self or an ancestor.
        node: TreeItem | None = self
        while node is not None:
            if node is item:
                raise ValueError(f"cannot add {item!r} as a descendant of itself")
            node = node.parent
        if item.parent is not None:
            item.parent.remove_child(item)
        item.parent = self
        self.children.append(item)
        return item

    def remove_child(self, item: TreeItem):
        """Remove *item* from this node's children and clear its parent.

        A no-op if *item* is not a child of this node.
        """
        try:
            self.children.remove(item)
        except ValueError:
            return
        item.parent = None

    def __repr__(self) -> str:
        return f"TreeItem({self.text!r}, children={len(self.children)})"
[docs]
class TreeView(Control):
    """Scrollable tree widget with expand/collapse and selection.

    Renders a TreeItem hierarchy with indentation, clickable
    expand/collapse arrows, and selection highlight. Uses virtual
    scrolling -- only rows visible in the viewport are drawn.

    Signals (each emitted with the affected TreeItem):
        item_selected: a row was left-clicked outside its expand/collapse arrow.
        item_expanded: a row's arrow was clicked and the row is now expanded.
        item_collapsed: a row's arrow was clicked and the row is now collapsed.

    Example:
        root = TreeItem("Scene")
        root.add_child(TreeItem("Player"))
        root.add_child(TreeItem("Enemies", children=[
            TreeItem("Goblin"),
            TreeItem("Dragon"),
        ]))
        tree = TreeView(root=root)
        tree.item_selected.connect(lambda item: print(item.text))
    """

    # NOTE(review): class flags presumably read by Control / the renderer
    # (draw caching, and that this widget draws its own content) -- confirm.
    _draw_caching = True
    _draws_children = True

    bg_colour = ThemeColour("tree_bg")
    text_colour = ThemeColour("text")
    select_colour = ThemeColour("tree_select")
    hover_colour = ThemeColour("tree_hover")
    arrow_colour = ThemeColour("tree_arrow")

    def __init__(self, root: TreeItem | None = None, **kwargs):
        super().__init__(**kwargs)
        self._root: TreeItem | None = root
        self.selected: TreeItem | None = None  # set when a row is clicked
        self.indent = 20.0  # horizontal offset added per depth level
        self.row_height = 22.0
        self.font_size = 14.0  # text is drawn at scale font_size / 14
        self._hovered_item: TreeItem | None = None
        # Hit-test rows from the last draw: (item, viewport left x, row top y, depth).
        self._row_map: list[tuple[TreeItem, float, float, int]] = []
        self._scroll_y: float = 0.0
        self._content_height: float = 0.0
        # Virtual scroll: visible rows flattened in display order as (item, depth).
        # None means stale; rebuilt lazily by _ensure_flat_rows().
        self._flat_rows: list[tuple[TreeItem, int]] | None = None
        # Signals
        self.item_selected = Signal()
        self.item_expanded = Signal()
        self.item_collapsed = Signal()
        # NOTE(review): this overwrites any size Control.__init__ may have
        # applied from kwargs -- confirm that is intended.
        self.size = Vec2(250, 300)

    # -------------------------------------------------------------- root property
    @property
    def root(self) -> TreeItem | None:
        """The top-level TreeItem being displayed (``None`` shows nothing)."""
        return self._root

    @root.setter
    def root(self, value: TreeItem | None):
        self._root = value
        self._flat_rows = None
        # Hit-test rows and hover still refer to the previous tree until the
        # next draw; drop them so a click in that window cannot select or
        # toggle a stale item. Selection is kept deliberately: re-assigning the
        # same root to refresh after mutating it should not lose the selection.
        self._row_map = []
        self._hovered_item = None
        # Only request a redraw once Control's draw state exists (the setter
        # may run before Control.__init__ has finished -- presumably why this
        # guard was added).
        if hasattr(self, "_draw_dirty"):
            self.queue_redraw()

    # -------------------------------------------------------- flat row management
    def _invalidate_flat_rows(self):
        """Mark the flattened row cache as stale and request redraw."""
        self._flat_rows = None
        self.queue_redraw()

    def _ensure_flat_rows(self):
        """Rebuild ``_flat_rows`` from the root if it has been invalidated."""
        if self._flat_rows is not None:
            return
        self._flat_rows = []
        if self._root:
            self._flatten(self._root, 0)

    def _flatten(self, item: TreeItem, depth: int):
        """Append *item* and, if it is expanded, its descendants (pre-order).

        Recursive, so very deep hierarchies are bounded by Python's
        recursion limit.
        """
        self._flat_rows.append((item, depth))
        if item.expanded:
            for child in item.children:
                self._flatten(child, depth + 1)

    # ------------------------------------------------------------------- input
    def _on_gui_input(self, event):
        """Handle wheel scrolling, hover tracking and left-button row clicks.

        Reads ``event.key``, ``event.char``, ``event.button``, ``event.pressed``
        and ``event.position``.

        A left press on the arrow area of a row that has children toggles its
        expansion and emits ``item_expanded`` or ``item_collapsed``. A left
        press elsewhere on a row selects it and emits ``item_selected``.
        """
        # Scroll handling: each wheel step moves two rows.
        if event.key == "scroll_up":
            self._scroll_y = max(0, self._scroll_y - self.row_height * 2)
            self.queue_redraw()
            return
        if event.key == "scroll_down":
            _, _, _, h = self.get_global_rect()
            max_scroll = max(0, self._content_height - h)
            self._scroll_y = min(max_scroll, self._scroll_y + self.row_height * 2)
            self.queue_redraw()
            return
        # Mouse-move hover tracking (button=0, no key/char)
        if event.button == 0 and not event.key and not event.char:
            self._resolve_hover(event.position)
            return
        if event.button != 1 or not event.pressed:
            return
        if not self._row_map:
            return
        # Position may be a Vec2-like object or an indexable (x, y) pair.
        px = event.position.x if hasattr(event.position, "x") else event.position[0]
        py = event.position.y if hasattr(event.position, "y") else event.position[1]
        for item, row_x, row_y, depth in self._row_map:
            if row_y <= py < row_y + self.row_height:
                # The arrow occupies a row_height-wide square at the row's indent.
                arrow_x = row_x + depth * self.indent
                if item.children and arrow_x <= px < arrow_x + self.row_height:
                    item.expanded = not item.expanded
                    self._invalidate_flat_rows()
                    if item.expanded:
                        self.item_expanded.emit(item)
                    else:
                        self.item_collapsed.emit(item)
                else:
                    self.selected = item
                    self.queue_redraw()
                    self.item_selected.emit(item)
                break

    def _resolve_hover(self, position):
        """Set ``_hovered_item`` to the row under *position* (by y only).

        Requests a redraw only when the hovered item actually changes.
        """
        py = position.y if hasattr(position, "y") else position[1]
        old = self._hovered_item
        self._hovered_item = None
        for item, _rx, row_y, _depth in self._row_map:
            if row_y <= py < row_y + self.row_height:
                self._hovered_item = item
                break
        if self._hovered_item is not old:
            self.queue_redraw()

    def _update_mouse_over(self, mouse_pos):
        """Override to clear hover highlight when the mouse leaves the widget."""
        super()._update_mouse_over(mouse_pos)
        if not self.mouse_over and self._hovered_item is not None:
            self._hovered_item = None
            self.queue_redraw()

    # -------------------------------------------------------------------- draw
    def draw(self, renderer):
        """Draw the background and visible rows, plus a scrollbar thumb on overflow.

        Rows and scrollbar are clipped to the widget rect. The clip is always
        popped, even if drawing raises, so it cannot leak into later draws.
        """
        x, y, w, h = self.get_global_rect()
        # Background
        renderer.draw_filled_rect(x, y, w, h, self.bg_colour)
        # Clip to widget bounds
        renderer.push_clip(x, y, w, h)
        try:
            self._draw_visible_rows(renderer, x, y, w, h)
            # Scrollbar thumb (if content overflows)
            if self._content_height > h and h > 0:
                sb_w = 6.0
                sb_x = x + w - sb_w - 2
                visible_ratio = h / self._content_height
                thumb_h = max(20.0, h * visible_ratio)
                max_scroll = self._content_height - h
                scroll_ratio = self._scroll_y / max_scroll if max_scroll > 0 else 0
                thumb_y = y + scroll_ratio * (h - thumb_h)
                renderer.draw_filled_rect(sb_x, thumb_y, sb_w, thumb_h, (0.4, 0.4, 0.4, 0.5))
        finally:
            renderer.pop_clip()

    def _draw_visible_rows(self, renderer, x: float, y: float, w: float, h: float):
        """Draw only the rows visible in the viewport defined by (x, y, w, h).

        Populates ``_row_map`` with the visible rows for hit testing and
        updates ``_content_height``.

        ``_scroll_y`` is first clamped to the scrollable range for this
        viewport height. Without the clamp, collapsing a subtree or swapping
        the root while scrolled down would leave the viewport past the end of
        the content, with the scrollbar thumb drawn off its track.

        Can be called externally (e.g. by FileBrowser) to render the tree
        into a custom region.
        """
        self._ensure_flat_rows()
        total_rows = len(self._flat_rows)
        self._content_height = total_rows * self.row_height
        max_scroll = max(0.0, self._content_height - h)
        self._scroll_y = min(max(0.0, self._scroll_y), max_scroll)
        # Compute visible range
        first_visible = max(0, int(self._scroll_y / self.row_height))
        visible_count = int(h / self.row_height) + 2  # +2 for partial rows at top/bottom
        last_visible = min(total_rows, first_visible + visible_count)
        # Build row map for hit testing (only visible rows)
        self._row_map.clear()
        scale = self.font_size / 14.0
        for i in range(first_visible, last_visible):
            item, depth = self._flat_rows[i]
            row_y = y + i * self.row_height - self._scroll_y
            indent_px = depth * self.indent
            row_x = x + indent_px
            # Store the viewport x (not row_x); _on_gui_input re-adds the indent.
            self._row_map.append((item, x, row_y, depth))
            # Selection / hover highlight (selection wins)
            if item is self.selected:
                renderer.draw_filled_rect(x, row_y, w, self.row_height, self.select_colour)
            elif item is self._hovered_item:
                renderer.draw_filled_rect(x, row_y, w, self.row_height, self.hover_colour)
            # Expand/collapse arrow, centred in a row_height square at the indent
            if item.children:
                arrow_cx = row_x + self.row_height * 0.5
                arrow_cy = row_y + self.row_height * 0.5
                half = 4.0
                if item.expanded:
                    # Down-pointing chevron v
                    renderer.draw_line_coloured(
                        arrow_cx - half, arrow_cy - half * 0.5, arrow_cx, arrow_cy + half * 0.5, self.arrow_colour
                    )
                    renderer.draw_line_coloured(
                        arrow_cx, arrow_cy + half * 0.5, arrow_cx + half, arrow_cy - half * 0.5, self.arrow_colour
                    )
                else:
                    # Right-pointing chevron >
                    renderer.draw_line_coloured(
                        arrow_cx - half * 0.5, arrow_cy - half, arrow_cx + half * 0.5, arrow_cy, self.arrow_colour
                    )
                    renderer.draw_line_coloured(
                        arrow_cx + half * 0.5, arrow_cy, arrow_cx - half * 0.5, arrow_cy + half, self.arrow_colour
                    )
            # Item text, vertically centred; leave a row_height gap for the arrow
            text_x = row_x + self.row_height
            text_y = row_y + (self.row_height - self.font_size) / 2
            renderer.draw_text_coloured(item.text, text_x, text_y, scale, self.text_colour)