mjbuehler commited on
Commit
469c382
·
verified ·
1 Parent(s): 6d38791

Create graph_reasoning.py

Browse files
Files changed (1) hide show
  1. graph_reasoning.py +588 -0
graph_reasoning.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ graph_reasoning.py
4
+
5
+ CLI runner for Graph-PRefLexOR-style models:
6
+ - Load a user-specified HF model
7
+ - Accept a user prompt (arg or stdin)
8
+ - Generate with Hugging Face Transformers
9
+ - Save prompt, rendered prompt, thinking/content/full output, and graph artifacts
10
+ - Extract <graph_json>...</graph_json>, parse JSON, build NetworkX DiGraph
11
+ - Render graph to PNG + SVG (Graphviz dot if available, else spring layout)
12
+ - Robust fail-safe crash handling + atomic writes
13
+
14
+ Example:
15
+ python graph_reasoning.py \
16
+ --model lamm-mit/Graph-Preflexor-8b_12292025 \
17
+ --prompt "Explain dragline silk toughness."
18
+
19
+ Stdin prompt:
20
+ echo "Your prompt here" | python graph_reasoning.py --model ... --prompt -
21
+
22
+ Notes:
23
+ - If the model uses a different thinking end token, pass --think-end-token-id
24
+ - If the model doesn't support enable_thinking in apply_chat_template, we fall back safely.
25
+ """
26
+
27
+ import os
28
+ import re
29
+ import sys
30
+ import json
31
+ import math
32
+ import time
33
+ import argparse
34
+ import logging
35
+ from datetime import datetime
36
+ from typing import Optional, Tuple, Any, Dict
37
+
38
+ import torch
39
+ import networkx as nx
40
+ import matplotlib.pyplot as plt
41
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
42
+
43
+
44
+ # ==============================================================================
45
+ # Constants / defaults
46
+ # ==============================================================================
47
+
48
+ GRAPH_JSON_OPEN = "<graph_json>"
49
+ GRAPH_JSON_CLOSE = "</graph_json>"
50
+
51
+
52
+ # ==============================================================================
53
+ # Helpers: filesystem + parsing
54
+ # ==============================================================================
55
+
56
+ def atomic_write_text(path: str, text: str) -> None:
57
+ """Write text atomically to avoid partial files on crash."""
58
+ tmp = path + ".tmp"
59
+ with open(tmp, "w", encoding="utf-8") as f:
60
+ f.write(text)
61
+ os.replace(tmp, path)
62
+
63
+
64
+ def atomic_write_bytes(path: str, data: bytes) -> None:
65
+ """Atomic binary write."""
66
+ tmp = path + ".tmp"
67
+ with open(tmp, "wb") as f:
68
+ f.write(data)
69
+ os.replace(tmp, path)
70
+
71
+
72
+ def safe_json_loads(s: str) -> Optional[Any]:
73
+ """Best-effort JSON parsing."""
74
+ try:
75
+ return json.loads(s)
76
+ except Exception:
77
+ return None
78
+
79
+
80
+ def now_run_id() -> str:
81
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
82
+
83
+
84
+ def resolve_prompt(prompt_arg: str) -> str:
85
+ """
86
+ Resolve prompt from:
87
+ - literal string
88
+ - '-' meaning read stdin fully
89
+ - '@path' meaning read prompt from file
90
+ """
91
+ if prompt_arg == "-":
92
+ return sys.stdin.read().strip()
93
+ if prompt_arg.startswith("@"):
94
+ path = prompt_arg[1:]
95
+ with open(path, "r", encoding="utf-8") as f:
96
+ return f.read().strip()
97
+ return prompt_arg
98
+
99
+
100
+ def split_thinking_by_token_id(
101
+ output_ids: list,
102
+ tokenizer,
103
+ think_end_id: Optional[int],
104
+ ) -> Tuple[str, str]:
105
+ """
106
+ Split generated token ids into (thinking, final_content) based on think_end_id.
107
+ If think_end_id is None or not found, returns ("", decoded_all) as a safe fallback.
108
+ """
109
+ if think_end_id is None:
110
+ return "", tokenizer.decode(output_ids, skip_special_tokens=True).strip()
111
+
112
+ try:
113
+ # Find first occurrence of think_end_id
114
+ idx = output_ids.index(think_end_id) + 1
115
+ except ValueError:
116
+ idx = 0
117
+
118
+ thinking = tokenizer.decode(output_ids[:idx], skip_special_tokens=True).strip()
119
+ content = tokenizer.decode(output_ids[idx:], skip_special_tokens=True).strip()
120
+ return thinking, content
121
+
122
+
123
+ def extract_graph_json_block(text: str) -> Tuple[Optional[str], Optional[dict]]:
124
+ """
125
+ Extract first <graph_json>...</graph_json> block.
126
+ Returns (raw_json_text, parsed_obj) or (None, None).
127
+
128
+ Fail-safe recovery:
129
+ - try parsing inner content
130
+ - else take largest {...} region inside tag block
131
+ """
132
+ m = re.search(
133
+ rf"{re.escape(GRAPH_JSON_OPEN)}(.*?){re.escape(GRAPH_JSON_CLOSE)}",
134
+ text,
135
+ flags=re.DOTALL,
136
+ )
137
+ if not m:
138
+ return None, None
139
+
140
+ inner = m.group(1).strip()
141
+
142
+ obj = safe_json_loads(inner)
143
+ if obj is not None and isinstance(obj, dict):
144
+ return inner, obj
145
+
146
+ i1 = inner.find("{")
147
+ i2 = inner.rfind("}")
148
+ if i1 != -1 and i2 != -1 and i2 > i1:
149
+ candidate = inner[i1 : i2 + 1].strip()
150
+ obj2 = safe_json_loads(candidate)
151
+ if obj2 is not None and isinstance(obj2, dict):
152
+ return candidate, obj2
153
+
154
+ return inner, None
155
+
156
+
157
+ # ==============================================================================
158
+ # Graph utilities
159
+ # ==============================================================================
160
+
161
+ def build_nx_graph(graph_obj: Dict[str, Any]) -> nx.DiGraph:
162
+ """
163
+ Build a NetworkX DiGraph from JSON:
164
+ graph_obj["nodes"] = [{"id": "...", ...}, ...]
165
+ graph_obj["edges"] = [{"source":"...", "target":"...", "relation":"...", ...}, ...]
166
+ """
167
+ G = nx.DiGraph()
168
+
169
+ nodes = graph_obj.get("nodes", []) or []
170
+ edges = graph_obj.get("edges", []) or []
171
+
172
+ for n in nodes:
173
+ if not isinstance(n, dict):
174
+ continue
175
+ nid = n.get("id")
176
+ if nid:
177
+ attrs = {k: v for k, v in n.items() if k != "id"}
178
+ G.add_node(nid, **attrs)
179
+
180
+ for e in edges:
181
+ if not isinstance(e, dict):
182
+ continue
183
+ src = e.get("source")
184
+ tgt = e.get("target")
185
+ if not (src and tgt):
186
+ continue
187
+ rel = e.get("relation", "")
188
+ attrs = {k: v for k, v in e.items() if k not in ("source", "target")}
189
+ attrs["relation"] = rel
190
+
191
+ if src not in G:
192
+ G.add_node(src)
193
+ if tgt not in G:
194
+ G.add_node(tgt)
195
+ G.add_edge(src, tgt, **attrs)
196
+
197
+ return G
198
+
199
+
200
+ def layout_graph(G: nx.DiGraph):
201
+ """
202
+ Prefer Graphviz 'dot' layout if available; else spring layout.
203
+ """
204
+ try:
205
+ from networkx.drawing.nx_pydot import graphviz_layout
206
+ pos = graphviz_layout(G, prog="dot")
207
+ return pos, "graphviz(dot)"
208
+ except Exception:
209
+ pos = nx.spring_layout(G, seed=7, k=0.9)
210
+ return pos, "spring_layout"
211
+
212
+
213
+ def visualize_and_save_graph(G: nx.DiGraph, out_dir: str, title: str, log: logging.Logger):
214
+ """
215
+ Render and save PNG + SVG with edge relation labels.
216
+ Fail-safe: saves a minimal plot if something fails.
217
+ """
218
+ png_path = os.path.join(out_dir, "graph.png")
219
+ svg_path = os.path.join(out_dir, "graph.svg")
220
+
221
+ if G.number_of_nodes() == 0:
222
+ log.warning("Graph has 0 nodes; skipping visualization.")
223
+ return None, None
224
+
225
+ pos, layout_used = layout_graph(G)
226
+ log.info(f"Graph layout: {layout_used} | nodes={G.number_of_nodes()} edges={G.number_of_edges()}")
227
+
228
+ n = G.number_of_nodes()
229
+ fig_w = min(22, max(12, 0.9 * math.sqrt(n) * 8))
230
+ fig_h = min(12, max(7, 0.6 * math.sqrt(n) * 6))
231
+
232
+ plt.figure(figsize=(fig_w, fig_h))
233
+ try:
234
+ nx.draw_networkx_nodes(G, pos, node_size=2200, linewidths=1.2)
235
+ nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle="-|>", arrowsize=18, width=1.6)
236
+ nx.draw_networkx_labels(G, pos, font_size=10)
237
+
238
+ edge_labels = {(u, v): (d.get("relation") or "") for u, v, d in G.edges(data=True)}
239
+ nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9, rotate=False)
240
+
241
+ plt.title(f"{title} ({layout_used})")
242
+ plt.axis("off")
243
+ plt.tight_layout()
244
+ plt.savefig(png_path, dpi=300, bbox_inches="tight")
245
+ plt.savefig(svg_path, bbox_inches="tight")
246
+ plt.close()
247
+ return png_path, svg_path
248
+
249
+ except Exception as e:
250
+ log.exception(f"Visualization failed (attempting minimal save): {e}")
251
+ plt.clf()
252
+ plt.figure(figsize=(12, 7))
253
+ nx.draw(G, with_labels=True)
254
+ plt.title(f"{title} (minimal)")
255
+ plt.axis("off")
256
+ plt.tight_layout()
257
+ plt.savefig(png_path, dpi=200, bbox_inches="tight")
258
+ plt.savefig(svg_path, bbox_inches="tight")
259
+ plt.close()
260
+ return png_path, svg_path
261
+
262
+
263
+ # ==============================================================================
264
+ # Tokenizer / prompt template compatibility
265
+ # ==============================================================================
266
+
267
+ def render_chat_prompt(tokenizer, user_prompt: str, enable_thinking: bool, log: logging.Logger) -> str:
268
+ """
269
+ Render prompt using chat template when available.
270
+ - Tries enable_thinking=True if requested.
271
+ - Falls back to enable_thinking=False.
272
+ - Falls back to a minimal plain prompt if apply_chat_template fails.
273
+ """
274
+ messages = [{"role": "user", "content": user_prompt}]
275
+
276
+ if hasattr(tokenizer, "apply_chat_template"):
277
+ # Try with enable_thinking if requested
278
+ if enable_thinking:
279
+ try:
280
+ return tokenizer.apply_chat_template(
281
+ messages,
282
+ tokenize=False,
283
+ add_generation_prompt=True,
284
+ enable_thinking=True,
285
+ )
286
+ except TypeError as e:
287
+ # Some tokenizers don't accept enable_thinking kwarg
288
+ log.warning(f"Tokenizer chat template does not support enable_thinking kwarg: {e}")
289
+ except Exception as e:
290
+ log.warning(f"apply_chat_template(enable_thinking=True) failed; falling back: {e}")
291
+
292
+ # Try without enable_thinking
293
+ try:
294
+ return tokenizer.apply_chat_template(
295
+ messages,
296
+ tokenize=False,
297
+ add_generation_prompt=True,
298
+ )
299
+ except Exception as e:
300
+ log.warning(f"apply_chat_template failed; falling back to plain prompt: {e}")
301
+
302
+ # Plain prompt fallback
303
+ return user_prompt.strip()
304
+
305
+
306
+ # ==============================================================================
307
+ # Main
308
+ # ==============================================================================
309
+
310
+ def parse_args() -> argparse.Namespace:
311
+ p = argparse.ArgumentParser(
312
+ description="CLI Graph Reasoning Runner (Graph-PRefLexOR style): generate, extract <graph_json>, visualize.",
313
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
314
+ )
315
+
316
+ # Model/token/auth
317
+ p.add_argument("--model", required=True, help="Hugging Face model name or local path")
318
+ p.add_argument("--hf-token", default=None, help="HF token (or set HF_TOKEN env var)")
319
+ p.add_argument("--revision", default=None, help="Model revision (branch/tag/commit)")
320
+
321
+ # Prompt
322
+ p.add_argument(
323
+ "--prompt",
324
+ required=True,
325
+ help="Prompt text, or '-' for stdin, or '@path' to read from file",
326
+ )
327
+ p.add_argument(
328
+ "--enable-thinking",
329
+ action="store_true",
330
+ help="Attempt to enable thinking via tokenizer.apply_chat_template(enable_thinking=True)",
331
+ )
332
+
333
+ # Generation
334
+ p.add_argument("--max-new-tokens", type=int, default=32768)
335
+ p.add_argument("--temperature", type=float, default=0.2)
336
+ p.add_argument("--do-sample", action="store_true", help="Enable sampling")
337
+ p.add_argument("--top-p", type=float, default=None, help="Optional top_p")
338
+ p.add_argument("--top-k", type=int, default=None, help="Optional top_k")
339
+ p.add_argument("--repetition-penalty", type=float, default=None, help="Optional repetition penalty")
340
+
341
+ # Thinking split
342
+ p.add_argument(
343
+ "--think-end-token-id",
344
+ type=int,
345
+ default=None,
346
+ help="Token id marking end of thinking (e.g., 151668). If unset, no splitting occurs.",
347
+ )
348
+
349
+ # Output
350
+ p.add_argument("--out-dir", default=None, help="Output directory (default: ./run_<timestamp>)")
351
+ p.add_argument("--run-id", default=None, help="Optional custom run id (default: timestamp)")
352
+ p.add_argument("--print-thinking", action="store_true", help="Also print the thinking section to stdout")
353
+ p.add_argument("--no-print", action="store_true", help="Do not print model output to stdout")
354
+
355
+ # Performance/device
356
+ p.add_argument("--dtype", default="auto", choices=["auto", "float16", "bfloat16", "float32"], help="torch_dtype")
357
+ p.add_argument("--device-map", default="auto", help="Transformers device_map (e.g., auto, cuda:0, cpu)")
358
+ p.add_argument("--attn-impl", default=None, help="Optional attn_implementation (e.g., flash_attention_2)")
359
+
360
+ return p.parse_args()
361
+
362
+
363
+ def setup_outdir(run_id: str, out_dir_arg: Optional[str]) -> str:
364
+ if out_dir_arg:
365
+ out_dir = os.path.abspath(out_dir_arg)
366
+ else:
367
+ out_dir = os.path.abspath(f"./run_{run_id}")
368
+ os.makedirs(out_dir, exist_ok=True)
369
+ return out_dir
370
+
371
+
372
+ def setup_logger(out_dir: str) -> logging.Logger:
373
+ log_path = os.path.join(out_dir, "run.log")
374
+ logger = logging.getLogger("graph_reasoning")
375
+ logger.setLevel(logging.INFO)
376
+ logger.handlers = [] # avoid duplicate handlers in repeated runs
377
+
378
+ fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
379
+ fh = logging.FileHandler(log_path)
380
+ fh.setFormatter(fmt)
381
+ sh = logging.StreamHandler(sys.stdout)
382
+ sh.setFormatter(fmt)
383
+
384
+ logger.addHandler(fh)
385
+ logger.addHandler(sh)
386
+ return logger
387
+
388
+
389
+ def torch_dtype_from_arg(dtype: str):
390
+ if dtype == "auto":
391
+ return "auto"
392
+ if dtype == "float16":
393
+ return torch.float16
394
+ if dtype == "bfloat16":
395
+ return torch.bfloat16
396
+ if dtype == "float32":
397
+ return torch.float32
398
+ return "auto"
399
+
400
+
401
+ def main() -> int:
402
+ args = parse_args()
403
+
404
+ run_id = args.run_id or now_run_id()
405
+ out_dir = setup_outdir(run_id, args.out_dir)
406
+ log = setup_logger(out_dir)
407
+
408
+ hf_token = args.hf_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
409
+
410
+ # Persist run metadata early
411
+ meta = {
412
+ "run_id": run_id,
413
+ "timestamp": datetime.now().isoformat(),
414
+ "model": args.model,
415
+ "revision": args.revision,
416
+ "max_new_tokens": args.max_new_tokens,
417
+ "temperature": args.temperature,
418
+ "do_sample": bool(args.do_sample),
419
+ "top_p": args.top_p,
420
+ "top_k": args.top_k,
421
+ "repetition_penalty": args.repetition_penalty,
422
+ "think_end_token_id": args.think_end_token_id,
423
+ "enable_thinking": bool(args.enable_thinking),
424
+ "dtype": args.dtype,
425
+ "device_map": args.device_map,
426
+ "attn_impl": args.attn_impl,
427
+ "python": sys.version,
428
+ "torch": getattr(torch, "__version__", None),
429
+ }
430
+ atomic_write_text(os.path.join(out_dir, "run_meta.json"), json.dumps(meta, indent=2))
431
+
432
+ # Resolve prompt
433
+ prompt = resolve_prompt(args.prompt)
434
+ if not prompt:
435
+ log.error("Prompt is empty.")
436
+ return 2
437
+
438
+ atomic_write_text(os.path.join(out_dir, "prompt.txt"), prompt)
439
+
440
+ log.info(f"Output dir: {out_dir}")
441
+ log.info(f"Model: {args.model}")
442
+ if args.revision:
443
+ log.info(f"Revision: {args.revision}")
444
+ log.info("Loading tokenizer/model...")
445
+
446
+ # Load tokenizer/model
447
+ tok_kwargs = {"token": hf_token} if hf_token else {}
448
+ if args.revision:
449
+ tok_kwargs["revision"] = args.revision
450
+
451
+ tokenizer = AutoTokenizer.from_pretrained(args.model, **tok_kwargs)
452
+
453
+ model_kwargs = {
454
+ "device_map": args.device_map,
455
+ "token": hf_token if hf_token else None,
456
+ }
457
+ if args.revision:
458
+ model_kwargs["revision"] = args.revision
459
+
460
+ td = torch_dtype_from_arg(args.dtype)
461
+ if td != "auto":
462
+ model_kwargs["torch_dtype"] = td
463
+ else:
464
+ model_kwargs["torch_dtype"] = "auto"
465
+
466
+ if args.attn_impl:
467
+ model_kwargs["attn_implementation"] = args.attn_impl
468
+
469
+ model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
470
+ model.eval()
471
+
472
+ # Render chat prompt
473
+ rendered = render_chat_prompt(tokenizer, prompt, enable_thinking=args.enable_thinking, log=log)
474
+ atomic_write_text(os.path.join(out_dir, "prompt_rendered.txt"), rendered)
475
+
476
+ # Tokenize
477
+ model_inputs = tokenizer(rendered, return_tensors="pt")
478
+
479
+ # Move inputs to model device where possible
480
+ try:
481
+ model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
482
+ except Exception:
483
+ # In some device_map setups, model.device may not be meaningful; leave as-is.
484
+ pass
485
+
486
+ # Generation config
487
+ gen_cfg_kwargs = dict(
488
+ max_new_tokens=args.max_new_tokens,
489
+ do_sample=bool(args.do_sample),
490
+ temperature=float(args.temperature),
491
+ )
492
+ if args.top_p is not None:
493
+ gen_cfg_kwargs["top_p"] = float(args.top_p)
494
+ if args.top_k is not None:
495
+ gen_cfg_kwargs["top_k"] = int(args.top_k)
496
+ if args.repetition_penalty is not None:
497
+ gen_cfg_kwargs["repetition_penalty"] = float(args.repetition_penalty)
498
+
499
+ gen_config = GenerationConfig(**gen_cfg_kwargs)
500
+
501
+ log.info("Generating...")
502
+ t0 = time.time()
503
+ with torch.no_grad():
504
+ generated = model.generate(**model_inputs, generation_config=gen_config)
505
+ t1 = time.time()
506
+ log.info(f"Generation done in {t1 - t0:.2f}s")
507
+
508
+ # Slice off prompt tokens to get only generated continuation
509
+ input_len = model_inputs["input_ids"].shape[1]
510
+ output_ids = generated[0, input_len:].tolist()
511
+
512
+ thinking, content = split_thinking_by_token_id(output_ids, tokenizer, args.think_end_token_id)
513
+
514
+ # Persist outputs (always)
515
+ atomic_write_text(os.path.join(out_dir, "thinking.txt"), thinking or "")
516
+ atomic_write_text(os.path.join(out_dir, "content.txt"), content or "")
517
+ atomic_write_text(os.path.join(out_dir, "full_output.txt"), (thinking + "\n\n" + content).strip())
518
+
519
+ # Print
520
+ if not args.no_print:
521
+ if args.print_thinking and thinking:
522
+ sys.stdout.write("\n" + "=" * 80 + "\nTHINKING\n" + "=" * 80 + "\n")
523
+ sys.stdout.write(thinking + "\n")
524
+ sys.stdout.write("\n" + "=" * 80 + "\nFINAL OUTPUT\n" + "=" * 80 + "\n")
525
+ sys.stdout.write(content + "\n")
526
+ sys.stdout.flush()
527
+
528
+ # Extract graph json
529
+ raw_block, graph_obj = extract_graph_json_block((thinking or "") + "\n" + (content or ""))
530
+
531
+ if raw_block is None:
532
+ log.warning("No <graph_json>...</graph_json> block found in output.")
533
+ atomic_write_text(os.path.join(out_dir, "graph_status.txt"), "not_found")
534
+ return 0
535
+
536
+ atomic_write_text(os.path.join(out_dir, "graph_json_raw.txt"), raw_block)
537
+
538
+ if graph_obj is None:
539
+ log.warning("Found <graph_json> block, but JSON parsing failed. Saved raw block for inspection.")
540
+ atomic_write_text(os.path.join(out_dir, "graph_status.txt"), "found_but_parse_failed")
541
+ return 0
542
+
543
+ atomic_write_text(os.path.join(out_dir, "graph.json"), json.dumps(graph_obj, indent=2, ensure_ascii=False))
544
+ atomic_write_text(os.path.join(out_dir, "graph_status.txt"), "parsed_ok")
545
+
546
+ # Build & visualize graph
547
+ G = build_nx_graph(graph_obj)
548
+ atomic_write_text(
549
+ os.path.join(out_dir, "graph_stats.json"),
550
+ json.dumps(
551
+ {"nodes": G.number_of_nodes(), "edges": G.number_of_edges()},
552
+ indent=2,
553
+ ),
554
+ )
555
+
556
+ png_path, svg_path = visualize_and_save_graph(G, out_dir, title="Graph Reasoning Output Graph", log=log)
557
+ if png_path and svg_path:
558
+ log.info(f"Saved graph: {png_path}")
559
+ log.info(f"Saved graph: {svg_path}")
560
+
561
+ return 0
562
+
563
+
564
+ if __name__ == "__main__":
565
+ # Hard fail-safe: always write CRASH marker if something bubbles up
566
+ _run_id = None
567
+ _out_dir = None
568
+ _log = None
569
+ try:
570
+ rc = main()
571
+ raise SystemExit(rc)
572
+ except SystemExit:
573
+ raise
574
+ except Exception as e:
575
+ # Best-effort to write crash marker if we can infer out_dir from args
576
+ try:
577
+ # Minimal heuristic: if user passed --out-dir use that; else default to latest run_* in cwd
578
+ # (We do not attempt to re-parse args fully here to avoid cascading failures.)
579
+ candidates = []
580
+ for name in os.listdir("."):
581
+ if name.startswith("run_") and os.path.isdir(name):
582
+ candidates.append(name)
583
+ candidates.sort(reverse=True)
584
+ fallback_dir = os.path.abspath(candidates[0]) if candidates else os.path.abspath("./")
585
+ atomic_write_text(os.path.join(fallback_dir, "CRASH.txt"), repr(e))
586
+ except Exception:
587
+ pass
588
+ raise