BlueDice's picture
Upload 3 files
80aa3d0
import gradio as gr
from transformers import (
AutoModelForCausalLM,
AutoTokenizer
)
from peft import PeftModel
# Hub id of the base checkpoint, and the identifier later handed to
# PeftModel.from_pretrained for the fine-tuned adapter.
model_name = "tiiuae/falcon-7b"
model_id = "personachat-finetuned-3000-steps"

# Prompt template; generate() fills its {personality} and {history} fields.
# Read through a context manager so the file handle is closed right away
# instead of lingering until garbage collection.
with open("template.txt", "r", encoding="utf-8") as template_file:
    template = template_file.read()
# Tokenizer shared by both models (the adapter is loaded on top of the same
# base checkpoint, so one tokenizer serves both).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Base checkpoint, requested with 8-bit weights and automatic device placement.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)

# Fine-tuned variant: the adapter identified by model_id loaded on top of
# base_model.
# NOTE(review): depending on the adapter type, PEFT may inject the adapter
# layers into base_model in place, in which case base_model.generate() would
# also run with the adapter active. If a true baseline is required, consider
# generating under `with tuned_model.disable_adapter():` -- TODO confirm.
tuned_model = PeftModel.from_pretrained(base_model, model_id)
def parse_response(encoded_output, user_input):
    """Turn generated token ids into the bare text of the model's reply.

    Args:
        encoded_output: Batch of generated token id sequences; only the first
            sequence is used.
        user_input: Text to remove from the decoded output (callers pass the
            full prompt so that only the newly generated part remains).

    Returns:
        The decoded text with every occurrence of ``user_input`` removed,
        cut at the first ``<|endoftext|>`` marker and stripped of surrounding
        whitespace.
    """
    first_sequence = tokenizer.batch_decode(encoded_output)[0]
    without_prompt = first_sequence.replace(user_input, "")
    # Keep only what precedes the first end-of-text marker (the whole string
    # when the marker is absent).
    completion, _, _ = without_prompt.partition("<|endoftext|>")
    return completion.strip()
def _generate_reply(model, prompt, generation_kwargs):
    """Sample one completion of ``prompt`` from ``model`` and return its text."""
    encoded = tokenizer(prompt, return_tensors="pt").to("cuda")
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        **generation_kwargs,
    )
    return parse_response(output_ids, prompt)


def generate(personality, user_input, state=None):
    """Answer ``user_input`` with both the base and the fine-tuned model.

    Args:
        personality: Persona description. When it is a string, its sentences
            (split on ".") are placed on separate lines before being inserted
            into the prompt template; any other value is used unchanged.
        user_input: The user's latest chat message.
        state: Conversation state, a dict with the keys ``"base_state"`` and
            ``"tune_state"``. Each holds a list alternating user messages and
            that model's replies. ``None`` starts a new conversation.

    Returns:
        A tuple ``(base_model_response, tune_model_response, state)`` where
        ``state`` is the same dict, extended with this turn.
    """
    # A fresh dict per call: a mutable default in the signature would be
    # shared by every call that omits ``state``, leaking chat history between
    # unrelated conversations.
    if state is None:
        state = {"base_state": [], "tune_state": []}

    # Best effort: only strings can be split into sentences; anything else
    # (e.g. None from an empty field) is passed to the template unchanged.
    if isinstance(personality, str):
        personality = "\n".join(personality.split("."))

    state["base_state"].append(user_input)
    state["tune_state"].append(user_input)

    # Each model is prompted with its own history, so the two conversations
    # can diverge as the models give different replies.
    base_prompt = template.format(
        personality=personality,
        history="\n".join(state["base_state"]),
    )
    tune_prompt = template.format(
        personality=personality,
        history="\n".join(state["tune_state"]),
    )

    print("****************************")
    print(base_prompt)
    print("****************************")
    print(tune_prompt)
    print("****************************")

    # Sampling settings shared by both models so the comparison is like for
    # like.
    generation_kwargs = {
        "top_k": 0,
        "top_p": 0.9,
        "do_sample": True,
        "temperature": 0.5,
        "max_new_tokens": 50,
        "repetition_penalty": 1.1,
        "num_return_sequences": 1,
    }

    base_model_response = _generate_reply(base_model, base_prompt, generation_kwargs)
    tune_model_response = _generate_reply(tuned_model, tune_prompt, generation_kwargs)

    state["base_state"].append(base_model_response)
    state["tune_state"].append(tune_model_response)
    return base_model_response, tune_model_response, state
# Build and launch the comparison UI: two text inputs plus the conversation
# state go into generate(), which returns one reply per model and the updated
# state.
gr.Interface(
    fn=generate,
    inputs=[
        # The Textbox hint keyword is `placeholder`; the previous spelling
        # `place_holder` is not a Textbox parameter, so no hint was shown.
        gr.Textbox(label="user personality", placeholder="Enter your personality"),
        gr.Textbox(label="user chat", placeholder="Enter your message"),
        "state",
    ],
    outputs=[
        gr.Textbox(label="base model response"),
        gr.Textbox(label="fine tuned model response"),
        "state",
    ],
    theme="gradio/seafoam",
).launch(share=True)