"""Gradio app that classifies a code snippet as vulnerable or non-vulnerable.

A Hugging Face sequence-classification model is loaded once at import time
and exposed through a small Gradio Blocks interface.
"""

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr

# Example code snippets
VULNERABLE_EXAMPLE = """static int cirrus_bitblt_videotovideo_patterncopy(CirrusVGAState * s)\n{\n return cirrus_bitblt_common_patterncopy(s,\n\t\t\t\t\t s->vram_ptr +\n (s->cirrus_blt_srcaddr & ~7));\n}"""

NON_VULNERABLE_EXAMPLE = """static void loongarch_cpu_synchronize_from_tb(CPUState *cs, \n const TranslationBlock *tb)\n{\n LoongArchCPU *cpu = LOONGARCH_CPU(cs);\n CPULoongArchState *env = &cpu->env;\n\n env->pc = tb->pc;\n}"""

# Label for each model output index (index 0 -> first name, index 1 -> second).
CLASS_NAMES = ["Non-vulnerable", "Vulnerable"]

# Icons are spelled with \N{...} escapes so that a wrong-encoding round trip of
# this file cannot garble them again.
WARNING_ICON = "\N{WARNING SIGN}\N{VARIATION SELECTOR-16}"
CLIPBOARD_ICON = "\N{CLIPBOARD}"

# Introductory text shown under the page title.
APP_DESCRIPTION = """
This tool analyzes code snippets for various types of vulnerabilities, including:
- Security vulnerabilities (e.g., buffer overflows, injection flaws)
- Memory management issues
- Concurrency problems
- Resource leaks
- Logic errors
- Performance issues
- Reliability problems
"""


# Load the model and tokenizer
def load_model():
    """Load the classifier and its tokenizer and move the model to a device.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model is placed on CUDA
        when ``torch.cuda.is_available()`` is true, otherwise on the CPU, and
        is switched to evaluation mode.
    """
    model_name = "moazx/Code-Vulnerability-Classifier_app"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    return model, tokenizer, device


# Load the model and tokenizer once when the app starts
model, tokenizer, device = load_model()


def classify_code_sample(code_sample):
    """Classify a single code sample and return its class probabilities.

    Args:
        code_sample: Source code text to classify. Input longer than 512
            tokens is truncated.

    Returns:
        A 1-D NumPy array holding the softmax probability of each class.
    """
    # No padding: there is only one sequence, so there is no batch to align.
    # Padding to max_length would only add masked tokens for the model to
    # process on every request.
    inputs = tokenizer(
        code_sample,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Select the single batch row explicitly; a bare squeeze() would also drop
    # the class axis if it ever had size 1.
    return F.softmax(logits, dim=-1)[0].cpu().numpy()


def analyze_code(code_input):
    """Analyze a code snippet and return a Markdown report.

    Args:
        code_input: Text from the input box. ``None`` or whitespace-only input
            is treated as empty.

    Returns:
        A Markdown string with the predicted class, its confidence and the
        per-class probabilities, a prompt to enter code when the input is
        empty, or an error message if classification fails.
    """
    if not code_input or not code_input.strip():
        return "Please enter some code to analyze."

    # This is the UI boundary: report any failure to the user instead of
    # letting the exception escape into the event handler.
    try:
        # Get predictions
        probabilities = classify_code_sample(code_input)

        # Class name and confidence
        predicted_class_index = int(probabilities.argmax())
        predicted_class = CLASS_NAMES[predicted_class_index]
        confidence = probabilities[predicted_class_index] * 100

        # Prepare results
        parts = [
            f"**Prediction:** {predicted_class}\n",
            f"**Confidence:** {confidence:.1f}%\n\n",
            "**Detailed Probabilities:**\n",
        ]
        parts.extend(
            f"- {class_name}: {prob * 100:.1f}%\n"
            for class_name, prob in zip(CLASS_NAMES, probabilities)
        )

        # Additional warnings for vulnerable code
        if predicted_class == "Vulnerable":
            parts.extend(
                [
                    f"\n{WARNING_ICON} **Warning:** This code has been flagged as potentially vulnerable. Please review it carefully for:\n",
                    "- Security issues (e.g., input validation, authentication)\n",
                    "- Implementation issues (e.g., memory management, resource handling)\n",
                    "- Design issues (e.g., concurrency, logic errors)\n",
                ]
            )

        return "".join(parts)
    except Exception as e:
        return f"Error during analysis: {str(e)}"


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# DiverseVul Code Vulnerability Classifier")
    gr.Markdown(APP_DESCRIPTION)

    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                label="Enter your code snippet here:",
                placeholder="Paste your code here...",
                lines=10,
                max_lines=20,
                value="",
            )
            analyze_button = gr.Button("Analyze Code")

        with gr.Column():
            output = gr.Markdown(label="Analysis Results")

    # Example buttons
    gr.Markdown("### Try an Example")
    with gr.Row():
        vulnerable_example_button = gr.Button(f"{CLIPBOARD_ICON} Load Vulnerable Example")
        non_vulnerable_example_button = gr.Button(f"{CLIPBOARD_ICON} Load Non-Vulnerable Example")

    # Event handlers
    analyze_button.click(
        analyze_code,
        inputs=code_input,
        outputs=output,
    )

    # The example buttons only copy a canned snippet into the input box.
    vulnerable_example_button.click(
        lambda: VULNERABLE_EXAMPLE,
        outputs=code_input,
    )

    non_vulnerable_example_button.click(
        lambda: NON_VULNERABLE_EXAMPLE,
        outputs=code_input,
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()