streaming generate #2
by weege007
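Here is a streaming version of the LLaSA inference example: a custom `TokenStreamer` pushes token ids from `model.generate` (running in a background thread) into a queue, and the main thread decodes the speech tokens to audio in 60-token chunks as they arrive.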
```python
from queue import Queue

from transformers.generation.streamers import BaseStreamer


class TokenStreamer(BaseStreamer):
    """Streamer that puts each generated token id into a queue so it can be
    consumed from another thread while model.generate() is still running."""

    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt

        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        # Called by generate() with each batch of newly generated token ids.
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            # The first call contains the prompt ids; skip them if requested.
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        # Called by generate() once generation is finished.
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        return value
```
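For context, `model.generate` drives the streamer: it calls `put()` with each newly generated token tensor and `end()` when finished, while a consumer on another thread simply iterates. A minimal sketch of that contract, with a dummy producer thread standing in for `model.generate` (the `fake_generate` helper is purely illustrative):

```python
import torch
from threading import Thread

streamer = TokenStreamer(skip_prompt=False)

def fake_generate():
    # Stands in for model.generate(), which calls put() per token, then end().
    for tok in range(5):
        streamer.put(torch.tensor([tok]))
    streamer.end()

Thread(target=fake_generate).start()
print(list(streamer))  # -> [0, 1, 2, 3, 4]
```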
```python
import torch
import soundfile as sf
from threading import Thread

# model, tokenizer, Codec_model, extract_speech_ids and input_text are the
# same objects as in the non-streaming LLaSA inference example.

# TTS start!
with torch.no_grad():
    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        return_tensors='pt',
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

    streamer = TokenStreamer(skip_prompt=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,          # Adjusts the diversity of generated content
        temperature=0.8,  # Controls randomness in output
    )

    # Run generation in a background thread; the streamer yields token ids
    # on this thread as soon as they are produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    batch_size = 60  # number of speech tokens to decode per audio chunk
    generated_ids = []
    j = 0
    for token_id in streamer:
        print(token_id, end=',', flush=True)
        generated_ids.append(token_id)
        if len(generated_ids) == batch_size:
            speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids), skip_special_tokens=True)
            # Convert token <|s_23456|> to int 23456
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
            # Decode the speech tokens to a speech waveform
            gen_wav = Codec_model.decode_code(speech_tokens)
            sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
            generated_ids = []
            j += 1
        # yield token_id

    # Decode whatever is left over once generation stops.
    if len(generated_ids) > 0:
        speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids), skip_special_tokens=True)
        # Convert token <|s_23456|> to int 23456
        speech_tokens = extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        # Decode the speech tokens to a speech waveform
        gen_wav = Codec_model.decode_code(speech_tokens)
        sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
```
Colab notebook: https://github.com/weedge/doraemon-nb/blob/main/LLaSA.ipynb
Thank you for sharing the Colab notebook! It's very well-written and helpful!