alpaca-cot / app.py
love2poppy's picture
update chatbot state
3fa0f4e
raw
history blame contribute delete
No virus
5.37 kB
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import sys
import torch
import argparse
from peft import PeftModel
import transformers
from collections import namedtuple
from transformers import (
LlamaForCausalLM, LlamaTokenizer,
AutoModel, AutoTokenizer,
BloomForCausalLM, BloomTokenizerFast, GenerationConfig)
# Populated by get_model_class(); read by the predict() callback.
tokenizer = None
model = None
# Forwarded as load_in_8bit= when the base model is loaded.
LOAD_8BIT = False

# Pairs the tokenizer class with the model class for one model family.
ModelClass = namedtuple("ModelClass", ["tokenizer", "model"])

# Supported model types -> (tokenizer class, model class).
_MODEL_CLASSES = {
    "llama": ModelClass(tokenizer=LlamaTokenizer, model=LlamaForCausalLM),
    "bloom": ModelClass(tokenizer=AutoTokenizer, model=BloomForCausalLM),
}

# Device the prompt tensors are moved to; CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def get_model_class(model_type,
                    model_name_or_path,
                    lora_model_path):
    """Load the tokenizer, base model and LoRA adapter into the module globals.

    Despite its name, this does not merely look up classes: it instantiates
    the tokenizer and model registered for ``model_type`` in
    ``_MODEL_CLASSES``, wraps the model with the LoRA adapter via
    ``PeftModel``, and stores both in the module-level ``tokenizer`` and
    ``model`` globals that ``predict`` uses.

    Args:
        model_type: key of ``_MODEL_CLASSES`` (e.g. "llama" or "bloom").
        model_name_or_path: hub id or local path of the base model.
        lora_model_path: hub id or local path of the LoRA adapter.

    Returns:
        The ``(tokenizer, model)`` pair that was loaded (also stored globally).

    Raises:
        ValueError: if ``model_type`` has no entry in ``_MODEL_CLASSES``.
    """
    global model, tokenizer
    try:
        model_class = _MODEL_CLASSES[model_type]  # tokenizer, model
    except KeyError:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            f"expected one of {sorted(_MODEL_CLASSES)}"
        ) from None

    # float16 is only used on GPU: many CPU kernels in common PyTorch versions
    # lack (or have very slow) Half implementations, which breaks generate().
    use_fp16 = device == "cuda"
    dtype = torch.float16 if use_fp16 else torch.float32

    model_base = model_class.model.from_pretrained(
        model_name_or_path,
        load_in_8bit=LOAD_8BIT,
        torch_dtype=dtype,
        device_map="auto",
    )
    tokenizer = model_class.tokenizer.from_pretrained(model_name_or_path)  # default add_eos_token=False
    model = PeftModel.from_pretrained(
        model_base,
        lora_model_path,
        torch_dtype=dtype,
    )
    if use_fp16 and not LOAD_8BIT:
        model.half()
    # Inference only: disable dropout and other training-mode behaviour.
    model.eval()
    return tokenizer, model
def predict(
    instruction,
    top_p=0.75,
    temperature=0.1,
    history=None,
    top_k=40,
    num_beams=4,
    max_new_tokens=512,
    **kwargs,
):
    """Answer one instruction with the loaded model and append it to the chat.

    The instruction is wrapped in the Alpaca prompt template, generated with
    the global ``model``/``tokenizer``, and the reply is appended to
    ``history`` as an ``(instruction, reply)`` tuple.

    Args:
        instruction: the user's message.
        top_p, temperature, top_k, num_beams: forwarded to ``GenerationConfig``.
        history: list of ``(user, bot)`` tuples; ``None`` or empty starts a
            new list.
        max_new_tokens: upper bound on the number of generated tokens.
        **kwargs: extra ``GenerationConfig`` fields (e.g. ``do_sample``).

    Returns:
        ``("", history, history)``: clears the input box, then the updated
        history for the Chatbot and the State outputs.
    """
    history = history or []
    # Nothing to answer: avoid a full generate() call on an empty submission.
    if not (instruction or "").strip():
        return "", history, history

    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{0}\n\n### Response:"
    ).format(instruction)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs.get("attention_mask")

    # NOTE(review): with num_beams=4 and no do_sample=True, generate() runs
    # deterministic beam search, so temperature/top_p/top_k have no effect
    # unless a caller passes do_sample=True through **kwargs.
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            max_new_tokens=max_new_tokens,
        )
    sequence = generation_output.sequences[0]
    # Decode only the tokens generated after the prompt. Splitting the full
    # text on "### Response:" picks the wrong span (or raises IndexError) when
    # the instruction itself contains that marker; skip_special_tokens keeps
    # BOS/EOS markers out of the reply.
    output = tokenizer.decode(sequence[input_ids.shape[-1]:], skip_special_tokens=True)
    print('模型回复', output)  # "model reply"
    bot_response = output.strip()
    history.append((instruction, bot_response))
    return "", history, history
