"""TreeView -- hierarchical tree display with expand/collapse.
TreeItem is a lightweight data node. TreeView renders the hierarchy
with indentation, selection highlight, and expand/collapse arrows.
Uses virtual scrolling: only visible rows are drawn each frame.
"""
import logging
from typing import Any
from ..signals import Signal
from ..math.types import Vec2
from .core import Control, ThemeColour
from ..input.enums import MouseButton
log = logging.getLogger(__name__)
__all__ = ["TreeItem", "TreeView"]
[docs]
class TreeItem:
"""Data node for a tree hierarchy.
Attributes:
text: Display text for this item.
children: Child items.
expanded: Whether children are visible.
data: Arbitrary user data attached to this item.
parent: Parent TreeItem (set internally by add_child/remove_child).
badge_text: Optional trailing icon/text drawn at the right edge of the
row (e.g. an error indicator). Hit-tested separately so panels can
respond to badge clicks via ``TreeView.item_badge_clicked``.
badge_tooltip: Tooltip text shown when hovering the badge region.
Example:
root = TreeItem("Root")
root.add_child(TreeItem("Child A"))
root.add_child(TreeItem("Child B", data={"key": 42}))
"""
__slots__ = ("text", "children", "expanded", "data", "parent", "badge_text", "badge_tooltip")
def __init__(
self,
text: str = "",
children: list[TreeItem] | None = None,
expanded: bool = True,
data: Any = None,
badge_text: str = "",
badge_tooltip: str = "",
):
self.text = text
self.expanded = expanded
self.data = data
self.parent: TreeItem | None = None
self.children: list[TreeItem] = []
self.badge_text = badge_text
self.badge_tooltip = badge_tooltip
if children:
for child in children:
self.add_child(child)
[docs]
def add_child(self, item: TreeItem) -> TreeItem:
"""Add a child item and set its parent."""
item.parent = self
self.children.append(item)
return item
[docs]
def remove_child(self, item: TreeItem):
"""Remove a child item and clear its parent."""
if item in self.children:
item.parent = None
self.children.remove(item)
[docs]
def __repr__(self) -> str:
return f"TreeItem({self.text!r}, children={len(self.children)})"
[docs]
class TreeView(Control):
"""Scrollable tree widget with expand/collapse and selection.
Renders a TreeItem hierarchy with indentation, clickable
expand/collapse arrows, and selection highlight. Uses virtual
scrolling -- only rows visible in the viewport are drawn.
Example:
root = TreeItem("Scene")
root.add_child(TreeItem("Player"))
root.add_child(TreeItem("Enemies", children=[
TreeItem("Goblin"),
TreeItem("Dragon"),
]))
tree = TreeView(root=root)
tree.item_selected.connect(lambda item: print(item.text))
"""
_draw_caching = True
_draws_children = True
bg_colour = ThemeColour("tree_bg")
text_colour = ThemeColour("text")
select_colour = ThemeColour("tree_select")
hover_colour = ThemeColour("tree_hover")
arrow_colour = ThemeColour("tree_arrow")
def __init__(self, root: TreeItem | None = None, **kwargs):
super().__init__(**kwargs)
self._root: TreeItem | None = root
self.selected: TreeItem | None = None
self.indent = 20.0
self.row_height = 22.0
self.font_size = 14.0
self._hovered_item: TreeItem | None = None
self._row_map: list[tuple[TreeItem, float, float, int]] = []
self._scroll_y: float = 0.0
self._content_height: float = 0.0
# Virtual scroll: flattened tree cache
self._flat_rows: list[tuple[TreeItem, int]] | None = None # (item, depth)
# Signals
self.item_selected = Signal()
self.item_expanded = Signal()
self.item_collapsed = Signal()
# Emitted when a user clicks the trailing badge region of a row. Used
# by the scene tree to navigate from an error icon to its source.
self.item_badge_clicked = Signal()
# badge-region hit-test overlays, parallel to _row_map
self._badge_rects: list[tuple[TreeItem, float, float, float, float]] = []
self.size = Vec2(250, 300)
# -------------------------------------------------------------- root property
@property
def root(self) -> TreeItem | None:
return self._root
[docs]
@root.setter
def root(self, value: TreeItem | None):
self._root = value
self._flat_rows = None
self._row_map.clear()
if hasattr(self, "_draw_dirty"):
self.queue_redraw()
# -------------------------------------------------------- flat row management
def _invalidate_flat_rows(self):
"""Mark the flattened row cache as stale and request redraw."""
self._flat_rows = None
self._row_map.clear()
self.queue_redraw()
def _ensure_flat_rows(self):
"""Rebuild the flat row list if stale."""
if self._flat_rows is not None:
return
self._flat_rows = []
if self._root:
self._flatten(self._root, 0)
def _flatten(self, item: TreeItem, depth: int):
"""Recursively flatten visible tree items into _flat_rows."""
self._flat_rows.append((item, depth))
if item.expanded:
for child in item.children:
self._flatten(child, depth + 1)
# ------------------------------------------------------------------- input
def _on_gui_input(self, event):
# Scroll handling
if event.key == "scroll_up":
self._scroll_y = max(0, self._scroll_y - self.row_height * 2)
self.queue_redraw()
return
if event.key == "scroll_down":
_, _, _, h = self.get_global_rect()
max_scroll = max(0, self._content_height - h)
self._scroll_y = min(max_scroll, self._scroll_y + self.row_height * 2)
self.queue_redraw()
return
# Mouse-move hover tracking (button=0, no key/char)
if event.button is None and not event.key and not event.char:
self._resolve_hover(event.position)
return
if event.button != MouseButton.LEFT or not event.pressed:
return
if not self._row_map:
return
px = event.position.x if hasattr(event.position, "x") else event.position[0]
py = event.position.y if hasattr(event.position, "y") else event.position[1]
# Badge click: takes priority over row selection so clicking an error
# icon dispatches navigation instead of just selecting the node.
for item, bx, by, bw, bh in self._badge_rects:
if bx <= px <= bx + bw and by <= py <= by + bh:
self.item_badge_clicked.emit(item)
return
for item, row_x, row_y, depth in self._row_map:
if row_y <= py < row_y + self.row_height:
# Check if click is on the expand/collapse arrow area
arrow_x = row_x + depth * self.indent
if item.children and arrow_x <= px < arrow_x + self.row_height:
item.expanded = not item.expanded
self._invalidate_flat_rows()
if item.expanded:
self.item_expanded.emit(item)
else:
self.item_collapsed.emit(item)
else:
self.selected = item
self.queue_redraw()
self.item_selected.emit(item)
break
def _resolve_hover(self, position):
"""Set _hovered_item from mouse position using _row_map."""
py = position.y if hasattr(position, "y") else position[1]
old = self._hovered_item
self._hovered_item = None
for item, _rx, row_y, _depth in self._row_map:
if row_y <= py < row_y + self.row_height:
self._hovered_item = item
break
if self._hovered_item is not old:
self.queue_redraw()
def _update_mouse_over(self, mouse_pos):
"""Override to clear hover highlight when the mouse leaves the widget."""
super()._update_mouse_over(mouse_pos)
if not self.mouse_over and self._hovered_item is not None:
self._hovered_item = None
self.queue_redraw()
# -------------------------------------------------------------------- draw
[docs]
def on_draw(self, renderer):
x, y, w, h = self.get_global_rect()
# Background
renderer.draw_rect((x, y), (w, h), colour=self.bg_colour, filled=True)
# Clip to widget bounds
renderer.push_clip(x, y, w, h)
self._draw_visible_rows(renderer, x, y, w, h)
# Scrollbar thumb (if content overflows)
if self._content_height > h and h > 0:
sb_w = 6.0
sb_x = x + w - sb_w - 2
visible_ratio = h / self._content_height
thumb_h = max(20.0, h * visible_ratio)
max_scroll = self._content_height - h
scroll_ratio = self._scroll_y / max_scroll if max_scroll > 0 else 0
thumb_y = y + scroll_ratio * (h - thumb_h)
renderer.draw_rect((sb_x, thumb_y), (sb_w, thumb_h), colour=(0.4, 0.4, 0.4, 0.5), filled=True)
renderer.pop_clip()
def _draw_visible_rows(self, renderer, x: float, y: float, w: float, h: float):
"""Draw only the rows visible in the viewport defined by (x, y, w, h).
Populates ``_row_map`` with the visible rows for hit testing.
Can be called externally (e.g. by FileBrowser) to render the tree
into a custom region.
"""
self._ensure_flat_rows()
total_rows = len(self._flat_rows)
self._content_height = total_rows * self.row_height
# Compute visible range
first_visible = max(0, int(self._scroll_y / self.row_height))
visible_count = int(h / self.row_height) + 2 # +2 for partial rows at top/bottom
last_visible = min(total_rows, first_visible + visible_count)
# Build row map for hit testing (only visible rows)
self._row_map.clear()
self._badge_rects.clear()
scale = self.font_size / 14.0
for i in range(first_visible, last_visible):
item, depth = self._flat_rows[i]
row_y = y + i * self.row_height - self._scroll_y
indent_px = depth * self.indent
row_x = x + indent_px
# Store for hit testing
self._row_map.append((item, x, row_y, depth))
# Selection / hover highlight
if item is self.selected:
renderer.draw_rect((x, row_y), (w, self.row_height), colour=self.select_colour, filled=True)
elif item is self._hovered_item:
renderer.draw_rect((x, row_y), (w, self.row_height), colour=self.hover_colour, filled=True)
# Expand/collapse arrow
if item.children:
arrow_cx = row_x + self.row_height * 0.5
arrow_cy = row_y + self.row_height * 0.5
half = 4.0
if item.expanded:
# Down-pointing triangle v
renderer.draw_line(
(arrow_cx - half, arrow_cy - half * 0.5),
(arrow_cx, arrow_cy + half * 0.5),
colour=self.arrow_colour,
)
renderer.draw_line(
(arrow_cx, arrow_cy + half * 0.5),
(arrow_cx + half, arrow_cy - half * 0.5),
colour=self.arrow_colour,
)
else:
# Right-pointing triangle >
renderer.draw_line(
(arrow_cx - half * 0.5, arrow_cy - half),
(arrow_cx + half * 0.5, arrow_cy),
colour=self.arrow_colour,
)
renderer.draw_line(
(arrow_cx + half * 0.5, arrow_cy),
(arrow_cx - half * 0.5, arrow_cy + half),
colour=self.arrow_colour,
)
# Item text
text_x = row_x + self.row_height # leave space for arrow
text_y = row_y + (self.row_height - self.font_size) / 2
renderer.draw_text(item.text, (text_x, text_y), colour=self.text_colour, scale=scale)
# Trailing badge (clickable): drawn at the right edge of the row
# and recorded in _badge_rects for hit testing in _on_gui_input.
if item.badge_text:
badge_w = renderer.text_width(item.badge_text, scale) + 10.0
badge_x = x + w - badge_w - 6.0
badge_bg = (0.6, 0.2, 0.2, 0.35)
renderer.draw_rect((badge_x, row_y + 2), (badge_w, self.row_height - 4), colour=badge_bg, filled=True)
renderer.draw_text(item.badge_text, (badge_x + 5, text_y), colour=(1.0, 0.85, 0.3, 1.0), scale=scale)
self._badge_rects.append((item, badge_x, row_y + 2, badge_w, self.row_height - 4))