| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig |
|
|
| from utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels |
|
|
|
|
class StepAudio2Base:
    """Inference wrapper around a causal LM that accepts text, audio and raw token prompts.

    ``apply_chat_template`` here performs no role/turn formatting: each message
    is forwarded as-is (string), expanded (audio) or passed through (token ids).
    """

    # Generated ids strictly below this value are decoded as text.
    _TEXT_TOKEN_LIMIT = 151688
    # Generated ids at or above this value are audio codes; subtracting the
    # offset yields the value handed back in ``output_audio_tokens``.
    _AUDIO_TOKEN_OFFSET = 151696
    # Audio is cut into windows of this many samples before mel extraction
    # (presumably 25 s at 16 kHz -- confirm against ``load_audio``).
    _AUDIO_CHUNK_SAMPLES = 16000 * 25

    def __init__(self, model_path: str):
        """Load tokenizer and model from ``model_path`` and move the model to CUDA."""
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, padding_side="right"
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).cuda()
        self.eos_token_id = self.llm_tokenizer.eos_token_id

    def __call__(self, messages: list, **kwargs):
        """Run generation for ``messages``.

        Args:
            messages: Prompt pieces understood by ``apply_chat_template``.
            **kwargs: Extra/overriding fields for ``GenerationConfig``. The
                defaults set here are ``max_new_tokens=2048`` plus the
                tokenizer's pad id and this instance's EOS id.

        Returns:
            A tuple ``(output_token_ids, output_text, output_audio_tokens)``:
            the newly generated ids (without a trailing EOS), the decoded text
            ids, and the audio ids shifted down by ``_AUDIO_TOKEN_OFFSET``.

        Raises:
            ValueError: If a templated piece is neither ``str`` nor ``list``.
        """
        pieces, mels = self.apply_chat_template(messages)

        # Strings are tokenized; lists are treated as ready-made token ids.
        prompt_parts = []
        for piece in pieces:
            if isinstance(piece, str):
                prompt_parts.append(
                    self.llm_tokenizer(text=piece, return_tensors="pt", padding=True)["input_ids"]
                )
            elif isinstance(piece, list):
                prompt_parts.append(torch.tensor([piece], dtype=torch.int32))
            else:
                raise ValueError(f"Unsupported content type: {type(piece)}")
        prompt_ids = torch.cat(prompt_parts, dim=-1).cuda()
        attention_mask = torch.ones_like(prompt_ids)

        if mels:
            mels, mel_lengths = padding_mels(mels)
            mels = mels.cuda()
            mel_lengths = mel_lengths.cuda()
        else:
            # Text-only prompt: the model receives no audio features.
            mels = None
            mel_lengths = None

        generate_inputs = {
            "input_ids": prompt_ids,
            "wavs": mels,
            "wav_lens": mel_lengths,
            "attention_mask": attention_mask,
        }

        generation_config = dict(
            max_new_tokens=2048,
            pad_token_id=self.llm_tokenizer.pad_token_id,
            eos_token_id=self.eos_token_id,
        )
        generation_config.update(kwargs)
        generation_config = GenerationConfig(**generation_config)

        outputs = self.llm.generate(**generate_inputs, generation_config=generation_config)
        generated = outputs[0, prompt_ids.shape[-1]:].tolist()
        # Only drop the final id when it really is EOS; if generation stopped
        # on max_new_tokens the last id is genuine output and must be kept.
        output_token_ids = self._strip_trailing_eos(generated, generation_config.eos_token_id)
        output_text_tokens = [i for i in output_token_ids if i < self._TEXT_TOKEN_LIMIT]
        output_audio_tokens = [
            i - self._AUDIO_TOKEN_OFFSET for i in output_token_ids if i >= self._AUDIO_TOKEN_OFFSET
        ]
        output_text = self.llm_tokenizer.decode(output_text_tokens)
        return output_token_ids, output_text, output_audio_tokens

    @staticmethod
    def _strip_trailing_eos(token_ids: list, eos_token_id) -> list:
        """Return ``token_ids`` without its last element if that element is an EOS id.

        ``eos_token_id`` may be ``None``, a single int, or an iterable of ints.
        """
        if eos_token_id is None or not token_ids:
            return token_ids
        if isinstance(eos_token_id, int):
            eos_ids = (eos_token_id,)
        else:
            eos_ids = tuple(eos_token_id)
        return token_ids[:-1] if token_ids[-1] in eos_ids else token_ids

    def apply_chat_template(self, messages: list):
        """Flatten ``messages`` into prompt pieces and collect mel spectrograms.

        Each message is either a plain string (kept verbatim) or a dict with a
        ``"type"`` key:

        * ``"text"``  -- the ``"text"`` value is appended as a string.
        * ``"audio"`` -- the ``"audio"`` value is loaded, split into windows of
          ``_AUDIO_CHUNK_SAMPLES`` samples, and each window contributes one mel
          plus an ``<audio_start>...<audio_end>`` placeholder string.
        * ``"token"`` -- the ``"token"`` value is appended unchanged.

        Dicts with any other ``"type"`` are skipped without error.

        Returns:
            ``(results, mels)``: the list of prompt pieces and the list of mels.

        Raises:
            ValueError: If a message is neither ``str`` nor ``dict``.
        """
        results = []
        mels = []
        chunk = self._AUDIO_CHUNK_SAMPLES
        for content in messages:
            if isinstance(content, str):
                results.append(content)
            elif isinstance(content, dict):
                if content["type"] == "text":
                    results.append(f"{content['text']}")
                elif content["type"] == "audio":
                    audio = load_audio(content['audio'])
                    for start in range(0, audio.shape[0], chunk):
                        mel = log_mel_spectrogram(audio[start:start + chunk], n_mels=128, padding=479)
                        mels.append(mel)
                        # One placeholder per audio token the model will see for this window.
                        audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                        results.append(f"<audio_start>{audio_tokens}<audio_end>")
                elif content["type"] == "token":
                    results.append(content["token"])
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")

        return results, mels
