# References: https://github.com/yxlu-0102/MP-SENet/blob/main/models/discriminator.py

import torch
import torch.nn as nn
import numpy as np
from pesq import pesq
from joblib import Parallel, delayed

from models.lsigmoid import LearnableSigmoid1D


def pesq_loss(clean, noisy, sr=16000):
    # Wide-band PESQ between a clean reference and a noisy/enhanced signal.
    try:
        pesq_score = pesq(sr, clean, noisy, 'wb')
    except Exception:
        # PESQ can fail on silent or degenerate segments; mark the failure with -1.
        pesq_score = -1
    return pesq_score


def batch_pesq(clean, noisy, cfg):
    # Score every (clean, noisy) pair in parallel with joblib.
    num_worker = cfg['env_setting']['num_workers']
    pesq_score = Parallel(n_jobs=num_worker)(
        delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy)
    )
    pesq_score = np.array(pesq_score)
    # Drop the whole batch if PESQ failed on any utterance.
    if -1 in pesq_score:
        return None
    # Rescale PESQ (upper bound 4.5) so the discriminator target is at most 1.0.
    pesq_score = (pesq_score - 1) / 3.5
    return torch.FloatTensor(pesq_score)


class MetricDiscriminator(nn.Module):
    # CNN that predicts a normalized metric (PESQ) score from a pair of spectrograms.
    def __init__(self, dim=16, in_channel=2):
        super(MetricDiscriminator, self).__init__()
        self.layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim * 2, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 2, affine=True),
            nn.PReLU(dim * 2),
            nn.utils.spectral_norm(nn.Conv2d(dim * 2, dim * 4, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 4, affine=True),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Conv2d(dim * 4, dim * 8, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 8, affine=True),
            nn.PReLU(dim * 8),
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1D(1)
        )

    def forward(self, x, y):
        # Stack the two spectrograms into a 2-channel input of shape (B, 2, F, T).
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)
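

# Minimal usage sketch (not part of the original file): the input shapes below are
# illustrative assumptions; the discriminator only requires a pair of equally sized
# 2D feature maps per utterance, e.g. clean and enhanced magnitude spectrograms.
if __name__ == "__main__":
    disc = MetricDiscriminator()
    clean_mag = torch.randn(4, 201, 100)      # hypothetical batch of clean magnitudes (B, F, T)
    enhanced_mag = torch.randn(4, 201, 100)   # hypothetical batch of enhanced magnitudes
    score = disc(clean_mag, enhanced_mag)     # predicted normalized metric score, shape (4, 1)
    print(score.shape)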