Commit 04be12f · Parent(s): ff769a6

disable deepspeed and cuda kernel

Files changed:
- indextts/infer_v2.py +2 -3
- indextts/s2mel/modules/.ipynb_checkpoints/audio-checkpoint.py +82 -0
- indextts/s2mel/modules/.ipynb_checkpoints/commons-checkpoint.py +610 -0
- indextts/s2mel/modules/.ipynb_checkpoints/diffusion_transformer-checkpoint.py +258 -0
- indextts/s2mel/modules/.ipynb_checkpoints/flow_matching-checkpoint.py +171 -0
- indextts/s2mel/modules/.ipynb_checkpoints/length_regulator-checkpoint.py +141 -0
- webui.py +3 -1
    	
indextts/infer_v2.py    CHANGED
@@ -35,7 +35,7 @@ import torch.nn.functional as F
 class IndexTTS2:
     def __init__(
             self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, device=None,
-            use_cuda_kernel=None,
+            use_cuda_kernel=None,use_deepspeed=False
     ):
         """
         Args:
@@ -83,14 +83,13 @@ class IndexTTS2:
             try:
                 import deepspeed
 
-                use_deepspeed = True
             except (ImportError, OSError, CalledProcessError) as e:
                 use_deepspeed = False
                 print(f">> DeepSpeed加载失败,回退到标准推理: {e}")
 
             self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
         else:
-            self.gpt.post_init_gpt2_config(use_deepspeed=
+            self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=False)
 
         if self.use_cuda_kernel:
             # preload the CUDA kernel for BigVGAN
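
As a quick illustration of the change (not part of the commit), the constructor can now be called with DeepSpeed explicitly disabled. This is a minimal sketch assuming the class is importable from indextts.infer_v2, matching the file path above; only the parameters visible in the diff are used.

from indextts.infer_v2 import IndexTTS2

# Hypothetical usage sketch: both acceleration paths are switched off,
# which is what this commit's "disable deepspeed and cuda kernel" intends.
tts = IndexTTS2(
    cfg_path="checkpoints/config.yaml",
    model_dir="checkpoints",
    is_fp16=False,
    use_cuda_kernel=False,   # skip the BigVGAN CUDA kernel preload
    use_deepspeed=False,     # new keyword introduced by this commit
)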
    	
indextts/s2mel/modules/.ipynb_checkpoints/audio-checkpoint.py    ADDED
@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
#     if torch.min(y) < -1.0:
#         print("min value is ", torch.min(y))
#     if torch.max(y) > 1.0:
#         print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
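
For context (not part of the commit), here is a minimal sketch of how the mel_spectrogram helper added above is typically driven. The STFT/mel parameter values are illustrative placeholders, not values taken from this repository's config, and the import path is assumed from the file location.

import torch
# assumed import path for the module defined above:
# from indextts.s2mel.modules.audio import mel_spectrogram

y = torch.randn(1, 22050)  # one second of dummy mono audio at 22.05 kHz
mel = mel_spectrogram(
    y,
    n_fft=1024,        # placeholder analysis settings
    num_mels=80,
    sampling_rate=22050,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=8000,
)
print(mel.shape)  # -> (1, 80, n_frames)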
    	
indextts/s2mel/modules/.ipynb_checkpoints/commons-checkpoint.py    ADDED
@@ -0,0 +1,610 @@
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torch import nn
         | 
| 5 | 
            +
            from torch.nn import functional as F
         | 
| 6 | 
            +
            from munch import Munch
         | 
| 7 | 
            +
            import json
         | 
| 8 | 
            +
            import argparse
         | 
| 9 | 
            +
            from torch.nn.parallel import DistributedDataParallel as DDP
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            def str2bool(v):
         | 
| 12 | 
            +
                if isinstance(v, bool):
         | 
| 13 | 
            +
                    return v
         | 
| 14 | 
            +
                if v.lower() in ("yes", "true", "t", "y", "1"):
         | 
| 15 | 
            +
                    return True
         | 
| 16 | 
            +
                elif v.lower() in ("no", "false", "f", "n", "0"):
         | 
| 17 | 
            +
                    return False
         | 
| 18 | 
            +
                else:
         | 
| 19 | 
            +
                    raise argparse.ArgumentTypeError("Boolean value expected.")
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            class AttrDict(dict):
         | 
| 22 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 23 | 
            +
                    super(AttrDict, self).__init__(*args, **kwargs)
         | 
| 24 | 
            +
                    self.__dict__ = self
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def init_weights(m, mean=0.0, std=0.01):
         | 
| 28 | 
            +
                classname = m.__class__.__name__
         | 
| 29 | 
            +
                if classname.find("Conv") != -1:
         | 
| 30 | 
            +
                    m.weight.data.normal_(mean, std)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def get_padding(kernel_size, dilation=1):
         | 
| 34 | 
            +
                return int((kernel_size * dilation - dilation) / 2)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def convert_pad_shape(pad_shape):
         | 
| 38 | 
            +
                l = pad_shape[::-1]
         | 
| 39 | 
            +
                pad_shape = [item for sublist in l for item in sublist]
         | 
| 40 | 
            +
                return pad_shape
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def intersperse(lst, item):
         | 
| 44 | 
            +
                result = [item] * (len(lst) * 2 + 1)
         | 
| 45 | 
            +
                result[1::2] = lst
         | 
| 46 | 
            +
                return result
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def kl_divergence(m_p, logs_p, m_q, logs_q):
         | 
| 50 | 
            +
                """KL(P||Q)"""
         | 
| 51 | 
            +
                kl = (logs_q - logs_p) - 0.5
         | 
| 52 | 
            +
                kl += (
         | 
| 53 | 
            +
                    0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
         | 
| 54 | 
            +
                )
         | 
| 55 | 
            +
                return kl
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            def rand_gumbel(shape):
         | 
| 59 | 
            +
                """Sample from the Gumbel distribution, protect from overflows."""
         | 
| 60 | 
            +
                uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
         | 
| 61 | 
            +
                return -torch.log(-torch.log(uniform_samples))
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            def rand_gumbel_like(x):
         | 
| 65 | 
            +
                g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
         | 
| 66 | 
            +
                return g
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def slice_segments(x, ids_str, segment_size=4):
         | 
| 70 | 
            +
                ret = torch.zeros_like(x[:, :, :segment_size])
         | 
| 71 | 
            +
                for i in range(x.size(0)):
         | 
| 72 | 
            +
                    idx_str = ids_str[i]
         | 
| 73 | 
            +
                    idx_end = idx_str + segment_size
         | 
| 74 | 
            +
                    ret[i] = x[i, :, idx_str:idx_end]
         | 
| 75 | 
            +
                return ret
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            def slice_segments_audio(x, ids_str, segment_size=4):
         | 
| 79 | 
            +
                ret = torch.zeros_like(x[:, :segment_size])
         | 
| 80 | 
            +
                for i in range(x.size(0)):
         | 
| 81 | 
            +
                    idx_str = ids_str[i]
         | 
| 82 | 
            +
                    idx_end = idx_str + segment_size
         | 
| 83 | 
            +
                    ret[i] = x[i, idx_str:idx_end]
         | 
| 84 | 
            +
                return ret
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            def rand_slice_segments(x, x_lengths=None, segment_size=4):
         | 
| 88 | 
            +
                b, d, t = x.size()
         | 
| 89 | 
            +
                if x_lengths is None:
         | 
| 90 | 
            +
                    x_lengths = t
         | 
| 91 | 
            +
                ids_str_max = x_lengths - segment_size + 1
         | 
| 92 | 
            +
                ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
         | 
| 93 | 
            +
                    dtype=torch.long
         | 
| 94 | 
            +
                )
         | 
| 95 | 
            +
                ret = slice_segments(x, ids_str, segment_size)
         | 
| 96 | 
            +
                return ret, ids_str
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
         | 
| 100 | 
            +
                position = torch.arange(length, dtype=torch.float)
         | 
| 101 | 
            +
                num_timescales = channels // 2
         | 
| 102 | 
            +
                log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
         | 
| 103 | 
            +
                    num_timescales - 1
         | 
| 104 | 
            +
                )
         | 
| 105 | 
            +
                inv_timescales = min_timescale * torch.exp(
         | 
| 106 | 
            +
                    torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
         | 
| 107 | 
            +
                )
         | 
| 108 | 
            +
                scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
         | 
| 109 | 
            +
                signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
         | 
| 110 | 
            +
                signal = F.pad(signal, [0, 0, 0, channels % 2])
         | 
| 111 | 
            +
                signal = signal.view(1, channels, length)
         | 
| 112 | 
            +
                return signal
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
         | 
| 116 | 
            +
                b, channels, length = x.size()
         | 
| 117 | 
            +
                signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
         | 
| 118 | 
            +
                return x + signal.to(dtype=x.dtype, device=x.device)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
             | 
| 121 | 
            +
            def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
         | 
| 122 | 
            +
                b, channels, length = x.size()
         | 
| 123 | 
            +
                signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
         | 
| 124 | 
            +
                return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            def subsequent_mask(length):
         | 
| 128 | 
            +
                mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
         | 
| 129 | 
            +
                return mask
         | 
| 130 | 
            +
             | 
| 131 | 
            +
             | 
| 132 | 
            +
            @torch.jit.script
         | 
| 133 | 
            +
            def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
         | 
| 134 | 
            +
                n_channels_int = n_channels[0]
         | 
| 135 | 
            +
                in_act = input_a + input_b
         | 
| 136 | 
            +
                t_act = torch.tanh(in_act[:, :n_channels_int, :])
         | 
| 137 | 
            +
                s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
         | 
| 138 | 
            +
                acts = t_act * s_act
         | 
| 139 | 
            +
                return acts
         | 
| 140 | 
            +
             | 
| 141 | 
            +
             | 
| 142 | 
            +
            def convert_pad_shape(pad_shape):
         | 
| 143 | 
            +
                l = pad_shape[::-1]
         | 
| 144 | 
            +
                pad_shape = [item for sublist in l for item in sublist]
         | 
| 145 | 
            +
                return pad_shape
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def shift_1d(x):
         | 
| 149 | 
            +
                x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
         | 
| 150 | 
            +
                return x
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
            def sequence_mask(length, max_length=None):
         | 
| 154 | 
            +
                if max_length is None:
         | 
| 155 | 
            +
                    max_length = length.max()
         | 
| 156 | 
            +
                x = torch.arange(max_length, dtype=length.dtype, device=length.device)
         | 
| 157 | 
            +
                return x.unsqueeze(0) < length.unsqueeze(1)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 160 | 
            +
            def avg_with_mask(x, mask):
         | 
| 161 | 
            +
                assert mask.dtype == torch.float, "Mask should be float"
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                if mask.ndim == 2:
         | 
| 164 | 
            +
                    mask = mask.unsqueeze(1)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                if mask.shape[1] == 1:
         | 
| 167 | 
            +
                    mask = mask.expand_as(x)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                return (x * mask).sum() / mask.sum()
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
            def generate_path(duration, mask):
         | 
| 173 | 
            +
                """
         | 
| 174 | 
            +
                duration: [b, 1, t_x]
         | 
| 175 | 
            +
                mask: [b, 1, t_y, t_x]
         | 
| 176 | 
            +
                """
         | 
| 177 | 
            +
                device = duration.device
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                b, _, t_y, t_x = mask.shape
         | 
| 180 | 
            +
                cum_duration = torch.cumsum(duration, -1)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                cum_duration_flat = cum_duration.view(b * t_x)
         | 
| 183 | 
            +
                path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
         | 
| 184 | 
            +
                path = path.view(b, t_x, t_y)
         | 
| 185 | 
            +
                path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
         | 
| 186 | 
            +
                path = path.unsqueeze(1).transpose(2, 3) * mask
         | 
| 187 | 
            +
                return path
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
            def clip_grad_value_(parameters, clip_value, norm_type=2):
         | 
| 191 | 
            +
                if isinstance(parameters, torch.Tensor):
         | 
| 192 | 
            +
                    parameters = [parameters]
         | 
| 193 | 
            +
                parameters = list(filter(lambda p: p.grad is not None, parameters))
         | 
| 194 | 
            +
                norm_type = float(norm_type)
         | 
| 195 | 
            +
                if clip_value is not None:
         | 
| 196 | 
            +
                    clip_value = float(clip_value)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                total_norm = 0
         | 
| 199 | 
            +
                for p in parameters:
         | 
| 200 | 
            +
                    param_norm = p.grad.data.norm(norm_type)
         | 
| 201 | 
            +
                    total_norm += param_norm.item() ** norm_type
         | 
| 202 | 
            +
                    if clip_value is not None:
         | 
| 203 | 
            +
                        p.grad.data.clamp_(min=-clip_value, max=clip_value)
         | 
| 204 | 
            +
                total_norm = total_norm ** (1.0 / norm_type)
         | 
| 205 | 
            +
                return total_norm
         | 
| 206 | 
            +
             | 
| 207 | 
            +
             | 
| 208 | 
            +
            def log_norm(x, mean=-4, std=4, dim=2):
         | 
| 209 | 
            +
                """
         | 
| 210 | 
            +
                normalized log mel -> mel -> norm -> log(norm)
         | 
| 211 | 
            +
                """
         | 
| 212 | 
            +
                x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
         | 
| 213 | 
            +
                return x
         | 
| 214 | 
            +
             | 
| 215 | 
            +
             | 
| 216 | 
            +
            def load_F0_models(path):
         | 
| 217 | 
            +
                # load F0 model
         | 
| 218 | 
            +
                from .JDC.model import JDCNet
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                F0_model = JDCNet(num_class=1, seq_len=192)
         | 
| 221 | 
            +
                params = torch.load(path, map_location="cpu")["net"]
         | 
| 222 | 
            +
                F0_model.load_state_dict(params)
         | 
| 223 | 
            +
                _ = F0_model.train()
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                return F0_model
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
            def modify_w2v_forward(self, output_layer=15):
         | 
| 229 | 
            +
                """
         | 
| 230 | 
            +
                change forward method of w2v encoder to get its intermediate layer output
         | 
| 231 | 
            +
                :param self:
         | 
| 232 | 
            +
                :param layer:
         | 
| 233 | 
            +
                :return:
         | 
| 234 | 
            +
                """
         | 
| 235 | 
            +
                from transformers.modeling_outputs import BaseModelOutput
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                def forward(
         | 
| 238 | 
            +
                    hidden_states,
         | 
| 239 | 
            +
                    attention_mask=None,
         | 
| 240 | 
            +
                    output_attentions=False,
         | 
| 241 | 
            +
                    output_hidden_states=False,
         | 
| 242 | 
            +
                    return_dict=True,
         | 
| 243 | 
            +
                ):
         | 
| 244 | 
            +
                    all_hidden_states = () if output_hidden_states else None
         | 
| 245 | 
            +
                    all_self_attentions = () if output_attentions else None
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    conv_attention_mask = attention_mask
         | 
| 248 | 
            +
                    if attention_mask is not None:
         | 
| 249 | 
            +
                        # make sure padded tokens output 0
         | 
| 250 | 
            +
                        hidden_states = hidden_states.masked_fill(
         | 
| 251 | 
            +
                            ~attention_mask.bool().unsqueeze(-1), 0.0
         | 
| 252 | 
            +
                        )
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                        # extend attention_mask
         | 
| 255 | 
            +
                        attention_mask = 1.0 - attention_mask[:, None, None, :].to(
         | 
| 256 | 
            +
                            dtype=hidden_states.dtype
         | 
| 257 | 
            +
                        )
         | 
| 258 | 
            +
                        attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
         | 
| 259 | 
            +
                        attention_mask = attention_mask.expand(
         | 
| 260 | 
            +
                            attention_mask.shape[0],
         | 
| 261 | 
            +
                            1,
         | 
| 262 | 
            +
                            attention_mask.shape[-1],
         | 
| 263 | 
            +
                            attention_mask.shape[-1],
         | 
| 264 | 
            +
                        )
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                    hidden_states = self.dropout(hidden_states)
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    if self.embed_positions is not None:
         | 
| 269 | 
            +
                        relative_position_embeddings = self.embed_positions(hidden_states)
         | 
| 270 | 
            +
                    else:
         | 
| 271 | 
            +
                        relative_position_embeddings = None
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    deepspeed_zero3_is_enabled = False
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    for i, layer in enumerate(self.layers):
         | 
| 276 | 
            +
                        if output_hidden_states:
         | 
| 277 | 
            +
                            all_hidden_states = all_hidden_states + (hidden_states,)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
         | 
| 280 | 
            +
                        dropout_probability = torch.rand([])
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                        skip_the_layer = (
         | 
| 283 | 
            +
                            True
         | 
| 284 | 
            +
                            if self.training and (dropout_probability < self.config.layerdrop)
         | 
| 285 | 
            +
                            else False
         | 
| 286 | 
            +
                        )
         | 
| 287 | 
            +
                        if not skip_the_layer or deepspeed_zero3_is_enabled:
         | 
| 288 | 
            +
                            # under deepspeed zero3 all gpus must run in sync
         | 
| 289 | 
            +
                            if self.gradient_checkpointing and self.training:
         | 
| 290 | 
            +
                                layer_outputs = self._gradient_checkpointing_func(
         | 
| 291 | 
            +
                                    layer.__call__,
         | 
| 292 | 
            +
                                    hidden_states,
         | 
| 293 | 
            +
                                    attention_mask,
         | 
| 294 | 
            +
                                    relative_position_embeddings,
         | 
| 295 | 
            +
                                    output_attentions,
         | 
| 296 | 
            +
                                    conv_attention_mask,
         | 
| 297 | 
            +
                                )
         | 
| 298 | 
            +
                            else:
         | 
| 299 | 
            +
                                layer_outputs = layer(
         | 
| 300 | 
            +
                                    hidden_states,
         | 
| 301 | 
            +
                                    attention_mask=attention_mask,
         | 
| 302 | 
            +
                                    relative_position_embeddings=relative_position_embeddings,
         | 
| 303 | 
            +
                                    output_attentions=output_attentions,
         | 
| 304 | 
            +
                                    conv_attention_mask=conv_attention_mask,
         | 
| 305 | 
            +
                                )
         | 
| 306 | 
            +
                            hidden_states = layer_outputs[0]
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                        if skip_the_layer:
         | 
| 309 | 
            +
                            layer_outputs = (None, None)
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                        if output_attentions:
         | 
| 312 | 
            +
                            all_self_attentions = all_self_attentions + (layer_outputs[1],)
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                        if i == output_layer - 1:
         | 
| 315 | 
            +
                            break
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    if output_hidden_states:
         | 
| 318 | 
            +
                        all_hidden_states = all_hidden_states + (hidden_states,)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    if not return_dict:
         | 
| 321 | 
            +
                        return tuple(
         | 
| 322 | 
            +
                            v
         | 
| 323 | 
            +
                            for v in [hidden_states, all_hidden_states, all_self_attentions]
         | 
| 324 | 
            +
                            if v is not None
         | 
| 325 | 
            +
                        )
         | 
| 326 | 
            +
                    return BaseModelOutput(
         | 
| 327 | 
            +
                        last_hidden_state=hidden_states,
         | 
| 328 | 
            +
                        hidden_states=all_hidden_states,
         | 
| 329 | 
            +
                        attentions=all_self_attentions,
         | 
| 330 | 
            +
                    )
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                return forward
         | 
| 333 | 
            +
             | 
| 334 | 
            +
             | 
| 335 | 
            +
            MATPLOTLIB_FLAG = False
         | 
| 336 | 
            +
             | 
| 337 | 
            +
             | 
| 338 | 
            +
            def plot_spectrogram_to_numpy(spectrogram):
         | 
| 339 | 
            +
                global MATPLOTLIB_FLAG
         | 
| 340 | 
            +
                if not MATPLOTLIB_FLAG:
         | 
| 341 | 
            +
                    import matplotlib
         | 
| 342 | 
            +
                    import logging
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                    matplotlib.use("Agg")
         | 
| 345 | 
            +
                    MATPLOTLIB_FLAG = True
         | 
| 346 | 
            +
                    mpl_logger = logging.getLogger("matplotlib")
         | 
| 347 | 
            +
                    mpl_logger.setLevel(logging.WARNING)
         | 
| 348 | 
            +
                import matplotlib.pylab as plt
         | 
| 349 | 
            +
                import numpy as np
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                fig, ax = plt.subplots(figsize=(10, 2))
         | 
| 352 | 
            +
                im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
         | 
| 353 | 
            +
                plt.colorbar(im, ax=ax)
         | 
| 354 | 
            +
                plt.xlabel("Frames")
         | 
| 355 | 
            +
                plt.ylabel("Channels")
         | 
| 356 | 
            +
                plt.tight_layout()
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                fig.canvas.draw()
         | 
| 359 | 
            +
                data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
         | 
| 360 | 
            +
                data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
         | 
| 361 | 
            +
                plt.close()
         | 
| 362 | 
            +
                return data
         | 
| 363 | 
            +
             | 
| 364 | 
            +
             | 
| 365 | 
            +
            def normalize_f0(f0_sequence):
         | 
| 366 | 
            +
                # Remove unvoiced frames (replace with -1)
         | 
| 367 | 
            +
                voiced_indices = np.where(f0_sequence > 0)[0]
         | 
| 368 | 
            +
                f0_voiced = f0_sequence[voiced_indices]
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                # Convert to log scale
         | 
| 371 | 
            +
                log_f0 = np.log2(f0_voiced)
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                # Calculate mean and standard deviation
         | 
| 374 | 
            +
                mean_f0 = np.mean(log_f0)
         | 
| 375 | 
            +
                std_f0 = np.std(log_f0)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                # Normalize the F0 sequence
         | 
| 378 | 
            +
                normalized_f0 = (log_f0 - mean_f0) / std_f0
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                # Create the normalized F0 sequence with unvoiced frames
         | 
| 381 | 
            +
                normalized_sequence = np.zeros_like(f0_sequence)
         | 
| 382 | 
            +
                normalized_sequence[voiced_indices] = normalized_f0
         | 
| 383 | 
            +
                normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                return normalized_sequence
         | 
| 386 | 
            +
             | 
| 387 | 
            +
             | 
| 388 | 
            +
            class MyModel(nn.Module):
         | 
| 389 | 
            +
                def __init__(self,args):
         | 
| 390 | 
            +
                    super(MyModel, self).__init__()
         | 
| 391 | 
            +
                    from modules.flow_matching import CFM
         | 
| 392 | 
            +
                    from modules.length_regulator import InterpolateRegulator
         | 
| 393 | 
            +
                    
         | 
| 394 | 
            +
                    length_regulator = InterpolateRegulator(
         | 
| 395 | 
            +
                        channels=args.length_regulator.channels,
         | 
| 396 | 
            +
                        sampling_ratios=args.length_regulator.sampling_ratios,
         | 
| 397 | 
            +
                        is_discrete=args.length_regulator.is_discrete,
         | 
| 398 | 
            +
                        in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
         | 
| 399 | 
            +
                        vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
         | 
| 400 | 
            +
                        codebook_size=args.length_regulator.content_codebook_size,
         | 
| 401 | 
            +
                        n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
         | 
| 402 | 
            +
                        quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
         | 
| 403 | 
            +
                        f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
         | 
| 404 | 
            +
                        n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
         | 
| 405 | 
            +
                    )
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                    self.models = nn.ModuleDict({
         | 
| 408 | 
            +
                        'cfm': CFM(args),
         | 
| 409 | 
            +
                        'length_regulator': length_regulator 
         | 
| 410 | 
            +
                    })
         | 
| 411 | 
            +
                
         | 
| 412 | 
            +
                def forward(self, x, target_lengths, prompt_len, cond, y):
         | 
| 413 | 
            +
                    x = self.models['cfm'](x, target_lengths, prompt_len, cond, y)
         | 
| 414 | 
            +
                    return x
         | 
| 415 | 
            +
                
         | 
| 416 | 
            +
                def forward2(self, S_ori,target_lengths,F0_ori):
         | 
| 417 | 
            +
                    x = self.models['length_regulator'](S_ori, ylens=target_lengths, f0=F0_ori)
         | 
| 418 | 
            +
                    return x
         | 
| 419 | 
            +
             | 
| 420 | 
            +
            def build_model(args, stage="DiT"):
         | 
| 421 | 
            +
                if stage == "DiT":
         | 
| 422 | 
            +
                    from modules.flow_matching import CFM
         | 
| 423 | 
            +
                    from modules.length_regulator import InterpolateRegulator
         | 
| 424 | 
            +
                    
         | 
| 425 | 
            +
                    length_regulator = InterpolateRegulator(
         | 
| 426 | 
            +
                        channels=args.length_regulator.channels,
         | 
| 427 | 
            +
                        sampling_ratios=args.length_regulator.sampling_ratios,
         | 
| 428 | 
            +
                        is_discrete=args.length_regulator.is_discrete,
         | 
| 429 | 
            +
                        in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
         | 
| 430 | 
            +
                        vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
         | 
| 431 | 
            +
                        codebook_size=args.length_regulator.content_codebook_size,
         | 
| 432 | 
            +
                        n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
         | 
| 433 | 
            +
                        quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
         | 
| 434 | 
            +
                        f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
         | 
| 435 | 
            +
                        n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
         | 
| 436 | 
            +
                    )
         | 
| 437 | 
            +
                    cfm = CFM(args)
         | 
| 438 | 
            +
                    nets = Munch(
         | 
| 439 | 
            +
                        cfm=cfm,
         | 
| 440 | 
            +
                        length_regulator=length_regulator,
         | 
| 441 | 
            +
                    )
         | 
| 442 | 
            +
                    
         | 
| 443 | 
            +
                elif stage == 'codec':
         | 
| 444 | 
            +
                    from dac.model.dac import Encoder
         | 
| 445 | 
            +
                    from modules.quantize import (
         | 
| 446 | 
            +
                        FAquantizer,
         | 
| 447 | 
            +
                    )
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    encoder = Encoder(
         | 
| 450 | 
            +
                        d_model=args.DAC.encoder_dim,
         | 
| 451 | 
            +
                        strides=args.DAC.encoder_rates,
         | 
| 452 | 
            +
                        d_latent=1024,
         | 
| 453 | 
            +
                        causal=args.causal,
         | 
| 454 | 
            +
                        lstm=args.lstm,
         | 
| 455 | 
            +
                    )
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                    quantizer = FAquantizer(
         | 
| 458 | 
            +
                        in_dim=1024,
         | 
| 459 | 
            +
                        n_p_codebooks=1,
         | 
| 460 | 
            +
                        n_c_codebooks=args.n_c_codebooks,
         | 
| 461 | 
            +
                        n_t_codebooks=2,
         | 
| 462 | 
            +
                        n_r_codebooks=3,
         | 
| 463 | 
            +
                        codebook_size=1024,
         | 
| 464 | 
            +
                        codebook_dim=8,
         | 
| 465 | 
            +
                        quantizer_dropout=0.5,
         | 
| 466 | 
            +
                        causal=args.causal,
         | 
| 467 | 
            +
                        separate_prosody_encoder=args.separate_prosody_encoder,
         | 
| 468 | 
            +
                        timbre_norm=args.timbre_norm,
         | 
| 469 | 
            +
                    )
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                    nets = Munch(
         | 
| 472 | 
            +
                        encoder=encoder,
         | 
| 473 | 
            +
                        quantizer=quantizer,
         | 
| 474 | 
            +
                    )
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                elif stage == "mel_vocos":
         | 
| 477 | 
            +
                    from modules.vocos import Vocos
         | 
| 478 | 
            +
                    decoder = Vocos(args)
         | 
| 479 | 
            +
                    nets = Munch(
         | 
| 480 | 
            +
                        decoder=decoder,
         | 
| 481 | 
            +
                    )
         | 
| 482 | 
            +
             | 
| 483 | 
            +
                else:
         | 
| 484 | 
            +
                    raise ValueError(f"Unknown stage: {stage}")
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                return nets
         | 
| 487 | 
            +
             | 
| 488 | 
            +
             | 
| 489 | 
            +
            def load_checkpoint(
         | 
| 490 | 
            +
                model,
         | 
| 491 | 
            +
                optimizer,
         | 
| 492 | 
            +
                path,
         | 
| 493 | 
            +
                load_only_params=True,
         | 
| 494 | 
            +
                ignore_modules=[],
         | 
| 495 | 
            +
                is_distributed=False,
         | 
| 496 | 
            +
                load_ema=False,
         | 
| 497 | 
            +
            ):
         | 
| 498 | 
            +
                state = torch.load(path, map_location="cpu")
         | 
| 499 | 
            +
                params = state["net"]
         | 
| 500 | 
            +
                if load_ema and "ema" in state:
         | 
| 501 | 
            +
                    print("Loading EMA")
         | 
| 502 | 
            +
                    for key in model:
         | 
| 503 | 
            +
                        i = 0
         | 
| 504 | 
            +
                        for param_name in params[key]:
         | 
| 505 | 
            +
                            if "input_pos" in param_name:
         | 
| 506 | 
            +
                                continue
         | 
| 507 | 
            +
                            assert params[key][param_name].shape == state["ema"][key][0][i].shape
         | 
| 508 | 
            +
                            params[key][param_name] = state["ema"][key][0][i].clone()
         | 
| 509 | 
            +
                            i += 1
         | 
| 510 | 
            +
                for key in model:
         | 
| 511 | 
            +
                    if key in params and key not in ignore_modules:
         | 
| 512 | 
            +
                        if not is_distributed:
         | 
| 513 | 
            +
                            # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
         | 
| 514 | 
            +
                            for k in list(params[key].keys()):
         | 
| 515 | 
            +
                                if k.startswith("module."):
         | 
| 516 | 
            +
                                    params[key][k[len("module.") :]] = params[key][k]
         | 
| 517 | 
            +
                                    del params[key][k]
         | 
| 518 | 
            +
                        model_state_dict = model[key].state_dict()
         | 
| 519 | 
            +
                        # 过滤出形状匹配的键值对
         | 
| 520 | 
            +
                        filtered_state_dict = {
         | 
| 521 | 
            +
                            k: v
         | 
| 522 | 
            +
                            for k, v in params[key].items()
         | 
| 523 | 
            +
                            if k in model_state_dict and v.shape == model_state_dict[k].shape
         | 
| 524 | 
            +
                        }
         | 
| 525 | 
            +
                        skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
         | 
| 526 | 
            +
                        if skipped_keys:
         | 
| 527 | 
            +
                            print(
         | 
| 528 | 
            +
                                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
         | 
| 529 | 
            +
                            )
         | 
| 530 | 
            +
                        print("%s loaded" % key)
         | 
| 531 | 
            +
                        model[key].load_state_dict(filtered_state_dict, strict=False)
         | 
| 532 | 
            +
                _ = [model[key].eval() for key in model]
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                if not load_only_params:
         | 
| 535 | 
            +
                    epoch = state["epoch"] + 1
         | 
| 536 | 
            +
                    iters = state["iters"]
         | 
| 537 | 
            +
                    optimizer.load_state_dict(state["optimizer"])
         | 
| 538 | 
            +
                    optimizer.load_scheduler_state_dict(state["scheduler"])
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                else:
         | 
| 541 | 
            +
                    epoch = 0
         | 
| 542 | 
            +
                    iters = 0
         | 
| 543 | 
            +
             | 
| 544 | 
            +
                return model, optimizer, epoch, iters
         | 
| 545 | 
            +
             | 
| 546 | 
            +
            def load_checkpoint2(
         | 
| 547 | 
            +
                model,
         | 
| 548 | 
            +
                optimizer,
         | 
| 549 | 
            +
                path,
         | 
| 550 | 
            +
                load_only_params=True,
         | 
| 551 | 
            +
                ignore_modules=[],
         | 
| 552 | 
            +
                is_distributed=False,
         | 
| 553 | 
            +
                load_ema=False,
         | 
| 554 | 
            +
            ):
         | 
| 555 | 
            +
                state = torch.load(path, map_location="cpu")
         | 
| 556 | 
            +
                params = state["net"]
         | 
| 557 | 
            +
                if load_ema and "ema" in state:
         | 
| 558 | 
            +
                    print("Loading EMA")
         | 
| 559 | 
            +
                    for key in model.models:
         | 
| 560 | 
            +
                        i = 0
         | 
| 561 | 
            +
                        for param_name in params[key]:
         | 
| 562 | 
            +
                            if "input_pos" in param_name:
         | 
| 563 | 
            +
                                continue
         | 
| 564 | 
            +
                            assert params[key][param_name].shape == state["ema"][key][0][i].shape
         | 
| 565 | 
            +
                            params[key][param_name] = state["ema"][key][0][i].clone()
         | 
| 566 | 
            +
                            i += 1
         | 
| 567 | 
            +
                for key in model.models:
         | 
| 568 | 
            +
                    if key in params and key not in ignore_modules:
         | 
| 569 | 
            +
                        if not is_distributed:
         | 
| 570 | 
            +
                            # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
         | 
| 571 | 
            +
                            for k in list(params[key].keys()):
         | 
| 572 | 
            +
                                if k.startswith("module."):
         | 
| 573 | 
            +
                                    params[key][k[len("module.") :]] = params[key][k]
         | 
| 574 | 
            +
                                    del params[key][k]
         | 
| 575 | 
            +
                        model_state_dict = model.models[key].state_dict()
         | 
| 576 | 
            +
                        # keep only the key/value pairs whose shapes match the target state dict
         | 
| 577 | 
            +
                        filtered_state_dict = {
         | 
| 578 | 
            +
                            k: v
         | 
| 579 | 
            +
                            for k, v in params[key].items()
         | 
| 580 | 
            +
                            if k in model_state_dict and v.shape == model_state_dict[k].shape
         | 
| 581 | 
            +
                        }
         | 
| 582 | 
            +
                        skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
         | 
| 583 | 
            +
                        if skipped_keys:
         | 
| 584 | 
            +
                            print(
         | 
| 585 | 
            +
                                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
         | 
| 586 | 
            +
                            )
         | 
| 587 | 
            +
                        print("%s loaded" % key)
         | 
| 588 | 
            +
                        model.models[key].load_state_dict(filtered_state_dict, strict=False)
         | 
| 589 | 
            +
                model.eval()
         | 
| 590 | 
            +
            #     _ = [model[key].eval() for key in model]
         | 
| 591 | 
            +
             | 
| 592 | 
            +
                if not load_only_params:
         | 
| 593 | 
            +
                    epoch = state["epoch"] + 1
         | 
| 594 | 
            +
                    iters = state["iters"]
         | 
| 595 | 
            +
                    optimizer.load_state_dict(state["optimizer"])
         | 
| 596 | 
            +
                    optimizer.load_scheduler_state_dict(state["scheduler"])
         | 
| 597 | 
            +
             | 
| 598 | 
            +
                else:
         | 
| 599 | 
            +
                    epoch = 0
         | 
| 600 | 
            +
                    iters = 0
         | 
| 601 | 
            +
             | 
| 602 | 
            +
                return model, optimizer, epoch, iters
         | 
| 603 | 
            +
             | 
| 604 | 
            +
            def recursive_munch(d):
         | 
| 605 | 
            +
                if isinstance(d, dict):
         | 
| 606 | 
            +
                    return Munch((k, recursive_munch(v)) for k, v in d.items())
         | 
| 607 | 
            +
                elif isinstance(d, list):
         | 
| 608 | 
            +
                    return [recursive_munch(v) for v in d]
         | 
| 609 | 
            +
                else:
         | 
| 610 | 
            +
                    return d
         | 
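A minimal usage sketch for load_checkpoint2 above (the wrapper class, module names and checkpoint path are illustrative placeholders, not part of this commit; load_only_params=True keeps the optimizer untouched, so a real optimizer wrapper is only needed when resuming training state):

    import torch
    import torch.nn as nn

    class Wrapper(nn.Module):
        """Minimal stand-in exposing the .models dict that load_checkpoint2 iterates over."""
        def __init__(self):
            super().__init__()
            self.models = nn.ModuleDict({"dit": nn.Linear(4, 4), "regulator": nn.Linear(4, 4)})

    wrapper = Wrapper()
    # With load_only_params=True the optimizer argument is never used, so None is acceptable here.
    wrapper, _, epoch, iters = load_checkpoint2(wrapper, None, "checkpoints/s2mel.pth")  # hypothetical path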
    	
        indextts/s2mel/modules/.ipynb_checkpoints/diffusion_transformer-checkpoint.py
    ADDED
    
    | @@ -0,0 +1,258 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch import nn
         | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from modules.gpt_fast.model import ModelArgs, Transformer
         | 
| 6 | 
            +
            # from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
         | 
| 7 | 
            +
            from modules.wavenet import WN
         | 
| 8 | 
            +
            from modules.commons import sequence_mask
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from torch.nn.utils import weight_norm
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            def modulate(x, shift, scale):
         | 
| 13 | 
            +
                return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            #################################################################################
         | 
| 17 | 
            +
            #               Embedding Layers for Timesteps and Class Labels                 #
         | 
| 18 | 
            +
            #################################################################################
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            class TimestepEmbedder(nn.Module):
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                Embeds scalar timesteps into vector representations.
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                def __init__(self, hidden_size, frequency_embedding_size=256):
         | 
| 25 | 
            +
                    super().__init__()
         | 
| 26 | 
            +
                    self.mlp = nn.Sequential(
         | 
| 27 | 
            +
                        nn.Linear(frequency_embedding_size, hidden_size, bias=True),
         | 
| 28 | 
            +
                        nn.SiLU(),
         | 
| 29 | 
            +
                        nn.Linear(hidden_size, hidden_size, bias=True),
         | 
| 30 | 
            +
                    )
         | 
| 31 | 
            +
                    self.frequency_embedding_size = frequency_embedding_size
         | 
| 32 | 
            +
                    self.max_period = 10000
         | 
| 33 | 
            +
                    self.scale = 1000
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    half = frequency_embedding_size // 2
         | 
| 36 | 
            +
                    freqs = torch.exp(
         | 
| 37 | 
            +
                        -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
         | 
| 38 | 
            +
                    )
         | 
| 39 | 
            +
                    self.register_buffer("freqs", freqs)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def timestep_embedding(self, t):
         | 
| 42 | 
            +
                    """
         | 
| 43 | 
            +
                    Create sinusoidal timestep embeddings.
         | 
| 44 | 
            +
                    :param t: a 1-D Tensor of N indices, one per batch element.
         | 
| 45 | 
            +
                                      These may be fractional.
         | 
| 46 | 
            +
                    :param dim: the dimension of the output.
         | 
| 47 | 
            +
                    :param max_period: controls the minimum frequency of the embeddings.
         | 
| 48 | 
            +
                    :return: an (N, D) Tensor of positional embeddings.
         | 
| 49 | 
            +
                    """
         | 
| 50 | 
            +
                    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    args = self.scale * t[:, None].float() * self.freqs[None]
         | 
| 53 | 
            +
                    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         | 
| 54 | 
            +
                    if self.frequency_embedding_size % 2:
         | 
| 55 | 
            +
                        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
         | 
| 56 | 
            +
                    return embedding
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def forward(self, t):
         | 
| 59 | 
            +
                    t_freq = self.timestep_embedding(t)
         | 
| 60 | 
            +
                    t_emb = self.mlp(t_freq)
         | 
| 61 | 
            +
                    return t_emb
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            class StyleEmbedder(nn.Module):
         | 
| 65 | 
            +
                """
         | 
| 66 | 
            +
                Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
         | 
| 67 | 
            +
                """
         | 
| 68 | 
            +
                def __init__(self, input_size, hidden_size, dropout_prob):
         | 
| 69 | 
            +
                    super().__init__()
         | 
| 70 | 
            +
                    use_cfg_embedding = dropout_prob > 0
         | 
| 71 | 
            +
                    self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
         | 
| 72 | 
            +
                    self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
         | 
| 73 | 
            +
                    self.input_size = input_size
         | 
| 74 | 
            +
                    self.dropout_prob = dropout_prob
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def forward(self, labels, train, force_drop_ids=None):
         | 
| 77 | 
            +
                    use_dropout = self.dropout_prob > 0
         | 
| 78 | 
            +
                    if (train and use_dropout) or (force_drop_ids is not None):
         | 
| 79 | 
            +
                        labels = self.token_drop(labels, force_drop_ids)
         | 
| 80 | 
            +
                    else:
         | 
| 81 | 
            +
                        labels = self.style_in(labels)
         | 
| 82 | 
            +
                    embeddings = labels
         | 
| 83 | 
            +
                    return embeddings
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            class FinalLayer(nn.Module):
         | 
| 86 | 
            +
                """
         | 
| 87 | 
            +
                The final layer of DiT.
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
                def __init__(self, hidden_size, patch_size, out_channels):
         | 
| 90 | 
            +
                    super().__init__()
         | 
| 91 | 
            +
                    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         | 
| 92 | 
            +
                    self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
         | 
| 93 | 
            +
                    self.adaLN_modulation = nn.Sequential(
         | 
| 94 | 
            +
                        nn.SiLU(),
         | 
| 95 | 
            +
                        nn.Linear(hidden_size, 2 * hidden_size, bias=True)
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def forward(self, x, c):
         | 
| 99 | 
            +
                    shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
         | 
| 100 | 
            +
                    x = modulate(self.norm_final(x), shift, scale)
         | 
| 101 | 
            +
                    x = self.linear(x)
         | 
| 102 | 
            +
                    return x
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            class DiT(torch.nn.Module):
         | 
| 105 | 
            +
                def __init__(
         | 
| 106 | 
            +
                    self,
         | 
| 107 | 
            +
                    args
         | 
| 108 | 
            +
                ):
         | 
| 109 | 
            +
                    super(DiT, self).__init__()
         | 
| 110 | 
            +
                    self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
         | 
| 111 | 
            +
                    self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
         | 
| 112 | 
            +
                    self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
         | 
| 113 | 
            +
                    model_args = ModelArgs(
         | 
| 114 | 
            +
                        block_size=16384,#args.DiT.block_size,
         | 
| 115 | 
            +
                        n_layer=args.DiT.depth,
         | 
| 116 | 
            +
                        n_head=args.DiT.num_heads,
         | 
| 117 | 
            +
                        dim=args.DiT.hidden_dim,
         | 
| 118 | 
            +
                        head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
         | 
| 119 | 
            +
                        vocab_size=1024,
         | 
| 120 | 
            +
                        uvit_skip_connection=self.uvit_skip_connection,
         | 
| 121 | 
            +
                        time_as_token=self.time_as_token,
         | 
| 122 | 
            +
                    )
         | 
| 123 | 
            +
                    self.transformer = Transformer(model_args)
         | 
| 124 | 
            +
                    self.in_channels = args.DiT.in_channels
         | 
| 125 | 
            +
                    self.out_channels = args.DiT.in_channels
         | 
| 126 | 
            +
                    self.num_heads = args.DiT.num_heads
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    self.content_type = args.DiT.content_type  # 'discrete' or 'continuous'
         | 
| 131 | 
            +
                    self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
         | 
| 132 | 
            +
                    self.content_dim = args.DiT.content_dim # for continuous content
         | 
| 133 | 
            +
                    self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim)  # discrete content
         | 
| 134 | 
            +
                    self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    self.is_causal = args.DiT.is_causal
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
         | 
| 141 | 
            +
                    # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    input_pos = torch.arange(16384)
         | 
| 144 | 
            +
                    self.register_buffer("input_pos", input_pos)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    self.final_layer_type = args.DiT.final_layer_type  # mlp or wavenet
         | 
| 147 | 
            +
                    if self.final_layer_type == 'wavenet':
         | 
| 148 | 
            +
                        self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
         | 
| 149 | 
            +
                        self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
         | 
| 150 | 
            +
                        self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
         | 
| 151 | 
            +
                        self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
         | 
| 152 | 
            +
                                          kernel_size=args.wavenet.kernel_size,
         | 
| 153 | 
            +
                                          dilation_rate=args.wavenet.dilation_rate,
         | 
| 154 | 
            +
                                          n_layers=args.wavenet.num_layers,
         | 
| 155 | 
            +
                                          gin_channels=args.wavenet.hidden_dim,
         | 
| 156 | 
            +
                                          p_dropout=args.wavenet.p_dropout,
         | 
| 157 | 
            +
                                          causal=False)
         | 
| 158 | 
            +
                        self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
         | 
| 159 | 
            +
                        self.res_projection = nn.Linear(args.DiT.hidden_dim,
         | 
| 160 | 
            +
                                                    args.wavenet.hidden_dim)  # residual connection from transformer output to final output
         | 
| 161 | 
            +
                        self.wavenet_style_condition = args.wavenet.style_condition
         | 
| 162 | 
            +
                        assert args.DiT.style_condition == args.wavenet.style_condition
         | 
| 163 | 
            +
                    else:
         | 
| 164 | 
            +
                        self.final_mlp = nn.Sequential(
         | 
| 165 | 
            +
                                nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
         | 
| 166 | 
            +
                                nn.SiLU(),
         | 
| 167 | 
            +
                                nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
         | 
| 168 | 
            +
                        )
         | 
| 169 | 
            +
                    self.transformer_style_condition = args.DiT.style_condition
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
                    self.class_dropout_prob = args.DiT.class_dropout_prob
         | 
| 173 | 
            +
                    self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    self.long_skip_connection = args.DiT.long_skip_connection
         | 
| 176 | 
            +
                    self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
         | 
| 179 | 
            +
                                                         args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
         | 
| 180 | 
            +
                                                         args.DiT.hidden_dim)
         | 
| 181 | 
            +
                    if self.style_as_token:
         | 
| 182 | 
            +
                        self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                def setup_caches(self, max_batch_size, max_seq_length):
         | 
| 185 | 
            +
                    self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
         | 
| 186 | 
            +
                    
         | 
| 187 | 
            +
                def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
         | 
| 188 | 
            +
                    """
         | 
| 189 | 
            +
                        x (torch.Tensor): random noise
         | 
| 190 | 
            +
                        prompt_x (torch.Tensor): reference mel + zero mel
         | 
| 191 | 
            +
                            shape: (batch_size, 80, 795+1068)
         | 
| 192 | 
            +
                        x_lens (torch.Tensor): mel frames output
         | 
| 193 | 
            +
                            shape: (batch_size, mel_timesteps)
         | 
| 194 | 
            +
                    t (torch.Tensor): random diffusion timestep
         | 
| 195 | 
            +
                            shape: (batch_size)    
         | 
| 196 | 
            +
                        style (torch.Tensor): reference global style
         | 
| 197 | 
            +
                            shape: (batch_size, 192)
         | 
| 198 | 
            +
                        cond (torch.Tensor): semantic info of reference audio and altered audio
         | 
| 199 | 
            +
                            shape: (batch_size, mel_timesteps(795+1069), 512)
         | 
| 200 | 
            +
                    
         | 
| 201 | 
            +
                    """
         | 
| 202 | 
            +
                    class_dropout = False
         | 
| 203 | 
            +
                    if self.training and torch.rand(1) < self.class_dropout_prob:
         | 
| 204 | 
            +
                        class_dropout = True
         | 
| 205 | 
            +
                    if not self.training and mask_content:
         | 
| 206 | 
            +
                        class_dropout = True
         | 
| 207 | 
            +
                    # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
         | 
| 208 | 
            +
                    cond_in_module = self.cond_projection
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    B, _, T = x.size()
         | 
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
                    t1 = self.t_embedder(t)  # (N, D) # t1 [2, 512]
         | 
| 214 | 
            +
                    cond = cond_in_module(cond) # cond [2,1863,512]->[2,1863,512]
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    x = x.transpose(1, 2) # [2,1863,80]
         | 
| 217 | 
            +
                    prompt_x = prompt_x.transpose(1, 2) # [2,1863,80]
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    x_in = torch.cat([x, prompt_x, cond], dim=-1) # 80+80+512=672 [2, 1863, 672]
         | 
| 220 | 
            +
                    
         | 
| 221 | 
            +
                    if self.transformer_style_condition and not self.style_as_token: # True and True
         | 
| 222 | 
            +
                        x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1) #[2, 1863, 864]
         | 
| 223 | 
            +
                        
         | 
| 224 | 
            +
                    if class_dropout: #False
         | 
| 225 | 
            +
                    x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0 # zero out everything beyond the first in_channels (80) dims
         | 
| 226 | 
            +
                        
         | 
| 227 | 
            +
                    x_in = self.cond_x_merge_linear(x_in)  # (N, T, D) [2, 1863, 512]
         | 
| 228 | 
            +
                    
         | 
| 229 | 
            +
                    if self.style_as_token: # False
         | 
| 230 | 
            +
                        style = self.style_in(style)
         | 
| 231 | 
            +
                        style = torch.zeros_like(style) if class_dropout else style
         | 
| 232 | 
            +
                        x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
         | 
| 233 | 
            +
                        
         | 
| 234 | 
            +
                    if self.time_as_token: # False
         | 
| 235 | 
            +
                        x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
         | 
| 236 | 
            +
                        
         | 
| 237 | 
            +
                x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1) # torch.Size([1, 1, 1863]), all True
         | 
| 238 | 
            +
                    input_pos = self.input_pos[:x_in.size(1)]  # (T,) range(0,1863)
         | 
| 239 | 
            +
                    x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None # torch.Size([1, 1, 1863, 1863]
         | 
| 240 | 
            +
                    x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded) # [2, 1863, 512]
         | 
| 241 | 
            +
                    x_res = x_res[:, 1:] if self.time_as_token else x_res
         | 
| 242 | 
            +
                    x_res = x_res[:, 1:] if self.style_as_token else x_res
         | 
| 243 | 
            +
                    
         | 
| 244 | 
            +
                    if self.long_skip_connection: #True
         | 
| 245 | 
            +
                        x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
         | 
| 246 | 
            +
                    if self.final_layer_type == 'wavenet':
         | 
| 247 | 
            +
                        x = self.conv1(x_res)
         | 
| 248 | 
            +
                        x = x.transpose(1, 2)
         | 
| 249 | 
            +
                        t2 = self.t_embedder2(t)
         | 
| 250 | 
            +
                        x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
         | 
| 251 | 
            +
                            x_res)  # long residual connection
         | 
| 252 | 
            +
                        x = self.final_layer(x, t1).transpose(1, 2)
         | 
| 253 | 
            +
                        x = self.conv2(x)
         | 
| 254 | 
            +
                    else:
         | 
| 255 | 
            +
                        x = self.final_mlp(x_res)
         | 
| 256 | 
            +
                        x = x.transpose(1, 2)
         | 
| 257 | 
            +
                    # x [2,80,1863]
         | 
| 258 | 
            +
                    return x
         | 
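For reference, the tensor shapes that DiT.forward documents above, spelled out as a hedged smoke-test sketch (instantiating DiT itself needs the full Munch config from the YAML, so only the inputs are built here; all sizes follow the docstring, with 1863 frames standing in for prompt plus target):

    import torch

    B, n_mels, T, sem_dim, style_dim = 2, 80, 1863, 512, 192
    x        = torch.randn(B, n_mels, T)     # random noise to be denoised
    prompt_x = torch.randn(B, n_mels, T)     # reference mel in the prompt region, zeros elsewhere
    x_lens   = torch.full((B,), T)           # valid output frame count per sample
    t        = torch.rand(B)                 # one diffusion timestep per sample
    style    = torch.randn(B, style_dim)     # global style embedding
    cond     = torch.randn(B, T, sem_dim)    # semantic features of reference + target audio
    # out = dit(x, prompt_x, x_lens, t, style, cond)  # expected output shape: (B, 80, T)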
    	
        indextts/s2mel/modules/.ipynb_checkpoints/flow_matching-checkpoint.py
    ADDED
    
    | @@ -0,0 +1,171 @@ | |
| 1 | 
            +
            from abc import ABC
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from modules.diffusion_transformer import DiT
         | 
| 7 | 
            +
            from modules.commons import sequence_mask
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from tqdm import tqdm
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            class BASECFM(torch.nn.Module, ABC):
         | 
| 12 | 
            +
                def __init__(
         | 
| 13 | 
            +
                    self,
         | 
| 14 | 
            +
                    args,
         | 
| 15 | 
            +
                ):
         | 
| 16 | 
            +
                    super().__init__()
         | 
| 17 | 
            +
                    self.sigma_min = 1e-6
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                    self.estimator = None
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                    self.in_channels = args.DiT.in_channels
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                    if hasattr(args.DiT, 'zero_prompt_speech_token'):
         | 
| 26 | 
            +
                        self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
         | 
| 27 | 
            +
                    else:
         | 
| 28 | 
            +
                        self.zero_prompt_speech_token = False
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                @torch.inference_mode()
         | 
| 31 | 
            +
                def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
         | 
| 32 | 
            +
                    """Forward diffusion
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    Args:
         | 
| 35 | 
            +
                        mu (torch.Tensor): semantic info of reference audio and altered audio
         | 
| 36 | 
            +
                            shape: (batch_size, mel_timesteps(795+1069), 512)
         | 
| 37 | 
            +
                        x_lens (torch.Tensor): mel frames output
         | 
| 38 | 
            +
                            shape: (batch_size, mel_timesteps)
         | 
| 39 | 
            +
                        prompt (torch.Tensor): reference mel
         | 
| 40 | 
            +
                            shape: (batch_size, 80, 795)
         | 
| 41 | 
            +
                        style (torch.Tensor): reference global style
         | 
| 42 | 
            +
                            shape: (batch_size, 192)
         | 
| 43 | 
            +
                        f0: None
         | 
| 44 | 
            +
                        n_timesteps (int): number of diffusion steps
         | 
| 45 | 
            +
                        temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    Returns:
         | 
| 48 | 
            +
                        sample: generated mel-spectrogram
         | 
| 49 | 
            +
                            shape: (batch_size, 80, mel_timesteps)
         | 
| 50 | 
            +
                    """
         | 
| 51 | 
            +
                    B, T = mu.size(0), mu.size(1)
         | 
| 52 | 
            +
                    z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
         | 
| 53 | 
            +
                    t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
         | 
| 54 | 
            +
                    # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
         | 
| 55 | 
            +
                    return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
         | 
| 58 | 
            +
                    """
         | 
| 59 | 
            +
                Fixed-step Euler solver for ODEs.
         | 
| 60 | 
            +
                    Args:
         | 
| 61 | 
            +
                        x (torch.Tensor): random noise
         | 
| 62 | 
            +
                        t_span (torch.Tensor): n_timesteps interpolated
         | 
| 63 | 
            +
                            shape: (n_timesteps + 1,)
         | 
| 64 | 
            +
                        mu (torch.Tensor): semantic info of reference audio and altered audio
         | 
| 65 | 
            +
                            shape: (batch_size, mel_timesteps(795+1069), 512)
         | 
| 66 | 
            +
                        x_lens (torch.Tensor): mel frames output
         | 
| 67 | 
            +
                            shape: (batch_size, mel_timesteps)
         | 
| 68 | 
            +
                        prompt (torch.Tensor): reference mel
         | 
| 69 | 
            +
                            shape: (batch_size, 80, 795)
         | 
| 70 | 
            +
                        style (torch.Tensor): reference global style
         | 
| 71 | 
            +
                            shape: (batch_size, 192)
         | 
| 72 | 
            +
                    """
         | 
| 73 | 
            +
                    t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         | 
| 76 | 
            +
                    # Or in future might add like a return_all_steps flag
         | 
| 77 | 
            +
                    sol = []
         | 
| 78 | 
            +
                    # apply prompt
         | 
| 79 | 
            +
                    prompt_len = prompt.size(-1)
         | 
| 80 | 
            +
                    prompt_x = torch.zeros_like(x)
         | 
| 81 | 
            +
                    prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
         | 
| 82 | 
            +
                    x[..., :prompt_len] = 0
         | 
| 83 | 
            +
                    if self.zero_prompt_speech_token:
         | 
| 84 | 
            +
                        mu[..., :prompt_len] = 0
         | 
| 85 | 
            +
                    for step in tqdm(range(1, len(t_span))):
         | 
| 86 | 
            +
                        dt = t_span[step] - t_span[step - 1]
         | 
| 87 | 
            +
                        if inference_cfg_rate > 0:
         | 
| 88 | 
            +
                            # Stack original and CFG (null) inputs for batched processing
         | 
| 89 | 
            +
                            stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
         | 
| 90 | 
            +
                            stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
         | 
| 91 | 
            +
                            stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
         | 
| 92 | 
            +
                            stacked_x = torch.cat([x, x], dim=0)
         | 
| 93 | 
            +
                            stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                            # Perform a single forward pass for both original and CFG inputs
         | 
| 96 | 
            +
                            stacked_dphi_dt = self.estimator(
         | 
| 97 | 
            +
                                stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
         | 
| 98 | 
            +
                            )
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                            # Split the output back into the original and CFG components
         | 
| 101 | 
            +
                            dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                            # Apply CFG formula
         | 
| 104 | 
            +
                            dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
         | 
| 105 | 
            +
                        else:
         | 
| 106 | 
            +
                            dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                        x = x + dt * dphi_dt
         | 
| 109 | 
            +
                        t = t + dt
         | 
| 110 | 
            +
                        sol.append(x)
         | 
| 111 | 
            +
                        if step < len(t_span) - 1:
         | 
| 112 | 
            +
                            dt = t_span[step + 1] - t
         | 
| 113 | 
            +
                        x[:, :, :prompt_len] = 0
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    return sol[-1]
         | 
| 116 | 
            +
                def forward(self, x1, x_lens, prompt_lens, mu, style):
         | 
| 117 | 
            +
                    """Computes diffusion loss
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    Args:
         | 
| 120 | 
            +
                        mu (torch.Tensor): semantic info of reference audio and altered audio
         | 
| 121 | 
            +
                            shape: (batch_size, mel_timesteps(795+1069), 512)
         | 
| 122 | 
            +
                        x1: mel
         | 
| 123 | 
            +
                        x_lens (torch.Tensor): mel frames output
         | 
| 124 | 
            +
                            shape: (batch_size, mel_timesteps)
         | 
| 125 | 
            +
                        prompt (torch.Tensor): reference mel
         | 
| 126 | 
            +
                            shape: (batch_size, 80, 795)
         | 
| 127 | 
            +
                        style (torch.Tensor): reference global style
         | 
| 128 | 
            +
                            shape: (batch_size, 192)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    Returns:
         | 
| 131 | 
            +
                        loss: conditional flow matching loss
         | 
| 132 | 
            +
                        y: conditional flow
         | 
| 133 | 
            +
                            shape: (batch_size, n_feats, mel_timesteps)
         | 
| 134 | 
            +
                    """
         | 
| 135 | 
            +
                    b, _, t = x1.shape
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    # random timestep
         | 
| 138 | 
            +
                    t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
         | 
| 139 | 
            +
                    # sample noise p(x_0)
         | 
| 140 | 
            +
                    z = torch.randn_like(x1)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    y = (1 - (1 - self.sigma_min) * t) * z + t * x1
         | 
| 143 | 
            +
                    u = x1 - (1 - self.sigma_min) * z
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    prompt = torch.zeros_like(x1)
         | 
| 146 | 
            +
                    for bib in range(b):
         | 
| 147 | 
            +
                        prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
         | 
| 148 | 
            +
                        # range covered by prompt are set to 0
         | 
| 149 | 
            +
                        y[bib, :, :prompt_lens[bib]] = 0
         | 
| 150 | 
            +
                        if self.zero_prompt_speech_token:
         | 
| 151 | 
            +
                            mu[bib, :, :prompt_lens[bib]] = 0
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
         | 
| 154 | 
            +
                    loss = 0
         | 
| 155 | 
            +
                    for bib in range(b):
         | 
| 156 | 
            +
                        loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
         | 
| 157 | 
            +
                    loss /= b
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    return loss, estimator_out + (1 - self.sigma_min) * z
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
             | 
| 163 | 
            +
            class CFM(BASECFM):
         | 
| 164 | 
            +
                def __init__(self, args):
         | 
| 165 | 
            +
                    super().__init__(
         | 
| 166 | 
            +
                        args
         | 
| 167 | 
            +
                    )
         | 
| 168 | 
            +
                    if args.dit_type == "DiT":
         | 
| 169 | 
            +
                        self.estimator = DiT(args)
         | 
| 170 | 
            +
                    else:
         | 
| 171 | 
            +
                        raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
         | 
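The heart of solve_euler above is an explicit Euler update on a classifier-free-guidance-weighted velocity; a distilled, hedged sketch of that single step (names are illustrative, and v_cond / v_uncond stand for the two halves of the batched estimator output):

    import torch

    def cfg_euler_step(x, v_cond, v_uncond, dt, cfg_rate=0.5):
        # CFG formula used above: amplify the conditional velocity, subtract the null-conditioned one
        v = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond
        return x + dt * v  # explicit Euler update

    x = torch.randn(1, 80, 100)
    v_c, v_u = torch.randn_like(x), torch.randn_like(x)
    x = cfg_euler_step(x, v_c, v_u, dt=0.1)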
    	
        indextts/s2mel/modules/.ipynb_checkpoints/length_regulator-checkpoint.py
    ADDED
    
    | @@ -0,0 +1,141 @@ | |
| 1 | 
            +
            from typing import Tuple
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import torch.nn as nn
         | 
| 4 | 
            +
            from torch.nn import functional as F
         | 
| 5 | 
            +
            from modules.commons import sequence_mask
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from dac.nn.quantize import VectorQuantize
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # f0_bin = 256
         | 
| 10 | 
            +
            f0_max = 1100.0
         | 
| 11 | 
            +
            f0_min = 50.0
         | 
| 12 | 
            +
            f0_mel_min = 1127 * np.log(1 + f0_min / 700)
         | 
| 13 | 
            +
            f0_mel_max = 1127 * np.log(1 + f0_max / 700)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            def f0_to_coarse(f0, f0_bin):
         | 
| 16 | 
            +
              f0_mel = 1127 * (1 + f0 / 700).log()
         | 
| 17 | 
            +
              a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
         | 
| 18 | 
            +
              b = f0_mel_min * a - 1.
         | 
| 19 | 
            +
              f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
         | 
| 20 | 
            +
              # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
         | 
| 21 | 
            +
              f0_coarse = torch.round(f0_mel).long()
         | 
| 22 | 
            +
              f0_coarse = f0_coarse * (f0_coarse > 0)
         | 
| 23 | 
            +
              f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
         | 
| 24 | 
            +
              f0_coarse = f0_coarse * (f0_coarse < f0_bin)
         | 
| 25 | 
            +
              f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
         | 
| 26 | 
            +
              return f0_coarse
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            class InterpolateRegulator(nn.Module):
         | 
| 29 | 
            +
                def __init__(
         | 
| 30 | 
            +
                        self,
         | 
| 31 | 
            +
                        channels: int,
         | 
| 32 | 
            +
                        sampling_ratios: Tuple,
         | 
| 33 | 
            +
                        is_discrete: bool = False,
         | 
| 34 | 
            +
                        in_channels: int = None,  # only applies to continuous input
         | 
| 35 | 
            +
                        vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
         | 
| 36 | 
            +
                        codebook_size: int = 1024, # for discrete only
         | 
| 37 | 
            +
                        out_channels: int = None,
         | 
| 38 | 
            +
                        groups: int = 1,
         | 
| 39 | 
            +
                        n_codebooks: int = 1,  # number of codebooks
         | 
| 40 | 
            +
                        quantizer_dropout: float = 0.0,  # dropout for quantizer
         | 
| 41 | 
            +
                        f0_condition: bool = False,
         | 
| 42 | 
            +
                        n_f0_bins: int = 512,
         | 
| 43 | 
            +
                ):
         | 
| 44 | 
            +
                    super().__init__()
         | 
| 45 | 
            +
                    self.sampling_ratios = sampling_ratios
         | 
| 46 | 
            +
                    out_channels = out_channels or channels
         | 
| 47 | 
            +
                    model = nn.ModuleList([])
         | 
| 48 | 
            +
                    if len(sampling_ratios) > 0:
         | 
| 49 | 
            +
                        self.interpolate = True
         | 
| 50 | 
            +
                        for _ in sampling_ratios:
         | 
| 51 | 
            +
                            module = nn.Conv1d(channels, channels, 3, 1, 1)
         | 
| 52 | 
            +
                            norm = nn.GroupNorm(groups, channels)
         | 
| 53 | 
            +
                            act = nn.Mish()
         | 
| 54 | 
            +
                            model.extend([module, norm, act])
         | 
| 55 | 
            +
                    else:
         | 
| 56 | 
            +
                        self.interpolate = False
         | 
| 57 | 
            +
                    model.append(
         | 
| 58 | 
            +
                        nn.Conv1d(channels, out_channels, 1, 1)
         | 
| 59 | 
            +
                    )
         | 
| 60 | 
            +
                    self.model = nn.Sequential(*model)
         | 
| 61 | 
            +
        self.embedding = nn.Embedding(codebook_size, channels)
        self.is_discrete = is_discrete

        self.mask_token = nn.Parameter(torch.zeros(1, channels))

        self.n_codebooks = n_codebooks
        if n_codebooks > 1:
            self.extra_codebooks = nn.ModuleList([
                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
            ])
            self.extra_codebook_mask_tokens = nn.ParameterList([
                nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
            ])
        self.quantizer_dropout = quantizer_dropout

        if f0_condition:
            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
            self.f0_condition = f0_condition
            self.n_f0_bins = n_f0_bins
            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
        else:
            self.f0_condition = False

        if not is_discrete:
            self.content_in_proj = nn.Linear(in_channels, channels)
            if vector_quantize:
                self.vq = VectorQuantize(channels, codebook_size, 8)

    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
        # apply token drop
        if self.training:
            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
            n_dropout = int(x.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(x.device)
            # decide whether to drop for each sample in batch
        else:
            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (
                self.n_codebooks if n_quantizers is None else n_quantizers
            )
        if self.is_discrete:
            if self.n_codebooks > 1:
                assert len(x.size()) == 3
                x_emb = self.embedding(x[:, 0])
                for i, emb in enumerate(self.extra_codebooks):
                    x_emb = x_emb + (n_quantizers > i + 1)[..., None, None] * emb(x[:, i + 1])
                    # add mask token if not using this codebook
                    # x_emb = x_emb + (n_quantizers <= i + 1)[..., None, None] * self.extra_codebook_mask_tokens[i]
                x = x_emb
            elif self.n_codebooks == 1:
                if len(x.size()) == 2:
                    x = self.embedding(x)
                else:
                    x = self.embedding(x[:, 0])
        else:
            x = self.content_in_proj(x)
        # x in (B, T, D)
        mask = sequence_mask(ylens).unsqueeze(-1)
        if self.interpolate:
            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
        else:
            x = x.transpose(1, 2).contiguous()
            mask = mask[:, :x.size(2), :]
            ylens = ylens.clamp(max=x.size(2)).long()
        if self.f0_condition:
            if f0 is None:
                x = x + self.f0_mask.unsqueeze(-1)
            else:
                # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
                quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
                quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
                f0_emb = self.f0_embedding(quantized_f0)
                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
                x = x + f0_emb
        out = self.model(x).transpose(1, 2).contiguous()
        if hasattr(self, 'vq'):
            out_q, commitment_loss, codebook_loss, codes, out = self.vq(out.transpose(1, 2))
            out_q = out_q.transpose(1, 2)
            return out_q * mask, ylens, codes, commitment_loss, codebook_loss
        olens = ylens
        return out * mask, olens, None, None, None
    	
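The forward pass above drops quantizers stochastically during training, embeds the discrete content tokens, stretches them to the target length with nearest-neighbour interpolation, and zeroes padded frames with a sequence mask. Below is a minimal, self-contained sketch of just that length-regulation step in plain PyTorch; the sequence_mask stand-in and all tensor shapes are illustrative assumptions, not code from this commit.

    import torch
    import torch.nn.functional as F

    def sequence_mask(lengths, max_len=None):
        # minimal stand-in for the sequence_mask helper the module relies on
        max_len = max_len or int(lengths.max())
        return (torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]).float()

    x = torch.randn(2, 75, 512)                # (B, T_in, D) content features, e.g. token embeddings
    ylens = torch.tensor([200, 180])           # target frame counts per sample

    mask = sequence_mask(ylens).unsqueeze(-1)  # (B, T_out, 1), 1.0 for valid frames
    x = F.interpolate(x.transpose(1, 2), size=int(ylens.max()), mode='nearest')
    x = x.transpose(1, 2) * mask               # (B, T_out, D) with padded frames zeroed
    print(x.shape)                             # torch.Size([2, 200, 512])

In the module itself the interpolated features additionally pass through self.model and, optionally, an F0 embedding and a VectorQuantize bottleneck before the mask is applied.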
webui.py  CHANGED

@@ -38,7 +38,9 @@ from modelscope.hub import api
 
 i18n = I18nAuto(language="Auto")
 MODE = 'local'
-tts = IndexTTS2(model_dir=cmd_args.model_dir,
+tts = IndexTTS2(model_dir=cmd_args.model_dir,
+                cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
+                is_fp16=False, use_cuda_kernel=False)
 
 # list of supported languages
 LANGUAGES = {

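For orientation, the webui.py change above spells out the constructor arguments instead of relying on defaults. A rough standalone equivalent is sketched below; the import path and the literal checkpoints directory are assumptions for illustration (webui.py resolves the directory from its own command-line arguments), not something this diff shows.

    import os
    from indextts.infer_v2 import IndexTTS2   # assumed import path, based on the file names in this commit

    model_dir = "checkpoints"                  # assumed; webui.py uses cmd_args.model_dir instead
    tts = IndexTTS2(model_dir=model_dir,
                    cfg_path=os.path.join(model_dir, "config.yaml"),
                    is_fp16=False, use_cuda_kernel=False)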