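"""Gradio chat demo that streams text from a 🤗 Transformers model token by token,
using TextIteratorStreamer to pull partial output while `generate` runs on a thread."""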

from threading import Thread
from functools import lru_cache

import torch
import gradio as gr
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

@lru_cache(maxsize=1)  # only cache the latest model
def get_model_and_tokenizer(model_id):
    # Pick the model class matching the checkpoint's architecture (encoder-decoder vs. decoder-only).
    config = AutoConfig.from_pretrained(model_id)
    if config.is_encoder_decoder:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = model.to(torch_device)
    return model, tokenizer
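
# Note: `lru_cache(maxsize=1)` keeps only the most recently used model in memory, so
# entering a new repo id in the UI loads it and evicts the previous one. For example:
#   get_model_and_tokenizer("EleutherAI/pythia-410m")  # downloads/loads, then caches
#   get_model_and_tokenizer("EleutherAI/pythia-410m")  # cache hit, returns immediately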

def run_generation(model_id, user_text, top_p, temperature, top_k, max_new_tokens, history):
    if history is None:
        history = []
    history.append([user_text, ""])

    # Get the model and tokenizer, and tokenize the user text.
    model, tokenizer = get_model_and_tokenizer(model_id)
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # `generate` rejects sampling with temperature == 0, so treat it as greedy decoding instead.
    do_sample = temperature > 0
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        temperature=temperature if do_sample else 1.0,
        top_k=top_k,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the chatbot.
    for new_text in streamer:
        history[-1][1] += new_text
        yield history
    return history
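
# The same streaming pattern outside Gradio, as a minimal sketch (assumes the helpers
# above; prints text to stdout as `generate` produces it on the background thread):
#   model, tokenizer = get_model_and_tokenizer("EleutherAI/pythia-410m")
#   inputs = tokenizer(["Hello, world!"], return_tensors="pt").to(torch_device)
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=30)).start()
#   for chunk in streamer:
#       print(chunk, end="", flush=True)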

def reset_chat():
    # The button says "Clear message history", so empty both the chatbot and the input textbox.
    return [], gr.update(value='')

with gr.Blocks(
    css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
           #chatbot {height: 520px; overflow: auto;}"""
) as demo:
    with gr.Column(elem_id="col_container"):
        demo_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming"
        gr.Markdown(
            f"""
            # 🤗 Transformers Gradio 🔥Streaming🔥
            This demo showcases how to use the streaming feature of 🤗 Transformers with Gradio to generate text in real time.

            ⚠️ [Duplicate this Space]({demo_link}) if ⚠️
            - You want to use a large model (> 1GB). Otherwise, this public space will become slow for others 🙏
            - You want to build your own app, using this demo as a template 🚀
            - You want to bypass the queue and/or add hardware resources 💾
            """
        )
        model_id = gr.Textbox(value='EleutherAI/pythia-410m', label="🤗 Hub Model repo")
        chatbot = gr.Chatbot(elem_id='chatbot', label="Message history")
        user_text = gr.Textbox(placeholder="Is pineapple a pizza topping?", label="Type an input and press Enter")
        button = gr.Button(value="Clear message history")
        with gr.Accordion("Generation Parameters", open=False):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=1000, value=100, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            temperature = gr.Slider(
                minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature (set to 0 for Greedy Decoding)",
            )
            top_k = gr.Slider(
                minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
            )

        user_text.submit(
            run_generation,
            [model_id, user_text, top_p, temperature, top_k, max_new_tokens, chatbot],
            chatbot
        )
        button.click(reset_chat, [], [chatbot, user_text])

demo.queue(max_size=32).launch()
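
# To run this app locally (a sketch; assumes the dependencies are available):
#   pip install torch transformers gradio
#   python app.py  # Gradio serves the UI on http://127.0.0.1:7860 by default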