Spaces:
Running
Running
| """Gradio demo: P&ID graph extraction with Claude VLM + evaluation. | |
| Usage (local): | |
| ANTHROPIC_API_KEY=sk-ant-... python app.py | |
| Or put the key in a `.env` next to this file. The app: | |
| 1. Takes a P&ID image (preset or upload) | |
| 2. Runs extraction (optionally tiled 2x2) via Claude Opus 4.6 | |
| 3. If a ground-truth graphml is provided, collapses it to semantic-only | |
| form and computes node/edge P/R/F1 via `pid2graph_eval.metrics` | |
| 4. Draws both the prediction and the ground truth as NetworkX graphs | |
| using bbox-based layouts so the topology matches the source image | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Optional | |
| # Matplotlib backend must be set before pyplot import for headless use; | |
| # the `matplotlib.use()` call below taints every subsequent import with | |
| # E402 ("module level import not at top of file"), which is expected here. | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.patches as mpatches # noqa: E402 | |
| import matplotlib.pyplot as plt # noqa: E402 | |
| import networkx as nx # noqa: E402 | |
| import anthropic # noqa: E402 | |
| # --------------------------------------------------------------------------- | |
| # Gradio 4.44 / gradio_client 1.3.0 bug workaround | |
| # --------------------------------------------------------------------------- | |
| # At `/info` boot, Gradio walks every component's JSON schema via | |
| # `gradio_client.utils._json_schema_to_python_type`. That function does not | |
| # handle bool schemas (`additionalProperties: false` or `true`, both of which | |
| # are valid JSON Schema) — it recurses with the bool and then `if "const" in | |
| # schema:` on line 863 raises `TypeError: argument of type 'bool' is not | |
| # iterable`. Patch the function here before importing gradio so the crash is | |
| # avoided regardless of which component triggers it. (Fixed upstream in later | |
| # gradio_client releases; we stay on 4.44 because Python 3.9 can't run | |
| # gradio 5.) Harmless on versions where the bug is already fixed. | |
import gradio_client.utils as _gc_utils  # noqa: E402

_orig_json_schema_to_python_type = _gc_utils._json_schema_to_python_type


def _patched_json_schema_to_python_type(schema, defs=None):  # type: ignore[override]
    """Tolerate bare-boolean JSON schemas, then delegate to the original.

    JSON Schema allows a schema to be literally `true` ("anything goes")
    or `false` ("nothing validates"); gradio_client 1.3.0 chokes on both.
    """
    if isinstance(schema, bool):
        if schema:
            return "Any"
        return "None"
    return _orig_json_schema_to_python_type(schema, defs)


_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
| import gradio as gr # noqa: E402 (must come after the monkey-patch) | |
| from dotenv import load_dotenv # noqa: E402 | |
| from pid2graph_eval.extractor import ( # noqa: E402 | |
| DEFAULT_MODEL, | |
| extract_graph, | |
| extract_graph_tiled, | |
| ) | |
| from pid2graph_eval.gt_loader import ( # noqa: E402 | |
| SEMANTIC_EQUIPMENT_TYPES, | |
| collapse_through_primitives, | |
| filter_by_types, | |
| load_graphml, | |
| ) | |
| from pid2graph_eval.metrics import evaluate # noqa: E402 | |
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
APP_ROOT = Path(__file__).parent
SAMPLES_DIR = APP_ROOT / "samples"
# Pull ANTHROPIC_API_KEY (and any other settings) from a .env next to app.py.
load_dotenv(APP_ROOT / ".env")
# Presets: (display name) -> (image path, graphml path)
PRESETS: dict[str, tuple[Path, Path]] = {
    "OPEN100 #1 — small (27 semantic nodes)": (
        SAMPLES_DIR / "open100_01_small.png",
        SAMPLES_DIR / "open100_01_small.graphml",
    ),
    "OPEN100 #3 — medium (53 semantic nodes)": (
        SAMPLES_DIR / "open100_03_medium.png",
        SAMPLES_DIR / "open100_03_medium.graphml",
    ),
    "OPEN100 #0 — large (82 semantic nodes)": (
        SAMPLES_DIR / "open100_00_large.png",
        SAMPLES_DIR / "open100_00_large.graphml",
    ),
}
# Dropdown sentinel meaning "no preset selected; use the manual upload".
NONE_LABEL = "(none — upload your own)"
# Fixed palette so pred/GT visualizations use matching colors.
# Node types outside this map fall back to grey (#cccccc) in `draw_graph`.
TYPE_COLORS: dict[str, str] = {
    "valve": "#ff6b6b",
    "pump": "#4ecdc4",
    "tank": "#ffd93d",
    "instrumentation": "#6bcfff",
    "inlet/outlet": "#c47bff",
}
# One legend patch per known equipment type, shared by both plots.
LEGEND_HANDLES = [
    mpatches.Patch(color=c, label=t) for t, c in TYPE_COLORS.items()
]
| # --------------------------------------------------------------------------- | |
| # Visualization | |
| # --------------------------------------------------------------------------- | |
| def _bbox_to_xyxy(bbox) -> Optional[tuple[float, float, float, float]]: | |
| """Normalize a bbox to `(xmin, ymin, xmax, ymax)` floats. | |
| Accepts both shapes that flow through the pipeline: | |
| * **list / tuple** `[x1, y1, x2, y2]` — produced by `gt_loader._bbox` | |
| and by `tile.merge_tile_graphs` for the tiled pred path. | |
| * **dict** `{"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}` — | |
| produced by `GraphOut.to_dict()` in single-shot mode, because the | |
| Pydantic `BBox` model round-trips through `model_dump()`. | |
| Returns `None` if the bbox is missing or malformed. | |
| """ | |
| if bbox is None: | |
| return None | |
| if isinstance(bbox, dict): | |
| try: | |
| return ( | |
| float(bbox["xmin"]), | |
| float(bbox["ymin"]), | |
| float(bbox["xmax"]), | |
| float(bbox["ymax"]), | |
| ) | |
| except (KeyError, TypeError, ValueError): | |
| return None | |
| if isinstance(bbox, (list, tuple)) and len(bbox) == 4: | |
| try: | |
| return ( | |
| float(bbox[0]), | |
| float(bbox[1]), | |
| float(bbox[2]), | |
| float(bbox[3]), | |
| ) | |
| except (TypeError, ValueError): | |
| return None | |
| return None | |
def draw_graph(graph_dict: dict, title: str, figsize=(8, 6)) -> plt.Figure:
    """Plot a `{"nodes": [...], "edges": [...]}` dict as a matplotlib figure.

    Nodes with a bbox are placed at the bbox center so the drawing keeps the
    spatial layout of the original P&ID. Nodes without one are positioned by
    a spring layout (when no bbox exists at all) or dropped onto the centroid
    of the located nodes.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=110)
    graph = nx.Graph()
    layout: dict[str, tuple[float, float]] = {}
    ordered_ids: list[str] = []
    fill_colors: list[str] = []
    for node in graph_dict.get("nodes", []):
        node_id = node["id"]
        graph.add_node(node_id)
        ordered_ids.append(node_id)
        fill_colors.append(TYPE_COLORS.get(node.get("type", ""), "#cccccc"))
        xyxy = _bbox_to_xyxy(node.get("bbox"))
        if xyxy is None:
            continue
        xmin, ymin, xmax, ymax = xyxy
        # Negate y: image coordinates grow downward, plot coordinates upward.
        layout[node_id] = ((xmin + xmax) / 2.0, -(ymin + ymax) / 2.0)
    # Keep only edges whose endpoints both made it into the node set.
    for edge in graph_dict.get("edges", []):
        src, dst = edge.get("source"), edge.get("target")
        if src in graph.nodes and dst in graph.nodes:
            graph.add_edge(src, dst)
    # Every node needs coordinates before drawing; handle the bbox-less ones.
    unplaced = [nid for nid in graph.nodes if nid not in layout]
    if unplaced:
        if layout:
            # Park bbox-less stragglers at the centroid of the placed cloud.
            xs = [p[0] for p in layout.values()]
            ys = [p[1] for p in layout.values()]
            centroid = (sum(xs) / len(xs), sum(ys) / len(ys))
            for nid in unplaced:
                layout[nid] = centroid
        else:
            layout = nx.spring_layout(graph, seed=42)
    nx.draw_networkx_edges(graph, layout, alpha=0.35, width=0.6, ax=ax)
    nx.draw_networkx_nodes(
        graph, layout,
        nodelist=ordered_ids,
        node_color=fill_colors,
        node_size=55,
        linewidths=0.5,
        edgecolors="#222",
        ax=ax,
    )
    ax.set_title(title, fontsize=11)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.legend(handles=LEGEND_HANDLES, loc="lower right", fontsize=7, framealpha=0.9)
    fig.tight_layout()
    return fig
