"""This file contains some base implementation for discrminators. | |
Copyright (2024) Bytedance Ltd. and/or its affiliates | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
TODO: Add reference to Mark Weber's tech report on the improved discriminator architecture. | |
""" | |
import functools
import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .maskgit_vqgan import Conv2dSame
class BlurBlock(torch.nn.Module):
    def __init__(self,
                 kernel: Tuple[int, ...] = (1, 3, 3, 1)
        ):
        super().__init__()

        # Build a normalized 2D low-pass kernel as the outer product of the
        # 1D kernel with itself, stored as a (1, 1, k, k) non-trainable buffer.
        kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        kernel = kernel[None, :] * kernel[:, None]
        kernel /= kernel.sum()
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        self.register_buffer("kernel", kernel)
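
    # Worked example (illustrative, not from the original file): the default
    # kernel (1, 3, 3, 1) yields a 4x4 outer product whose entries sum to
    # (1 + 3 + 3 + 1)**2 = 64 before normalization, i.e. a separable binomial
    # low-pass filter as in BlurPool-style anti-aliased downsampling
    # (Zhang, 2019).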
    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        # Total padding needed so a stride-s convolution reproduces
        # TensorFlow "SAME" behavior: output size equals ceil(i / s).
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)
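
    # Worked example (illustrative): for i=7, k=4, s=2 we get ceil(7/2)=4,
    # so the padding is (4 - 1)*2 + (4 - 1) + 1 - 7 = 3 and the padded input
    # of width 10 convolves down to the expected output width of 4.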
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ic, ih, iw = x.size()[-3:]
        # Pad using the actual kernel size so the non-default 3- and 5-tap
        # kernels are padded correctly as well.
        k = self.kernel.shape[-1]
        pad_h = self.calc_same_pad(i=ih, k=k, s=2)
        pad_w = self.calc_same_pad(i=iw, k=k, s=2)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

        # Depthwise blur + stride-2 downsample: the same kernel is applied to
        # every input channel independently (groups == num channels).
        weight = self.kernel.expand(ic, -1, -1, -1)

        out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
        return out
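
# Minimal usage sketch (illustrative, not part of the original module):
#   blur = BlurBlock()                       # default 4-tap kernel
#   y = blur(torch.randn(1, 64, 32, 32))     # -> torch.Size([1, 64, 16, 16])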
class NLayerDiscriminator(torch.nn.Module):
    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 128,
        num_stages: int = 3,
        blur_resample: bool = True,
        blur_kernel_size: int = 4
    ):
        """Initializes the NLayerDiscriminator.

        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels.
            num_stages -> int: The number of stages.
            blur_resample -> bool: Whether to use blur resampling.
            blur_kernel_size -> int: The blur kernel size.
        """
        super().__init__()
        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (not blur_resample) or (3 <= blur_kernel_size <= 5), \
            "Blur kernel size must be in [3, 5] when blur resampling is used"
        # Channel multipliers per stage: (1, 1, 2, 4, ...), so the channel
        # count doubles at every stage after the first.
        in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
        init_kernel_size = 5
        activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)

        self.block_in = torch.nn.Sequential(
            Conv2dSame(
                num_channels,
                hidden_channels,
                kernel_size=init_kernel_size
            ),
            activation(),
        )

        BLUR_KERNEL_MAP = {
            3: (1, 2, 1),
            4: (1, 3, 3, 1),
            5: (1, 4, 6, 4, 1),
        }

        discriminator_blocks = []
        for i_level in range(num_stages):
            in_channels = hidden_channels * in_channel_mult[i_level]
            out_channels = hidden_channels * in_channel_mult[i_level + 1]
            block = torch.nn.Sequential(
                Conv2dSame(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                ),
                # Each stage downsamples by 2x, either with an anti-aliased
                # blur or with plain average pooling.
                BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]) if blur_resample
                else torch.nn.AvgPool2d(kernel_size=2, stride=2),
                torch.nn.GroupNorm(32, out_channels),
                activation(),
            )
            discriminator_blocks.append(block)

        self.blocks = torch.nn.ModuleList(discriminator_blocks)

        self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))

        self.to_logits = torch.nn.Sequential(
            Conv2dSame(out_channels, out_channels, 1),
            activation(),
            Conv2dSame(out_channels, 1, kernel_size=5)
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x -> torch.Tensor: The input tensor.
        Returns:
            output -> torch.Tensor: The output tensor.
        """
        hidden_states = self.block_in(x)

        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.pool(hidden_states)

        return self.to_logits(hidden_states)
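

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module; run as a
    # module, e.g. `python -m <package>.<this_module>`, so the relative import
    # of Conv2dSame resolves). Assumes Conv2dSame preserves spatial size
    # ("SAME" padding). With the defaults, a 256x256 RGB batch passes through
    # three stride-2 stages (256 -> 128 -> 64 -> 32), is adaptively pooled to
    # 16x16, and is mapped to per-patch logits of shape (batch, 1, 16, 16).
    disc = NLayerDiscriminator()
    logits = disc(torch.randn(2, 3, 256, 256))
    assert logits.shape == (2, 1, 16, 16), logits.shape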