expertllama / app.py
SpiketheCowboy's picture
Upload folder using huggingface_hub
77f24ac
raw
history blame
3.39 kB
'''
Simple ExpertLLaMA chatbot demo, adapted from the Gradio chatbot guide (https://gradio.app/creating-a-chatbot/).
'''
import gradio as gr
import random
import time
import transformers
import os
import json
import torch
import argparse
from transformers import LlamaTokenizer, LlamaForCausalLM
def apply_delta(base_model_path, target_model_path, delta_path):
    """Rebuild the target model by adding published delta weights to a base LLaMA model.

    Every parameter of the base model is updated in place as ``base + delta``.
    The merged base model is returned together with the tokenizer loaded from
    ``delta_path``.

    Args:
        base_model_path: Hub id or local path of the base LLaMA weights.
        target_model_path: Destination for the merged model. Currently unused
            because saving is disabled below; kept for interface compatibility.
        delta_path: Hub id or local path of the delta weights and tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple: the merged ``LlamaForCausalLM`` and the
        ``LlamaTokenizer`` loaded from ``delta_path``.

    Raises:
        KeyError: if a base parameter has no counterpart in the delta weights.
        ValueError: if a base parameter and its delta differ in shape.
    """
    print(f"Loading the delta weights from {delta_path}")
    delta_tokenizer = LlamaTokenizer.from_pretrained(delta_path, use_fast=False)
    delta = LlamaForCausalLM.from_pretrained(
        delta_path, low_cpu_mem_usage=True, torch_dtype=torch.float16
    )

    print(f"Loading the base model from {base_model_path}")
    base_tokenizer = LlamaTokenizer.from_pretrained(base_model_path, use_fast=False)
    base = LlamaForCausalLM.from_pretrained(base_model_path, low_cpu_mem_usage=True)

    # Following the Alpaca training recipe, new special tokens were added to the
    # vocabulary; grow the base embeddings so they line up with the delta.
    special_tokens_dict = {
        "pad_token": "[PAD]",
        "eos_token": "</s>",
        "bos_token": "<s>",
        "unk_token": "<unk>",
    }
    num_new_tokens = base_tokenizer.add_special_tokens(special_tokens_dict)
    base.resize_token_embeddings(len(base_tokenizer))
    # Zero the freshly added rows so base + delta yields exactly the delta rows.
    # Guard the empty case: with num_new_tokens == 0 the slice [-0:] is [0:],
    # which would wipe the entire embedding matrix.
    if num_new_tokens > 0:
        base.get_input_embeddings().weight.data[-num_new_tokens:] = 0
        base.get_output_embeddings().weight.data[-num_new_tokens:] = 0

    print("Applying the delta")
    # state_dict() builds a new dict on every call, so take it once.
    delta_state = delta.state_dict()
    for name, param in base.state_dict().items():
        if name not in delta_state:
            raise KeyError(f"Parameter {name!r} is missing from the delta weights")
        delta_param = delta_state[name]
        if delta_param.shape != param.shape:
            raise ValueError(
                f"Shape mismatch for {name!r}: base {tuple(param.shape)} "
                f"vs delta {tuple(delta_param.shape)}"
            )
        # Tensors returned by state_dict() share storage with the model's own
        # parameters/buffers, so this in-place add updates `base` directly and
        # no load_state_dict() round-trip is needed.
        param.data += delta_param
    # The delta is fully merged; release it to reduce peak memory.
    del delta, delta_state

    # Persisting the merged model is disabled in this demo:
    # base.save_pretrained(target_model_path)
    # delta_tokenizer.save_pretrained(target_model_path)
    return base, delta_tokenizer
# Hub id of the base LLaMA-7B weights the delta was computed against.
base_weights = 'decapoda-research/llama-7b-hf'
# Local path passed to apply_delta as the save target (saving is commented out there).
target_weights = 'expertllama'
# Hub id of the ExpertLLaMA delta weights and tokenizer.
delta_weights = 'OFA-Sys/expertllama-7b-delta'
model, tokenizer = apply_delta(base_weights, target_weights, delta_weights)

# The merged model is created on the CPU. Move it to the GPU when one is
# available, in half precision as the original float16 loader did, so that
# generation runs on the same device the request handler targets.
if torch.cuda.is_available():
    model = model.half().cuda()

# Alternative: load an already-merged checkpoint instead of re-applying the delta.
# tokenizer = transformers.LlamaTokenizer.from_pretrained(expertllama_path)
# model = transformers.LlamaForCausalLM.from_pretrained(expertllama_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def respond(message, chat_history):
        """Generate a model reply to ``message`` and append the turn to the chat.

        Only single-turn prompting is supported for now: earlier turns in
        ``chat_history`` are displayed but not fed back to the model.

        Args:
            message: The user's text from the textbox.
            chat_history: List of ``(user, bot)`` pairs shown in the chatbot;
                ``None`` (e.g. after Clear) is treated as an empty history.

        Returns:
            ``("", chat_history)``, which clears the textbox and refreshes the chatbot.
        """
        chat_history = chat_history or []
        # Prompt wrapper used by the model: a Human turn followed by an open Assistant turn.
        prompt = f"### Human:\n{message}\n\n### Assistant:\n"
        batch = tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )
        # Place the inputs on whatever device the model lives on (CPU or GPU).
        batch = {k: v.to(model.device) for k, v in batch.items()}
        generated = model.generate(
            **batch,  # input_ids and attention_mask
            max_length=1024,
            do_sample=True,  # temperature has no effect under greedy decoding
            temperature=0.8,
        )
        # Decode only the newly generated tokens. skip_special_tokens drops the
        # trailing </s> without blindly cutting the last token, which would
        # lose real text whenever generation stops at max_length.
        prompt_len = batch["input_ids"].shape[1]
        bot_message = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
        chat_history.append((message, bot_message))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()