# python3.7
"""Contains the implementation of the encoder used in GH-Feat (including
IDInvert).

ResNet is used as the backbone.

GH-Feat paper: https://arxiv.org/pdf/2007.10379.pdf
IDInvert paper: https://arxiv.org/pdf/2004.00049.pdf

NOTE: Please use `latent_dim` and `num_latents_per_head` to control the
inversion space, such as the Y-space used in GH-Feat and the W-space used in
IDInvert. In addition, IDInvert sets `use_fpn` and `use_sam` to `False` by
default.
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

__all__ = ['GHFeatEncoder']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# pylint: disable=missing-function-docstring


class BasicBlock(nn.Module):
    """Implementation of ResNet BasicBlock."""

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 base_width=64,
                 stride=1,
                 groups=1,
                 dilation=1,
                 norm_layer=None,
                 downsample=None):
        super().__init__()
        if base_width != 64:
            raise ValueError(f'BasicBlock of ResNet only supports '
                             f'`base_width=64`, but {base_width} received!')
        if stride not in [1, 2]:
            raise ValueError(f'BasicBlock of ResNet only supports `stride=1` '
                             f'and `stride=2`, but {stride} received!')
        if groups != 1:
            raise ValueError(f'BasicBlock of ResNet only supports `groups=1`, '
                             f'but {groups} received!')
        if dilation != 1:
            raise ValueError(f'BasicBlock of ResNet only supports '
                             f'`dilation=1`, but {dilation} received!')
        assert self.expansion == 1

        self.stride = stride
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.conv1 = nn.Conv2d(in_channels=inplanes,
                               out_channels=planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               groups=1,
                               dilation=1,
                               bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=planes,
                               out_channels=planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               groups=1,
                               dilation=1,
                               bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = self.downsample(x) if self.downsample is not None else x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out + identity)
        return out


class Bottleneck(nn.Module):
    """Implementation of ResNet Bottleneck."""

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 base_width=64,
                 stride=1,
                 groups=1,
                 dilation=1,
                 norm_layer=None,
                 downsample=None):
        super().__init__()
        if stride not in [1, 2]:
            raise ValueError(f'Bottleneck of ResNet only supports `stride=1` '
                             f'and `stride=2`, but {stride} received!')

        width = int(planes * (base_width / 64)) * groups
        self.stride = stride
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.conv1 = nn.Conv2d(in_channels=inplanes,
                               out_channels=width,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               dilation=1,
                               groups=1,
                               bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(in_channels=width,
                               out_channels=width,
                               kernel_size=3,
                               stride=stride,
                               padding=dilation,
                               groups=groups,
                               dilation=dilation,
                               bias=False)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(in_channels=width,
                               out_channels=planes * self.expansion,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               dilation=1,
                               groups=1,
                               bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = self.downsample(x) if self.downsample is not None else x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu(out + identity)
        return out


class GHFeatEncoder(nn.Module):
    """Defines the ResNet-based encoder network for GAN inversion.

    On top of the backbone, there are several task-heads to produce inverted
    codes. Please use `latent_dim` and `num_latents_per_head` to define the
    structure. For example, `latent_dim = [512] * 14` and
    `num_latents_per_head = [4, 4, 6]` can be used for StyleGAN inversion with
    14-layer latent codes, where 3 task heads (corresponding to 4, 4, 6
    layers, respectively) are used.

    Settings for the encoder network:

    (1) resolution: The resolution of the input image, i.e., the image to
        invert.
    (2) latent_dim: Dimension of the latent space. A number (one code will be
        produced), or a list of numbers regarding layer-wise latent codes.
    (3) num_latents_per_head: Number of latent codes produced by each head.
    (4) image_channels: Number of channels of the input image. (default: 3)
    (5) final_res: Final resolution of the convolutional layers. (default: 4)

    ResNet-related settings:

    (1) network_depth: Depth of the network, like 18 for ResNet18.
        (default: 18)
    (2) inplanes: Number of channels of the first convolutional layer.
        (default: 64)
    (3) groups: Groups of the convolution, used in ResNet. (default: 1)
    (4) width_per_group: Number of channels per group, used in ResNet.
        (default: 64)
    (5) replace_stride_with_dilation: Whether to replace stride with dilation,
        used in ResNet. (default: None)
    (6) norm_layer: Normalization layer used in the encoder. If set as `None`,
        `nn.BatchNorm2d` will be used. Also, please NOTE that when using batch
        normalization, the batch size is required to be larger than one for
        training. (default: nn.BatchNorm2d)
    (7) max_channels: Maximum number of channels in each layer. (default: 512)

    Task-head related settings:

    (1) use_fpn: Whether to use Feature Pyramid Network (FPN) before
        outputting the latent code. (default: True)
    (2) fpn_channels: Number of channels used in FPN. (default: 512)
    (3) use_sam: Whether to use Spatial Alignment Module (SAM) before
        outputting the latent code. (default: True)
    (4) sam_channels: Number of channels used in SAM.
        (default: 512)
    """

    arch_settings = {
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
        152: (Bottleneck, [3, 8, 36, 3])
    }

    def __init__(self,
                 resolution,
                 latent_dim,
                 num_latents_per_head,
                 image_channels=3,
                 final_res=4,
                 network_depth=18,
                 inplanes=64,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=nn.BatchNorm2d,
                 max_channels=512,
                 use_fpn=True,
                 fpn_channels=512,
                 use_sam=True,
                 sam_channels=512):
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        if network_depth not in self.arch_settings:
            raise ValueError(f'Invalid network depth: `{network_depth}`!\n'
                             f'Options allowed: '
                             f'{list(self.arch_settings.keys())}.')
        if isinstance(latent_dim, int):
            latent_dim = [latent_dim]
        assert isinstance(latent_dim, (list, tuple))
        assert isinstance(num_latents_per_head, (list, tuple))
        assert sum(num_latents_per_head) == len(latent_dim)

        self.resolution = resolution
        self.latent_dim = latent_dim
        self.num_latents_per_head = num_latents_per_head
        self.num_heads = len(self.num_latents_per_head)
        self.image_channels = image_channels
        self.final_res = final_res
        self.inplanes = inplanes
        self.network_depth = network_depth
        self.groups = groups
        self.dilation = 1
        self.base_width = width_per_group
        self.replace_stride_with_dilation = replace_stride_with_dilation
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if norm_layer == nn.BatchNorm2d and dist.is_initialized():
            norm_layer = nn.SyncBatchNorm
        self.norm_layer = norm_layer
        self.max_channels = max_channels
        self.use_fpn = use_fpn
        self.fpn_channels = fpn_channels
        self.use_sam = use_sam
        self.sam_channels = sam_channels

        block_fn, num_blocks_per_stage = self.arch_settings[network_depth]

        self.num_stages = int(np.log2(resolution // final_res)) - 1
        # Add one block for additional stages.
        for i in range(len(num_blocks_per_stage), self.num_stages):
            num_blocks_per_stage.append(1)
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False] * self.num_stages

        # Backbone.
        self.conv1 = nn.Conv2d(in_channels=self.image_channels,
                               out_channels=self.inplanes,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.stage_channels = [self.inplanes]
        self.stages = nn.ModuleList()
        for i in range(self.num_stages):
            inplanes = self.inplanes if i == 0 else planes * block_fn.expansion
            planes = min(self.max_channels, self.inplanes * (2 ** i))
            num_blocks = num_blocks_per_stage[i]
            stride = 1 if i == 0 else 2
            dilate = replace_stride_with_dilation[i]
            self.stages.append(self._make_stage(block_fn=block_fn,
                                                inplanes=inplanes,
                                                planes=planes,
                                                num_blocks=num_blocks,
                                                stride=stride,
                                                dilate=dilate))
            self.stage_channels.append(planes * block_fn.expansion)

        if self.num_heads > len(self.stage_channels):
            raise ValueError('Number of task heads is larger than number of '
                             'stages! Please reduce the number of heads.')

        # Task-head.
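        # Each task head consumes the feature from one of the last
        # `num_heads` pyramid levels (refined by FPN and/or SAM when enabled)
        # and maps it to a contiguous chunk of the layer-wise latent codes.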
        if self.num_heads == 1:
            self.use_fpn = False
            self.use_sam = False

        if self.use_fpn:
            fpn_pyramid_channels = self.stage_channels[-self.num_heads:]
            self.fpn = FPN(pyramid_channels=fpn_pyramid_channels,
                           out_channels=self.fpn_channels)
        if self.use_sam:
            if self.use_fpn:
                sam_pyramid_channels = [self.fpn_channels] * self.num_heads
            else:
                sam_pyramid_channels = self.stage_channels[-self.num_heads:]
            self.sam = SAM(pyramid_channels=sam_pyramid_channels,
                           out_channels=self.sam_channels)

        self.heads = nn.ModuleList()
        for head_idx in range(self.num_heads):
            # Parse in_channels.
            if self.use_sam:
                in_channels = self.sam_channels
            elif self.use_fpn:
                in_channels = self.fpn_channels
            else:
                in_channels = self.stage_channels[head_idx - self.num_heads]
            in_channels = in_channels * final_res * final_res

            # Parse out_channels.
            start_latent_idx = sum(self.num_latents_per_head[:head_idx])
            end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
            out_channels = sum(
                self.latent_dim[start_latent_idx:end_latent_idx])

            self.heads.append(CodeHead(in_channels=in_channels,
                                       out_channels=out_channels,
                                       norm_layer=self.norm_layer))

    def _make_stage(self, block_fn, inplanes, planes, num_blocks, stride,
                    dilate):
        norm_layer = self.norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or inplanes != planes * block_fn.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=inplanes,
                          out_channels=planes * block_fn.expansion,
                          kernel_size=1,
                          stride=stride,
                          padding=0,
                          dilation=1,
                          groups=1,
                          bias=False),
                norm_layer(planes * block_fn.expansion),
            )

        blocks = []
        blocks.append(block_fn(inplanes=inplanes,
                               planes=planes,
                               base_width=self.base_width,
                               stride=stride,
                               groups=self.groups,
                               dilation=previous_dilation,
                               norm_layer=norm_layer,
                               downsample=downsample))
        for _ in range(1, num_blocks):
            blocks.append(block_fn(inplanes=planes * block_fn.expansion,
                                   planes=planes,
                                   base_width=self.base_width,
                                   stride=1,
                                   groups=self.groups,
                                   dilation=self.dilation,
                                   norm_layer=norm_layer,
                                   downsample=None))
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        features = [x]
        for i in range(self.num_stages):
            x = self.stages[i](x)
            features.append(x)
        features = features[-self.num_heads:]

        if self.use_fpn:
            features = self.fpn(features)
        if self.use_sam:
            features = self.sam(features)
        else:
            final_size = features[-1].shape[2:]
            for i in range(self.num_heads - 1):
                features[i] = F.adaptive_avg_pool2d(features[i], final_size)

        outputs = []
        for head_idx in range(self.num_heads):
            codes = self.heads[head_idx](features[head_idx])
            start_latent_idx = sum(self.num_latents_per_head[:head_idx])
            end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
            split_size = self.latent_dim[start_latent_idx:end_latent_idx]
            outputs.extend(torch.split(codes, split_size, dim=1))
        max_dim = max(self.latent_dim)
        for i, dim in enumerate(self.latent_dim):
            if dim < max_dim:
                outputs[i] = F.pad(outputs[i], (0, max_dim - dim))
            outputs[i] = outputs[i].unsqueeze(1)
        return torch.cat(outputs, dim=1)


class FPN(nn.Module):
    """Implementation of Feature Pyramid Network (FPN).

    The input of this module is a pyramid of features with reducing
    resolutions. Then, this module fuses these multi-level features from
    `top_level` to `bottom_level`. In particular, starting from the
    `top_level`, each feature is convolved, upsampled, and fused into its
    previous feature (which is also convolved).
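    For example, given three levels with spatial resolutions 16, 8, and 4,
    the 4x4 feature is upsampled by a factor of 2 and added to the 8x8
    lateral, and the fused result is upsampled again and added to the 16x16
    lateral.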
    Args:
        pyramid_channels: A list of integers, each of which indicates the
            number of channels of the feature from a particular level.
        out_channels: Number of channels for each output.

    Returns:
        A list of feature maps, each of which has `out_channels` channels.
    """

    def __init__(self, pyramid_channels, out_channels):
        super().__init__()
        assert isinstance(pyramid_channels, (list, tuple))
        self.num_levels = len(pyramid_channels)

        self.lateral_layers = nn.ModuleList()
        self.feature_layers = nn.ModuleList()
        for i in range(self.num_levels):
            in_channels = pyramid_channels[i]
            self.lateral_layers.append(nn.Conv2d(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=3,
                                                 padding=1,
                                                 bias=True))
            self.feature_layers.append(nn.Conv2d(in_channels=out_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=3,
                                                 padding=1,
                                                 bias=True))

    def forward(self, inputs):
        if len(inputs) != self.num_levels:
            raise ValueError('Number of inputs and `num_levels` mismatch!')

        # Project all related features to `out_channels`.
        laterals = []
        for i in range(self.num_levels):
            laterals.append(self.lateral_layers[i](inputs[i]))

        # Fusion, starting from `top_level`.
        for i in range(self.num_levels - 1, 0, -1):
            scale_factor = laterals[i - 1].shape[2] // laterals[i].shape[2]
            laterals[i - 1] = (laterals[i - 1] +
                               F.interpolate(laterals[i],
                                             mode='nearest',
                                             scale_factor=scale_factor))

        # Get outputs.
        outputs = []
        for i, lateral in enumerate(laterals):
            outputs.append(self.feature_layers[i](lateral))

        return outputs


class SAM(nn.Module):
    """Implementation of Spatial Alignment Module (SAM).

    The input of this module is a pyramid of features with reducing
    resolutions. Then this module downsamples all levels of features to the
    minimum resolution and fuses each of them with the smallest feature map.

    Args:
        pyramid_channels: A list of integers, each of which indicates the
            number of channels of the feature from a particular level.
        out_channels: Number of channels for each output.

    Returns:
        A list of feature maps, each of which has `out_channels` channels.
    """

    def __init__(self, pyramid_channels, out_channels):
        super().__init__()
        assert isinstance(pyramid_channels, (list, tuple))
        self.num_levels = len(pyramid_channels)

        self.fusion_layers = nn.ModuleList()
        for i in range(self.num_levels):
            in_channels = pyramid_channels[i]
            self.fusion_layers.append(nn.Conv2d(in_channels=in_channels,
                                                out_channels=out_channels,
                                                kernel_size=3,
                                                padding=1,
                                                bias=True))

    def forward(self, inputs):
        if len(inputs) != self.num_levels:
            raise ValueError('Number of inputs and `num_levels` mismatch!')

        output_res = inputs[-1].shape[2:]
        for i in range(self.num_levels - 1, -1, -1):
            if i != self.num_levels - 1:
                inputs[i] = F.adaptive_avg_pool2d(inputs[i], output_res)
            inputs[i] = self.fusion_layers[i](inputs[i])
            if i != self.num_levels - 1:
                inputs[i] = inputs[i] + inputs[-1]

        return inputs


class CodeHead(nn.Module):
    """Implementation of the task-head to produce inverted codes."""

    def __init__(self, in_channels, out_channels, norm_layer):
        super().__init__()
        self.fc = nn.Linear(in_channels, out_channels, bias=True)
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_channels)

    def forward(self, x):
        if x.ndim > 2:
            x = x.flatten(start_dim=1)
        latent = self.fc(x)
        latent = latent.unsqueeze(2).unsqueeze(3)
        latent = self.norm(latent)

        return latent.flatten(start_dim=1)

# pylint: enable=missing-function-docstring
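

# A minimal usage sketch (illustrative, not part of the original module):
# it builds the encoder for 256x256 images with 14 layer-wise 512-dim codes
# split across 3 task heads, as suggested in the `GHFeatEncoder` docstring,
# and checks the output shape. The dummy batch and this `__main__` guard are
# assumptions for demonstration only.
if __name__ == '__main__':
    encoder = GHFeatEncoder(resolution=256,
                            latent_dim=[512] * 14,
                            num_latents_per_head=[4, 4, 6])
    encoder.eval()  # Batch normalization needs batch size > 1 for training.
    with torch.no_grad():
        image = torch.randn(2, 3, 256, 256)  # (N, C, H, W) dummy input.
        codes = encoder(image)
    print(codes.shape)  # Expected: torch.Size([2, 14, 512])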