joaogante's picture
joaogante HF staff
limit to pythia 6.9b
1be532a
raw
history blame
No virus
3.75 kB
from threading import Thread
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# Pick the device that tokenized inputs are moved to: CUDA when PyTorch can
# see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

# The model and its tokenizer are loaded once at import time from the same
# checkpoint; the model is requested in 8-bit form with automatic device mapping.
_CHECKPOINT = "EleutherAI/pythia-6.9b-deduped"
model = AutoModelForCausalLM.from_pretrained(
    _CHECKPOINT,
    load_in_8bit=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, history):
    """Stream a model completion for ``user_text`` into the chat history.

    This is a generator. ``model.generate`` runs on a background thread and
    pushes decoded text into a ``TextIteratorStreamer``; this function drains
    the streamer and yields the updated history after every chunk so the UI
    can refresh incrementally.

    Args:
        user_text: Prompt typed by the user. Only this text is tokenized and
            sent to the model; earlier turns in ``history`` are not included.
        top_p: Nucleus-sampling threshold (used only when sampling).
        temperature: Sampling temperature. A value of 0 (or less) selects
            greedy decoding instead of sampling.
        top_k: Top-k sampling cutoff (used only when sampling).
        max_new_tokens: Upper bound on the number of generated tokens.
        history: List of ``[user_message, bot_message]`` pairs, or ``None``
            when the conversation is empty. The list is extended in place.

    Yields:
        The history list, whose last entry holds the text generated so far.
    """
    if history is None:
        history = []
    history.append([user_text, ""])

    # Tokenize the user text and move the tensors to the selected device.
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        # UI sliders may deliver numbers as floats; generate expects an integer here.
        max_new_tokens=int(max_new_tokens),
    )
    if temperature > 0:
        # Sampling path. The temperature is coerced to float because the
        # library validates that it is a strictly positive float.
        generate_kwargs.update(
            do_sample=True,
            top_p=top_p,
            temperature=float(temperature),
            top_k=int(top_k),
        )
    else:
        # Temperature 0 means greedy decoding. Sampling with a non-positive
        # temperature is rejected by the library, which would kill the
        # generation thread and leave the streamer loop below waiting forever.
        generate_kwargs["do_sample"] = False

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the chatbot.
    for new_text in streamer:
        history[-1][1] += new_text
        yield history
    return history
def reset_textbox():
    """Return a Gradio update that sets the target component's value to an empty string."""
    cleared = gr.update(value='')
    return cleared
# Build the demo page: a header, the chat window, the input box, a clear
# button and a collapsible panel of generation parameters.
with gr.Blocks(
    css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
#chatbot {height: 520px; overflow: auto;}"""
) as demo:
    with gr.Column(elem_id="col_container"):
        duplicate_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming?duplicate=true"
        gr.Markdown(
            "# 🤗 Transformers Gradio 🔥Streaming🔥\n"
            "This demo showcases the use of the "
            "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
            "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
            "[EleutherAI/pythia-6.9b-deduped](https://huggingface.co/EleutherAI/pythia-6.9b-deduped), "
            "a 6.9B parameter GPT-NeoX model by EleutherAI, loaded in 8-bit quantized form.\n\n"
            f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or to use this space as a "
            "template! 💛"
        )

        chatbot = gr.Chatbot(elem_id='chatbot', label="Message history")
        user_text = gr.Textbox(placeholder="Is pineapple a pizza topping?", label="Type an input and press Enter")
        button = gr.Button(value="Clear message history")

        with gr.Accordion("Generation Parameters", open=False):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=1000, value=100, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            temperature = gr.Slider(
                minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature (set to 0 for Greedy Decoding)",
            )
            top_k = gr.Slider(
                minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
            )

        # Pressing Enter in the textbox streams a reply into the chatbot; the
        # current chatbot value is passed in as the running history.
        user_text.submit(
            run_generation,
            [user_text, top_p, temperature, top_k, max_new_tokens, chatbot],
            chatbot
        )

        # The button is labelled "Clear message history", so it must reset the
        # chatbot itself, not only the input box. Returning None empties the
        # chatbot; the textbox reset is kept as before.
        button.click(lambda: None, None, chatbot, queue=False)
        button.click(reset_textbox, [], [user_text])

# Enable request queuing (needed for generator-based streaming handlers) with
# at most 32 pending requests, then start the server.
demo.queue(max_size=32).launch()