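# TransGPT / app.py
# Page header captured from the hosting site's file view (not part of the program):
#   avatar caption: "iKING-ROC's picture" | commit message: "Update app.py"
#   commit: a7d13e5 | size: 8.25 kB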
import argparse
import os
import gradio as gr
import mdtex2html
from gradio.themes.utils import colors, fonts, sizes
import torch
from peft import PeftModel
from transformers import (
AutoModel,
AutoTokenizer,
AutoModelForCausalLM,
BloomForCausalLM,
BloomTokenizerFast,
LlamaTokenizer,
LlamaForCausalLM,
GenerationConfig,
)
# Lookup table keyed by the --model_type command-line value; each entry is the
# (model class, tokenizer class) pair whose from_pretrained() is used to load
# the checkpoint and its tokenizer.
MODEL_CLASSES = dict(
    bloom=(BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoModel, AutoTokenizer),
    llama=(LlamaForCausalLM, LlamaTokenizer),
    auto=(AutoModelForCausalLM, AutoTokenizer),
)
class OpenGVLab(gr.themes.base.Base):
    """Gradio theme used by the demo page.

    Forwards the colour, sizing and font options to the base theme and then
    overrides the page background fill with the ``*neutral_50`` theme variable.
    """

    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(fonts.GoogleFont("Noto Sans"), "ui-sans-serif", "sans-serif"),
        font_mono=(fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"),
    ):
        """Initialise the base theme with the given options (all keyword-only)."""
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)
        # Applied after base initialisation so it overrides the default fill.
        super().set(body_background_fill="*neutral_50")
# Module-level theme instance; it is passed as `theme=` when the Blocks UI is built.
gvlabtheme = OpenGVLab(
    primary_hue=colors.blue,
    secondary_hue=colors.sky,
    neutral_hue=colors.gray,
    spacing_size=sizes.spacing_md,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_md,
)
def main():
    """Parse command-line options, load the model and serve the Gradio chat UI.

    Steps, in order:
      1. Parse CLI arguments and set ``CUDA_VISIBLE_DEVICES`` from ``--gpus``
         (emptied when ``--only_cpu`` is given).
      2. Replace ``gr.Chatbot.postprocess`` so chat messages are rendered
         through ``mdtex2html``.
      3. Load the tokenizer and base model selected by ``--model_type``,
         optionally resize the embeddings and wrap the model with a LoRA
         adapter.
      4. Build the Blocks UI and launch it with queuing enabled.

    The helper functions are closures because they share ``tokenizer``,
    ``model`` and ``device`` with this function.
    """
    parser = argparse.ArgumentParser()
    # `choices` turns an unknown model type into a clear usage error instead of
    # a KeyError at the MODEL_CLASSES lookup below.
    parser.add_argument('--model_type', default="llama", type=str, choices=list(MODEL_CLASSES))
    parser.add_argument('--base_model', default="DUOMO-Lab/TransGPT-v0", type=str)
    parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
    parser.add_argument('--tokenizer_path', default=None, type=str)
    parser.add_argument('--gpus', default="0", type=str)
    parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    args = parser.parse_args()
    if args.only_cpu:
        args.gpus = ""
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    def postprocess(self, y):
        """Convert every (message, response) pair of ``y`` to HTML, in place.

        Returns an empty list when ``y`` is None; ``None`` entries inside a
        pair are kept as ``None``.
        """
        if y is None:
            return []
        for i, (message, response) in enumerate(y):
            y[i] = (
                None if message is None else mdtex2html.convert((message)),
                None if response is None else mdtex2html.convert(response),
            )
        return y

    # Monkey-patch: applies to every gr.Chatbot created in this process.
    gr.Chatbot.postprocess = postprocess

    load_type = torch.float16
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')
    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    base_model = model_class.from_pretrained(
        args.base_model,
        load_in_8bit=False,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
        device_map='auto',
        trust_remote_code=True,
    )
    if args.resize_emb:
        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        print(f"Vocab of the base model: {model_vocab_size}")
        print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            print("Resize model embeddings to fit tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)
    if args.lora_model:
        model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
        print("loaded lora model")
    else:
        model = base_model
    if device.type == 'cpu':
        # The weights were loaded as float16; cast to float32 for CPU inference.
        model.float()
    model.eval()

    def reset_user_input():
        """Return an update that clears the input textbox."""
        return gr.update(value='')

    def reset_state():
        """Return empty values for the chatbot display and the history state."""
        return [], []

    def generate_prompt(instruction):
        """Wrap ``instruction`` in the instruction/response prompt template.

        The template text is emitted exactly as in the original file
        (including "transportation.Below" without a space); it is assembled
        from explicit pieces so it does not depend on source indentation.
        """
        return (
            "You are TransGPT, a specialist in the field of transportation."
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n"
            "### Instruction:\n"
            f"{instruction}\n"
            "### Response: "
        )

    def predict(
        message,
        chatbot,
        history,
        max_new_tokens=128,
        top_p=0.75,
        temperature=0.1,
        top_k=40,
        num_beams=4,
        repetition_penalty=1.0,
        max_memory=256,
        **kwargs,
    ):
        """Generate a reply to ``message`` and append it to the conversation.

        Args:
            message: Text the user just submitted.
            chatbot: List of (user, bot) pairs currently displayed.
            history: List of (user, bot) pairs kept as conversation state.
            max_new_tokens: Upper bound on generated tokens (forced to >= 1).
            top_p, temperature, top_k: Sampling parameters; applied only when
                ``temperature`` is greater than zero.
            num_beams: Number of beams used for generation.
            repetition_penalty: Passed to ``model.generate`` as a float.
            max_memory: Maximum number of trailing *characters* of the
                conversation text that are kept in the prompt.
            **kwargs: Extra ``GenerationConfig`` options; they override the
                values computed here.

        Returns:
            The updated ``(chatbot, history)`` pair.
        """
        chatbot = chatbot or []
        history = history or []
        chatbot.append((message, ""))

        instruction_tag = "### Instruction:\n"
        context = message
        if history:
            past_turns = "".join(
                f"{instruction_tag}{question}\n\n### Response: {answer}\n\n"
                for question, answer in history
            )
            # generate_prompt() adds the first instruction tag itself, so the
            # leading one is dropped here.
            context = (past_turns + instruction_tag + message)[len(instruction_tag):]
        if len(context) > max_memory:
            context = context[-max_memory:]

        prompt = generate_prompt(context)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)

        # temperature/top_p/top_k only influence generation when sampling is
        # enabled; without do_sample the UI sliders would have no effect.
        # A temperature of 0 (the slider minimum) falls back to deterministic
        # beam search because sampling requires a strictly positive value.
        sampling = float(temperature) > 0
        config_options = {"num_beams": num_beams, "do_sample": sampling}
        if sampling:
            config_options.update(
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=top_k,
            )
        config_options.update(kwargs)
        generation_config = GenerationConfig(**config_options)

        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=False,
                # The slider can deliver 0 or a float; generation needs a
                # positive integer.
                max_new_tokens=max(1, int(max_new_tokens)),
                repetition_penalty=float(repetition_penalty),
            )
        sequence = generation_output.sequences[0]
        output = tokenizer.decode(sequence, skip_special_tokens=True)
        # The decoded text contains the prompt; keep what follows the last
        # response marker.
        output = output.split("### Response:")[-1].strip()
        history.append((message, output))
        chatbot[-1] = (message, output)
        return chatbot, history

    title = """<h1 align="center">Welcome to TransGPT!</h1>"""
    # `visibility: hidden` is the valid CSS value for hiding the footer.
    css = (
        "#chatbot {overflow:auto; height:500px;} "
        "#InputVideo {overflow:visible; height:320px;} "
        "footer {visibility: hidden}"
    )
    with gr.Blocks(title="DUOMO TransGPT!", theme=gvlabtheme, css=css) as demo:
        gr.Markdown(title)
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                        container=False)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("Clear History")
                max_length = gr.Slider(
                    0, 4096, value=128, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01,
                                  label="Top P", interactive=True)
                temperature = gr.Slider(
                    0, 1, value=0.7, step=0.01, label="Temperature", interactive=True)
        history = gr.State([])  # (message, bot_message)
        submitBtn.click(predict, [user_input, chatbot, history, max_length, top_p, temperature], [chatbot, history],
                        show_progress=True)
        # Second handler on the same button clears the textbox after submit.
        submitBtn.click(reset_user_input, [], [user_input])
        emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
    demo.queue().launch()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()