Update discriminator.py
discriminator.py  +18 -244

discriminator.py  CHANGED
@@ -1,231 +1,3 @@
-# import torch
-# import torch.nn as nn
-# import torch.nn.functional as F
-# from audiotools import AudioSignal
-# from audiotools import ml
-# from audiotools import STFTParams
-# from einops import rearrange
-# from torch.nn.utils import weight_norm
-
-
-# def WNConv1d(*args, **kwargs):
-#     act = kwargs.pop("act", True)
-#     conv = weight_norm(nn.Conv1d(*args, **kwargs))
-#     if not act:
-#         return conv
-#     return nn.Sequential(conv, nn.LeakyReLU(0.1))
-
-
-# def WNConv2d(*args, **kwargs):
-#     act = kwargs.pop("act", True)
-#     conv = weight_norm(nn.Conv2d(*args, **kwargs))
-#     if not act:
-#         return conv
-#     return nn.Sequential(conv, nn.LeakyReLU(0.1))
-
-
-# class MPD(nn.Module):
-#     def __init__(self, period):
-#         super().__init__()
-#         self.period = period
-#         self.convs = nn.ModuleList(
-#             [
-#                 WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
-#                 WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
-#                 WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
-#                 WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
-#                 WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
-#             ]
-#         )
-#         self.conv_post = WNConv2d(
-#             1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
-#         )
-
-#     def pad_to_period(self, x):
-#         t = x.shape[-1]
-#         x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
-#         return x
-
-#     def forward(self, x):
-#         fmap = []
-
-#         x = self.pad_to_period(x)
-#         x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
-
-#         for layer in self.convs:
-#             x = layer(x)
-#             fmap.append(x)
-
-#         x = self.conv_post(x)
-#         fmap.append(x)
-
-#         return fmap
-
-
-# class MSD(nn.Module):
-#     def __init__(self, rate: int = 1, sample_rate: int = 44100):
-#         super().__init__()
-#         self.convs = nn.ModuleList(
-#             [
-#                 WNConv1d(1, 16, 15, 1, padding=7),
-#                 WNConv1d(16, 64, 41, 4, groups=4, padding=20),
-#                 WNConv1d(64, 256, 41, 4, groups=16, padding=20),
-#                 WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
-#                 WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
-#                 WNConv1d(1024, 1024, 5, 1, padding=2),
-#             ]
-#         )
-#         self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
-#         self.sample_rate = sample_rate
-#         self.rate = rate
-
-#     def forward(self, x):
-#         x = AudioSignal(x, self.sample_rate)
-#         x.resample(self.sample_rate // self.rate)
-#         x = x.audio_data
-
-#         fmap = []
-
-#         for l in self.convs:
-#             x = l(x)
-#             fmap.append(x)
-#         x = self.conv_post(x)
-#         fmap.append(x)
-
-#         return fmap
-
-
-# BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
-
-
-# class MRD(nn.Module):
-#     def __init__(
-#         self,
-#         window_length: int,
-#         hop_factor: float = 0.25,
-#         sample_rate: int = 44100,
-#         bands: list = BANDS,
-#     ):
-#         """Complex multi-band spectrogram discriminator.
-#         Parameters
-#         ----------
-#         window_length : int
-#             Window length of STFT.
-#         hop_factor : float, optional
-#             Hop factor of the STFT, defaults to ``0.25 * window_length``.
-#         sample_rate : int, optional
-#             Sampling rate of audio in Hz, by default 44100
-#         bands : list, optional
-#             Bands to run discriminator over.
-#         """
-#         super().__init__()
-
-#         self.window_length = window_length
-#         self.hop_factor = hop_factor
-#         self.sample_rate = sample_rate
-#         self.stft_params = STFTParams(
-#             window_length=window_length,
-#             hop_length=int(window_length * hop_factor),
-#             match_stride=True,
-#         )
-
-#         n_fft = window_length // 2 + 1
-#         bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
-#         self.bands = bands
-
-#         ch = 32
-#         convs = lambda: nn.ModuleList(
-#             [
-#                 WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
-#                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
-#                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
-#                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
-#                 WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
-#             ]
-#         )
-#         self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
-#         self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
-
-#     def spectrogram(self, x):
-#         x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
-#         x = torch.view_as_real(x.stft())
-#         x = rearrange(x, "b 1 f t c -> (b 1) c t f")
-#         # Split into bands
-#         x_bands = [x[..., b[0] : b[1]] for b in self.bands]
-#         return x_bands
-
-#     def forward(self, x):
-#         x_bands = self.spectrogram(x)
-#         fmap = []
-
-#         x = []
-#         for band, stack in zip(x_bands, self.band_convs):
-#             for layer in stack:
-#                 band = layer(band)
-#                 fmap.append(band)
-#             x.append(band)
-
-#         x = torch.cat(x, dim=-1)
-#         x = self.conv_post(x)
-#         fmap.append(x)
-
-#         return fmap
-
-
-# class Discriminator(ml.BaseModel):
-#     def __init__(
-#         self,
-#         rates: list = [],
-#         periods: list = [2, 3, 5, 7, 11],
-#         fft_sizes: list = [2048, 1024, 512],
-#         sample_rate: int = 44100,
-#         bands: list = BANDS,
-#     ):
-#         """Discriminator that combines multiple discriminators.
-
-#         Parameters
-#         ----------
-#         rates : list, optional
-#             sampling rates (in Hz) to run MSD at, by default []
-#             If empty, MSD is not used.
-#         periods : list, optional
-#             periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
-#         fft_sizes : list, optional
-#             Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
-#         sample_rate : int, optional
-#             Sampling rate of audio in Hz, by default 44100
-#         bands : list, optional
-#             Bands to run MRD at, by default `BANDS`
-#         """
-#         super().__init__()
-#         discs = []
-#         discs += [MPD(p) for p in periods]
-#         discs += [MSD(r, sample_rate=sample_rate) for r in rates]
-#         discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
-#         self.discriminators = nn.ModuleList(discs)
-
-#     def preprocess(self, y):
-#         # Remove DC offset
-#         y = y - y.mean(dim=-1, keepdims=True)
-#         # Peak normalize the volume of input audio
-#         y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
-#         return y
-
-#     def forward(self, x):
-#         x = self.preprocess(x)
-#         fmaps = [d(x) for d in self.discriminators]
-#         return fmaps
-
-
-# if __name__ == "__main__":
-#     disc = Discriminator()
-#     x = torch.zeros(1, 1, 44100)
-#     results = disc(x)
-#     for i, result in enumerate(results):
-#         print(f"disc{i}")
-#         for i, r in enumerate(result):
-#             print(r.shape, r.mean(), r.min(), r.max())
-#         print()
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -313,7 +85,7 @@ class MPD(nn.Module):
 
 
 class MSD(nn.Module):
-    def __init__(self, rate: int = 1, sample_rate: int = 44100):
+    def __init__(self, rate: int = 1, sample_rate: int = 24000):
         super().__init__()
         self.convs = nn.ModuleList([
             WNConv1d(1, 16, 15, 1, padding=7),
@@ -463,19 +235,19 @@ class DiscriminatorCQT(nn.Module):
 
 class MultiScaleSubbandCQT(nn.Module):
     """CQT discriminator at multiple scales"""
-    def __init__(self, sample_rate=
+    def __init__(self, sample_rate=24000):
         super().__init__()
         cfg = Munch({
-
-
-
-
-
-
-
-
-
-
+            "hop_lengths": [512, 256, 256],
+            "sampling_rate": 24000,
+            "filters": 32,
+            "max_filters": 1024,
+            "filters_scale": 1,
+            "dilations": [1, 2, 4],
+            "in_channels": 1,
+            "out_channels": 1,
+            "n_octaves": [9, 9, 9],
+            "bins_per_octaves": [24, 36, 48],
         })
         self.cfg = cfg
         self.discriminators = nn.ModuleList([
@@ -499,7 +271,7 @@ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
 
 class MRD(nn.Module):
     def __init__(self, window_length: int, hop_factor: float = 0.25, 
-                 sample_rate: int = 44100, bands: list = BANDS):
+                 sample_rate: int = 24000, bands: list = BANDS):
         """Multi-resolution spectrogram discriminator."""
         super().__init__()
         self.window_length = window_length
@@ -556,7 +328,7 @@ class Discriminator(ml.BaseModel):
         rates: list = [],
         periods: list = [2, 3, 5, 7, 11],
         fft_sizes: list = [2048, 1024, 512],
-        sample_rate: int = 44100,
+        sample_rate: int = 24000,
     ):
         """Discriminator combining MPD, MSD, MRD and CQT.
 
@@ -569,7 +341,7 @@ class Discriminator(ml.BaseModel):
         fft_sizes : list, optional
             FFT sizes for MRD, by default [2048, 1024, 512]
         sample_rate : int, optional
-            Sampling rate of audio in Hz, by default 44100
+            Sampling rate of audio in Hz, by default 24000
         """
         super().__init__()
         discs = []
@@ -593,4 +365,6 @@ class Discriminator(ml.BaseModel):
     def forward(self, x):
         x = self.preprocess(x)
         fmaps = [d(x) for d in self.discriminators]
-        return fmaps
+        return fmaps
+
+
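For reference, a quick way to sanity-check the new 24 kHz defaults is the same smoke test the removed commented-out `__main__` block used to perform. This is a minimal sketch, assuming the file is importable as a `discriminator` module (the import path is illustrative, not part of this commit) and that its dependencies (`audiotools`, `einops`, `munch`, and the CQT backend) are installed:

    import torch
    from discriminator import Discriminator  # hypothetical import path for this file

    # Instantiate with the new default sample rate introduced by this commit.
    disc = Discriminator(sample_rate=24000)

    # One second of silence at 24 kHz, shaped (batch, channels, samples).
    x = torch.zeros(1, 1, 24000)

    # forward() returns one collection of feature maps per sub-discriminator
    # (MPD, MRD, and CQT by default; MSD only if `rates` is non-empty).
    fmaps = disc(x)
    print(f"{len(fmaps)} sub-discriminators")
    for i, fmap in enumerate(fmaps):
        print(f"disc{i}: {len(fmap)} feature maps")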
