# File-view header captured along with the source (not code):
# elselse — "Update app.py" — commit 6519258 (verified)
import gradio as gr
import json
from transformers import pipeline
# Hugging Face model used for scoring commit messages.
_MODEL_ID = "CIRCL/cwe-parent-vulnerability-classification-roberta-base"

# Text-classification pipeline over that model. return_all_scores=True asks
# the pipeline for a score for every label rather than only the best one.
classifier = pipeline(
    "text-classification",
    model=_MODEL_ID,
    return_all_scores=True,
)
# Load the child-CWE -> parent-CWE mapping from the JSON file next to the app
# (path is relative to the current working directory). predict_cwe looks keys
# up by the model label with any "CWE-" prefix removed.
# Explicit UTF-8: JSON is UTF-8 by specification, and the default text
# encoding otherwise depends on the platform locale.
with open("child_to_parent_mapping.json", "r", encoding="utf-8") as f:
    child_to_parent = json.load(f)
def predict_cwe(commit_message: str, *, top_k: int = 5) -> dict:
    """Predict CWE categories for a commit message, reported as parent CWEs.

    The message is scored by the module-level ``classifier`` pipeline. The
    ``top_k`` highest-scoring labels are then translated through the
    module-level ``child_to_parent`` mapping. A label with no mapping entry
    is reported as itself.

    Args:
        commit_message: Free-text Git commit message. Empty, whitespace-only
            or ``None`` input returns an empty result without running the
            model.
        top_k: Number of highest-scoring model labels to consider
            (default 5, matching the UI's ``num_top_classes``).

    Returns:
        Dict mapping ``"CWE-<id>"`` to the score rounded to 4 decimals,
        inserted from highest to lowest score. When several of the kept
        labels map to the same parent, that parent keeps the highest of
        their scores.
    """
    if not commit_message or not commit_message.strip():
        return {}

    raw = classifier(commit_message)
    # A single string may come back nested one level ([[{label, score}, ...]],
    # as with the legacy return_all_scores=True flag) or flat
    # ([{label, score}, ...], as with top_k=None). Accept both shapes.
    scores = raw[0] if raw and isinstance(raw[0], list) else raw

    ranked = sorted(scores, key=lambda item: item["score"], reverse=True)

    mapped_results = {}
    for item in ranked[:top_k]:
        child_cwe = item["label"].replace("CWE-", "")
        # Default to the child itself when the mapping has no parent for it.
        parent_cwe = child_to_parent.get(child_cwe, child_cwe)
        key = f"CWE-{parent_cwe}"
        # ranked is in descending score order, so the first score seen for a
        # parent is its highest. Don't let a weaker sibling overwrite it.
        if key not in mapped_results:
            mapped_results[key] = round(float(item["score"]), 4)
    return mapped_results
# ---------------------------------------------------------------------------
# Gradio UI: a three-line text box in, a Label showing up to 5 classes out.
# ---------------------------------------------------------------------------
_EXAMPLE_COMMITS = [
    "Fixed buffer overflow in input parsing",
    "SQL injection possible in login flow",
    "Improved input validation to prevent XSS",
    "Added try/catch to avoid null pointer crash",
    "Patched race condition in thread lock logic",
]

_DESCRIPTION = (
    "This tool uses a fine-tuned model to predict CWE categories from Git commit messages. "
    "Predicted child CWEs are mapped to their parent CWEs if applicable."
)

demo = gr.Interface(
    fn=predict_cwe,
    inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
    outputs=gr.Label(num_top_classes=5),
    title="CWE Prediction from Commit Message",
    description=_DESCRIPTION,
    # Gradio expects one inner list per example, one value per input component.
    examples=[[message] for message in _EXAMPLE_COMMITS],
)

if __name__ == "__main__":
    # Launch the web UI only when executed as a script, not on import.
    demo.launch()