def predict_test(message, top_p, temperature, history):
    """Model-free UI stub: echo the message together with the slider values.

    Appends ``(message, "<message> <top_p>, <temperature>")`` to ``history``
    (a new list when ``history`` is empty or None), prints the echoed text,
    and returns the same history list twice.
    """
    turns = history if history else []
    echoed = f"{message} {top_p}, {temperature}"
    print(echoed)
    turns.append((message, echoed))
    return turns, turns
def clear_session():
    """Reset the chat UI.

    Returns values for the ``[message, chatbot, state]`` outputs: an empty
    input box, an empty Chatbot transcript (the Chatbot's value is a list of
    message pairs, so ``[]`` rather than ``""``), and a ``None`` state, which
    ``predict`` treats as a fresh history.
    """
    return "", [], None
# Command-line options. --size, --data and --local_rank are parsed but not
# read anywhere in this script (presumably kept for CLI compatibility with the
# training scripts; TODO confirm).
parser = argparse.ArgumentParser(
    description="Serve an Alpaca-CoT LoRA model behind a Gradio chat UI.")
parser.add_argument('--size', default=7, type=int, help='the size of llama model')
parser.add_argument('--data', default="", type=str, help='the data used for instructing tuning')
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
# Derive the choices from _MODEL_CLASSES so every accepted --model_type has a
# loader entry (a type without one would crash later with a KeyError).
parser.add_argument('--model_type', default="llama", choices=list(_MODEL_CLASSES))
parser.add_argument('--model_name_or_path', default="decapoda-research/llama-7b-hf", type=str)
parser.add_argument('--lora_name_or_path', default="./saved-alpaca-belle-cot7b", type=str)
args = parser.parse_args()
# Loads the tokenizer and LoRA-wrapped model into the globals used by predict().
get_model_class(args.model_type, args.model_name_or_path, args.lora_name_or_path)
def _regenerate_last(top_p, temperature, history):
    """Re-answer the most recent instruction in the chat history.

    Drops the last ``(instruction, reply)`` pair and runs ``predict`` again on
    that instruction with the current slider values. With no history there is
    nothing to regenerate, so the (empty) history is returned unchanged.
    NOTE: predict's default beam search is deterministic, so the new reply may
    equal the old one unless sampling is enabled.
    """
    history = history or []
    if not history:
        return "", history, history
    last_instruction, _ = history.pop()
    return predict(last_instruction, top_p, temperature, history)


block = gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""")
with block as demo:
    # top_p / temperature controls, passed positionally to predict().
    with gr.Accordion("Parameters", open=False):
        top_p = gr.Slider(minimum=0, maximum=1.0, value=0.75, step=0.05,
                          interactive=True, label="Top-p (nucleus sampling)")
        temperature = gr.Slider(minimum=0, maximum=5.0, value=0.1, step=0.1,
                                interactive=True, label="Temperature")
    # elem_id binds this component to the "#chatbot" rule in the css above.
    chatbot = gr.Chatbot(label="Alpaca-CoT", elem_id="chatbot")
    message = gr.Textbox()
    state = gr.State()
    # Enter in the textbox and the Send button run the same queued generation.
    message.submit(predict, inputs=[message, top_p, temperature, state], outputs=[message, chatbot, state])
    with gr.Row():
        clear_history = gr.Button("🗑 清除历史对话 | Clear History")
        clear = gr.Button('🧹 清除输入 | Clear Input')
        send = gr.Button("🚀 发送 | Send")
        regenerate = gr.Button("🚗 重新生成 | regenerate")
    regenerate.click(fn=_regenerate_last, inputs=[top_p, temperature, state], outputs=[message, chatbot, state])
    send.click(predict, inputs=[message, top_p, temperature, state], outputs=[message, chatbot, state])
    clear.click(lambda: None, None, message, queue=False)
    clear_history.click(fn=clear_session, inputs=[], outputs=[message, chatbot, state], queue=False)
# NOTE(review): concurrency_count=20 lets up to 20 generate() calls share the
# single loaded model at once; lower it if GPU memory is tight.
demo.queue(max_size=20, concurrency_count=20).launch()