| # --------------------------------------------------------------------------- | |
| # Pipeline | |
| # --------------------------------------------------------------------------- | |
def _preset_paths(preset_name: str) -> tuple[Optional[str], Optional[str]]:
    """Map a preset dropdown choice to `(image_path, graphml_path)` strings.

    Either slot is `None` when the preset is the "(none)" sentinel, unknown,
    or its file is missing on disk.
    """
    if preset_name == NONE_LABEL or preset_name not in PRESETS:
        return None, None
    image_file, graphml_file = PRESETS[preset_name]
    image_path = str(image_file) if image_file.exists() else None
    graphml_path = str(graphml_file) if graphml_file.exists() else None
    return image_path, graphml_path
| def _format_metrics(metrics: dict, latency_s: float, mode: str) -> str: | |
| nm = metrics["nodes"] | |
| em = metrics["edges"] | |
| return f""" | |
| ### Metrics | |
| | | Precision | Recall | F1 | TP | FP | FN | | |
| |---|---:|---:|---:|---:|---:|---:| | |
| | **Nodes** | {nm['precision']:.3f} | {nm['recall']:.3f} | **{nm['f1']:.3f}** | {nm['tp']} | {nm['fp']} | {nm['fn']} | | |
| | **Edges** | {em['precision']:.3f} | {em['recall']:.3f} | **{em['f1']:.3f}** | {em['tp']} | {em['fp']} | {em['fn']} | | |
| - Pred: **{metrics['n_pred_nodes']}** ノード / **{metrics['n_pred_edges']}** エッジ | |
| - GT (semantic-collapsed): **{metrics['n_gt_nodes']}** ノード / **{metrics['n_gt_edges']}** エッジ | |
| - Mode: `{mode}` · Latency: **{latency_s:.1f}s** | |
| """ | |
| def _format_pred_only(pred_dict: dict, latency_s: float, mode: str) -> str: | |
| return f""" | |
| ### Prediction | |
| - **{len(pred_dict['nodes'])}** ノード / **{len(pred_dict['edges'])}** エッジ | |
| - Mode: `{mode}` · Latency: **{latency_s:.1f}s** | |
| - (正解 graphml 未指定のため評価スキップ) | |
| """ | |
def run_extraction(
    preset_name: str,
    image_path: Optional[str],
    gt_path: Optional[str],
    use_tiling: bool,
    progress: gr.Progress = gr.Progress(),
) -> tuple[str, Optional[plt.Figure], Optional[plt.Figure], str]:
    """Entry point wired to the Run button.

    Args:
        preset_name: Dropdown selection; a valid preset overrides the two
            path arguments below so demo runs stay reproducible.
        image_path: Filepath of the P&ID image (may be None if not uploaded).
        gt_path: Filepath of the optional ground-truth graphml.
        use_tiling: True -> 2x2 tiled extraction + merge; False -> one
            single-shot VLM call.
        progress: Gradio progress reporter (injected by Gradio at call time).

    Returns:
        `(metrics_markdown, pred_figure, gt_figure, pred_json)` matching the
        four output components; the figures are None and the JSON empty on
        early validation or extraction errors.
    """
    # Preset overrides manual upload so the demo is reproducible.
    preset_img, preset_gt = _preset_paths(preset_name)
    if preset_img:
        image_path = preset_img
    if preset_gt:
        gt_path = preset_gt
    # Validation failures return a user-facing markdown warning, not a raise.
    if not image_path:
        return (
            "⚠️ 画像をアップロードするか、プリセットを選択してください。",
            None, None, "",
        )
    if not os.environ.get("ANTHROPIC_API_KEY"):
        return (
            "⚠️ `ANTHROPIC_API_KEY` が設定されていません。`.env` に追記して再起動してください。",
            None, None, "",
        )
    client = anthropic.Anthropic()
    mode = "tiled 2x2 + seam filter" if use_tiling else "single-shot"
    try:
        progress(0.05, desc=f"VLM 抽出開始 ({mode})…")
        # Latency covers only the VLM extraction, not drawing/evaluation.
        t0 = time.time()
        if use_tiling:
            # Tiled path already returns a plain dict (merged across tiles).
            pred_dict = extract_graph_tiled(
                Path(image_path),
                client=client,
                rows=2,
                cols=2,
                overlap=0.1,
                dedup_px=40.0,
            )
        else:
            # Single-shot path returns a GraphOut model; flatten to dict.
            pred = extract_graph(Path(image_path), client=client)
            pred_dict = pred.to_dict()
        latency = time.time() - t0
        progress(0.55, desc="予測を semantic types に絞り込み…")
        # Defensive: drop anything non-semantic the VLM may have emitted.
        pred_dict = filter_by_types(pred_dict, SEMANTIC_EQUIPMENT_TYPES)
    except Exception as e:
        # Broad catch is deliberate: any API/parse failure becomes a
        # readable error message instead of a stack trace in the UI.
        return (f"❌ VLM 抽出中にエラー: `{e}`", None, None, "")
    progress(0.65, desc="予測グラフを描画…")
    pred_fig = draw_graph(
        pred_dict,
        title=f"Prediction — {len(pred_dict['nodes'])} nodes, {len(pred_dict['edges'])} edges",
    )
    # Default outputs for the no-ground-truth case; overwritten below when
    # a graphml is available and evaluation succeeds.
    gt_fig = None
    metrics_md = _format_pred_only(pred_dict, latency, mode)
    if gt_path and Path(gt_path).exists():
        try:
            progress(0.75, desc="GT graphml をロード & 縮約…")
            gt_raw = load_graphml(Path(gt_path))
            gt_dict = collapse_through_primitives(gt_raw, SEMANTIC_EQUIPMENT_TYPES)
            progress(0.85, desc="P/R/F1 を評価…")
            metrics = evaluate(
                pred_dict,
                gt_dict,
                directed=False,
                match_threshold=0.5,
            )
            metrics_md = _format_metrics(metrics, latency, mode)
            progress(0.95, desc="GT グラフを描画…")
            gt_fig = draw_graph(
                gt_dict,
                title=f"Ground Truth — {len(gt_dict['nodes'])} nodes, {len(gt_dict['edges'])} edges",
            )
        except Exception as e:
            # GT problems must not discard the prediction; append a note.
            metrics_md += f"\n\n⚠️ GT 処理でエラー: `{e}`"
    # Strip heavy-ish keys before JSON display.
    display_dict = {
        "nodes": pred_dict["nodes"],
        "edges": pred_dict["edges"],
    }
    pred_json = json.dumps(display_dict, indent=2, ensure_ascii=False)
    progress(1.0, desc="完了")
    return metrics_md, pred_fig, gt_fig, pred_json
