tsm-net / tsmnet /modules.py
ernestchu's picture
update
b6ef12a
raw
history blame contribute delete
No virus
5.21 kB
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.nn.utils import weight_norm
import numpy as np
def weights_init(m):
    """Re-initialise the parameters of ``m`` based on its class name.

    Intended to be passed to ``nn.Module.apply`` so it visits every submodule.

    * Classes whose name contains ``"Conv"``: ``weight`` is drawn in place
      from N(0.0, 0.02).
    * Classes whose name contains ``"BatchNorm2d"``: ``weight`` is drawn in
      place from N(1.0, 0.02) and ``bias`` is zeroed.
    * Any other module is left untouched.

    NOTE(review): for layers wrapped with ``torch.nn.utils.weight_norm`` the
    ``weight`` attribute is derived from ``weight_g`` / ``weight_v``, so this
    draw presumably does not persist past the next forward pass -- confirm
    whether that is intended.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        m.weight.data.normal_(0.0, 0.02)
    elif "BatchNorm2d" in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` and wrap it with weight normalisation.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``; the weight-normalised module is returned.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` and wrap it with weight normalisation.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.ConvTranspose1d``; the weight-normalised module is returned.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
class ResnetBlock(nn.Module):
    """Residual block: a dilated conv branch summed with a 1x1 conv shortcut.

    The main branch is Tanh -> reflection pad -> dilated 3-tap conv ->
    Tanh -> 1x1 conv.  The reflection pad of ``dilation`` samples per side
    offsets the ``2 * dilation`` samples the dilated 3-tap conv removes, so
    both branches yield the same length and can be added.

    Args:
        dim: Channel count of both the input and the output.
        dilation: Dilation of the 3-tap convolution (also the pad width).
    """

    def __init__(self, dim, dilation=1):
        super().__init__()
        # Layers are created in the same order as they are applied.
        branch = [
            nn.Tanh(),
            nn.ReflectionPad1d(dilation),
            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
            nn.Tanh(),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*branch)
        # Learned (1x1 conv) skip path rather than a plain identity.
        self.shortcut = WNConv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Return ``shortcut(x) + block(x)``."""
        skip = self.shortcut(x)
        return skip + self.block(x)
