# doc-mining / mine_diffs.py (uploaded by SharathReddy)
# Last change: "Update mine_diffs.py" (commit eb46234, verified)
import os
import json
import ast
import tempfile
from git import Repo, exc
# --- Configuration ---
# Upper bound on how many commits (newest first) are examined per repository.
MAX_COMMITS_PER_REPO = 5_000

# Path of the JSON Lines file the mined examples are written to.
OUTPUT_FILE = "diff_dataset.jsonl"

# Repositories to mine: short name (also used as the clone sub-directory name)
# mapped to the git URL that is cloned.
REPO_CONFIG = {
    "fastapi": "https://github.com/tiangolo/fastapi.git",
    "requests": "https://github.com/psf/requests.git",
    "scikit-learn": "https://github.com/scikit-learn/scikit-learn.git",
}
class FuncParser(ast.NodeVisitor):
    """Collect the docstring of every function and method in a module AST.

    After ``visit(tree)``, ``self.functions`` maps each function's qualified
    name to its docstring ("" when it has none). A qualified name joins the
    names of the enclosing classes/functions with dots (``"Client.send"``,
    ``"outer.inner"``). Same-named methods in different classes (e.g. every
    ``__init__``) therefore get distinct entries instead of overwriting each
    other. Both ``def`` and ``async def`` functions are recorded. If the same
    qualified name is defined more than once, the last definition wins.
    """

    def __init__(self):
        super().__init__()
        self.functions = {}
        # Names of the classes/functions enclosing the node currently visited.
        self._scope = []

    def _visit_scope(self, node):
        """Visit *node*'s children with its name pushed onto the scope stack."""
        self._scope.append(node.name)
        try:
            self.generic_visit(node)
        finally:
            self._scope.pop()

    def visit_ClassDef(self, node):
        self._visit_scope(node)

    def visit_FunctionDef(self, node):
        qualname = ".".join(self._scope + [node.name])
        self.functions[qualname] = ast.get_docstring(node) or ""
        self._visit_scope(node)

    # ``async def`` yields a distinct node type; record it exactly like ``def``.
    visit_AsyncFunctionDef = visit_FunctionDef
def get_functions_from_source(source_code):
    """Parse Python source text and return FuncParser's function -> docstring map.

    Returns an empty dict when the source cannot be analysed instead of
    aborting the mining run. That covers:

    * syntax errors, e.g. Python 2 files from old history;
    * null bytes (``ValueError`` on Python < 3.12, ``SyntaxError`` after);
    * nesting too deep for the parser or for the recursive AST visitor
      (``RecursionError``).
    """
    try:
        tree = ast.parse(source_code)
        parser = FuncParser()
        parser.visit(tree)
    except (SyntaxError, ValueError, RecursionError):
        return {}
    return parser.functions
def format_for_model(diff_text, old_doc, new_doc):
    """Build one training example from a patch and a docstring before/after pair.

    Args:
        diff_text: Unified-diff text shown to the model inside a ```diff fence.
        old_doc: Docstring before the change (surrounding whitespace stripped).
        new_doc: Docstring after the change (surrounding whitespace stripped);
            this is the target the model should learn to produce.

    Returns:
        A dict with a single ``"text"`` key holding the full prompt + target.
    """
    text = (
        "### INSTRUCTION:\n"
        "A Python function's code was changed. Based on the `git diff` provided, "
        "update the function's documentation.\n"
        "### GIT DIFF:\n"
        "```diff\n"
        # Close the fence on its own line so the documentation sections below
        # are not swallowed into the diff block.
        f"{diff_text.rstrip(chr(10))}\n"
        "```\n"
        "### OLD DOCUMENTATION:\n"
        f"{old_doc.strip()}\n"
        "### UPDATED DOCUMENTATION:\n"
        f"{new_doc.strip()}\n"
    )
    return {"text": text}
def _examples_from_file_diff(diff):
    """Return training examples for one file-level GitPython ``Diff``.

    Expects *diff* to come from ``parent.diff(commit, create_patch=True)``,
    so the ``a`` side is the pre-commit file and the ``b`` side the
    post-commit file. One example is produced for every function present on
    both sides whose docstring changed and is longer than 20 characters
    both before and after the change.
    """
    # Only Python files that exist on both sides (modified/renamed) qualify.
    if not (diff.a_path and diff.b_path
            and diff.a_path.endswith('.py') and diff.b_path.endswith('.py')):
        return []
    if diff.a_blob is None or diff.b_blob is None:
        return []
    try:
        old_source = diff.a_blob.data_stream.read().decode('utf-8')
        new_source = diff.b_blob.data_stream.read().decode('utf-8')
    except UnicodeDecodeError:
        return []

    old_funcs = get_functions_from_source(old_source)
    new_funcs = get_functions_from_source(new_source)
    # The patch text is shared by every function in this file; decode it once.
    diff_text = (diff.diff or b"").decode('utf-8', errors='ignore')

    examples = []
    for func_name, old_doc in old_funcs.items():
        if func_name not in new_funcs:
            continue
        new_doc = new_funcs[func_name]
        if old_doc != new_doc and len(old_doc) > 20 and len(new_doc) > 20:
            examples.append(format_for_model(diff_text, old_doc, new_doc))
    return examples


def _mine_repo(repo):
    """Return training examples from the newest MAX_COMMITS_PER_REPO commits of *repo*."""
    examples = []
    for commit in repo.iter_commits(max_count=MAX_COMMITS_PER_REPO):
        # Root commits have nothing to diff against.
        if not commit.parents:
            continue
        # NOTE: merge commits are compared against their first parent only.
        parent = commit.parents[0]
        # Diff parent -> commit so the "a" side is the old code and the "b" side
        # the new code; commit.diff(parent) would yield the reversed patch and
        # swap old/new docstrings.
        for diff in parent.diff(commit, create_patch=True, unified=0):
            examples.extend(_examples_from_file_diff(diff))
    return examples


def _write_dataset(dataset, path):
    """Write *dataset* to *path* as JSON Lines, reporting the outcome on stdout."""
    try:
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item) + "\n" for item in dataset)
    except OSError as e:
        print(f"FATAL: Could not write final dataset file to {path}. Error: {e}")
        return
    print(f"Dataset successfully saved to '{path}'.")


def main():
    """Clone every repository in REPO_CONFIG, mine docstring changes, save them.

    Clones are made in a fresh temporary directory that is deleted once mining
    finishes (even on error). The collected examples are written to
    OUTPUT_FILE as JSON Lines.
    """
    import shutil  # only needed here, for removing the temporary clones

    dataset = []
    base_repo_dir = tempfile.mkdtemp()
    print(f"Using temporary directory for clones: {base_repo_dir}")
    try:
        for name, url in REPO_CONFIG.items():
            repo_dir = os.path.join(base_repo_dir, name)
            try:
                print(f"Cloning {name} from {url}...")
                repo = Repo.clone_from(url, repo_dir)
            except exc.GitCommandError as e:
                print(f"Error cloning {name}: {e}")
                continue
            print(f"Mining commit history for {name}...")
            try:
                dataset.extend(_mine_repo(repo))
            finally:
                # Release the resources GitPython keeps open for this repo
                # (e.g. persistent git helper processes) before deletion.
                repo.close()
    finally:
        # Full clones are large; do not leave them behind in the temp directory.
        shutil.rmtree(base_repo_dir, ignore_errors=True)

    print(f"\nFound {len(dataset)} high-quality examples.")
    _write_dataset(dataset, OUTPUT_FILE)
# Run the miner only when executed as a script, not when imported.
# (__name__ is "__main__" for the script; comparing against "main" never matched.)
if __name__ == "__main__":
    main()