| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | import math |
| | from functools import partial |
| | from math import log2 |
| | from typing import List |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from kornia.filters import filter2d |
| |
|
| | from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor |
| |
|
| |
|
class Blur(torch.nn.Module):
    """Filters a batch of 2D maps with a 3x3 kernel built from the taps ``[1, 2, 1]``."""

    def __init__(self):
        super().__init__()
        # Stored as a buffer so that it follows the module across devices and is
        # part of the state dict without being a trainable parameter.
        self.register_buffer("f", torch.Tensor([1, 2, 1]))

    def forward(self, x):
        """Apply the kernel to ``x`` through kornia's ``filter2d`` with ``normalized=True``."""
        taps = self.f
        # Outer product of the taps with themselves gives the 3x3 kernel;
        # the leading axis of size one is the kernel batch axis passed to filter2d.
        kernel = (taps[:, None] * taps[None, :])[None]
        return filter2d(x, kernel, normalized=True)
| |
|
| |
|
class EqualLinear(torch.nn.Module):
    """Linear layer whose weight and bias are scaled by ``lr_mul`` on every forward pass.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        lr_mul: constant factor applied to both the weight and the bias in ``forward``.
        bias: when False, no bias parameter is created and nothing is added to the output.
    """

    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(out_dim))
        else:
            # Register the name with a None value so that ``self.bias`` always
            # exists; otherwise ``forward`` fails with AttributeError.
            self.register_parameter("bias", None)

        self.lr_mul = lr_mul

    def forward(self, input):
        """Return ``input @ (weight * lr_mul).T`` plus ``bias * lr_mul`` when a bias exists."""
        scaled_bias = None if self.bias is None else self.bias * self.lr_mul
        return F.linear(input, self.weight * self.lr_mul, bias=scaled_bias)
