# Llama2_jailbreak/test.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # pin the process to GPU 0 before importing torch

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Tokenizer from the base chat model; weights from the local fine-tuning output.
tokenizer = AutoTokenizer.from_pretrained("Llama-2-13b-chat-hf", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    './results/final_checkpoint', device_map={"": 0}, torch_dtype=torch.bfloat16
)
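# BitsAndBytesConfig is imported above but never used; if GPU memory is tight,
# a 4-bit load is one option (a sketch, assuming bitsandbytes is installed):
#
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
#   model = AutoModelForCausalLM.from_pretrained(
#       './results/final_checkpoint', device_map={"": 0}, quantization_config=bnb_config
#   )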
def answer(my_input):
    # Llama-2 chat format. The tokenizer adds the BOS token itself, so the
    # literal "<s>" from the original prompt string is dropped to avoid a
    # doubled BOS at the start of the sequence.
    text = "[INST] " + my_input + " [/INST]"
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
                             max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Everything after the closing [/INST] tag is the model's reply.
    return result.split('[/INST]')[-1].strip()
def expand(my_input, response):
    # Re-feed the prompt plus the previous response so the model continues it.
    text = "[INST] " + my_input + " [/INST] " + response
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
                             max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Returns the previous response plus the newly generated continuation.
    return result.split('[/INST]')[-1].strip()
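# For reference, Llama-2 chat also supports an optional system block; a minimal
# sketch of the full single-turn template (build_prompt is a hypothetical
# helper, not used by this demo):
def build_prompt(user_msg, system_msg):
    return f"[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{user_msg} [/INST]"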
with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    response = gr.Textbox(label='Response')
    generate_btn = gr.Button('Generate')
    generate_btn.click(fn=answer, inputs=prompt, outputs=response, api_name='answer')
    expand_btn = gr.Button('Continue')
    expand_btn.click(fn=expand, inputs=[prompt, response], outputs=response, api_name='continue')

demo.launch(share=True)
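# Once the demo is up, the named endpoints can also be called programmatically
# with gradio_client (the URL below is the local default and an assumption;
# share=True prints a public *.gradio.live URL to use instead):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   reply = client.predict("Hello!", api_name="/answer")
#   longer = client.predict("Hello!", reply, api_name="/continue")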