import torch
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_INPUT_TOKEN_LENGTH = 4096

# `model` and `tokenizer` were undefined in the original snippet; the checkpoint
# below is a placeholder assumption — any causal LM with a chat template works.
model_id = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)


def generate(message, chat_history):
    # Step 1: pre-process the inputs
    conversation = []
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})
    # `add_generation_prompt=True` appends the assistant turn marker so the
    # model continues as the assistant rather than echoing the user.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )

    # In case our inputs exceed the maximum length, trim them from the left,
    # keeping the most recent tokens so the current message is preserved.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input from conversation as it was longer than "
            f"{MAX_INPUT_TOKEN_LENGTH} tokens."
        )
    input_ids = input_ids.to(model.device)

    # The streamer yields decoded text chunks as the model produces them.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    # Step 2: define generation arguments
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
    )
    # Run generation in a background thread so we can iterate over the
    # streamer in the foreground while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Step 3: stream the outputs, yielding the accumulated text so the
    # Gradio chat window updates incrementally.
    outputs = ""
    for text in streamer:
        outputs += text
        yield outputs


chat_interface = gr.ChatInterface(generate)
chat_interface.queue().launch(share=True)