| |
|
| |
|
class StyleMapping(torch.nn.Module):
    """Stack of ``depth`` (EqualLinear, LeakyReLU) pairs applied to a normalized input.

    Args:
        emb: input and output width of every linear layer.
        depth: number of (EqualLinear, LeakyReLU(0.2)) pairs.
        lr_mul: multiplier forwarded to every EqualLinear.
    """

    def __init__(self, emb, depth, lr_mul=0.1):
        super().__init__()
        # Flattened sequence: linear, activation, linear, activation, ...
        stages = [
            stage
            for _ in range(depth)
            for stage in (EqualLinear(emb, emb, lr_mul), torch.nn.LeakyReLU(0.2, inplace=True))
        ]
        self.net = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Normalize ``x`` along dimension 1, then run it through the layer stack."""
        normalized = F.normalize(x, dim=1)
        return self.net(normalized)
| |
|
| |
|
class RGBBlock(torch.nn.Module):
    """Projects a feature map to ``channels`` output channels with a style-modulated 1x1 convolution.

    Args:
        latent_dim: width of the style vector given to ``forward``.
        input_channel: number of channels of the incoming feature map.
        upsample: when truthy, the result is upsampled 2x (bilinear) and blurred.
        channels: number of output channels.
    """

    def __init__(self, latent_dim, input_channel, upsample, channels=3):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = torch.nn.Linear(latent_dim, input_channel)

        # 1x1 modulated convolution without demodulation.
        self.conv = Conv2DModulated(input_channel, channels, 1, demod=False)

        if upsample:
            self.upsample = torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), Blur(),
            )
        else:
            self.upsample = None

    def forward(self, x, prev_rgb, istyle):
        """Return the projected map, summed with ``prev_rgb`` when given and optionally upsampled.

        Args:
            x: feature map passed to the modulated convolution.
            prev_rgb: output of the previous RGBBlock to accumulate on, or None.
            istyle: style vector mapped to per-channel modulation by ``to_style``.
        """
        out = self.conv(x, self.to_style(istyle))

        if prev_rgb is not None:
            out = out + prev_rgb

        if self.upsample is None:
            return out
        return self.upsample(out)
| |
|
| |
|
class Conv2DModulated(torch.nn.Module):
    """
    Modulated convolution.
    For details refer to [1]
    [1] Karras et. al. - Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958)

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        kernel: side of the square kernel.
        demod: when True, modulated weights are rescaled per sample and output channel.
        stride: only enters the padding formula; it is not applied to the convolution itself.
            NOTE(review): values other than 1 are presumably unsupported - confirm with callers.
        dilation: dilation used both for the padding formula and for the convolution.
        eps: constant added under the square root during demodulation.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs,
    ):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = torch.nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        self.eps = eps
        torch.nn.init.kaiming_normal_(self.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")

    def _get_same_padding(self, size, kernel, dilation, stride):
        """Return the per-side padding computed from the input size, kernel, dilation and stride."""
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        """Convolve every sample of ``x`` with its own modulated copy of the weight.

        Args:
            x: input unpacked as (batch, channels, height, width).
            y: modulation values, indexed as (batch, in_chan); ``y + 1`` scales the
                input-channel axis of the weight for each sample.

        Returns:
            Tensor of shape (batch, out_chan, height, width).
        """
        b, c, h, w = x.shape

        # (b, 1, in, 1, 1) against (1, out, in, k, k) -> one weight per sample.
        modulation = y[:, None, :, None, None]
        weights = self.weight[None, :, :, :, :] * (modulation + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weights = weights * d

        # Fold the batch into the channel axis and use one group per sample, so a
        # single grouped convolution applies a different kernel to every sample.
        x = x.reshape(1, -1, h, w)
        weights = weights.reshape(b * self.filters, *weights.shape[2:])

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        # The padding above already accounts for dilation, so the convolution has to
        # use the same dilation; otherwise the output is larger than the input.
        x = F.conv2d(x, weights, padding=padding, dilation=self.dilation, groups=b)

        # Explicit batch size: a size mismatch raises instead of changing the batch axis.
        return x.reshape(b, self.filters, h, w)
| |
|
| |
|
class GeneratorBlock(torch.nn.Module):
    """One generator stage: optional 2x upsampling, two modulated 3x3 convolutions with noise, and an RGB head.

    Args:
        latent_dim: width of the style vector.
        input_channels: channels of the incoming feature map.
        filters: channels produced by both convolutions.
        upsample: when truthy, the input is bilinearly upsampled 2x first.
        upsample_rgb: forwarded to RGBBlock as its ``upsample`` flag.
        channels: number of channels produced by the RGB head.
    """

    def __init__(
        self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, channels=1,
    ):
        super().__init__()
        if upsample:
            self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        else:
            self.upsample = None

        # First convolution with its style and noise projections.
        self.to_style1 = torch.nn.Linear(latent_dim, input_channels)
        self.to_noise1 = torch.nn.Linear(1, filters)
        self.conv1 = Conv2DModulated(input_channels, filters, 3)

        # Second convolution with its style and noise projections.
        self.to_style2 = torch.nn.Linear(latent_dim, filters)
        self.to_noise2 = torch.nn.Linear(1, filters)
        self.conv2 = Conv2DModulated(filters, filters, 3)

        self.activation = torch.nn.LeakyReLU(0.2, inplace=True)
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, channels)

    def forward(self, x, prev_rgb, istyle, inoise):
        """Run the stage and return ``(features, rgb)``.

        Args:
            x: incoming feature map; axes 2 and 3 are used to crop the noise.
            prev_rgb: RGB output of the previous stage, or None.
            istyle: style vector for this stage.
            inoise: noise indexed as (batch, height, width, 1); it is cropped to the
                spatial size of ``x`` (after upsampling) and projected to ``filters`` channels.
        """
        if self.upsample is not None:
            x = self.upsample(x)

        height, width = x.shape[2], x.shape[3]
        cropped_noise = inoise[:, :height, :width, :]
        # Channels-last projection moved to channels-first to match the feature maps.
        noise_a = self.to_noise1(cropped_noise).permute((0, 3, 1, 2))
        noise_b = self.to_noise2(cropped_noise).permute((0, 3, 1, 2))

        x = self.conv1(x, self.to_style1(istyle))
        x = self.activation(x + noise_a)

        x = self.conv2(x, self.to_style2(istyle))
        x = self.activation(x + noise_b)

        return x, self.to_rgb(x, prev_rgb, istyle)
| |
|
| |
|
class DiscriminatorBlock(torch.nn.Module):
    """Residual discriminator stage: two 3x3 convolutions plus a 1x1 shortcut, optionally downsampling by 2.

    Args:
        input_channels: channels of the incoming feature map.
        filters: channels of the output.
        downsample: when truthy, both the shortcut (stride 2) and the main path
            (blur followed by a stride-2 convolution) halve the spatial size.
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        shortcut_stride = 2 if downsample else 1
        self.conv_res = torch.nn.Conv2d(input_channels, filters, 1, stride=shortcut_stride)

        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(input_channels, filters, 3, padding=1),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(filters, filters, 3, padding=1),
            torch.nn.LeakyReLU(0.2, inplace=True),
        )

        if downsample:
            self.downsample = torch.nn.Sequential(
                Blur(), torch.nn.Conv2d(filters, filters, 3, padding=1, stride=2)
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``(main(x) + shortcut(x)) / sqrt(2)``."""
        shortcut = self.conv_res(x)
        features = self.net(x)
        if self.downsample is not None:
            features = self.downsample(features)
        # Sum of the two branches, scaled by 1/sqrt(2).
        return (features + shortcut) * (1 / math.sqrt(2))
| |
|
| |
|
class Generator(torch.nn.Module):
    """Conditional generator built from a style mapping network and a stack of GeneratorBlocks.

    Args:
        n_bands: height of the generated maps; the number of blocks is ``int(log2(n_bands) - 1)``.
        latent_dim: width of the style vectors.
        style_depth: depth of the style mapping network.
        network_capacity: base channel multiplier of the blocks.
        channels: number of output channels.
        fmap_max: upper bound on the channel count of any block.
        start_from_zero: when True, the initial constant block is zeroed after creation.
    """

    def __init__(
        self, n_bands, latent_dim, style_depth, network_capacity=16, channels=1, fmap_max=512, start_from_zero=True
    ):
        super().__init__()
        self.image_size = n_bands
        self.latent_dim = latent_dim
        self.num_layers = int(log2(n_bands) - 1)
        self.style_depth = style_depth

        self.style_mapping = StyleMapping(self.latent_dim, self.style_depth, lr_mul=0.1)

        # Channel widths, widest first, each capped at fmap_max. The first width is
        # repeated so that the first block maps init_channels -> init_channels.
        widths = [min(fmap_max, network_capacity * (2 ** (i + 1))) for i in reversed(range(self.num_layers))]
        init_channels = widths[0]
        widths = [init_channels] + widths

        self.initial_conv = torch.nn.Conv2d(init_channels, init_channels, 3, padding=1)
        self.blocks = torch.nn.ModuleList([])

        last_index = self.num_layers - 1
        for index, (in_chan, out_chan) in enumerate(zip(widths, widths[1:])):
            # The first block does not upsample its input; the last one does not upsample its RGB output.
            self.blocks.append(
                GeneratorBlock(
                    latent_dim,
                    in_chan,
                    out_chan,
                    upsample=index != 0,
                    upsample_rgb=index != last_index,
                    channels=channels,
                )
            )

        for module in self.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")
        # Noise projections start at zero, overriding the initialization above.
        for block in self.blocks:
            for noise_layer in (block.to_noise1, block.to_noise2):
                torch.nn.init.zeros_(noise_layer.weight)
                torch.nn.init.zeros_(noise_layer.bias)

        # Constant starting map of height n_bands / upsample_factor and width 1;
        # it is expanded along the last axis in forward() and is not trained.
        initial_height = n_bands // self.upsample_factor
        self.initial_block = torch.nn.Parameter(
            torch.randn((1, init_channels, initial_height, 1)), requires_grad=False
        )
        if start_from_zero:
            self.initial_block.data.zero_()

    def add_scaled_condition(self, target: torch.Tensor, condition: torch.Tensor, condition_lengths: torch.Tensor):
        """Average ``target`` with ``condition`` resized to the target's last two dimensions, then mask by length.

        The lengths given to ``mask_sequence_tensor`` are ``condition_lengths`` divided by the
        integer ratio of the condition height to the target height, rounded up.
        NOTE(review): presumably ``mask_sequence_tensor`` zeroes positions past each length - confirm.
        """
        target_height, _ = target.shape[-2:]
        condition_height, _ = condition.shape[-2:]
        scale = condition_height // target_height

        # Bring the condition to the spatial size of the target.
        resized = F.interpolate(condition, size=target.shape[-2:], mode="bilinear")

        # Plain average of the two maps.
        blended = (target + resized) / 2
        scaled_lengths = (condition_lengths / scale).ceil().long()
        return mask_sequence_tensor(blended, scaled_lengths)

    @property
    def upsample_factor(self):
        """Two to the power of the number of blocks that have an upsampling layer."""
        return 2 ** len([block for block in self.blocks if block.upsample])

    def forward(self, condition: torch.Tensor, lengths: torch.Tensor, ws: List[torch.Tensor], noise: torch.Tensor):
        """Generate a map conditioned on ``condition``.

        Args:
            condition: 4D tensor; its first and last dimensions give the batch size and the maximum length.
            lengths: per-sample lengths used for masking.
            ws: one style vector per block; blocks and styles are paired with ``zip``.
            noise: noise tensor handed to every block.

        Returns:
            The accumulated RGB output after the last paired block.
        """
        batch_size, _, _, max_length = condition.shape

        start = self.initial_block.expand(batch_size, -1, -1, max_length // self.upsample_factor)
        features = self.initial_conv(start)
        rgb = None

        for style, block in zip(ws, self.blocks):
            features, rgb = block(features, rgb, style, noise)

            # Re-inject the condition into both outputs after every block.
            features = self.add_scaled_condition(features, condition, lengths)
            rgb = self.add_scaled_condition(rgb, condition, lengths)

        return rgb
| |
|
| |
|
class Discriminator(torch.nn.Module):
    """Stack of DiscriminatorBlocks followed by a convolution, length-aware pooling and a linear logit.

    Args:
        n_bands: height of the input maps; ``int(log2(n_bands) - 1) + 1`` blocks are created.
        network_capacity: base channel multiplier; block ``i`` outputs ``4 * network_capacity * 2**i`` channels.
        channels: number of input channels.
        fmap_max: upper bound applied to every channel count.
    """

    def __init__(
        self, n_bands, network_capacity=16, channels=1, fmap_max=512,
    ):
        super().__init__()
        num_layers = int(log2(n_bands) - 1)
        num_init_filters = channels

        filters = [num_init_filters] + [(network_capacity * 4) * (2 ** i) for i in range(num_layers + 1)]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        chan_in_out = list(zip(filters[:-1], filters[1:]))

        blocks = []
        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            # Every block but the last one halves the spatial size.
            is_not_last = ind != (len(chan_in_out) - 1)
            blocks.append(DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last))

        self.blocks = torch.nn.ModuleList(blocks)

        channel_last = filters[-1]
        latent_dim = channel_last

        self.final_conv = torch.nn.Conv2d(channel_last, channel_last, 3, padding=1)
        self.to_logit = torch.nn.Linear(latent_dim, 1)

        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")

    def forward(self, x, condition: torch.Tensor, lengths: torch.Tensor):
        """Return one logit per sample.

        Args:
            x: input maps passed through the blocks.
            condition: only its last dimension is read, as the reference length for
                computing how much ``x`` has been shortened.
            lengths: per-sample lengths, 1D over the batch.

        Returns:
            Tensor of shape (batch,), including for a batch of one sample.
        """
        for block in self.blocks:
            x = block(x)
            # Lengths shrink by the same integer factor as the last axis of x.
            scale = condition.shape[-1] // x.shape[-1]
            x = mask_sequence_tensor(x, (lengths / scale).ceil().long())

        x = self.final_conv(x)

        scale = condition.shape[-1] // x.shape[-1]
        x = mask_sequence_tensor(x, (lengths / scale).ceil().long())

        # Mean over the second-to-last axis, then a sum over the last axis divided by
        # the scaled length of each sample.
        x = x.mean(axis=-2)
        x = (x / rearrange(lengths / scale, "b -> b 1 1")).sum(axis=-1)
        x = self.to_logit(x)
        # Drop only the logit axis; a bare squeeze() would also drop a batch axis of size one.
        return x.squeeze(-1)
| |
|