streaming generate

by weege007
from queue import Queue
from threading import Thread

import torch
import soundfile as sf
from transformers.generation.streamers import BaseStreamer

class TokenStreamer(BaseStreamer):
    """Streams generated token ids through a queue so they can be consumed
    as an iterator from a thread other than the one running model.generate."""

    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt

        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value
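
The streamer hands token ids across threads through its queue: `model.generate` calls `put` on the generation thread, and a consumer iterates this object on another thread until `end` enqueues the stop signal. A minimal sketch of that handoff, with a dummy producer standing in for `model.generate` (the producer and its token ids are purely illustrative):

from threading import Thread

import torch

demo_streamer = TokenStreamer(skip_prompt=False)

def fake_generate():
    # Simulate model.generate: push a few token ids, then signal the end.
    for tid in [101, 102, 103]:
        demo_streamer.put(torch.tensor([tid]))
    demo_streamer.end()

Thread(target=fake_generate).start()
print(list(demo_streamer))  # -> [101, 102, 103]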

# TTS start! (`tokenizer`, `model`, `Codec_model`, and `input_text` are
# assumed to be loaded as in the LLaSA inference example.)
with torch.no_grad():

    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat, 
        tokenize=True, 
        return_tensors='pt', 
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
    streamer = TokenStreamer(skip_prompt=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_length=2048,  # we trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,          # 1 keeps the full distribution; lower values narrow sampling
        temperature=0.8,  # controls randomness in the output
    )
    # Run generation in a background thread so tokens can be consumed as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    i = 0
    batch_size = 60  # number of speech tokens to accumulate before decoding a chunk
    generated_ids = []
    j = 0
    for token_id in streamer:
        print(token_id, end=',', flush=True)
        generated_ids.append(token_id)
        if i > 0 and i % batch_size == 0:
            speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
            # Convert tokens like <|s_23456|> to the int 23456
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
            # Decode the speech tokens to a speech waveform
            gen_wav = Codec_model.decode_code(speech_tokens)
            sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
            generated_ids = []
            j += 1
        i += 1
        # yield token_id  # uncomment if wrapping this loop in a generator
    if len(generated_ids) > 0:
        # Flush any remaining tokens as the final chunk.
        speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
        speech_tokens = extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        gen_wav = Codec_model.decode_code(speech_tokens)
        sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
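
Note that `extract_speech_ids` is not defined in this snippet; per the inline comment it maps strings like `<|s_23456|>` to the integer 23456. A plausible sketch of such a helper (it may differ from the exact function in the LLaSA example code):

def extract_speech_ids(speech_tokens_str):
    # Parse '<|s_23456|>' -> 23456; skip anything that doesn't match.
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            speech_ids.append(int(token_str[4:-2]))
    return speech_ids

Since each 60-token chunk is written to its own gen_{j}.wav, the chunks can be stitched back into one clip afterwards; a small sketch (the output file name is illustrative):

import glob

import numpy as np
import soundfile as sf

# Read the chunk files in numeric order and concatenate them at 16 kHz.
paths = sorted(glob.glob("gen_*.wav"), key=lambda p: int(p[4:-4]))
sf.write("gen_full.wav", np.concatenate([sf.read(p)[0] for p in paths]), 16000)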

Colab notebook: https://github.com/weedge/doraemon-nb/blob/main/LLaSA.ipynb

HKUST Audio org

Thank you for sharing the Colab notebook! It’s very well-written and helpful!
