import argparse
from collections import namedtuple

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    AutoTokenizer,
    BloomForCausalLM,
    GenerationConfig,
)

tokenizer = None
model = None
LOAD_8BIT = False

ModelClass = namedtuple("ModelClass", ('tokenizer', 'model'))

_MODEL_CLASSES = {
    "llama": ModelClass(**{
        "tokenizer": LlamaTokenizer,
        "model": LlamaForCausalLM,
    }),
    "bloom": ModelClass(**{
        "tokenizer": AutoTokenizer,
        "model": BloomForCausalLM,
    }),
}

device = "cuda" if torch.cuda.is_available() else "cpu"


def get_model_class(model_type, model_name_or_path, lora_model_path):
    """Load the base model and tokenizer, then wrap the base model with the LoRA adapter."""
    global model, tokenizer
    model_class = _MODEL_CLASSES[model_type]

    # tokenizer, model
    model_base = model_class.model.from_pretrained(
        model_name_or_path,
        load_in_8bit=LOAD_8BIT,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = model_class.tokenizer.from_pretrained(model_name_or_path)  # default add_eos_token=False

    model = PeftModel.from_pretrained(
        model_base,
        lora_model_path,
        torch_dtype=torch.float16,
    )
    if not LOAD_8BIT:
        model.half()


def predict(
    instruction,
    top_p=0.75,
    temperature=0.1,
    history=None,
    top_k=40,
    num_beams=4,
    max_new_tokens=512,
    **kwargs,
):
    history = history or []
    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{0}\n\n### Response:"
    ).format(instruction)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    print('Model response:', output)
    bot_response = output.split("### Response:")[1].strip()
    history.append((instruction, bot_response))
    return "", history, history


def predict_test(message, top_p, temperature, history):
    history = history or []
    user_message = f"{message} {top_p}, {temperature}"
    print(user_message)
    history.append((message, user_message))
    return history, history


def clear_session():
    # Reset the textbox, the chatbot display, and the conversation state.
    return '', None, None


parser = argparse.ArgumentParser(description='Serve an instruction-tuned model with a Gradio chat UI.')
parser.add_argument('--size', default=7, type=int, help='the size of the llama model')
parser.add_argument('--data', default="", type=str, help='the data used for instruction tuning')
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument('--model_type', default="llama", choices=['llama', 'chatglm', 'bloom'])
parser.add_argument('--model_name_or_path', default="decapoda-research/llama-7b-hf", type=str)
parser.add_argument('--lora_name_or_path', default="./saved-alpaca-belle-cot7b", type=str)
args = parser.parse_args()

get_model_class(args.model_type, args.model_name_or_path, args.lora_name_or_path)

block = gr.Blocks(css="""#col_container { margin-left: auto; margin-right: auto;}
                         #chatbot {height: 520px; overflow: auto;}""")

with block as demo:
    # top_p, temperature
    with gr.Accordion("Parameters", open=False):
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.75,
            step=0.05,
            interactive=True,
            label="Top-p (nucleus sampling)",
        )
        temperature = gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.1,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
    chatbot = gr.Chatbot(label="Alpaca-CoT")
    message = gr.Textbox()
    state = gr.State()
    message.submit(
        predict,
        inputs=[message, top_p, temperature, state],
        outputs=[message, chatbot, state],
        queue=False,
    )
    with gr.Row():
        clear_history = gr.Button("πŸ—‘ ζΈ…ι™€εŽ†ε²ε―Ήθ― | Clear History")
        clear = gr.Button('🧹 清陀输ε…₯ | Clear Input')
        send = gr.Button("πŸš€ 发送 | Send")
        regenerate = gr.Button("πŸš— ι‡ζ–°η”Ÿζˆ | Regenerate")

    # regenerate.click(regenerate, inputs=[message], outputs=[chatbot])
    # NOTE: a true regenerate handler is not implemented yet; the button currently just clears the session.
    regenerate.click(fn=clear_session, inputs=[], outputs=[message, chatbot, state], queue=False)
    send.click(predict, inputs=[message, top_p, temperature, state], outputs=[message, chatbot, state])
    clear.click(lambda: None, None, message, queue=False)
    clear_history.click(fn=clear_session, inputs=[], outputs=[message, chatbot, state], queue=False)

demo.queue(max_size=20, concurrency_count=20).launch()
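
# Example invocation (a sketch, not part of the original script): the flags and their
# defaults come from the argparse definitions above. The file name "interaction.py" is
# an assumption; substitute whatever this script is actually saved as.
#
#   python interaction.py --model_type llama \
#       --model_name_or_path decapoda-research/llama-7b-hf \
#       --lora_name_or_path ./saved-alpaca-belle-cot7b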