class Autoencoder(nn.Module):
    """Convolutional autoencoder over a single-channel 1-D signal.

    The encoder applies one strided convolution per entry of
    ``compress_ratios`` (doubling the channel count each time); the decoder
    mirrors it with transposed convolutions using the ratios in reverse order
    (halving the channel count each time).

    Args:
        compress_ratios: Sequence of integer strides, one per encoder stage.
        ngf: Channel count after the first (and before the last) convolution.
        n_residual_layers: Number of ``ResnetBlock`` modules per stage.
    """

    def __init__(self, compress_ratios, ngf, n_residual_layers):
        super().__init__()
        self.encoder = self.makeEncoder(compress_ratios, ngf, n_residual_layers)
        self.decoder = self.makeDecoder(
            list(reversed(compress_ratios)), ngf, n_residual_layers
        )
        self.apply(weights_init)

    def makeEncoder(self, ratios, ngf, n_residual_layers):
        """Assemble the encoder as an ``nn.Sequential``.

        Layout: 7-tap stem conv (1 -> ngf channels), then for every ratio
        ``r`` a stack of residual blocks (largest dilation first) followed by
        a stride-``r`` conv that doubles the channels, and a final Tanh.
        """
        layers = [
            nn.ReflectionPad1d(3),
            WNConv1d(1, ngf, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        # Downsample to neuralgram scale
        channels = ngf
        for r in ratios:
            widened = channels * 2
            # Dilations run 3**(n-1), ..., 3, 1.
            layers.extend(
                ResnetBlock(channels, dilation=3 ** j)
                for j in reversed(range(n_residual_layers))
            )
            layers.append(nn.Tanh())
            layers.append(
                WNConv1d(
                    channels,
                    widened,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2,
                )
            )
            channels = widened

        layers.append(nn.Tanh())
        return nn.Sequential(*layers)

    def makeDecoder(self, ratios, ngf, n_residual_layers):
        """Assemble the decoder as an ``nn.Sequential``.

        Layout: for every ratio ``r`` a stride-``r`` transposed conv that
        halves the channels followed by a stack of residual blocks (smallest
        dilation first), then a 7-tap head conv (ngf -> 1 channel) and Tanh.
        """
        layers = []

        # Upsample to raw audio scale
        channels = int(2 ** len(ratios)) * ngf
        for r in ratios:
            narrowed = channels // 2
            layers.append(nn.Tanh())
            layers.append(
                WNConvTranspose1d(
                    channels,
                    narrowed,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2,
                    output_padding=r % 2,
                )
            )
            # Dilations run 1, 3, ..., 3**(n-1).
            layers.extend(
                ResnetBlock(narrowed, dilation=3 ** j)
                for j in range(n_residual_layers)
            )
            channels = narrowed

        layers.extend(
            [
                nn.Tanh(),
                nn.ReflectionPad1d(3),
                WNConv1d(ngf, 1, kernel_size=7, padding=0),
                nn.Tanh(),
            ]
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` and decode the result."""
        latent = self.encoder(x)
        return self.decoder(latent)
class NLayerDiscriminator(nn.Module):
    """Single-scale convolutional discriminator over a single-channel signal.

    Layers are stored in an ``nn.ModuleDict`` under the keys ``layer_0`` ..
    ``layer_{n_layers + 2}``:

    * ``layer_0``: reflection pad + 15-tap conv (1 -> ``ndf`` channels).
    * ``layer_1`` .. ``layer_{n_layers}``: grouped strided convs, each
      multiplying the channel count by ``downsampling_factor`` (capped at
      1024).
    * ``layer_{n_layers + 1}``: 5-tap conv doubling the channels (capped at
      1024).
    * ``layer_{n_layers + 2}``: 3-tap conv down to a single output channel.

    Args:
        ndf: Channel count produced by ``layer_0``.
        n_layers: Number of strided downsampling layers.
        downsampling_factor: Stride (and channel multiplier) of each
            downsampling layer.

    NOTE(review): the grouped convs use ``groups = in_channels // 4``, so
    ``ndf`` presumably needs to be a multiple of 4 -- confirm against callers.
    """

    def __init__(self, ndf, n_layers, downsampling_factor):
        super().__init__()
        model = nn.ModuleDict()

        model["layer_0"] = nn.Sequential(
            nn.ReflectionPad1d(7),
            WNConv1d(1, ndf, kernel_size=15),
            nn.Tanh(),
        )

        nf = ndf
        stride = downsampling_factor
        for n in range(1, n_layers + 1):
            nf_prev = nf
            nf = min(nf * stride, 1024)

            model["layer_%d" % n] = nn.Sequential(
                WNConv1d(
                    nf_prev,
                    nf,
                    kernel_size=stride * 10 + 1,
                    stride=stride,
                    padding=stride * 5,
                    groups=nf_prev // 4,
                ),
                nn.Tanh(),
            )

        # The next layer consumes the OUTPUT of the last strided layer, so its
        # input width is the current ``nf``.  Reusing the loop's stale
        # ``nf_prev`` (the last strided layer's input width) only worked when
        # the width had already saturated at 1024 or the stride was 1, and
        # raised NameError for ``n_layers == 0``.
        nf_prev = nf
        nf = min(nf * 2, 1024)
        model["layer_%d" % (n_layers + 1)] = nn.Sequential(
            WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2),
            nn.Tanh(),
        )

        model["layer_%d" % (n_layers + 2)] = WNConv1d(
            nf, 1, kernel_size=3, stride=1, padding=1
        )

        self.model = model

    def forward(self, x):
        """Run ``x`` through every layer in order.

        Returns:
            A list holding the output of each layer, in layer order; the last
            entry is the output of the final single-channel conv.
        """
        results = []
        for layer in self.model.values():
            x = layer(x)
            results.append(x)
        return results
class Discriminator(nn.Module):
    """Bank of ``NLayerDiscriminator`` modules applied at successive scales.

    Sub-discriminators are stored in an ``nn.ModuleDict`` under the keys
    ``disc_0`` .. ``disc_{num_D - 1}``.  The input is average-pooled with
    stride 2 between consecutive sub-discriminators.

    Args:
        num_D: Number of sub-discriminators.
        ndf: Forwarded to each ``NLayerDiscriminator``.
        n_layers: Forwarded to each ``NLayerDiscriminator``.
        downsampling_factor: Forwarded to each ``NLayerDiscriminator``.
    """

    def __init__(self, num_D, ndf, n_layers, downsampling_factor):
        super().__init__()
        self.model = nn.ModuleDict(
            {
                f"disc_{idx}": NLayerDiscriminator(ndf, n_layers, downsampling_factor)
                for idx in range(num_D)
            }
        )
        self.downsample = nn.AvgPool1d(
            4, stride=2, padding=1, count_include_pad=False
        )
        self.apply(weights_init)

    def forward(self, x):
        """Return one entry per sub-discriminator, in ``disc_*`` order.

        Each entry is whatever that sub-discriminator returns for the input
        at its scale; ``x`` is pooled after every sub-discriminator call.
        """
        outputs = []
        for disc in self.model.values():
            outputs.append(disc(x))
            x = self.downsample(x)
        return outputs