import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Base model and the PEFT adapter produced by fine-tuning (3000 steps on PersonaChat).
model_name = "tiiuae/falcon-7b"
model_id = "personachat-finetuned-3000-steps"

# Prompt template with {personality} and {history} placeholders.
with open("template.txt", "r") as f:
    template = f.read()
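# template.txt is assumed to contain a prompt along these lines (illustrative,
# not the original file), with {personality} and {history} filled in by
# str.format in generate() below:
#
#   Persona:
#   {personality}
#   Dialogue:
#   {history}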

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the base Falcon-7B in 8-bit (requires bitsandbytes) and let accelerate
# place it across the available devices.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    trust_remote_code=True,
    low_cpu_mem_usage=True
)

# Attach the fine-tuned PEFT adapter on top of the base model.
tuned_model = PeftModel.from_pretrained(base_model, model_id)

def parse_response(encoded_output, prompt):
    """Decode the generated ids, strip the prompt that was fed in, and cut the
    reply off at the first <|endoftext|> token."""
    decoded_output = tokenizer.batch_decode(encoded_output)[0]
    decoded_output = decoded_output.replace(prompt, "")
    decoded_output = decoded_output.split("<|endoftext|>", 1)[0].strip()
    return decoded_output
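
# For example, if the raw decoding is prompt + "Hi! I love hiking.<|endoftext|>..."
# then parse_response returns "Hi! I love hiking." (illustrative, not an actual run).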

def generate(personality, user_input, state=None):
    # Per-session chat state: one history for the base model, one for the
    # fine-tuned model, so their conversations can diverge.
    if state is None:
        state = {"base_state": [], "tune_state": []}

    # Put each sentence of the personality description on its own line.
    if isinstance(personality, str):
        personality = "\n".join(personality.split("."))

    state["base_state"].append(user_input)
    state["tune_state"].append(user_input)

    # Build one prompt per model from the shared template and that model's history.
    base_prompt = template.format(
        personality=personality,
        history="\n".join(state["base_state"])
    )
    tune_prompt = template.format(
        personality=personality,
        history="\n".join(state["tune_state"])
    )

    # Log both prompts for debugging.
    print("****************************")
    print(base_prompt)
    print("****************************")
    print(tune_prompt)
    print("****************************")

    base_inputs = tokenizer(base_prompt, return_tensors="pt").to("cuda")
    tune_inputs = tokenizer(tune_prompt, return_tensors="pt").to("cuda")

    # Shared sampling settings for both models.
    gen_kwargs = {
        "top_k": 0,
        "top_p": 0.9,
        "do_sample": True,
        "temperature": 0.5,
        "max_new_tokens": 50,
        "repetition_penalty": 1.1,
        "num_return_sequences": 1
    }

    base_model_response = parse_response(
        base_model.generate(input_ids=base_inputs["input_ids"], **gen_kwargs),
        base_prompt
    )
    tune_model_response = parse_response(
        tuned_model.generate(input_ids=tune_inputs["input_ids"], **gen_kwargs),
        tune_prompt
    )

    # Append each reply to its own history so the next turn sees it.
    state["base_state"].append(base_model_response)
    state["tune_state"].append(tune_model_response)

    return base_model_response, tune_model_response, state

# Side-by-side demo: the same inputs go to both models; the "state" shortcut
# carries each session's chat histories between turns.
gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="user personality", placeholder="Enter your personality"),
        gr.Textbox(label="user chat", placeholder="Enter your message"),
        "state"
    ],
    outputs=[
        gr.Textbox(label="base model response"),
        gr.Textbox(label="fine-tuned model response"),
        "state"
    ],
    theme="gradio/seafoam"
).launch(share=True)  # share=True also serves a temporary public link