import ctranslate2
from transformers import AutoTokenizer
import threading
import gradio as gr
from typing import Optional
from queue import Queue


class TokenIteratorStreamer:
    """Thread-safe bridge that exposes generated token ids as an iterator.

    A producer thread calls ``put`` for every generated token id; a consumer
    iterates over the streamer and stops when ``end_token_id`` is received.
    """

    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        """Create a streamer.

        Args:
            end_token_id: Token id that terminates iteration (it is not
                yielded to the consumer).
            timeout: Seconds to wait on queue operations; ``None`` blocks
                indefinitely.
        """
        self.end_token_id = end_token_id
        self.queue = Queue()
        self.timeout = timeout

    def put(self, token: int):
        """Enqueue one token id for the consumer."""
        self.queue.put(token, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next token id, or stop once the end token arrives.

        Raises:
            StopIteration: When the end token id is dequeued.
            queue.Empty: When ``timeout`` is set and elapses with no token.
        """
        token = self.queue.get(timeout=self.timeout)
        if token == self.end_token_id:
            raise StopIteration()
        return token


def generate_prompt(history):
    """Build the model input tokens from the chat history.

    Args:
        history: Sequence of ``[user_message, bot_message]`` pairs; the last
            pair holds the pending user message whose reply is generated.

    Returns:
        The token strings of the encoded prompt.
    """
    # NOTE(review): the speaker labels before ":" are empty in this file --
    # possibly stripped markup; confirm the intended prompt format.
    parts = [f": {chain[0]}\n: {chain[1]}\n" for chain in history[:-1]]
    parts.append(f": {history[-1][0]}\n:")
    prompt = "".join(parts)
    tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
    return tokens


def generate(streamer, history):
    """Run generation for ``history`` and push each token id to ``streamer``.

    The end token id is always delivered to the streamer exactly once by this
    function (unless the model itself already emitted it), even when prompt
    building or decoding raises, so a consumer iterating the streamer cannot
    block forever waiting for a sentinel that never arrives.

    Args:
        streamer: Object with a ``put(token_id)`` method receiving token ids.
        history: Chat history passed to ``generate_prompt``.

    Returns:
        The value returned by ``translator.translate_batch``.
    """
    end_sent = False

    def on_step(result):
        # Called by the translator for every decoding step.
        nonlocal end_sent
        streamer.put(result.token_id)
        if result.token_id == end_token_id:
            end_sent = True
        elif result.is_last:
            # Decoding stopped without emitting the end token: send the
            # sentinel ourselves so the consumer terminates.
            streamer.put(end_token_id)
            end_sent = True
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")

    try:
        tokens = generate_prompt(history)
        results = translator.translate_batch(
            [tokens],
            beam_size=1,
            max_decoding_length=256,
            repetition_penalty=1.8,
            callback=on_step,
        )
    finally:
        # Guarantee termination of the consumer on errors or when no final
        # step was reported; the exception (if any) still propagates.
        if not end_sent:
            streamer.put(end_token_id)
    return results


translator = ctranslate2.Translator("model", intra_threads=2)
tokenizer = AutoTokenizer.from_pretrained("DKYoon/mt5-xl-lm-adapt")
# NOTE(review): the literal is empty in this file, so the id is whatever the
# tokenizer emits first for an empty string -- presumably its end-of-sequence
# id; confirm the intended token text.
end_token = ""
end_token_id = tokenizer.encode(end_token)[0]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Clear the textbox and append the new user turn with an empty reply."""
        return "", history + [[user_message, ""]]

    def bot(history):
        """Stream the bot reply, yielding the updated history after each token."""
        bot_message_tokens = []
        streamer = TokenIteratorStreamer(end_token_id=end_token_id)
        # Generation runs in a background thread so tokens can be consumed
        # (and the UI updated) while decoding is still in progress.
        generation_thread = threading.Thread(target=generate, args=(streamer, history))
        generation_thread.start()
        for token in streamer:
            bot_message_tokens.append(token)
            # Re-decode the full token list each step to refresh the reply text.
            history[-1][1] = tokenizer.decode(bot_message_tokens)
            yield history
        generation_thread.join()

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
if __name__ == "__main__":
    demo.launch()