#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio chat demo serving a selectable Hugging Face causal LM.

The UI offers Qwen/Qwen-7B-Chat, THUDM/chatglm2-6b and
baichuan-inc/Baichuan2-7B-Chat; replies are streamed token by token.
"""
import gc
from typing import List, Tuple
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer
import torch

from project_settings import project_path


def greet(question: str, history: List[Tuple[str, str]]):
    """Placeholder handler that answers "Hello <question>!".

    :param question: the user's message.
    :param history: previous (question, answer) pairs; not mutated.
    :return: a new history list with the (question, answer) pair appended.
    """
    answer = "Hello " + question + "!"
    result = history + [(question, answer)]
    return result


# Cache of loaded models keyed by model name/path. At most one entry is kept:
# requesting a different model evicts the current one before loading.
model_map: dict = dict()

# Per-round prompt template applied to ChatGLM questions (round number, question).
_CHATGLM_ROUND_TEMPLATE = "[Round {}]\n\n问:{}\n\n答:"


def init_model(pretrained_model_name_or_path: str):
    """Return ``(model, tokenizer)`` for a model, loading it on a cache miss.

    On a miss the previously cached model is dropped and its memory released
    before the new one is loaded.

    :param pretrained_model_name_or_path: Hugging Face hub id or local path.
    :return: tuple ``(model, tokenizer)``.
    """
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    cached = model_map.get(pretrained_model_name_or_path)
    if cached is not None:
        return cached["model"], cached["tokenizer"]

    # Drop the cache's references (deleting loop variables would free
    # nothing), then collect and release cached CUDA memory so the old model
    # does not coexist with the new one.
    model_map.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # build model
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
    )
    if model.config.model_type != "chatglm":
        # With device_map set, accelerate has already placed the weights.
        # Moving a dispatched model again is redundant and can raise when some
        # modules were offloaded to CPU/disk, so only move undispatched models.
        if getattr(model, "hf_device_map", None) is None:
            model = model.to(device)
        model = model.bfloat16()
    model = model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        # llama does not support the fast tokenizer.
        use_fast=model.config.model_type != "llama",
        padding_side="left",
    )
    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id
    # are all None. Its eod_id corresponds to the <|endoftext|> token, so it
    # is used for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    model_map[pretrained_model_name_or_path] = {
        "model": model,
        "tokenizer": tokenizer,
    }
    return model, tokenizer


def _build_input_ids(model,
                     tokenizer,
                     question: str,
                     history: List[Tuple[str, str]],
                     history_max_len,
                     device: str) -> torch.Tensor:
    """Encode history + question into a ``(1, seq_len)`` long tensor on ``device``.

    ChatGLM: every question (history and current) is wrapped in the
    ``[Round N]`` template with 1-based rounds, and eos is appended at the end.
    Other models: ids start with bos, followed by the concatenated turns.
    Only the last ``history_max_len`` tokens are kept; a value <= 0 disables
    truncation.
    """
    is_chatglm = model.config.model_type == "chatglm"
    input_ids = [] if is_chatglm else [tokenizer.bos_token_id]

    utterances = []
    for idx, (h_question, h_answer) in enumerate(history):
        if is_chatglm:
            h_question = _CHATGLM_ROUND_TEMPLATE.format(idx + 1, h_question)
        utterances.append(h_question)
        # An answer may be None (e.g. an unfinished turn); encode it as empty.
        utterances.append(h_answer or "")
    if is_chatglm:
        question = _CHATGLM_ROUND_TEMPLATE.format(len(history) + 1, question)
    utterances.append(question)

    encoded_utterances = tokenizer(utterances, add_special_tokens=False)["input_ids"]
    for encoded_utterance in encoded_utterances:
        input_ids.extend(encoded_utterance)
    if is_chatglm:
        input_ids.append(tokenizer.eos_token_id)

    input_ids = torch.tensor([input_ids], dtype=torch.long)
    history_max_len = int(history_max_len)
    if history_max_len > 0:
        # Keep the most recent tokens.
        input_ids = input_ids[:, -history_max_len:]
    return input_ids.to(device)


def _remove_eos(text: str, tokenizer) -> str:
    """Remove the tokenizer's eos token text from ``text`` when it is a non-empty str."""
    eos_token = tokenizer.eos_token
    if isinstance(eos_token, str) and eos_token:
        text = text.replace(eos_token, "")
    return text


