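"""Gradio Space for the admincybers2/sentinal model: given a code snippet, the
model identifies the vulnerable line and describes the type of vulnerability."""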
import spaces  # ZeroGPU helper; imported first, before any CUDA-touching library

import os

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextStreamer,
)
# Model configuration
model_config = {
    "model_name": "admincybers2/sentinal",
    "max_seq_length": 1024,
    "dtype": torch.float16,
    "load_in_4bit": True,
}
# Hugging Face token
hf_token = os.getenv("HF_TOKEN")
# Model and tokenizer are loaded once and cached in these globals
loaded_model = None
loaded_tokenizer = None
def load_model():
    global loaded_model, loaded_tokenizer
    if loaded_model is None:
        # Honour the load_in_4bit flag from model_config (previously defined but
        # never used); 4-bit loading requires the bitsandbytes package.
        quant_config = (
            BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_config["dtype"],
            )
            if model_config["load_in_4bit"]
            else None
        )
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_config["model_name"],
            torch_dtype=model_config["dtype"],
            device_map="auto",
            quantization_config=quant_config,
            token=hf_token,  # `token` replaces the deprecated `use_auth_token`
        )
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            model_config["model_name"],
            token=hf_token,
        )
    return loaded_model, loaded_tokenizer
# Vulnerability prompt template
vulnerability_prompt = """Identify the specific line of code that is vulnerable and describe the type of software vulnerability.
### Vulnerable Line:
{}
### Vulnerability Description:
"""
@spaces.GPU  # request a GPU for the duration of this call (ZeroGPU Spaces)
def predict(prompt):
    model, tokenizer = load_model()
    formatted_prompt = vulnerability_prompt.format(prompt)  # template has exactly one placeholder
    inputs = tokenizer(
        [formatted_prompt],
        return_tensors="pt",
        truncation=True,
        max_length=model_config["max_seq_length"],  # apply the configured context limit
    ).to(model.device)  # follow device_map="auto" rather than hard-coding "cuda"
    text_streamer = TextStreamer(tokenizer)
    output = model.generate(
        **inputs,
        streamer=text_streamer,
        use_cache=True,
        do_sample=True,  # required for temperature/top_k/top_p to take effect
        temperature=0.4,
        top_k=50,  # consider only the 50 most likely next tokens
        top_p=0.9,  # nucleus sampling: smallest token set covering 90% of probability mass
        min_p=0.01,  # drop tokens below 1% of the most likely token's probability
        typical_p=0.95,  # locally typical sampling given the context
        repetition_penalty=1.2,  # penalize repetitive sequences to improve diversity
        no_repeat_ngram_size=3,  # forbid repeating any 3-gram
        renormalize_logits=True,  # renormalize logits after the processors above run
        max_new_tokens=640,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont("Source Sans Pro"),
)
# Pre-load the model so the first request doesn't pay the load cost
load_model()
with gr.Blocks(theme=theme) as demo:
    prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet here...", label="Prompt")
    generated_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")
    generate_button.click(predict, inputs=[prompt], outputs=generated_text)
    # Sample input: a Perl snippet that writes an oversized buffer to a file
    gr.Examples(
        examples=[
            ["$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);"]
        ],
        inputs=[prompt],
    )
demo.queue(default_concurrency_limit=10).launch(
    server_name="0.0.0.0",
    allowed_paths=["/"],
)
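# Note: on Hugging Face Spaces this file runs as app.py. If the model repo
# requires authentication, set HF_TOKEN as a Space secret so os.getenv("HF_TOKEN")
# above resolves to a valid token.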