Upload 12 files
- .gitattributes +0 -33
- LICENSE +9 -0
- README.md +49 -3
- cli.py +34 -0
- configs/model_config.json +8 -0
- data_examples/example_code.py +9 -0
- inference.py +44 -0
- requirements.txt +7 -0
- tests/test_viz_and_explain.py +13 -0
- train_docgen.py +43 -0
- utils.py +6 -0
- viz_generator.py +117 -0
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,9 @@
Apache License 2.0

Copyright 2025 hmnshudhmn24

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0
README.md
CHANGED
@@ -1,3 +1,49 @@
---
language:
- en
license: apache-2.0
tags:
- code-explanation
- visualization
- mermaid
- codet5
- developer-tools
pipeline_tag: text-generation
library_name: transformers
base_model: Salesforce/codet5-small
---

# code-explain-viz

**Short:** `code-explain-viz` explains functions, generates step-by-step reasoning, creates a Mermaid flowchart of control flow, and suggests unit tests, combining LLM-generated explanations with deterministic AST-based visualizations.

## Quick start

1. Install requirements:
```bash
pip install -r requirements.txt
```

2. Run the CLI demo:
```bash
python cli.py --file data_examples/example_code.py
```

3. Copy the Mermaid flowchart text printed by the CLI into the Mermaid live editor (https://mermaid.live), or render it with mermaid-cli, to see the visual flowchart.

## What you get
- `short`: a one-line explanation
- `detailed`: a multi-line explanation
- `mermaid`: flowchart text describing the control flow
- `unit_tests`: a pytest template

## How it works
- A sequence-to-sequence model (CodeT5) generates natural-language explanations from code.
- `viz_generator.py` parses the function's AST and produces a deterministic Mermaid flowchart.
- Combining the two yields both a human-friendly narrative and a precise structural view; a programmatic usage sketch follows this README.

## Train / Fine-tune
Use `train_docgen.py` with a JSONL dataset (each line: `{"code": "...", "doc": "..."}`).

## License
Apache-2.0
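Beyond the CLI, the explainer can be called directly from Python. A minimal sketch using the repository's `CodeExplainViz` class from `inference.py`; the paths and default model mirror the Quick start above, and the result keys are the ones listed under "What you get".

```python
# Minimal sketch: programmatic use of the explainer, mirroring what cli.py does.
from inference import CodeExplainViz

with open("data_examples/example_code.py", encoding="utf-8") as f:
    code = f.read()

explainer = CodeExplainViz()        # defaults to Salesforce/codet5-small
result = explainer.explain(code)    # keys: short, detailed, mermaid, unit_tests

print(result["short"])      # one-line explanation
print(result["mermaid"])    # paste into https://mermaid.live to render
```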
cli.py
ADDED
@@ -0,0 +1,34 @@
# cli.py
import argparse
from inference import CodeExplainViz

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str, help="Path to python file with a function")
    parser.add_argument("--code", type=str, help="Code string to explain")
    parser.add_argument("--model", type=str, default="Salesforce/codet5-small", help="Model path or HF name")
    args = parser.parse_args()

    code = None
    if args.file:
        with open(args.file, "r", encoding="utf-8") as f:
            code = f.read()
    elif args.code:
        code = args.code
    else:
        print("Provide --file or --code")
        return

    explainer = CodeExplainViz(model_name_or_path=args.model)
    out = explainer.explain(code)
    print("\n--- Short Explanation ---\n")
    print(out["short"])
    print("\n--- Detailed Explanation ---\n")
    print(out["detailed"])
    print("\n--- Mermaid Flowchart (copy into mermaid live editor) ---\n")
    print(out["mermaid"])
    print("\n--- Unit test template ---\n")
    print(out["unit_tests"])

if __name__ == "__main__":
    main()
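The README only demonstrates `--file`; `cli.py` also accepts an inline snippet via `--code`. A small sketch driving the CLI from Python, using a throwaway `add` function as illustrative input:

```python
# Illustrative: invoking the CLI with an inline snippet via --code
# (shell equivalent: python cli.py --code "def add(a, b): return a + b").
import subprocess
import sys

subprocess.run(
    [sys.executable, "cli.py", "--code", "def add(a, b):\n    return a + b"],
    check=True,
)
```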
configs/model_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "base_model": "Salesforce/codet5-small",
  "max_input_length": 512,
  "max_target_length": 256,
  "train_epochs": 3,
  "train_batch_size": 8,
  "learning_rate": 3e-05
}
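As written, `train_docgen.py` hardcodes the base model and takes the epoch count from the command line, so this config file is not consumed anywhere yet. A hypothetical sketch of loading it, should you wire it into training:

```python
# Hypothetical sketch: loading configs/model_config.json (train_docgen.py does not read it yet).
import json

with open("configs/model_config.json", encoding="utf-8") as f:
    cfg = json.load(f)

print(cfg["base_model"])     # "Salesforce/codet5-small"
print(cfg["learning_rate"])  # 3e-05, could be passed as Seq2SeqTrainingArguments(learning_rate=...)
```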
data_examples/example_code.py
ADDED
@@ -0,0 +1,9 @@
def factorial(n):
    if not isinstance(n, int):
        raise TypeError("n must be an integer")
    if n < 0:
        raise ValueError("n must be >= 0")
    result = 1
    for i in range(2, n+1):
        result *= i
    return result
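A quick sanity check of the bundled example, stating only plain arithmetic facts, to make its expected behaviour concrete:

```python
# Expected behaviour of the bundled example function.
from data_examples.example_code import factorial

assert factorial(0) == 1    # empty product
assert factorial(5) == 120  # 1 * 2 * 3 * 4 * 5
```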
inference.py
ADDED
@@ -0,0 +1,44 @@
# inference.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import textwrap
from viz_generator import code_to_mermaid

DEFAULT_MODEL = "Salesforce/codet5-small"

class CodeExplainViz:
    def __init__(self, model_name_or_path=DEFAULT_MODEL):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

    def explain(self, code: str, max_length: int = 256) -> dict:
        # "explain: " matches the prompt prefix used in train_docgen.py
        prompt = "explain: " + code
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        outputs = self.model.generate(**inputs, max_length=max_length, num_beams=4, early_stopping=True)
        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # first non-empty line becomes the short summary; the rest is the detailed explanation
        lines = [l.strip() for l in text.splitlines() if l.strip()]
        short = lines[0] if lines else textwrap.shorten(text, width=120)
        detailed = "\n".join(lines[1:]) if len(lines) > 1 else text
        mermaid = code_to_mermaid(code)
        unit_tests = self._make_unit_test_template(code)
        return {"short": short, "detailed": detailed, "mermaid": mermaid, "unit_tests": unit_tests}

    def _make_unit_test_template(self, code: str) -> str:
        import re
        m = re.search(r"def\s+([A-Za-z0-9_]+)\s*\((.*?)\):", code)
        fn = m.group(1) if m else "function_under_test"
        params = m.group(2) if m else ""
        param_count = len([p for p in params.split(',') if p.strip()]) if params.strip() else 0
        args = ", ".join(["0"] * param_count)
        template = f"""import pytest

from your_module import {fn}

def test_{fn}_basic():
    # TODO: replace with real inputs and expected outputs
    assert {fn}({args}) == ...

def test_{fn}_edge_cases():
    # Example edge-case tests
    with pytest.raises(Exception):
        {fn}(...)"""
        return template
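For concreteness, a hedged trace of what `_make_unit_test_template` should return for the bundled `factorial` function: the regex matches `factorial` with one parameter, so `args` becomes `"0"`, and `your_module` is the placeholder the caller is expected to replace.

```python
# Hedged sketch: the template _make_unit_test_template builds for the factorial example.
# Derived by reading the f-string above; whitespace may differ slightly in practice.
expected = '''import pytest

from your_module import factorial

def test_factorial_basic():
    # TODO: replace with real inputs and expected outputs
    assert factorial(0) == ...

def test_factorial_edge_cases():
    # Example edge-case tests
    with pytest.raises(Exception):
        factorial(...)'''
```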
requirements.txt
ADDED
@@ -0,0 +1,7 @@
transformers>=4.30.0
datasets>=2.10.0
torch>=1.12.0
astor
graphviz
pytest
black
tests/test_viz_and_explain.py
ADDED
@@ -0,0 +1,13 @@
import os
from data_examples.example_code import factorial
from inference import CodeExplainViz

def test_explain_and_viz_runs():
    with open("data_examples/example_code.py", "r", encoding="utf-8") as f:
        code = f.read()
    expl = CodeExplainViz()
    out = expl.explain(code)
    assert "mermaid" in out
    assert out["mermaid"].startswith("flowchart")
    assert isinstance(out["short"], str)
    assert len(out["short"]) > 0
train_docgen.py
ADDED
@@ -0,0 +1,43 @@
# train_docgen.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
import argparse

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--data", type=str, default="data_examples/sample_dataset.jsonl", help="jsonl with {'code','doc'}")
    p.add_argument("--output_dir", type=str, default="./code-explain-viz-model")
    p.add_argument("--epochs", type=int, default=1)
    return p.parse_args()

def preprocess_batch(examples, tokenizer, max_src=512, max_tgt=256):
    inputs = ["explain: " + c for c in examples["code"]]
    model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=max_src)
    labels = tokenizer(text_target=examples["doc"], truncation=True, padding="max_length", max_length=max_tgt)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def main():
    args = parse_args()
    model_name = "Salesforce/codet5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    ds = load_dataset("json", data_files={"train": args.data})
    tokenized = ds["train"].map(lambda x: preprocess_batch(x, tokenizer), batched=True, remove_columns=ds["train"].column_names)

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=2,
        save_strategy="epoch",
        logging_steps=50
    )
    trainer = Seq2SeqTrainer(model=model, args=training_args, train_dataset=tokenized)
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print("Saved model to", args.output_dir)

if __name__ == "__main__":
    main()
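The default `--data` path, `data_examples/sample_dataset.jsonl`, is not part of this upload, so a dataset must be supplied. A minimal sketch of writing a compatible JSONL file; the single record is illustrative only.

```python
# Hypothetical sketch: write a tiny JSONL file in the {"code": ..., "doc": ...} format
# that train_docgen.py loads via load_dataset("json", ...).
import json

records = [
    {"code": "def add(a, b):\n    return a + b", "doc": "Return the sum of a and b."},
]
with open("data_examples/sample_dataset.jsonl", "w", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

# Then: python train_docgen.py --data data_examples/sample_dataset.jsonl --epochs 1
```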
utils.py
ADDED
@@ -0,0 +1,6 @@
# utils.py
import re

def extract_first_function_name(code: str):
    m = re.search(r"def\s+([A-Za-z0-9_]+)\s*\(", code)
    return m.group(1) if m else None
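An illustrative check of the helper's behaviour on made-up inputs:

```python
# Illustrative use of the helper in utils.py (inputs are made up).
from utils import extract_first_function_name

assert extract_first_function_name("def factorial(n):\n    return 1") == "factorial"
assert extract_first_function_name("x = 1") is None
```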
viz_generator.py
ADDED
@@ -0,0 +1,117 @@
# viz_generator.py
import ast

class VizBuilder(ast.NodeVisitor):
    def __init__(self):
        self.nodes = []
        self.edges = []
        self.counter = 0

    def new_id(self, prefix="n"):
        self.counter += 1
        return f"{prefix}{self.counter}"

    def add_node(self, nid, label):
        label = label.replace("\n", "\\n").replace('"', '\\"')
        self.nodes.append((nid, label))

    def add_edge(self, a, b, label=""):
        self.edges.append((a, b, label))

    def visit_FunctionDef(self, node: ast.FunctionDef):
        start = self.new_id("start")
        self.add_node(start, f"def {node.name}(...)")
        prev = start
        for stmt in node.body:
            cur = self.visit(stmt)
            if cur:
                self.add_edge(prev, cur)
                prev = cur
        return start

    def visit_Return(self, node: ast.Return):
        nid = self.new_id("ret")
        val = ast.unparse(node.value) if node.value else ""
        self.add_node(nid, f"return {val}")
        return nid

    def visit_Raise(self, node: ast.Raise):
        nid = self.new_id("raise")
        exc = ast.unparse(node.exc) if node.exc else ""
        self.add_node(nid, f"raise {exc}")
        return nid

    def visit_For(self, node: ast.For):
        nid = self.new_id("for")
        target = ast.unparse(node.target)
        iter_ = ast.unparse(node.iter)
        self.add_node(nid, f"for {target} in {iter_}")
        prev = nid
        for stmt in node.body:
            cur = self.visit(stmt)
            if cur:
                self.add_edge(prev, cur)
                prev = cur
        return nid

    def visit_While(self, node: ast.While):
        nid = self.new_id("while")
        cond = ast.unparse(node.test)
        self.add_node(nid, f"while {cond}")
        prev = nid
        for stmt in node.body:
            cur = self.visit(stmt)
            if cur:
                self.add_edge(prev, cur)
                prev = cur
        return nid

    def visit_If(self, node: ast.If):
        nid = self.new_id("if")
        cond = ast.unparse(node.test)
        self.add_node(nid, f"if {cond}")
        for stmt in node.body:
            cur = self.visit(stmt)
            if cur:
                self.add_edge(nid, cur, label="true")
        if node.orelse:
            for stmt in node.orelse:
                cur = self.visit(stmt)
                if cur:
                    self.add_edge(nid, cur, label="false")
        return nid

    def visit_Expr(self, node: ast.Expr):
        nid = self.new_id("expr")
        txt = ast.unparse(node.value)
        self.add_node(nid, txt)
        return nid

    def visit_Assign(self, node: ast.Assign):
        nid = self.new_id("assign")
        targets = ", ".join([ast.unparse(t) for t in node.targets])
        val = ast.unparse(node.value)
        self.add_node(nid, f"{targets} = {val}")
        return nid

    def generic_visit(self, node):
        super().generic_visit(node)
        return None

def code_to_mermaid(code: str) -> str:
    tree = ast.parse(code)
    vb = VizBuilder()
    root_id = None
    for node in tree.body:
        if isinstance(node, ast.FunctionDef):
            root_id = vb.visit(node)
            break
    lines = ["flowchart TD"]
    for nid, label in vb.nodes:
        lines.append(f'  {nid}["{label}"]')
    for a, b, lbl in vb.edges:
        if lbl:
            lines.append(f'  {a} -->|{lbl}| {b}')
        else:
            lines.append(f'  {a} --> {b}')
    return "\n".join(lines)
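To make the generator's output concrete, a hedged sketch of running it on the bundled example; the commented output was derived by tracing `VizBuilder` by hand, so treat it as approximate. Note that `result *= i` is an `ast.AugAssign`, which has no visitor here, so the loop body contributes no node.

```python
# Hedged sketch: expected Mermaid output for data_examples/example_code.py,
# derived by hand-tracing VizBuilder; exact labels come from ast.unparse.
from viz_generator import code_to_mermaid

with open("data_examples/example_code.py", encoding="utf-8") as f:
    print(code_to_mermaid(f.read()))

# Expected output (approximately):
# flowchart TD
#   start1["def factorial(...)"]
#   if2["if not isinstance(n, int)"]
#   raise3["raise TypeError('n must be an integer')"]
#   if4["if n < 0"]
#   raise5["raise ValueError('n must be >= 0')"]
#   assign6["result = 1"]
#   for7["for i in range(2, n + 1)"]
#   ret8["return result"]
#   if2 -->|true| raise3
#   start1 --> if2
#   if4 -->|true| raise5
#   if2 --> if4
#   if4 --> assign6
#   assign6 --> for7
#   for7 --> ret8
```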