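# Gradio web demo for Alpaca-CoT: loads a base causal LM (LLaMA or BLOOM),
# applies a LoRA adapter with PEFT, and serves a simple chat interface.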
import argparse
from collections import namedtuple

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    LlamaForCausalLM, LlamaTokenizer,
    AutoModel, AutoTokenizer,
    BloomForCausalLM, BloomTokenizerFast, GenerationConfig)
tokenizer = None
model = None
LOAD_8BIT = False

ModelClass = namedtuple("ModelClass", ('tokenizer', 'model'))
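# Map each supported model type to its tokenizer/model classes.
# Note: only "llama" and "bloom" are mapped here; the "chatglm" choice accepted
# by the CLI below has no entry and would raise a KeyError if selected.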
_MODEL_CLASSES = {
    "llama": ModelClass(**{
        "tokenizer": LlamaTokenizer,
        "model": LlamaForCausalLM,
    }),
    "bloom": ModelClass(**{
        "tokenizer": AutoTokenizer,
        "model": BloomForCausalLM,
    })
}
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def get_model_class(model_type,
                    model_name_or_path,
                    lora_model_path):
    global model, tokenizer
    model_class = _MODEL_CLASSES[model_type]  # (tokenizer, model)
    model_base = model_class.model.from_pretrained(model_name_or_path,
                                                   load_in_8bit=LOAD_8BIT,
                                                   torch_dtype=torch.float16,
                                                   device_map="auto",
                                                   )
    tokenizer = model_class.tokenizer.from_pretrained(model_name_or_path)  # default add_eos_token=False
    model = PeftModel.from_pretrained(
        model_base,
        lora_model_path,
        torch_dtype=torch.float16,
    )
    if not LOAD_8BIT:
        model.half()
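# Build an Alpaca-style instruction prompt, run generation with the given
# sampling parameters, and append the (instruction, response) pair to the
# chat history.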
def predict(
        instruction,
        top_p=0.75,
        temperature=0.1,
        history=None,
        top_k=40,
        num_beams=4,
        max_new_tokens=512,
        **kwargs,
):
    history = history or []
    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{0}\n\n### Response:"
    ).format(instruction)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    print('model output:', output)
    bot_response = output.split("### Response:")[1].strip()
    history.append((instruction, bot_response))
    return "", history, history
def predict_test(message, top_p, temperature, history):
    history = history or []
    user_message = f"{message} {top_p}, {temperature}"
    print(user_message)
    history.append((message, user_message))
    return history, history
def clear_session():
    return '', '', None
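# Command-line arguments. Only --model_type, --model_name_or_path and
# --lora_name_or_path are read below; --size, --data and --local_rank are
# parsed but not used in this script.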
parser = argparse.ArgumentParser(description='Serve an Alpaca-CoT chat demo with Gradio.')
parser.add_argument('--size', default=7, type=int, help='the size of the LLaMA model')
parser.add_argument('--data', default="", type=str, help='the data used for instruction tuning')
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument('--model_type', default="llama", choices=['llama', 'chatglm', 'bloom'])
parser.add_argument('--model_name_or_path', default="decapoda-research/llama-7b-hf", type=str)
parser.add_argument('--lora_name_or_path', default="./saved-alpaca-belle-cot7b", type=str)
args = parser.parse_args()
get_model_class(args.model_type, args.model_name_or_path, args.lora_name_or_path)
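# Gradio UI: a parameter accordion (top-p, temperature), a chatbot panel,
# a text input, and buttons for clearing history/input, sending, and regenerating.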
block = gr.Blocks(css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""")

with block as demo:
    # top_p, temperature
    with gr.Accordion("Parameters", open=False):
        top_p = gr.Slider(minimum=0, maximum=1.0, value=0.75, step=0.05, interactive=True, label="Top-p (nucleus sampling)")
        temperature = gr.Slider(minimum=0, maximum=5.0, value=0.1, step=0.1, interactive=True, label="Temperature")
    chatbot = gr.Chatbot(label="Alpaca-CoT")
    message = gr.Textbox()
    state = gr.State()
    message.submit(predict, inputs=[message, top_p, temperature, state], outputs=[message, chatbot, state], queue=False)
    with gr.Row():
        clear_history = gr.Button("Clear History")
        clear = gr.Button("Clear Input")
        send = gr.Button("Send")
        regenerate = gr.Button("Regenerate")

    # regenerate.click(regenerate, inputs=[message], outputs=[chatbot])
    # The regenerate button currently just clears the session; a true regenerate
    # handler (commented out above) is not implemented.
    regenerate.click(fn=clear_session, inputs=[], outputs=[message, chatbot, state], queue=False)
    send.click(predict, inputs=[message, top_p, temperature, state], outputs=[message, chatbot, state])
    clear.click(lambda: None, None, message, queue=False)
    clear_history.click(fn=clear_session, inputs=[], outputs=[message, chatbot, state], queue=False)

demo.queue(max_size=20, concurrency_count=20).launch()
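# Example launch (the script filename is assumed here):
#   python web.py --model_type llama \
#       --model_name_or_path decapoda-research/llama-7b-hf \
#       --lora_name_or_path ./saved-alpaca-belle-cot7b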