import ctranslate2
from tokenizers import Tokenizer
from huggingface_hub import snapshot_download
import threading
import gradio as gr
from typing import Optional
from queue import Queue


class TokenIteratorStreamer:
    """Blocking token iterator: the generation callback feeds it, the Gradio handler consumes it."""

    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        self.end_token_id = end_token_id
        self.queue = Queue()
        self.timeout = timeout

    def put(self, token: int):
        self.queue.put(token, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        token = self.queue.get(timeout=self.timeout)
        if token == self.end_token_id:
            raise StopIteration()
        return token


def generate_prompt(history):
    # Build a Llama-2 chat prompt: system prompt wrapped in <<SYS>> tags,
    # then the previous exchange (if any), then the new user message.
    prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    for chain in history[-2:-1]:  # only the immediately preceding exchange is kept as context
        prompt += f"{chain[0]} [/INST] {chain[1]} </s><s> [INST] "
    prompt += f"{history[-1][0]} [/INST]"
    tokens = tokenizer.encode(prompt).tokens
    return tokens


def generate(streamer, history):
    def step_result_callback(result):
        # Push each generated token id to the streamer; always send the
        # end-of-sequence id so the consumer loop terminates.
        streamer.put(result.token_id)
        if result.is_last and (result.token_id != end_token_id):
            streamer.put(end_token_id)
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")

    tokens = generate_prompt(history)
    results = generator.generate_batch(
        [tokens],
        include_prompt_in_result=False,
        max_length=256,
        callback=step_result_callback,
    )
    return results


system_prompt = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

model_name = "neongeckocom/Llama-2-7b-chat-hf"
model_path = snapshot_download(repo_id=model_name)
generator = ctranslate2.Generator(model_path, intra_threads=2)
tokenizer = Tokenizer.from_pretrained(model_name)
end_token = "</s>"
end_token_id = tokenizer.token_to_id(end_token)  # integer id, matches result.token_id in the callback

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history):
        # Run generation in a background thread and stream tokens back to the UI.
        bot_message_tokens = []
        streamer = TokenIteratorStreamer(end_token_id=end_token_id)
        generation_thread = threading.Thread(target=generate, args=(streamer, history))
        generation_thread.start()
        history[-1][1] = ""
        for token in streamer:
            bot_message_tokens.append(token)
            history[-1][1] = tokenizer.decode(bot_message_tokens)
            yield history
        generation_thread.join()

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()  # required so the generator-based bot handler can stream updates

if __name__ == "__main__":
    demo.launch()