# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import numpy as np
import torch
import torchaudio as ta
from functools import lru_cache
from typing import Optional

from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
from .const import S3GEN_SR
from .flow import CausalMaskedDiffWithXvec
from .xvector import CAMPPlus
from .utils.mel import mel_spectrogram
from .f0_predictor import ConvRNNF0Predictor
from .hifigan import HiFTGenerator
from .transformer.upsample_encoder import UpsampleConformerEncoder
from .flow_matching import CausalConditionalCFM
from .decoder import ConditionalDecoder
from .configs import CFM_PARAMS

def drop_invalid_tokens(x):
    assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
    return x[x < SPEECH_VOCAB_SIZE]

# TODO: global resampler cache
def get_resampler(src_sr, dst_sr, device):
    return ta.transforms.Resample(src_sr, dst_sr).to(device)
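
# A possible realization of the TODO above (a sketch, not wired in here): reuse one
# Resample module per (src_sr, dst_sr, device) so repeated reference embeddings don't
# rebuild the resampling kernel on every call. `lru_cache` is already imported above;
# the device argument must stay hashable (e.g. a str or torch.device) for this to work.
#
# @lru_cache(100)
# def get_resampler(src_sr, dst_sr, device):
#     return ta.transforms.Resample(src_sr, dst_sr).to(device)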

class S3Token2Mel(torch.nn.Module):
    """
    CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.

    TODO: make these modules configurable?
    """
    def __init__(self):
        super().__init__()
        self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
        self.mel_extractor = mel_spectrogram  # TODO: make it a torch module?
        self.speaker_encoder = CAMPPlus()  # use default args

        encoder = UpsampleConformerEncoder(
            output_size=512,
            attention_heads=8,
            linear_units=2048,
            num_blocks=6,
            dropout_rate=0.1,
            positional_dropout_rate=0.1,
            attention_dropout_rate=0.1,
            normalize_before=True,
            input_layer='linear',
            pos_enc_layer_type='rel_pos_espnet',
            selfattention_layer_type='rel_selfattn',
            input_size=512,
            use_cnn_module=False,
            macaron_style=False,
        )

        estimator = ConditionalDecoder(
            in_channels=320,
            out_channels=80,
            causal=True,
            channels=[256],
            dropout=0.0,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )
        cfm_params = CFM_PARAMS
        decoder = CausalConditionalCFM(
            spk_emb_dim=80,
            cfm_params=cfm_params,
            estimator=estimator,
        )

        self.flow = CausalMaskedDiffWithXvec(
            encoder=encoder,
            decoder=decoder
        )

        self.resamplers = {}

    @property
    def device(self):
        params = self.tokenizer.parameters()
        return next(params).device

    def embed_ref(
        self,
        ref_wav: torch.Tensor,
        ref_sr: int,
        device="auto",
        ref_fade_out=True,
    ):
        device = self.device if device == "auto" else device
        if isinstance(ref_wav, np.ndarray):
            ref_wav = torch.from_numpy(ref_wav).float()

        if ref_wav.device != device:
            ref_wav = ref_wav.to(device)

        if len(ref_wav.shape) == 1:
            ref_wav = ref_wav.unsqueeze(0)  # (B, L)

        if ref_wav.size(1) > 10 * ref_sr:
            print("WARNING: cosydec received ref longer than 10s")

        ref_wav_24 = ref_wav
        if ref_sr != S3GEN_SR:
            ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)

        ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
        ref_mels_24_len = None

        # Resample to 16kHz
        ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)

        # Speaker embedding
        ref_x_vector = self.speaker_encoder.inference(ref_wav_16)

        # Tokenize 16khz reference
        ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)

        # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
        if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
            logging.warning(
                "Reference mel length is not equal to 2 * reference token length.\n"
            )
            ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
            ref_speech_token_lens[0] = ref_speech_tokens.shape[1]

        return dict(
            prompt_token=ref_speech_tokens.to(device),
            prompt_token_len=ref_speech_token_lens,
            prompt_feat=ref_mels_24,
            prompt_feat_len=ref_mels_24_len,
            embedding=ref_x_vector,
        )
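
    # The dict returned by embed_ref() mirrors the keyword arguments that
    # self.flow.inference() receives via **ref_dict in forward() below, so it can be
    # computed once per reference clip and passed back in as `ref_dict` on later calls
    # (the pre-computed / "prod API" path). A rough sketch with hypothetical names:
    #
    #   ref_dict = model.embed_ref(ref_wav, ref_sr)
    #   mels = model(speech_tokens, ref_wav=None, ref_sr=None, ref_dict=ref_dict, finalize=True)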

    def forward(
        self,
        speech_tokens: torch.LongTensor,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        """
        Generate waveforms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.

        NOTE:
        - The speaker encoder accepts 16 kHz waveform.
        - S3TokenizerV2 accepts 16 kHz waveform.
        - The mel-spectrogram for the reference assumes 24 kHz input signal.
        - This function is designed for batch_size=1 only.

        Args
        ----
        - `speech_tokens`: S3 speech tokens [B=1, T]
        - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
        - `ref_sr`: reference sample rate
        - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
        """
        assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"

        if ref_dict is None:
            ref_dict = self.embed_ref(ref_wav, ref_sr)
        else:
            # type/device casting (all values will be numpy if it's from a prod API call)
            for rk in list(ref_dict):
                if isinstance(ref_dict[rk], np.ndarray):
                    ref_dict[rk] = torch.from_numpy(ref_dict[rk])
                if torch.is_tensor(ref_dict[rk]):
                    ref_dict[rk] = ref_dict[rk].to(self.device)

        if len(speech_tokens.shape) == 1:
            speech_tokens = speech_tokens.unsqueeze(0)

        # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
        speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)

        output_mels, _ = self.flow.inference(
            token=speech_tokens,
            token_len=speech_token_lens,
            finalize=finalize,
            **ref_dict,
        )
        return output_mels

class S3Token2Wav(S3Token2Mel):
    """
    The decoder of CosyVoice2 is a concatenation of a token-to-mel (CFM) module and a mel-to-waveform (HiFiGAN) module.

    TODO: make these modules configurable?
    """
    def __init__(self):
        super().__init__()

        f0_predictor = ConvRNNF0Predictor()
        self.mel2wav = HiFTGenerator(
            sampling_rate=S3GEN_SR,
            upsample_rates=[8, 5, 3],
            upsample_kernel_sizes=[16, 11, 7],
            source_resblock_kernel_sizes=[7, 7, 11],
            source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            f0_predictor=f0_predictor,
        )

        # silence out a few ms and fade audio in to reduce artifacts
        n_trim = S3GEN_SR // 50  # 20ms = half of a frame
        trim_fade = torch.zeros(2 * n_trim)
        trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
        self.register_buffer("trim_fade", trim_fade, persistent=False)  # (buffers get automatic device casting)

    def forward(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False
    ):
        output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

        # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
        hift_cache_source = torch.zeros(1, 1, 0).to(self.device)

        output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)

        if not self.training:
            # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
            output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

        return output_wavs

    def flow_inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

    def hift_inference(self, speech_feat, cache_source: Optional[torch.Tensor] = None):
        if cache_source is None:
            cache_source = torch.zeros(1, 1, 0).to(self.device)
        return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)

    def inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        cache_source: Optional[torch.Tensor] = None,  # NOTE: this arg is for streaming, it can probably be removed here
        finalize: bool = True,
    ):
        output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
        output_wavs, output_sources = self.hift_inference(output_mels, cache_source)

        # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
        output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

        return output_wavs, output_sources
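
# Example usage (a sketch under assumptions: weight loading is omitted, "ref.wav" is a
# hypothetical mono reference clip, and this module uses relative imports, so it must be
# used from within the package rather than run directly as a script):
#
#   model = S3Token2Wav().eval().to("cuda")
#   ref_wav, ref_sr = ta.load("ref.wav")  # (1, T) waveform at its native sample rate
#   speech_tokens = torch.randint(0, SPEECH_VOCAB_SIZE, (1, 100), device=model.device)  # dummy S3 tokens
#   with torch.inference_mode():
#       wavs, _ = model.inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr)
#   ta.save("out.wav", wavs.cpu(), S3GEN_SR)  # 24 kHz output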