Spaces:
Sleeping
Sleeping
from threading import Thread | |
import torch | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
# Hub checkpoint id; the model weights and the tokenizer are loaded from the same repository.
_CHECKPOINT = "EleutherAI/pythia-6.9b-deduped"

# Tokenized inputs are moved to this device before generation.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

# Loaded once at import time and shared by every request.
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT, load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, history):
    """Stream a completion for ``user_text`` into the chat history.

    This is a generator: each time the streamer produces a new piece of text,
    the last history entry is extended and the whole history is yielded so the
    chatbot can be refreshed incrementally.

    Args:
        user_text: Prompt typed by the user. Only this text is sent to the
            model; earlier turns in ``history`` are not part of the prompt.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        temperature: Sampling temperature. A value of 0 (or below) switches
            to greedy decoding, as advertised by the UI slider label.
        top_k: Top-k cutoff (used only when sampling).
        max_new_tokens: Upper bound on the number of generated tokens.
        history: List of ``[user_message, bot_message]`` pairs, or ``None``
            for a fresh conversation. The list is extended in place.

    Yields:
        The history list, with the last entry holding the text generated so far.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once the streamer has been drained.
    """
    if history is None:
        history = []
    history.append([user_text, ""])

    # Tokenize the user text and move the tensors to the inference device.
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Generation runs on a separate thread so that we don't block the UI; the
    # text is pulled from the streamer in this (calling) thread.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        # Slider values are cast explicitly so the numeric types are what
        # `generate` expects regardless of how the UI serialized them.
        max_new_tokens=int(max_new_tokens),
    )
    temperature = float(temperature)
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            top_p=float(top_p),
            temperature=temperature,
            top_k=int(top_k),
        )
    else:
        # Sampling requires a strictly positive temperature, so temperature 0
        # is mapped to greedy decoding and the sampling knobs are omitted.
        generate_kwargs["do_sample"] = False

    errors = []

    def _generate():
        # If generation fails, the streamer would never receive its end-of-stream
        # signal and the consumer loop below would wait forever. Record the error
        # and close the stream so the consumer can finish and report it.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    t = Thread(target=_generate)
    t.start()

    # Pull the generated text from the streamer, and update the chatbot.
    for new_text in streamer:
        history[-1][1] += new_text
        yield history

    t.join()
    if errors:
        # Surface the worker failure to the caller instead of ending silently.
        raise errors[0]
    return history
def reset_textbox():
    """Build a Gradio update that sets the target component's value to an empty string."""
    cleared = gr.update(value="")
    return cleared
# UI layout: a centered column holding the intro text, the chat window, the
# prompt box, a clear button and a collapsible panel of generation settings.
with gr.Blocks(
    css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
    #chatbot {height: 520px; overflow: auto;}"""
) as demo:
    with gr.Column(elem_id="col_container"):
        duplicate_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming?duplicate=true"
        gr.Markdown(
            "# 🤗 Transformers Gradio 🔥Streaming🔥\n"
            "This demo showcases the use of the "
            "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
            "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
            "[EleutherAI/pythia-6.9b-deduped](https://huggingface.co/EleutherAI/pythia-6.9b-deduped), "
            "a 6.9B parameter GPT-NeoX model by EleutherAI, loaded in 8-bit quantized form.\n\n"
            f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or to use this space as a "
            "template! 💛"
        )
        chatbot = gr.Chatbot(elem_id='chatbot', label="Message history")
        user_text = gr.Textbox(placeholder="Is pineapple a pizza topping?", label="Type an input and press Enter")
        button = gr.Button(value="Clear message history")

        with gr.Accordion("Generation Parameters", open=False):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=1000, value=100, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            temperature = gr.Slider(
                minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature (set to 0 for Greedy Decoding)",
            )
            top_k = gr.Slider(
                minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
            )

    # Pressing Enter in the text box streams the model's reply into the chatbot.
    user_text.submit(
        run_generation,
        [user_text, top_p, temperature, top_k, max_new_tokens, chatbot],
        chatbot
    )
    # The button is labelled "Clear message history", so it must empty the
    # chatbot as well as the text box; previously only the text box was reset.
    button.click(reset_textbox, [], [user_text])
    button.click(lambda: None, [], [chatbot])

# The queue is required for generator (streaming) handlers; at most 32 requests may wait.
demo.queue(max_size=32).launch()