"""Gradio demo that predicts CWE categories from Git commit messages and maps them to parent CWEs."""
import gradio as gr
import json
from transformers import pipeline

# Hugging Face text-classification pipeline used to score commit messages.
# NOTE(review): recent transformers releases deprecate `return_all_scores`
# in favour of `top_k=None` -- confirm the output nesting before migrating.
CWE_MODEL_ID = "CIRCL/cwe-parent-vulnerability-classification-roberta-base"
classifier = pipeline(
    task="text-classification",
    model=CWE_MODEL_ID,
    return_all_scores=True,  # request a score for every label, not just the best one
)

# Load the child-to-parent CWE mapping. predict_cwe looks entries up by the
# predicted label with its "CWE-" prefix stripped.
# The file is decoded explicitly as UTF-8 (the JSON interchange encoding) so
# that loading does not depend on the platform's default text encoding.
with open("child_to_parent_mapping.json", "r", encoding="utf-8") as f:
    child_to_parent = json.load(f)

def predict_cwe(commit_message: str):
    """
    Predict parent CWE categories for a commit message.

    The module-level ``classifier`` is run on ``commit_message`` and its
    labels are walked from the highest score to the lowest. Each label has
    its ``CWE-`` prefix stripped and is looked up in ``child_to_parent``;
    a label with no entry in the mapping is kept unchanged.

    Args:
        commit_message: Commit message text to classify.

    Returns:
        A dict mapping ``"CWE-<id>"`` to a score rounded to 4 decimals,
        inserted in descending score order and holding at most five distinct
        parent CWEs. When several predicted labels map to the same parent,
        the highest of their scores is reported.
    """
    max_results = 5

    # The first element holds the per-label {"label", "score"} dicts for the
    # single input text.
    results = classifier(commit_message)[0]
    ranked = sorted(results, key=lambda x: x["score"], reverse=True)

    mapped_results = {}
    for item in ranked:
        child_cwe = item["label"].replace("CWE-", "")
        parent_cwe = child_to_parent.get(child_cwe, child_cwe)  # default to child if no parent
        key = f"CWE-{parent_cwe}"

        # Scores arrive in descending order, so the first score seen for a
        # parent is its best one; a later sibling must not overwrite it.
        if key in mapped_results:
            continue
        mapped_results[key] = round(float(item["score"]), 4)

        # Stop once enough distinct parents were collected. Truncating before
        # the mapping step could yield fewer entries when labels collapse
        # onto the same parent.
        if len(mapped_results) == max_results:
            break

    return mapped_results

# Gradio UI: one text box in, a label widget with the top classes out.
_DESCRIPTION = (
    "This tool uses a fine-tuned model to predict CWE categories from Git commit messages. "
    "Predicted child CWEs are mapped to their parent CWEs if applicable."
)

# Sample commit messages offered as clickable examples in the interface.
_EXAMPLE_MESSAGES = [
    "Fixed buffer overflow in input parsing",
    "SQL injection possible in login flow",
    "Improved input validation to prevent XSS",
    "Added try/catch to avoid null pointer crash",
    "Patched race condition in thread lock logic",
]

demo = gr.Interface(
    fn=predict_cwe,
    inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
    outputs=gr.Label(num_top_classes=5),
    title="CWE Prediction from Commit Message",
    description=_DESCRIPTION,
    # Each example is a single-element row, one value per input component.
    examples=[[message] for message in _EXAMPLE_MESSAGES],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()