def chat_with_llm_non_stream(question: str,
                             history: List[Tuple[str, str]],
                             pretrained_model_name_or_path: str,
                             max_new_tokens: int,
                             top_p: float,
                             temperature: float,
                             repetition_penalty: float,
                             history_max_len: int,
                             ):
    """Generate a complete answer in one call and append it to the history.

    :return: ``history + [(question, answer)]`` (the input list is not mutated).
    """
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = init_model(pretrained_model_name_or_path)

    input_ids = _build_input_ids(model, tokenizer, question, history,
                                 history_max_len, device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns prompt + continuation; decode only the new tokens.
    new_tokens = outputs[0, input_ids.shape[1]:].tolist()
    answer = tokenizer.decode(new_tokens)
    answer = _remove_eos(answer.strip(), tokenizer).strip()
    return history + [(question, answer)]


def chat_with_llm_streaming(question: str,
                            history: List[Tuple[str, str]],
                            pretrained_model_name_or_path: str,
                            max_new_tokens: int,
                            top_p: float,
                            temperature: float,
                            repetition_penalty: float,
                            history_max_len: int,
                            ):
    """Stream the answer, yielding the updated history after each text chunk.

    Generation runs in a background thread feeding a ``TextIteratorStreamer``.
    The first yield shows the question with an empty answer.
    """
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = init_model(pretrained_model_name_or_path)

    input_ids = _build_input_ids(model, tokenizer, question, history,
                                 history_max_len, device)

    # skip_prompt=True keeps the decoded prompt (the entire history, not just
    # the question) out of the stream, so nothing must be cut from the answer.
    streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    answer = ""
    yield history + [(question, answer)]
    for chunk in streamer:
        answer += _remove_eos(chunk, tokenizer)
        yield history + [(question, answer)]
    thread.join()


def main():
    """Build the Gradio chat UI and launch it with queuing enabled (needed for streaming)."""
    description = """
chat llm
"""

    with gr.Blocks() as blocks:
        gr.Markdown(value=description)

        chatbot = gr.Chatbot([], elem_id="chatbot", height=400)
        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter", container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button(
                    '🗑️Clear',
                    variant='secondary',
                )

        with gr.Row():
            with gr.Column(scale=1):
                # generate() rejects max_new_tokens <= 0.
                max_new_tokens = gr.Slider(minimum=1, maximum=512, value=512, step=1, label="max_new_tokens")
            with gr.Column(scale=1):
                top_p = gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p")
            with gr.Column(scale=1):
                # Sampling requires a strictly positive temperature.
                temperature = gr.Slider(minimum=0.01, maximum=1, value=0.35, step=0.01, label="temperature")
            with gr.Column(scale=1):
                # The repetition penalty must be strictly positive.
                repetition_penalty = gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, label="repetition_penalty")
            with gr.Column(scale=1):
                # 0 disables prompt truncation.
                history_max_len = gr.Slider(minimum=0, maximum=4096, value=1024, step=1, label="history_max_len")

        with gr.Row():
            with gr.Column(scale=1):
                model_name = gr.Dropdown(
                    choices=[
                        "Qwen/Qwen-7B-Chat",
                        "THUDM/chatglm2-6b",
                        "baichuan-inc/Baichuan2-7B-Chat",
                    ],
                    value="Qwen/Qwen-7B-Chat",
                    label="model_name",
                )
        gr.Examples(examples=["你好"], inputs=text_box)

        inputs = [
            text_box, chatbot, model_name,
            max_new_tokens, top_p, temperature, repetition_penalty, history_max_len
        ]
        outputs = [
            chatbot
        ]
        text_box.submit(chat_with_llm_streaming, inputs, outputs)
        submit_button.click(chat_with_llm_streaming, inputs, outputs)
        # Reset the textbox and clear the chat history (an empty list of pairs).
        clear_button.click(
            fn=lambda: ("", []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()
    return


if __name__ == '__main__':
    main()