import gradio as gr
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Display name -> Hugging Face model id for the chat models offered in the UI.
llama_models = {
    "Chat-IPT 3.2": "meta-llama/Llama-3.2-1B-Instruct",
}
def load_model(model_name):
    """Load a causal LM and wrap it in a text-generation pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)
    return generator
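
# Illustrative sketch (commented out, not part of the app flow): how load_model
# and the resulting text-generation pipeline can be exercised directly.
# Downloads the model weights on first run.
#
# demo_generator = load_model(llama_models["Chat-IPT 3.2"])
# print(demo_generator("Hello, how are you?", max_new_tokens=32)[0]["generated_text"])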

# Cache of loaded generation pipelines, keyed by the display name picked in the UI.
model_cache = {}
def predict(model, prompt, response=None):
    """Return the probability that (prompt, response) is unsafe, according to the guard model."""
    device = model.device
    if response is None:
        inputs = tokenizer(prompt, return_tensors="pt")
    else:
        inputs = tokenizer(prompt, response, return_tensors="pt", padding=True, truncation=True)
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Class index 1 corresponds to the "unsafe" label of the classifier.
    unsafe_prob = F.softmax(outputs.logits, dim=-1)[:, 1]
    return unsafe_prob.item()
# Safety classifier (HarmAug-Guard) used to score prompt/response pairs.
tokenizer = AutoTokenizer.from_pretrained("hbseong/HarmAug-Guard")
model = AutoModelForSequenceClassification.from_pretrained("hbseong/HarmAug-Guard")
model = model.to(device)
model.eval()
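
# Illustrative sketch (commented out): scoring a prompt/response pair with the
# guard model via predict(); values close to 1.0 mean "likely unsafe".
#
# example_score = predict(model, "How do I pick a lock?", "Sorry, I can't help with that.")
# print(f"unsafe probability: {example_score:.4f}")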
def generate_chat(user_input, history, model_choice):
    """Generate a reply for user_input, filter it with the safety classifier, and update history."""
    # Load (and cache) the selected generation pipeline.
    if model_choice not in model_cache:
        model_cache[model_choice] = load_model(llama_models[model_choice])
    generator = model_cache[model_choice]

    system_prompt = {"role": "system", "content": "You are a great assistant"}
    if history is None:
        history = [system_prompt]
    history.append({"role": "user", "content": user_input})

    # Keep only the 5 most recent messages.
    if len(history) > 5:
        history = history[-5:]

    # The pipeline output ("generated_text") contains the prompt followed by the continuation.
    response = generator(
        user_input,
        max_length=1024,
        pad_token_id=generator.tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.55,
        top_p=0.7
    )[0]["generated_text"]

    # Score the prompt/response pair and replace the answer if it looks unsafe.
    unsafe_score = predict(model, user_input, response)
    safety_threshold = 0.85
    if unsafe_score > safety_threshold:
        response = "Sorry, but I cannot help with this request due to safety or ethical concerns."

    history.append({"role": "assistant", "content": response})
    return history
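
# For reference, the chat history built above follows Gradio's "messages"
# format, i.e. a list of role/content dicts such as:
#
# [
#     {"role": "system", "content": "You are a great assistant"},
#     {"role": "user", "content": "Hi!"},
#     {"role": "assistant", "content": "Hello! How can I help you?"},
# ]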
# Gradio UI: model selector, chat window, and a text box with a submit button.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1><center>Test</center></h1>")
    model_choice = gr.Dropdown(list(llama_models.keys()), label="Select the model.")
    chatbot = gr.Chatbot(label=" ", type="messages")
    txt_input = gr.Textbox(show_label=False, placeholder="Write your message here...")

    def respond(user_input, chat_history, model_choice):
        # Default to the first model if none has been selected yet.
        if model_choice is None:
            model_choice = list(llama_models.keys())[0]
        updated_history = generate_chat(user_input, chat_history, model_choice)
        # Clear the text box and return the updated chat history.
        return "", updated_history

    txt_input.submit(respond, [txt_input, chatbot, model_choice], [txt_input, chatbot])
    submit_btn = gr.Button("Send")
    submit_btn.click(respond, [txt_input, chatbot, model_choice], [txt_input, chatbot])

demo.launch(debug=False, show_error=True, share=True)
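
# Note: share=True exposes the demo through a temporary public gradio.live URL.
# For local-only testing, an equivalent call without the public link would be:
# demo.launch(debug=False, show_error=True, share=False)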