Spaces:
Runtime error
Runtime error
File size: 1,755 Bytes
6672bfb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
# norm_nc: the #channels of the normalized activations, hence the output dim of SPADE
# label_nc: the #channels of the input semantic map, hence the input dim of SPADE
# label_nc: also equivalent to the # of input label classes
class SPADE(nn.Module):
    """Normalization layer whose scale and bias are predicted from a segmentation map.

    The input activations are first normalized without any learned affine
    parameters. A small convolutional network then maps the (resized)
    segmentation map to a per-position ``gamma`` and ``beta`` which modulate
    the normalized activations as ``normalized * (1 + gamma) + beta``.

    Args:
        opt: Options container indexed by key; only ``opt['label_nc']`` (the
            number of channels of the segmentation map) is read here.
        norm_nc: Number of channels of the activations being normalized; this
            is also the channel count of the produced ``gamma`` and ``beta``.
    """

    def __init__(self, opt, norm_nc):
        super().__init__()

        # Width of the hidden embedding shared by the gamma and beta heads.
        hidden_channels = 128
        # Every convolution uses the same odd kernel; half-kernel padding
        # keeps the spatial size unchanged.
        ksize = 3
        conv_args = {'kernel_size': ksize, 'padding': ksize // 2}

        # Normalization carries no learned affine terms (affine=False); the
        # modulation comes entirely from the segmentation-driven heads below.
        self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)

        embed = nn.Conv2d(opt['label_nc'], hidden_channels, **conv_args)
        self.mlp_shared = nn.Sequential(embed, nn.ReLU())
        self.mlp_gamma = nn.Conv2d(hidden_channels, norm_nc, **conv_args)
        self.mlp_beta = nn.Conv2d(hidden_channels, norm_nc, **conv_args)

    def forward(self, x, segmap):
        """Normalize ``x`` and modulate it with parameters derived from ``segmap``.

        Args:
            x: Activations to normalize. Dimensions from index 2 onward are
                treated as the spatial size (presumably an (N, C, H, W)
                tensor with C == norm_nc -- confirm against callers).
            segmap: Segmentation map; it is resized with nearest-neighbour
                interpolation to the spatial size of ``x`` before use.

        Returns:
            The modulated activations, ``normalized * (1 + gamma) + beta``.
        """
        # Step 1: normalization without learned affine parameters.
        normalized = self.param_free_norm(x)

        # Step 2: bring the segmentation map to the resolution of x and
        # derive the scale / bias maps from a shared embedding.
        spatial_size = x.size()[2:]
        resized_map = F.interpolate(segmap, size=spatial_size, mode='nearest')
        shared = self.mlp_shared(resized_map)
        gamma = self.mlp_gamma(shared)
        beta = self.mlp_beta(shared)

        # Step 3: apply the predicted scale (offset by one) and bias.
        scale = 1 + gamma
        return normalized * scale + beta
|