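The snippet below wires a streaming text-generation loop into a Gradio `ChatInterface`: each user turn is formatted with the tokenizer's chat template, generation runs in a background thread, and partial outputs are yielded to the UI as they arrive. To make it a self-contained sketch, it also loads a model and tokenizer up front; the checkpoint id is a placeholder assumption, so substitute any chat-tuned causal LM you have access to.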
```python
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_INPUT_TOKEN_LENGTH = 4096

# Placeholder checkpoint (assumption) -- swap in any chat-tuned causal LM.
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)


def generate(message, chat_history):
    # Step 1: pre-process the inputs into the chat-template message format
    conversation = []
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    # In case the inputs exceed the maximum length, trim them from the left
    # so the most recent turns are kept.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
        )
    input_ids = input_ids.to(model.device)

    # Step 2: define generation arguments and run generation in a background
    # thread, so tokens can be consumed from the streamer as they are produced
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Step 3: yield the growing response so the UI streams partial text
    outputs = ""
    for text in streamer:
        outputs += text
        yield outputs


chat_interface = gr.ChatInterface(generate)
chat_interface.queue().launch(share=True)
```
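Because `generate` is a generator, `ChatInterface` re-renders the assistant bubble with each partial `outputs` string it yields, which is what produces the token-by-token streaming effect. `launch(share=True)` additionally prints a temporary public link so the demo can be tried from another machine.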