|
from typing import List |
|
from queue import Queue |
|
|
|
|
|
def build_chat_input(tokenizer, messages: List[dict]):
    """Render a chat history into the prompt format and tokenize it.

    The prompt starts with ``"<s>"``. Each user turn is rendered as
    ``"Human: <content>\\n\\nAssistant: </s>"`` and each assistant turn as
    ``"<content></s>"``. Messages whose ``"content"`` is ``None`` are
    skipped, as are messages with any role other than ``"user"`` or
    ``"assistant"``.

    Args:
        tokenizer: Object exposing ``encode(str)``; presumably a Hugging Face
            tokenizer -- confirm against callers.
        messages: Sequence of dicts, each with ``"role"`` and ``"content"``
            keys.

    Returns:
        Whatever ``tokenizer.encode`` returns for the rendered prompt.

    Raises:
        KeyError: If a message lacks a ``"role"`` or ``"content"`` key.
    """
    parts = ["<s>"]
    for msg in messages:
        role = msg["role"]
        message = msg["content"]
        if message is None:
            continue
        if role == "user":
            # NOTE(review): an end marker directly after "Assistant: " is
            # unusual; presumably it mirrors the model's training template --
            # confirm before changing.
            parts.append("Human: " + message + "\n\nAssistant: </s>")
        elif role == "assistant":
            parts.append(message + "</s>")
        # Any other role is silently dropped.

    # Join once rather than repeated `+=` to avoid quadratic string building.
    return tokenizer.encode("".join(parts))
|
|
|
|
|
class TextIterStreamer:
    """Queue cumulative decoded text from a producer for an iterating consumer.

    A producer (presumably ``model.generate(..., streamer=...)`` running in
    another thread -- confirm against callers) calls :meth:`put` with newly
    produced token ids and :meth:`end` when finished. A consumer iterates the
    streamer. Each item is the decoding of *all* tokens received so far in the
    current generation, not just the newest piece. Iteration stops once the
    item queued by :meth:`end` is reached.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        skip_prompt: If true, the first :meth:`put` of each generation (the
            prompt) is discarded.
        skip_special_tokens: Forwarded to ``tokenizer.decode``.
        timeout: Seconds the consumer waits for the next item. ``None``
            (the default) waits forever. On expiry ``queue.Empty`` is raised.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False, timeout=None):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.timeout = timeout
        # Token ids accumulated for the current generation.
        self.tokens = []
        # Hand-off from producer to consumer; a None item marks the end.
        self.text_queue = Queue()
        # True until the first put() of a generation, which carries the prompt.
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Append new token ids and queue the decoded text of the whole buffer.

        ``value`` must support ``.shape``, indexing and ``.tolist()``
        (presumably a torch tensor of token ids -- confirm). When it has more
        than one dimension, only row 0 is used.
        """
        if self.next_tokens_are_prompt:
            # First call of a generation: start a fresh buffer so a reused
            # streamer does not carry over tokens from the previous run.
            self.next_tokens_are_prompt = False
            self.tokens = []
            if self.skip_prompt:
                return
        if len(value.shape) > 1:
            value = value[0]
        self.tokens.extend(value.tolist())
        # Decode the whole buffer, so each queued item is the full text so far
        # rather than a delta.
        self.text_queue.put(
            self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))

    def end(self):
        """Signal the end of the current generation and re-arm for the next one."""
        # Re-arm before queueing the end marker, so the next put() is treated
        # as a new prompt. self.tokens is kept until then, so the final ids
        # remain inspectable.
        self.next_tokens_are_prompt = True
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next queued text; stop iteration at the end marker.

        Raises:
            StopIteration: When the end marker queued by :meth:`end` is reached.
            queue.Empty: If ``timeout`` is set and no item arrives in time.
        """
        value = self.text_queue.get(timeout=self.timeout)
        if value is None:
            raise StopIteration()
        return value
|
|