"""Gradio streaming chat demo backed by a text-generation-inference endpoint."""

import os
import random
import time

import gradio as gr
from text_generation import Client

from conversation import get_default_conv_template

# Inference server URL; must be set in the environment before launch.
endpoint_url = os.environ.get("ENDPOINT_URL")
client = Client(endpoint_url, timeout=120)
eos_token = ""  # end-of-sequence marker placeholder (currently unused)


def generate_response(history, max_new_token=512, top_p=0.9, temperature=0.8, do_sample=True):
    """Stream the assistant's reply for the latest user turn.

    Args:
        history: gr.Chatbot history as ``[[user, bot], ...]``; the last
            pair's bot entry is filled in-place as tokens arrive.
        max_new_token: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.
        do_sample: Whether to sample instead of decoding greedily.

    Yields:
        The updated history after each generated token, which is the shape
        a ``gr.Chatbot`` output component expects.
    """
    conv = get_default_conv_template("vicuna").copy()
    # Map transcript roles onto the template's (USER, ASSISTANT) labels.
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    for user_turn, bot_turn in history:
        conv.append_message(roles["human"], user_turn)
        conv.append_message(roles["gpt"], bot_turn)
    prompt = conv.get_prompt()

    # FIX: the original yielded bare token strings, but this generator is
    # wired directly to a gr.Chatbot output, which expects the whole history
    # list. Accumulate streamed tokens into the last turn and yield history.
    history[-1][1] = ""
    for response in client.generate_stream(
        prompt,
        max_new_tokens=max_new_token,
        top_p=top_p,
        temperature=temperature,
        do_sample=do_sample,
    ):
        if not response.token.special:
            history[-1][1] += response.token.text
            yield history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the submitted message as a new turn and clear the textbox."""
        return "", history + [[user_message, None]]

    def bot(history):
        """Unwired demo responder: types a canned reply character by character.

        Kept for reference only — the live app streams via generate_response.
        (A leftover ``ipdb.set_trace()`` debugger trap was removed here; it
        would have frozen the server if this handler were ever registered.)
        """
        bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
        history[-1][1] = ""
        for character in bot_message:
            history[-1][1] += character
            time.sleep(0.05)
            yield history

    # On submit: record the user turn first, then stream the model's reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        generate_response, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch()