import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

model_name = "tiiuae/falcon-7b"
model_id = "personachat-finetuned-3000-steps"

# Prompt template with {personality} and {history} placeholders.
with open("template.txt", "r") as f:
    template = f.read()

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)

# PeftModel.from_pretrained injects the adapter layers into base_model in
# place, so the two variables share weights. The adapter is explicitly
# disabled below when generating the "base" response; otherwise both
# responses would come from the fine-tuned model.
tuned_model = PeftModel.from_pretrained(base_model, model_id)


def parse_response(encoded_output, prompt):
    """Decode the generated tokens, strip the echoed prompt, and cut at EOS."""
    decoded_output = tokenizer.batch_decode(encoded_output)[0]
    decoded_output = decoded_output.replace(prompt, "")
    return decoded_output.split("<|endoftext|>", 1)[0].strip()


def generate(personality, user_input, state=None):
    # Gradio passes None for a "state" input on the first call. Initializing
    # here (rather than via a mutable default argument) avoids sharing one
    # history dict across all sessions.
    if state is None:
        state = {"base_state": [], "tune_state": []}

    # Put each personality sentence on its own line.
    personality = "\n".join(personality.split("."))

    state["base_state"].append(user_input)
    state["tune_state"].append(user_input)

    base_prompt = template.format(
        personality=personality,
        history="\n".join(state["base_state"]),
    )
    tune_prompt = template.format(
        personality=personality,
        history="\n".join(state["tune_state"]),
    )

    # Debug: log the prompts sent to each model.
    print("****************************")
    print(base_prompt)
    print("****************************")
    print(tune_prompt)
    print("****************************")

    base_inputs = tokenizer(base_prompt, return_tensors="pt").to("cuda")
    tune_inputs = tokenizer(tune_prompt, return_tensors="pt").to("cuda")

    kwargs = {
        "top_k": 0,
        "top_p": 0.9,
        "do_sample": True,
        "temperature": 0.5,
        "max_new_tokens": 50,
        "repetition_penalty": 1.1,
        "num_return_sequences": 1,
    }

    # Temporarily disable the LoRA adapter so this response reflects the
    # untouched base model rather than the fine-tuned weights.
    with tuned_model.disable_adapter():
        base_model_response = parse_response(
            base_model.generate(input_ids=base_inputs["input_ids"], **kwargs),
            base_prompt,
        )

    tune_model_response = parse_response(
        tuned_model.generate(input_ids=tune_inputs["input_ids"], **kwargs),
        tune_prompt,
    )

    state["base_state"].append(base_model_response)
    state["tune_state"].append(tune_model_response)
    return base_model_response, tune_model_response, state


gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="user personality", placeholder="Enter your personality"),
        gr.Textbox(label="user chat", placeholder="Enter your message"),
        "state",
    ],
    outputs=[
        gr.Textbox(label="base model response"),
        gr.Textbox(label="fine tuned model response"),
        "state",
    ],
    theme="gradio/seafoam",
).launch(share=True)
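
# A note on template.txt: its contents are not included in this script. The
# .format() calls above only tell us it must contain {personality} and
# {history} placeholders. A minimal sketch of a file that would satisfy the
# script (the exact wording is an assumption, not taken from the original):
#
#     Your persona:
#     {personality}
#
#     Dialogue history:
#     {history}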