# Combined discriminator: bundles the multi-resolution and multi-period
# discriminators behind a single nn.Module.
import torch
import torch.nn as nn
from .mpd import MultiPeriodDiscriminator
from .mrd import MultiResolutionDiscriminator
from omegaconf import OmegaConf


class Discriminator(nn.Module):
    """Run one input through both sub-discriminators.

    Holds a ``MultiResolutionDiscriminator`` (``self.MRD``) and a
    ``MultiPeriodDiscriminator`` (``self.MPD``), each constructed from the
    same ``hp`` object.
    """

    def __init__(self, hp):
        """Build both sub-discriminators from ``hp``.

        Args:
            hp: Configuration object forwarded unchanged to both
                sub-discriminator constructors (the script below loads it
                with ``OmegaConf.load``).
        """
        super().__init__()
        # Attribute names are kept as-is: they determine the parameter
        # prefixes in the module's state dict.
        self.MRD = MultiResolutionDiscriminator(hp)
        self.MPD = MultiPeriodDiscriminator(hp)

    def forward(self, x):
        """Return ``(MRD(x), MPD(x))`` as a 2-tuple.

        Args:
            x: Input passed as-is to both sub-discriminators.
                NOTE(review): the smoke test below uses a tensor of shape
                (3, 1, 16384) -- presumably (batch, channel, samples);
                confirm against the sub-discriminators.

        Returns:
            Tuple of the multi-resolution output followed by the
            multi-period output.
        """
        # MRD is evaluated first, then MPD, matching the tuple order.
        multi_resolution_out = self.MRD(x)
        multi_period_out = self.MPD(x)
        return multi_resolution_out, multi_period_out


if __name__ == '__main__':
    # Smoke test: build the model from the default config, push a random
    # batch through it, and print output shapes plus the parameter count.
    config = OmegaConf.load('../config/default.yaml')
    model = Discriminator(config)

    dummy_input = torch.randn(3, 1, 16384)
    print(dummy_input.shape)

    # Only the multi-period output is inspected; the multi-resolution
    # output is still computed but discarded.
    _, mpd_results = model(dummy_input)
    for feature_maps, score in mpd_results:
        for feature_map in feature_maps:
            print(feature_map.shape)
        print(score.shape)

    # Count trainable parameters only (those with requires_grad set).
    trainable_param_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_param_count += param.numel()
    print(trainable_param_count)