def on_preset_change(preset_name: str):
    """Auto-fill the image and graphml inputs when a preset is picked."""
    # `_preset_paths` already yields the (image, graphml) pair in the order
    # the two output components expect.
    return _preset_paths(preset_name)
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
# App header markdown (Japanese), rendered once at the top of the UI.
DESCRIPTION = """
# PID2Graph × Claude VLM Demo
P&ID (配管計装図) を Claude Opus 4.6 のビジョンで読み取り、シンボル(valve / pump /
tank / instrumentation / inlet・outlet)とその接続関係を JSON グラフに変換します。
正解 graphml を指定すると、ノード/エッジ単位の Precision / Recall / F1 を算出します。
- **タイル分割 (2x2)**: 大きな図面では 1 枚を 4 タイルに分割してから抽出し、マージ時に
bbox 距離で重複排除 + タイル境界の inlet/outlet FP を後処理で除去します。
- **評価ルール**: GT 側は semantic equipment のみを残し、配管プリミティブ (connector /
crossing / arrow / background) を経由する接続を 1 エッジに縮約します。
- **VLM 設定**: `temperature=0` で決定論的サンプリング、構造化出力で JSON スキーマを強制。
"""
def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks layout and wire its events.

    Layout: left column holds the inputs (preset dropdown, image upload,
    graphml upload, tiling toggle, run button); right column holds the
    outputs (metrics markdown, pred/GT plots, raw-JSON accordion).

    Returns:
        The constructed (un-launched) `gr.Blocks` app.
    """
    with gr.Blocks(title="PID2Graph × Claude VLM Demo") as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column(scale=1):
                preset = gr.Dropdown(
                    choices=[NONE_LABEL] + list(PRESETS.keys()),
                    value=NONE_LABEL,
                    label="プリセット (OPEN100 より)",
                )
                # type="filepath": run_extraction receives paths, not arrays.
                image_in = gr.Image(
                    type="filepath",
                    label="P&ID 画像",
                    height=260,
                )
                gt_in = gr.File(
                    label="正解 graphml (任意)",
                    file_types=[".graphml", ".xml"],
                    type="filepath",
                )
                tiling = gr.Checkbox(
                    value=True,
                    label="タイル分割 (2x2) で抽出 — 高精度だがコスト・時間 4 倍",
                )
                run_btn = gr.Button("抽出実行", variant="primary")
                gr.Markdown(
                    "モデル: `" + DEFAULT_MODEL + "` · 所要時間目安: single ~20s / tiled ~60-80s"
                )
            with gr.Column(scale=2):
                metrics_md = gr.Markdown()
                with gr.Row():
                    pred_plot = gr.Plot(label="Prediction")
                    gt_plot = gr.Plot(label="Ground Truth")
                with gr.Accordion("予測 JSON (nodes / edges)", open=False):
                    # NOTE: using Textbox rather than `gr.Code(language="json")`
                    # because the latter's schema has tripped the gradio_client
                    # `additionalProperties: false` bug on 4.44.1 in the past.
                    # Textbox is a plain string component — zero schema surface.
                    pred_json = gr.Textbox(
                        label="",
                        lines=20,
                        max_lines=30,
                        show_copy_button=True,
                        interactive=False,
                    )
        # Picking a preset auto-fills both file inputs.
        preset.change(on_preset_change, inputs=[preset], outputs=[image_in, gt_in])
        # Run button drives the full extract -> (optional) evaluate pipeline.
        run_btn.click(
            run_extraction,
            inputs=[preset, image_in, gt_in, tiling],
            outputs=[metrics_md, pred_plot, gt_plot, pred_json],
        )
    return demo
if __name__ == "__main__":
    # Bind to 0.0.0.0: inside the HF Spaces container localhost is not
    # reachable from the outside, and Gradio would otherwise raise "When
    # localhost is not accessible, a shareable link must be created".
    # Locally the wide bind is harmless.
    #
    # `show_api=False` only hides the API docs panel; the gradio_client
    # monkey-patch near the top of this file is what actually guards
    # against the 4.44 schema crash (kept defensively on pinned versions).
    demo = build_ui()
    demo.launch(show_api=False, server_name="0.0.0.0")