from threading import Thread import torch import gradio as gr from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer model_id = "declare-lab/flan-alpaca-xl" torch_device = "cuda" if torch.cuda.is_available() else "cpu" print("Running on device:", torch_device) print("CPU threads:", torch.get_num_threads()) model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, use_history, history): if history is None: history = [] history.append([user_text, ""]) # Get the model and tokenizer, and tokenize the user text. If `use_history` is True, we use the chatbot history if use_history: user_name, assistant_name, sep = "User: ", "Assistant: ", "\n" past = [] for data in history: user_data, model_data = data if not user_data.startswith(user_name): user_data = user_name + user_data if not model_data.startswith(sep + assistant_name): model_data = sep + assistant_name + model_data past.append(user_data + model_data.rstrip() + sep) text_input = "".join(past) else: text_input = user_text model_inputs = tokenizer([text_input], return_tensors="pt").to(torch_device) # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer # in the main thread. streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( model_inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p, temperature=temperature, top_k=top_k ) t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() # Pull the generated text from the streamer, and update the chatbot. for new_text in streamer: history[-1][1] += new_text yield history return history def reset_textbox(): return gr.update(value='') with gr.Blocks( css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""" ) as demo: with gr.Column(elem_id="col_container"): duplicate_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming?duplicate=true" gr.Markdown( "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n" "This demo showcases the use of the " "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) " "of 🤗 Transformers with Gradio to generate text in real-time, as a chatbot. It uses " f"[{model_id}](https://huggingface.co/{model_id}), " "loaded in 8-bit quantized form.\n\n" f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a " "template! 💛" ) chatbot = gr.Chatbot(elem_id='chatbot', label="Chat history") user_text = gr.Textbox( placeholder="Write an email about an alpaca that likes flan", label="Type an input and press Enter" ) with gr.Row(): button_submit = gr.Button(value="Submit") button_clear = gr.Button(value="Clear chat history") with gr.Accordion("Generation Parameters", open=False): use_history = gr.Checkbox(value=False, label="Use chat history as prompt") max_new_tokens = gr.Slider( minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens", ) top_p = gr.Slider( minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)", ) temperature = gr.Slider( minimum=0, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature (set to 0 for Greedy Decoding)", ) top_k = gr.Slider( minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k", ) user_text.submit( run_generation, [user_text, top_p, temperature, top_k, max_new_tokens, use_history, chatbot], chatbot ) button_submit.click( run_generation, [user_text, top_p, temperature, top_k, max_new_tokens, use_history, chatbot], chatbot ) button_clear.click(reset_textbox, [], [chatbot]) demo.queue(max_size=32).launch()