hmnshudhmn24 commited on
Commit
b91943a
·
verified ·
1 Parent(s): 4de82ad

Upload 12 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
1
  *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
LICENSE ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Apache License 2.0
2
+
3
+ Copyright 2025 hmnshudhmn24
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
README.md CHANGED
@@ -1,3 +1,49 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: apache-2.0
5
+ tags:
6
+ - code-explanation
7
+ - visualization
8
+ - mermaid
9
+ - codet5
10
+ - developer-tools
11
+ pipeline_tag: text2text-generation
12
+ library_name: transformers
13
+ base_model: Salesforce/codet5-small
14
+ ---
15
+
16
+ # code-explain-viz
17
+
18
+ **Short:** `code-explain-viz` explains functions, generates step-by-step reasoning, creates a Mermaid flowchart of control flow, and suggests unit tests — combining LLM-generated explanations with deterministic AST-based visualizations.
19
+
20
+ ## Quick start
21
+
22
+ 1. Install requirements:
23
+ ```bash
24
+ pip install -r requirements.txt
25
+ ```
26
+
27
+ 2. Run CLI demo:
28
+ ```bash
29
+ python cli.py --file data_examples/example_code.py
30
+ ```
31
+
32
+ 3. Copy the Mermaid flowchart text printed by CLI into a Mermaid live editor (https://mermaid.live) or render with mermaid-cli to see the visual flowchart.
33
+
34
+ ## What you get
35
+ - `short` one-line explanation
36
+ - `detailed` explanation (multi-line)
37
+ - `mermaid` flowchart text describing control flow
38
+ - `unit_tests` template (pytest)
39
+
40
+ ## How it works
41
+ - A sequence-to-sequence model (CodeT5) generates natural language explanations from code.
42
+ - `viz_generator.py` parses the function AST and produces a reliable mermaid flowchart.
43
+ - Combining the two yields both a human-friendly narrative and a precise structural view.
44
+
45
+ ## Train / Fine-tune
46
+ Use `train_docgen.py` with a JSONL dataset (each line: `{"code": "...", "doc": "..."}`).
47
+
48
+ ## License
49
+ Apache-2.0
cli.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cli.py
2
+ import argparse
3
+ from inference import CodeExplainViz
4
+
5
def main():
    """CLI entry point: explain a Python function taken from a file or a string."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str, help="Path to python file with a function")
    parser.add_argument("--code", type=str, help="Code string to explain")
    parser.add_argument("--model", type=str, default="Salesforce/codet5-small", help="Model path or HF name")
    args = parser.parse_args()

    # Resolve the code to explain: --file wins, then --code, else bail out.
    if args.file:
        with open(args.file, "r", encoding="utf-8") as handle:
            source = handle.read()
    elif args.code:
        source = args.code
    else:
        print("Provide --file or --code")
        return

    explainer = CodeExplainViz(model_name_or_path=args.model)
    result = explainer.explain(source)

    # Print each section under its banner, in the same order as before.
    sections = (
        ("\n--- Short Explanation ---\n", result["short"]),
        ("\n--- Detailed Explanation ---\n", result["detailed"]),
        ("\n--- Mermaid Flowchart (copy into mermaid live editor) ---\n", result["mermaid"]),
        ("\n--- Unit test template ---\n", result["unit_tests"]),
    )
    for banner, body in sections:
        print(banner)
        print(body)

if __name__ == "__main__":
    main()
configs/model_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model": "Salesforce/codet5-small",
3
+ "max_input_length": 512,
4
+ "max_target_length": 256,
5
+ "train_epochs": 3,
6
+ "train_batch_size": 8,
7
+ "learning_rate": 3e-05
8
+ }
data_examples/example_code.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
def factorial(n):
    """Return n! computed iteratively.

    Raises:
        TypeError: if *n* is not an int.
        ValueError: if *n* is negative.
    """
    if not isinstance(n, int):
        raise TypeError("n must be an integer")
    if n < 0:
        raise ValueError("n must be >= 0")
    product = 1
    factor = 2
    while factor <= n:
        product *= factor
        factor += 1
    return product
inference.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import textwrap
4
+ from viz_generator import code_to_mermaid
5
+
6
+ DEFAULT_MODEL = "Salesforce/codet5-small"
7
+
8
+ class CodeExplainViz:
9
+ def __init__(self, model_name_or_path=DEFAULT_MODEL):
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
11
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
12
+
13
+ def explain(self, code: str, max_length: int = 256) -> dict:
14
+ prompt = "explain: " + code
15
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
16
+ outputs = self.model.generate(**inputs, max_length=max_length, num_beams=4, early_stopping=True)
17
+ text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
18
+ lines = [l.strip() for l in text.splitlines() if l.strip()]
19
+ short = lines[0] if lines else textwrap.shorten(text, width=120)
20
+ detailed = "\n".join(lines[1:]) if len(lines) > 1 else text
21
+ mermaid = code_to_mermaid(code)
22
+ unit_tests = self._make_unit_test_template(code)
23
+ return {"short": short, "detailed": detailed, "mermaid": mermaid, "unit_tests": unit_tests}
24
+
25
+ def _make_unit_test_template(self, code: str) -> str:
26
+ import re
27
+ m = re.search(r"def\s+([A-Za-z0-9_]+)\s*\((.*?)\):", code)
28
+ fn = m.group(1) if m else "function_under_test"
29
+ params = m.group(2) if m else ""
30
+ param_count = len([p for p in params.split(',') if p.strip()]) if params.strip() else 0
31
+ args = ", ".join(["0"] * param_count)
32
+ template = f"""import pytest
33
+
34
+ from your_module import {fn}
35
+
36
+ def test_{fn}_basic():
37
+ # TODO: replace with real inputs and expected outputs
38
+ assert {fn}({args}) == ...
39
+
40
+ def test_{fn}_edge_cases():
41
+ # Example edge-case tests
42
+ with pytest.raises(Exception):
43
+ {fn}(...)"""
44
+ return template
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers>=4.30.0
2
+ datasets>=2.10.0
3
+ torch>=1.12.0
4
+ astor
5
+ graphviz
6
+ pytest
7
+ black
tests/test_viz_and_explain.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from data_examples.example_code import factorial
3
+ from inference import CodeExplainViz
4
+
5
def test_explain_and_viz_runs():
    """Smoke test: the full explain pipeline yields a flowchart and a summary."""
    with open("data_examples/example_code.py", "r", encoding="utf-8") as src:
        snippet = src.read()
    explainer = CodeExplainViz()
    result = explainer.explain(snippet)
    assert "mermaid" in result
    assert result["mermaid"].startswith("flowchart")
    assert isinstance(result["short"], str)
    assert len(result["short"]) > 0
train_docgen.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_docgen.py
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
3
+ from datasets import load_dataset
4
+ import argparse
5
+
6
def parse_args():
    """Parse the fine-tuning CLI options (data path, output dir, epoch count)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="data_examples/sample_dataset.jsonl", help="jsonl with {'code','doc'}")
    parser.add_argument("--output_dir", type=str, default="./code-explain-viz-model")
    parser.add_argument("--epochs", type=int, default=1)
    return parser.parse_args()
12
+
13
def preprocess_batch(examples, tokenizer, max_src=512, max_tgt=256):
    """Tokenize a batch of {"code", "doc"} examples for seq2seq training.

    Prefixes every code sample with "explain: ", encodes sources and targets
    with fixed-length padding, and attaches the target ids as "labels".
    """
    prompts = ["explain: " + snippet for snippet in examples["code"]]
    encoded = tokenizer(prompts, truncation=True, padding="max_length", max_length=max_src)
    targets = tokenizer(text_target=examples["doc"], truncation=True, padding="max_length", max_length=max_tgt)
    encoded["labels"] = targets["input_ids"]
    return encoded
19
+
20
def main():
    """Fine-tune CodeT5 on a JSONL dataset of {"code", "doc"} pairs and save it."""
    args = parse_args()
    base_model = "Salesforce/codet5-small"
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

    raw = load_dataset("json", data_files={"train": args.data})
    train_split = raw["train"]
    # Drop the raw text columns; only the tokenized fields are fed to the trainer.
    tokenized = train_split.map(
        lambda batch: preprocess_batch(batch, tokenizer),
        batched=True,
        remove_columns=train_split.column_names,
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=2,
        save_strategy="epoch",
        logging_steps=50
    )
    trainer = Seq2SeqTrainer(model=model, args=training_args, train_dataset=tokenized)
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print("Saved model to", args.output_dir)

if __name__ == "__main__":
    main()
utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # utils.py
2
+ import re
3
+
4
def extract_first_function_name(code: str):
    """Return the name of the first `def` found in *code*, or None if absent."""
    match = re.search(r"def\s+([A-Za-z0-9_]+)\s*\(", code)
    if match:
        return match.group(1)
    return None
viz_generator.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # viz_generator.py
2
+ import ast
3
+
4
class VizBuilder(ast.NodeVisitor):
    """Collect (id, label) nodes and (src, dst, label) edges from a function AST.

    Each visit_* handler creates one node and returns its id (or None for
    statement types without a handler); container handlers link consecutive
    child statements with unlabelled edges, while `if` fans out with
    true/false labels.
    """

    def __init__(self):
        self.nodes = []
        self.edges = []
        self.counter = 0

    def new_id(self, prefix="n"):
        # Sequential ids: prefix + 1-based counter, e.g. "if3".
        self.counter += 1
        return f"{prefix}{self.counter}"

    def add_node(self, nid, label):
        # Escape newlines and double quotes so labels survive mermaid quoting.
        cleaned = label.replace("\n", "\\n").replace('"', '\\"')
        self.nodes.append((nid, cleaned))

    def add_edge(self, a, b, label=""):
        self.edges.append((a, b, label))

    def _chain(self, head, statements):
        """Visit *statements* in order, linking each rendered node to the previous one."""
        tail = head
        for statement in statements:
            rendered = self.visit(statement)
            if rendered:
                self.add_edge(tail, rendered)
                tail = rendered

    def visit_FunctionDef(self, node: ast.FunctionDef):
        entry = self.new_id("start")
        self.add_node(entry, f"def {node.name}(...)")
        self._chain(entry, node.body)
        return entry

    def visit_Return(self, node: ast.Return):
        rid = self.new_id("ret")
        value = ast.unparse(node.value) if node.value else ""
        self.add_node(rid, f"return {value}")
        return rid

    def visit_Raise(self, node: ast.Raise):
        rid = self.new_id("raise")
        exc = ast.unparse(node.exc) if node.exc else ""
        self.add_node(rid, f"raise {exc}")
        return rid

    def visit_For(self, node: ast.For):
        fid = self.new_id("for")
        self.add_node(fid, f"for {ast.unparse(node.target)} in {ast.unparse(node.iter)}")
        self._chain(fid, node.body)
        return fid

    def visit_While(self, node: ast.While):
        wid = self.new_id("while")
        self.add_node(wid, f"while {ast.unparse(node.test)}")
        self._chain(wid, node.body)
        return wid

    def visit_If(self, node: ast.If):
        iid = self.new_id("if")
        self.add_node(iid, f"if {ast.unparse(node.test)}")
        # Both branches fan out directly from the condition node.
        for statement in node.body:
            rendered = self.visit(statement)
            if rendered:
                self.add_edge(iid, rendered, label="true")
        for statement in node.orelse:
            rendered = self.visit(statement)
            if rendered:
                self.add_edge(iid, rendered, label="false")
        return iid

    def visit_Expr(self, node: ast.Expr):
        eid = self.new_id("expr")
        self.add_node(eid, ast.unparse(node.value))
        return eid

    def visit_Assign(self, node: ast.Assign):
        aid = self.new_id("assign")
        lhs = ", ".join(ast.unparse(target) for target in node.targets)
        self.add_node(aid, f"{lhs} = {ast.unparse(node.value)}")
        return aid

    def generic_visit(self, node):
        # Unhandled statement types contribute no node of their own.
        super().generic_visit(node)
        return None
100
+
101
def code_to_mermaid(code: str) -> str:
    """Render the first top-level function in *code* as a mermaid flowchart.

    Only the first ast.FunctionDef is visited; the result is a "flowchart TD"
    document listing every collected node, then every edge (labelled edges
    use the `-->|label|` arrow form).
    """
    module = ast.parse(code)
    builder = VizBuilder()
    for top in module.body:
        if isinstance(top, ast.FunctionDef):
            builder.visit(top)
            break
    out = ["flowchart TD"]
    for nid, label in builder.nodes:
        out.append(f'  {nid}["{label}"]')
    for src, dst, tag in builder.edges:
        out.append(f'  {src} -->|{tag}| {dst}' if tag else f'  {src} --> {dst}')
    return "\n".join(out)