"""Gradio streaming chat demo backed by a Hugging Face Transformers model.

The model and tokenizer locations are read from the ``MODEL_PATH`` and
``TOKENIZER_PATH`` environment variables. Running this file loads the model,
builds the UI and launches the Gradio server.
"""

import os
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

MODEL_PATH = os.environ.get('MODEL_PATH', "ClueAI/ChatYuan-large-v2")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
# Alternative load without device_map, casting the weights to float32:
#   AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float()
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()

# Characters escaped inside fenced code blocks so the chat widget shows them
# literally instead of interpreting them as HTML or Markdown. None of the
# replacement strings contains another key, so a single translate() pass is
# equivalent to applying the replacements one after another.
_CODE_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation on a fixed set of token ids."""

    # NOTE(review): presumably the pad/eos ids of the default model; verify
    # against the tokenizer when MODEL_PATH is changed.
    stop_ids = (0, 2)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id."""
        last_token = input_ids[0][-1]
        return any(bool(last_token == stop_id) for stop_id in self.stop_ids)


def parse_text(text):
    """Convert chat text containing ``` fences into an HTML fragment.

    Empty lines are dropped. A line containing a fence opens (odd occurrence)
    or closes (even occurrence) a ``<pre><code>`` element; whatever follows
    the last backtick on an opening fence becomes the ``language-...`` class.
    Every other line except the first is prefixed with ``<br>``, and lines
    inside an open code block are additionally entity-escaped.

    Args:
        text: Raw message text.

    Returns:
        The HTML fragment as a single string ("" for empty input).
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split('`')[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = '<br></code></pre>'
        elif i > 0:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(history, max_length, top_p, temperature):
    """Stream the model's reply to the last user turn in ``history``.

    Args:
        history: List of ``[user_message, model_message]`` pairs; the final
            pair is expected to carry an empty model message, which is filled
            in place as text arrives.
        max_length: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.

    Yields:
        ``history`` after each non-empty chunk of streamed text.
    """
    stop = StopOnTokens()
    messages = []
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == last_idx and not model_msg:
            # Pending turn: only the user side exists so far.
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    print("\n\n====conversation====\n", messages)
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    ).to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        # The slider can deliver a float and generate() needs a positive
        # integer, so normalise the value here.
        "max_new_tokens": max(1, int(max_length)),
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.2,
    }
    # generate() blocks, so it runs in a worker thread while this generator
    # drains the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for new_token in streamer:
        if new_token != '':
            history[-1][1] += new_token
            yield history


with gr.Blocks() as demo:
    # NOTE(review): the banner is plain text; if heading markup was intended
    # here it is not present in this string -- confirm.
    gr.HTML("""

ChatGLGradio Simple Demo

""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # Minimum is 1 because zero new tokens is not a valid request.
            max_length = gr.Slider(1, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

    def user(query, history):
        """Clear the input box and append the new user turn to the history."""
        return "", history + [[parse_text(query), ""]]

    # First record the user turn immediately (unqueued), then stream the reply.
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: None, None, chatbot, queue=False)

demo.queue()
# share=True asks Gradio to create a publicly reachable link for the demo.
demo.launch(share=True)