# Spaces: Running on Zero
import argparse
import io
import os

import onnxruntime
import s3tokenizer
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from huggingface_hub import hf_hub_download
from hyperpyyaml import load_hyperpyyaml

from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
class Token2wav():
    """Render discrete speech tokens as WAV audio, conditioned on a prompt recording.

    The prompt recording supplies three conditioning inputs for the flow model:
    its own speech tokens (s3tokenizer), a speaker embedding (the ``campplus.onnx``
    ONNX model) and its mel spectrogram at 24 kHz. The flow model produces a mel
    spectrogram for the generated tokens, and the HiFT generator turns that into
    a waveform.
    """

    def __init__(self, model_path, float16=False):
        """Load every sub-model; the torch models are moved to CUDA in eval mode.

        Args:
            model_path: Directory containing ``flow.yaml``, ``flow.pt``, ``hift.pt``
                and (normally) ``campplus.onnx``.
            float16: If true, the flow model is cast to half precision and its
                inference runs under float16 autocast.
        """
        self.float16 = float16

        # Prompt-side speech tokenizer, loaded by name through s3tokenizer
        # (not from model_path).
        self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()

        # Speaker-embedding model: ONNX Runtime on CPU with a single intra-op thread.
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # Resolve relative to model_path like the other checkpoints. Fall back to
        # the legacy working-directory-relative location if it is not there.
        spk_model_path = f"{model_path}/campplus.onnx"
        if not os.path.isfile(spk_model_path):
            spk_model_path = "token2wav/campplus.onnx"
        self.spk_model = onnxruntime.InferenceSession(
            spk_model_path, sess_options=option, providers=["CPUExecutionProvider"]
        )

        # NOTE: hyperpyyaml instantiates the Python objects named in the YAML, so
        # model_path must point at a trusted directory.
        with open(f"{model_path}/flow.yaml", "r") as f:
            configs = load_hyperpyyaml(f)
        self.flow = configs['flow']
        if float16:
            self.flow.half()
        self.flow.load_state_dict(
            torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True
        )
        self.flow.cuda().eval()

        self.hift = HiFTGenerator()
        # Checkpoint key names contain 'generator.', which the bare HiFTGenerator's
        # parameter names do not; remove it so strict loading matches.
        hift_state_dict = {
            k.replace('generator.', ''): v
            for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()
        }
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.cuda().eval()

    @torch.no_grad()
    def __call__(self, generated_speech_tokens, prompt_wav):
        """Synthesize ``generated_speech_tokens`` conditioned on ``prompt_wav``.

        Runs under ``torch.no_grad()``: this is inference only, so no autograd
        graph is built.

        Args:
            generated_speech_tokens: Sequence of integer speech-token ids to render.
            prompt_wav: Path to the prompt recording. It is read both by
                ``s3tokenizer.load_audio`` and by ``torchaudio.load``.

        Returns:
            The synthesized audio as WAV-encoded ``bytes`` with a 24000 Hz sample rate.
        """
        audio_16k = s3tokenizer.load_audio(prompt_wav, sr=16000)  # [T]
        prompt_speech_tokens, prompt_speech_tokens_lens = self._prompt_tokens(audio_16k)
        spk_emb = self._speaker_embedding(audio_16k)
        prompt_mels, prompt_mels_lens = self._prompt_mels(prompt_wav)

        tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')  # [1, N]
        tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int32, device='cuda')

        with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
            # The trailing 10 is passed positionally; presumably the number of
            # flow-matching solver steps — confirm against flow.inference.
            mel = self.flow.inference(tokens, tokens_lens,
                                      prompt_speech_tokens, prompt_speech_tokens_lens,
                                      prompt_mels, prompt_mels_lens, spk_emb, 10)
        wav, _ = self.hift(speech_feat=mel)

        output = io.BytesIO()
        torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav')
        return output.getvalue()

    def _prompt_tokens(self, audio_16k):
        """Quantize 16 kHz prompt audio with the s3tokenizer model; returns (tokens, lengths)."""
        mels = s3tokenizer.log_mel_spectrogram(audio_16k)
        mels, mels_lens = s3tokenizer.padding([mels])
        return self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda())

    def _speaker_embedding(self, audio_16k):
        """Run the ONNX speaker model on 16 kHz prompt audio; returns the embedding on CUDA."""
        # 80-bin Kaldi fbank of shape (frames, 80), mean-normalized over frames.
        spk_feat = kaldi.fbank(audio_16k.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
        spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
        input_name = self.spk_model.get_inputs()[0].name
        embedding = self.spk_model.run(None, {input_name: spk_feat.unsqueeze(dim=0).cpu().numpy()})[0]
        return torch.tensor(embedding, device='cuda')

    def _prompt_mels(self, prompt_wav):
        """Load the prompt as mono 24 kHz audio; returns ([1, T, num_mels] mels, [1] length) on CUDA."""
        audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile')
        audio = audio.mean(dim=0, keepdim=True)  # average channels -> [1, T]
        if sample_rate != 24000:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
        prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
        prompt_mels = prompt_mel.unsqueeze(0).cuda()  # [1, T, num_mels]
        prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda')
        return prompt_mels, prompt_mels_lens
def main(argv=None):
    """Render a fixed demo token sequence with a prompt voice and write it to a WAV file.

    Every option defaults to the value this script previously hard-coded, so
    running it with no arguments behaves as before.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Token2wav demo: synthesize a fixed speech-token sequence.")
    parser.add_argument("--model-path", default="/mnt/gpfs/lijingbei/Step-Audio-2-mini/token2wav",
                        help="directory containing flow.yaml, flow.pt, hift.pt and campplus.onnx")
    parser.add_argument("--prompt-wav", default="assets/default_male.wav",
                        help="prompt recording passed to Token2wav")
    parser.add_argument("--output", default="assets/give_me_a_brief_introduction_to_the_great_wall.wav",
                        help="path of the WAV file to write")
    parser.add_argument("--float16", action="store_true",
                        help="run the flow model in half precision")
    args = parser.parse_args(argv)

    token2wav = Token2wav(args.model_path, float16=args.float16)
    tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299]
    audio = token2wav(tokens, args.prompt_wav)
    with open(args.output, 'wb') as f:
        f.write(audio)


if __name__ == '__main__':
    main()