Update modules/length_regulator.py
modules/length_regulator.py  (+141, -118)

@@ -1,118 +1,141 @@
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from modules.commons import sequence_mask
import numpy as np
from dac.nn.quantize import VectorQuantize

# f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)


def f0_to_coarse(f0, f0_bin):
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    b = f0_mel_min * a - 1.
    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
    # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
    f0_coarse = torch.round(f0_mel).long()
    f0_coarse = f0_coarse * (f0_coarse > 0)
    f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
    f0_coarse = f0_coarse * (f0_coarse < f0_bin)
    f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
    return f0_coarse


class InterpolateRegulator(nn.Module):
    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            is_discrete: bool = False,
            in_channels: int = None,  # only applies to continuous input
            vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
            codebook_size: int = 1024,  # for discrete only
            out_channels: int = None,
            groups: int = 1,
            n_codebooks: int = 1,  # number of codebooks
            quantizer_dropout: float = 0.0,  # dropout for quantizer
            f0_condition: bool = False,
            n_f0_bins: int = 512,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            self.interpolate = True
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        else:
            self.interpolate = False
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)
        self.embedding = nn.Embedding(codebook_size, channels)
        self.is_discrete = is_discrete

        self.mask_token = nn.Parameter(torch.zeros(1, channels))

        self.n_codebooks = n_codebooks
        if n_codebooks > 1:
            self.extra_codebooks = nn.ModuleList([
                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
            ])
            self.extra_codebook_mask_tokens = nn.ParameterList([
                nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
            ])
        self.quantizer_dropout = quantizer_dropout

        if f0_condition:
            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
            self.f0_condition = f0_condition
            self.n_f0_bins = n_f0_bins
            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
        else:
            self.f0_condition = False

        if not is_discrete:
            self.content_in_proj = nn.Linear(in_channels, channels)
            if vector_quantize:
                self.vq = VectorQuantize(channels, codebook_size, 8)

    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
        # apply token drop
        if self.training:
            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
            n_dropout = int(x.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(x.device)
            # decide whether to drop for each sample in batch
        else:
            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
        if self.is_discrete:
            if self.n_codebooks > 1:
                assert len(x.size()) == 3
                x_emb = self.embedding(x[:, 0])
                for i, emb in enumerate(self.extra_codebooks):
                    x_emb = x_emb + (n_quantizers > i + 1)[..., None, None] * emb(x[:, i + 1])
                    # add mask token if not using this codebook
                    # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
                x = x_emb
            elif self.n_codebooks == 1:
                if len(x.size()) == 2:
                    x = self.embedding(x)
                else:
                    x = self.embedding(x[:, 0])
        else:
            x = self.content_in_proj(x)
        # x in (B, T, D)
        mask = sequence_mask(ylens).unsqueeze(-1)
        if self.interpolate:
            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
        else:
            x = x.transpose(1, 2).contiguous()
            mask = mask[:, :x.size(2), :]
            ylens = ylens.clamp(max=x.size(2)).long()
        if self.f0_condition:
            if f0 is None:
                x = x + self.f0_mask.unsqueeze(-1)
            else:
                quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
                # quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
                # quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
                f0_emb = self.f0_embedding(quantized_f0)
                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
                x = x + f0_emb
        out = self.model(x).transpose(1, 2).contiguous()
        if hasattr(self, 'vq'):
            out_q, commitment_loss, codebook_loss, codes, out = self.vq(out.transpose(1, 2))
            out_q = out_q.transpose(1, 2)
            return out_q * mask, ylens, codes, commitment_loss, codebook_loss
        olens = ylens
        return out * mask, olens, None, None, None
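
Below is a minimal usage sketch of the rewritten InterpolateRegulator, included for orientation only; it is not part of the commit. The channel width, token shape, target lengths, and f0 values are illustrative assumptions, and running it requires this repo's modules package together with the dac (descript-audio-codec) dependency imported at the top of the file.

import torch
from modules.length_regulator import InterpolateRegulator

# Discrete content tokens from a single codebook, conditioned on frame-level f0.
regulator = InterpolateRegulator(
    channels=512,                  # assumed model width
    sampling_ratios=[1, 1, 1, 1],  # any non-empty tuple enables interpolation + the conv blocks
    is_discrete=True,
    codebook_size=1024,
    n_codebooks=1,
    f0_condition=True,
    n_f0_bins=512,
)
regulator.eval()

tokens = torch.randint(0, 1024, (2, 80))  # (B, T_in) discrete content tokens
target_lens = torch.tensor([200, 180])    # desired output length per sample
f0 = torch.rand(2, 80) * 300              # frame-level f0 in Hz (illustrative)

out, olens, codes, commit_loss, cb_loss = regulator(tokens, ylens=target_lens, f0=f0)
print(out.shape)  # torch.Size([2, 200, 512]): upsampled to max(target_lens), zeroed past each olens

Without vector_quantize, the last three returned values are None; with continuous input (is_discrete=False, in_channels set) and vector_quantize=True, the forward pass instead returns the quantized output together with the codes and the commitment/codebook losses.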
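
One detail worth noting in forward: when f0 conditioning is active, pitch is bucketed directly in Hz against self.f0_bins rather than through the mel-scale f0_to_coarse helper, which is defined but left commented out. A rough standalone check of that bucketing, with illustrative values:

import torch

n_f0_bins = 512
f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)       # boundaries at 2 Hz steps: 2, 4, ..., 1022
f0 = torch.tensor([[0.0, 110.0, 220.0, 440.0, 2000.0]])  # illustrative f0 values in Hz
print(torch.bucketize(f0, f0_bins))
# Unvoiced frames (0 Hz) land in bin 0; values above 1022 Hz saturate at index 511,
# the last row of the n_f0_bins-sized f0_embedding table.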