|
|
|
|
class StepAudio2(StepAudio2Base):
    """Chat-style variant that wraps every turn in ``<|BOT|>role`` / ``<|EOT|>`` markers."""

    def __init__(self, model_path: str):
        """Load the model, then make ``<|EOT|>`` the end-of-sequence token everywhere."""
        super().__init__(model_path)
        self.llm_tokenizer.eos_token = "<|EOT|>"
        eot_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")
        self.llm.config.eos_token_id = eot_id
        self.eos_token_id = eot_id

    def apply_chat_template(self, messages: list):
        """Render role/content messages into prompt pieces and collect mel spectrograms.

        Every message is a mapping with ``"role"`` and ``"content"``; the role
        ``"user"`` is rewritten to ``"human"``. ``content`` may be:

        * ``str``  -- emitted as one piece after the turn header.
        * ``list`` -- items with ``"type"`` of ``"text"``, ``"audio"`` or
          ``"token"`` are emitted in order after the header; other item types
          are skipped.
        * ``None`` -- only the turn header is emitted (an open turn for the
          model to complete).

        For ``str`` and ``list`` content, ``<|EOT|>`` is appended unless the
        message carries ``"eot": False``.

        Returns:
            ``(results, mels)``: the prompt pieces and the list of mels.

        Raises:
            ValueError: If ``content`` is not ``str``, ``list`` or ``None``.
        """
        pieces = []
        mel_chunks = []
        window = 16000 * 25

        for message in messages:
            speaker = message["role"]
            body = message["content"]
            if speaker == "user":
                speaker = "human"

            if isinstance(body, str):
                rendered = f"<|BOT|>{speaker}\n{body}"
                if message.get('eot', True):
                    rendered += '<|EOT|>'
                pieces.append(rendered)
                continue

            if isinstance(body, list):
                pieces.append(f"<|BOT|>{speaker}\n")
                for part in body:
                    if part["type"] == "text":
                        pieces.append(f"{part['text']}")
                    elif part["type"] == "audio":
                        waveform = load_audio(part['audio'])
                        for start in range(0, waveform.shape[0], window):
                            mel = log_mel_spectrogram(
                                waveform[start:start + window], n_mels=128, padding=479
                            )
                            mel_chunks.append(mel)
                            patches = "<audio_patch>" * compute_token_num(mel.shape[1])
                            pieces.append(f"<audio_start>{patches}<audio_end>")
                    elif part["type"] == "token":
                        pieces.append(part["token"])
                if message.get('eot', True):
                    pieces.append('<|EOT|>')
                continue

            if body is None:
                pieces.append(f"<|BOT|>{speaker}\n")
                continue

            raise ValueError(f"Unsupported content type: {type(body)}")

        return pieces, mel_chunks
|
|
if __name__ == '__main__':
    from token2wav import Token2wav

    model = StepAudio2('Step-Audio-2-mini')
    token2wav = Token2wav('Step-Audio-2-mini/token2wav')

    # Sampling settings and prompt fragments shared by every example below.
    sampling = dict(temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    system_prompt = "You are a helpful assistant."
    question_text = "Give me a brief introduction to the Great Wall."
    question_wav = "assets/give_me_a_brief_introduction_to_the_great_wall.wav"

    # Example 1: text question, text reply.
    print()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "human", "content": question_text},
        {"role": "assistant", "content": None},
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, **sampling)
    print(text)

    # Example 2: text question; the assistant turn is left open after
    # "<tts_start>" and the returned audio tokens are rendered to a wav file.
    print()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "human", "content": question_text},
        {"role": "assistant", "content": "<tts_start>", "eot": False},
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, **sampling)
    print(text)
    print(tokens)
    wav_bytes = token2wav(audio, prompt_wav='assets/default_male.wav')
    with open('output-male.wav', 'wb') as out_file:
        out_file.write(wav_bytes)

    # Example 3: audio question, text reply.
    print()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "human", "content": [{"type": "audio", "audio": question_wav}]},
        {"role": "assistant", "content": None},
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, **sampling)
    print(text)

    # Example 4: audio question, reply rendered to a wav file.
    print()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "human", "content": [{"type": "audio", "audio": question_wav}]},
        {"role": "assistant", "content": "<tts_start>", "eot": False},
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, **sampling)
    print(text)
    print(tokens)
    wav_bytes = token2wav(audio, prompt_wav='assets/default_female.wav')
    with open('output-female.wav', 'wb') as out_file:
        out_file.write(wav_bytes)

    # Example 5: continue the conversation from example 4. The open assistant
    # turn is replaced by one that carries the tokens just generated.
    print()
    messages.pop(-1)
    messages += [
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": "<tts_start>"},
                {"type": "token", "token": tokens},
            ],
        },
        {"role": "human", "content": "Now write a 4-line poem about it."},
        {"role": "assistant", "content": None},
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, **sampling)
    print(text)

    # Example 6: text instruction combined with an audio clip in one turn.
    print()
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "human",
            "content": [
                {"type": "text", "text": "Translate the speech into Chinese."},
                {"type": "audio", "audio": question_wav},
            ],
        },
        {"role": "assistant", "content": None},
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, **sampling)
    print(text)
|
|