| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Any |
|
|
| from dataflow.connection import Connection |
| from dataflow.enums import DataPortState |
| from dataflow.graph import DataGraph |
| from . import utils |
| from .image_data import ImageDataNode |
| from .text_data import TextDataNode |
| from .text_to_image import TextToImageNode |
|
|
|
|
| @dataclass(slots=True) |
| class GraphController: |
| """Handle Vue Flow events and keep the DataGraph plus node state in sync. |
| |
| This lives in the app layer so the core dataflow package does not |
| depend on your concrete node types. |
| """ |
|
|
| graph: DataGraph |
|
|
| def handle_event(self, event: dict[str, Any]) -> None: |
| event_type = event.get("type") |
| raw_payload = event.get("payload") |
|
|
| if event_type == "connect": |
| payload = raw_payload or {} |
| self._on_connect(payload) |
| elif event_type == "node_moved": |
| payload = raw_payload or {} |
| self._on_node_moved(payload) |
| elif event_type == "node_field_changed": |
| payload = raw_payload or {} |
| self._on_node_field_changed(payload) |
| elif event_type == "edges_delete": |
| edges = raw_payload or [] |
| if isinstance(edges, list): |
| self._on_edges_delete(edges) |
| elif event_type == "edges_change": |
| changes = raw_payload or [] |
| if isinstance(changes, list): |
| self._on_edges_change(changes) |
| elif event_type == "nodes_delete": |
| nodes = raw_payload or [] |
| if isinstance(nodes, list): |
| self._on_nodes_delete(nodes) |
| elif event_type == "nodes_change": |
| changes = raw_payload or [] |
| if isinstance(changes, list): |
| self._on_nodes_change(changes) |
| |
| |
|
|
| def _on_connect(self, payload: dict[str, Any]) -> None: |
| source_handle = payload.get("sourceHandle") or "" |
| target_handle = payload.get("targetHandle") or "" |
|
|
| def split(handle: str) -> tuple[str, str]: |
| if not handle: |
| return "", "" |
| if ":" in handle: |
| node_id, port = handle.split(":", 1) |
| return node_id, port |
| return "", handle |
|
|
| src_node_id, src_port_name = split(source_handle) |
| tgt_node_id, tgt_port_name = split(target_handle) |
|
|
| if not src_node_id: |
| src_node_id = payload.get("source") or "" |
| if not tgt_node_id: |
| tgt_node_id = payload.get("target") or "" |
|
|
| if not src_node_id or not tgt_node_id: |
| return |
|
|
| start_node = self.graph.nodes.get(src_node_id) |
| end_node = self.graph.nodes.get(tgt_node_id) |
| if start_node is None or end_node is None: |
| return |
|
|
| start_port = start_node.outputs.get(src_port_name) if start_node.outputs is not None else None |
| if start_port is None and start_node.inputs is not None: |
| start_port = start_node.inputs.get(src_port_name) |
|
|
| end_port = end_node.inputs.get(tgt_port_name) if end_node.inputs is not None else None |
| if end_port is None and end_node.outputs is not None: |
| end_port = end_node.outputs.get(tgt_port_name) |
|
|
| if start_port is None or end_port is None: |
| return |
|
|
| conn = Connection( |
| start_node=start_node, |
| start_port=start_port, |
| end_node=end_node, |
| end_port=end_port, |
| ) |
|
|
| try: |
| self.graph.add_connection(conn) |
| except ValueError: |
| |
| return |
|
|
| |
| if hasattr(start_port, "state"): |
| start_port.state = DataPortState.DIRTY |
|
|
| if hasattr(end_port, "state"): |
| end_port.state = DataPortState.DIRTY |
| |
| if hasattr(end_port, "value"): |
| end_port.value = None |
|
|
| def _update_node_position(self, node_id: str, position: dict[str, Any]) -> None: |
| node = self.graph.nodes.get(node_id) |
| if node is None: |
| return |
|
|
| pos = position or {} |
| x = pos.get("x") |
| y = pos.get("y") |
|
|
| if x is not None: |
| try: |
| node.x = float(x) |
| except (TypeError, ValueError): |
| pass |
| if y is not None: |
| try: |
| node.y = float(y) |
| except (TypeError, ValueError): |
| pass |
|
|
| def _on_node_moved(self, payload: dict[str, Any]) -> None: |
| node_id = payload.get("id") |
| if not node_id: |
| return |
| position = payload.get("position") or {} |
| self._update_node_position(node_id, position) |
|
|
| def _on_node_field_changed(self, payload: dict[str, Any]) -> None: |
| node_id = payload.get("id") |
| field = payload.get("field") |
| value = payload.get("value") |
|
|
| if not node_id or not field: |
| return |
|
|
| node = self.graph.nodes.get(node_id) |
| if node is None: |
| return |
|
|
| if isinstance(node, TextDataNode) and field == "text": |
| port = node.outputs.get("text") if node.outputs is not None else None |
| if port is not None: |
| port.value = "" if value is None else str(value) |
| port.state = DataPortState.DIRTY |
|
|
| elif isinstance(node, ImageDataNode) and field == "image": |
| port = node.outputs.get("image") if node.outputs is not None else None |
| if port is not None: |
| if value: |
| |
| try: |
| |
| img = utils.decode_image(str(value)) |
| port.value = img |
| port.state = DataPortState.DIRTY |
| except Exception: |
| |
| port.value = None |
| else: |
| port.value = None |
| port.state = DataPortState.DIRTY |
|
|
| elif isinstance(node, TextToImageNode) and field == "image": |
| node.image_src = "" if value is None else str(value) |
|
|
| elif isinstance(node, TextToImageNode) and field == "aspect_ratio": |
| |
| aspect_ratio_value = "1:1" |
|
|
| print(f"[DEBUG controller] Aspect ratio set to {value}, type={type(value)}") |
|
|
| if value is not None and isinstance(value, str): |
| aspect_ratio_value = value.strip() |
| else: |
| print(f"[DEBUG controller] aspect_ratio value is None, using default 1:1") |
|
|
| |
| |
| if aspect_ratio_value and ":" in aspect_ratio_value: |
| parts = aspect_ratio_value.split(":") |
| if len(parts) == 2: |
| try: |
| |
| float(parts[0]) |
| float(parts[1]) |
| old_ratio = node.aspect_ratio |
| node.aspect_ratio = aspect_ratio_value |
|
|
| except (ValueError, TypeError): |
| |
| node.aspect_ratio = "1:1" |
| print(f"[DEBUG Controller] Invalid numeric format, using default: {node.aspect_ratio}") |
|
|
| def _on_edges_delete(self, edges: list[dict[str, Any]]) -> None: |
| """Remove matching connections from the DataGraph when edges are deleted in Vue.""" |
| if not edges: |
| return |
|
|
| for edge in edges: |
| if not isinstance(edge, dict): |
| continue |
|
|
| src_id = edge.get("source") |
| tgt_id = edge.get("target") |
| src_handle = edge.get("sourceHandle") or "" |
| tgt_handle = edge.get("targetHandle") or "" |
|
|
| src_port = src_handle.split(":", 1)[1] if ":" in src_handle else None |
| tgt_port = tgt_handle.split(":", 1)[1] if ":" in tgt_handle else None |
|
|
| def should_remove(conn: Connection) -> bool: |
| if src_id and conn.start_node.node_id != src_id: |
| return False |
| if tgt_id and conn.end_node.node_id != tgt_id: |
| return False |
| if src_port is not None and conn.start_port.name != src_port: |
| return False |
| if tgt_port is not None and conn.end_port.name != tgt_port: |
| return False |
| return True |
|
|
| self.graph.connections = [c for c in self.graph.connections if not should_remove(c)] |
|
|
| def _on_edges_change(self, changes: list[dict[str, Any]]) -> None: |
| """Handle generic edge changes. |
| |
| Vue Flow sends EdgeChange objects. |
| """ |
| if not changes: |
| return |
|
|
| edges_to_delete: list[dict[str, Any]] = [] |
|
|
| for change in changes: |
| if not isinstance(change, dict): |
| continue |
|
|
| if change.get("type") == "remove": |
| |
| edge = change |
| if isinstance(edge, dict): |
| edges_to_delete.append(edge) |
|
|
| if edges_to_delete: |
| self._on_edges_delete(edges_to_delete) |
|
|
| def _on_nodes_delete(self, nodes: list[dict[str, Any]]) -> None: |
| """Remove nodes and all their connections when Vue deletes them.""" |
| if not nodes: |
| return |
|
|
| node_ids = {n.get("id") for n in nodes if isinstance(n, dict) and n.get("id")} |
| self._delete_nodes(node_ids) |
|
|
| def _on_nodes_change(self, changes: list[dict[str, Any]]) -> None: |
| """Handle generic node changes. |
| |
| Currently supports: |
| - type == "remove": delete node and related connections |
| - type == "position": update node position like 'node_moved' |
| Other change types (select, dimensions, etc.) do not affect the DataGraph. |
| """ |
| if not changes: |
| return |
|
|
| node_ids_to_delete: set[str] = set() |
|
|
| for change in changes: |
| if not isinstance(change, dict): |
| continue |
|
|
| ctype = change.get("type") |
| node_id = change.get("id") |
|
|
| if ctype == "remove" and node_id: |
| node_ids_to_delete.add(node_id) |
| elif ctype == "position" and node_id: |
| |
| position = change.get("position") or change.get("positionAbsolute") or {} |
| self._update_node_position(node_id, position) |
|
|
| if node_ids_to_delete: |
| self._delete_nodes(node_ids_to_delete) |
|
|
| def _delete_nodes(self, node_ids: set[str]) -> None: |
| if not node_ids: |
| return |
|
|
| self.graph.connections = [ |
| c |
| for c in self.graph.connections |
| if c.start_node.node_id not in node_ids and c.end_node.node_id not in node_ids |
| ] |
|
|
| for node_id in node_ids: |
| self.graph.nodes.pop(node_id, None) |
|
|