maxmax20160403's picture
final ver
c24b656
raw
history blame
1.05 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
class ScaleDiscriminator(torch.nn.Module):
    """Stack of weight-normalised 1-D convolutions that scores a 1-channel signal.

    The module applies six convolutions (each followed by a leaky ReLU with
    negative slope 0.1) and a final 1-output-channel convolution.  Every
    intermediate activation is recorded so callers can inspect the
    per-layer feature maps alongside the final score.

    NOTE(review): the layer layout resembles the scale discriminator used in
    MelGAN / HiFi-GAN style vocoders -- confirm against the training code.
    """

    def __init__(self):
        """Build the convolution stack; takes no configuration arguments."""
        super().__init__()
        # Conv1d positional args are (in_channels, out_channels, kernel_size,
        # stride).  Each padding equals (kernel_size - 1) // 2, so stride-1
        # layers keep the length and each stride-4 layer shrinks it to
        # floor((length - 1) / 4) + 1.
        self.convs = nn.ModuleList([
            weight_norm(nn.Conv1d(1, 16, 15, 1, padding=7)),
            weight_norm(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
            weight_norm(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
            weight_norm(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
            weight_norm(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
            weight_norm(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        # Final projection down to a single output channel.
        self.conv_post = weight_norm(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x: torch.Tensor):
        """Run the discriminator and collect every layer's activation.

        Args:
            x: Input tensor whose channel dimension is 1 (the first
                convolution has ``in_channels=1``), e.g. ``(batch, 1, length)``.

        Returns:
            A one-element list ``[(fmap, score)]`` where ``fmap`` is the list
            of the six post-activation feature maps followed by the raw
            ``conv_post`` output, and ``score`` is that last output flattened
            from dimension 1 onwards.
        """
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            fmap.append(x)

        x = self.conv_post(x)
        # The un-flattened final output is also kept as a feature map.
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        # Wrapped in a list of (feature_maps, score) tuples, as in the original.
        return [(fmap, x)]