import os

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer

# Model configurations
model_configs = {
    "CyberSentinel": {
        "model_name": "dad1909/CyberSentinel",
        "max_seq_length": 1024,  # declared for reference; not enforced at load time
        "dtype": torch.float16,
        "load_in_4bit": True,
    }
}

# Hugging Face token (needed for gated or private models)
hf_token = os.getenv("HF_TOKEN")


@spaces.GPU
def load_model(selected_model):
    """Load the selected model and its tokenizer."""
    config = model_configs[selected_model]
    # Apply the declared 4-bit setting via BitsAndBytesConfig (requires bitsandbytes);
    # previously the "load_in_4bit" flag was defined but never used.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True) if config["load_in_4bit"] else None
    model = AutoModelForCausalLM.from_pretrained(
        config["model_name"],
        torch_dtype=config["dtype"],
        quantization_config=quantization_config,
        device_map="auto",
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        config["model_name"],
        token=hf_token,
    )
    return model, tokenizer


# Prompt templates, one per task type
alpaca_prompts = {
    "Information": "Give me information about the following topic: {}",
    "Vulnerability": """Identify the line of code that is vulnerable and describe the type of software vulnerability.

### Code Snippet:
{}

### Vulnerability Description:""",
    "Math": "Give me a math prompt to show a math problem involving: {}",
    "Chat": "{}",
}


@spaces.GPU(duration=100)
def predict(selected_model, prompt, prompt_type, max_length=128):
    """Format the prompt, run generation, and return the decoded text."""
    model, tokenizer = load_model(selected_model)
    formatted_prompt = alpaca_prompts[prompt_type].format(prompt)
    # Use model.device rather than hard-coding "cuda" so this also works
    # when device_map places the model elsewhere.
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
    # TextStreamer prints tokens to stdout (server logs); the UI receives
    # the full decoded string returned below.
    text_streamer = TextStreamer(tokenizer)
    output = model.generate(
        **inputs,
        streamer=text_streamer,
        max_new_tokens=int(max_length),  # Gradio sliders may deliver floats
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont("Source Sans Pro"),
)

with gr.Blocks(theme=theme) as demo:
    selected_model = gr.Dropdown(choices=list(model_configs.keys()), value="CyberSentinel", label="Model")
    prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
    prompt_type = gr.Dropdown(choices=list(alpaca_prompts.keys()), value="Chat", label="Prompt Type")
    max_length = gr.Slider(minimum=128, maximum=2048, step=128, value=128, label="Max Length")
    generated_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")

    generate_button.click(
        predict,
        inputs=[selected_model, prompt, prompt_type, max_length],
        outputs=generated_text,
    )

    gr.Examples(
        examples=[
            ["CyberSentinel", "What is SQL injection?", "Information", 128],
            ["CyberSentinel", "$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);", "Vulnerability", 128],
            ["CyberSentinel", "Solve the equation x^2 - 4x + 4 = 0", "Math", 128],
            ["CyberSentinel", "Can you tell me a joke?", "Chat", 128],
        ],
        inputs=[selected_model, prompt, prompt_type, max_length],
    )

# Note: allowed_paths=["/"] exposes the entire filesystem to the web server;
# narrow this to specific asset directories in production.
demo.queue(default_concurrency_limit=5).launch(
    server_name="0.0.0.0",
    allowed_paths=["/"],
    share=True,
)
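
# --- Usage sketch (illustrative only) ---
# A minimal sanity check of predict() outside the UI, assuming HF_TOKEN is set
# in the environment and a CUDA-capable GPU (or a ZeroGPU Space) is available:
#
#   text = predict("CyberSentinel", "What is SQL injection?", "Information", 128)
#   print(text)
#
# To run the app locally (assuming this file is saved as app.py):
#
#   HF_TOKEN=<your token> python app.py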