import sys
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Upsample as NearestUpsample

sys.path.append(".")
from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock


class StyleMLP(nn.Module):
    r"""MLP converting a style code to an intermediate style representation."""

    def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True,
                 num_layers=5, normalize_input=True, output_act=True):
        super(StyleMLP, self).__init__()

        self.normalize_input = normalize_input
        self.output_act = output_act

        fc_layers = []
        fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
        for i in range(num_layers - 1):
            fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
        self.fc_layers = nn.ModuleList(fc_layers)

        self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)

        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = partial(F.relu, inplace=True)

    def forward(self, z):
        r"""Forward network.

        Args:
            z (N x style_dim tensor): Style codes.
        """
        if self.normalize_input:
            z = F.normalize(z, p=2, dim=-1, eps=1e-6)
        for fc_layer in self.fc_layers:
            z = self.act(fc_layer(z))
        z = self.fc_out(z)
        if self.output_act:
            z = self.act(z)
        return z


class histo_process(nn.Module):
    r"""Histogram encoder replacing the image-based style encoder.

    Args:
        style_enc_cfg (obj): Style encoder definition file.
    """

    def __init__(self, style_enc_cfg):
        super().__init__()
        # The input is assumed to be an RGB histogram flattened to
        # 3 channels x 90 bins = 270 values (a single-channel histogram
        # would use 90 instead).
        input_channel = 270
        style_dims = style_enc_cfg.style_dims
        self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
        num_filters = getattr(style_enc_cfg, 'num_filters', 180)
        self.layer1 = LinearBlock(input_channel, num_filters)
        self.layer4 = LinearBlock(num_filters, num_filters)
        self.fc_mu = LinearBlock(num_filters, style_dims, nonlinearity='tanh')
        if not self.no_vae:
            self.fc_var = LinearBlock(num_filters, style_dims, nonlinearity='tanh')

    def forward(self, histo):
        x = self.layer1(histo)
        x = self.layer4(x)
        mu = self.fc_mu(x)  # in [-1, 1]
        if not self.no_vae:
            logvar = self.fc_var(x)  # in [-1, 1]
            std = torch.exp(0.5 * logvar)  # in [exp(-0.5), exp(0.5)] ~ [0.607, 1.649]
            eps = torch.randn_like(std)
            z = eps.mul(std) + mu
        else:
            z = mu
            logvar = torch.zeros_like(mu)
        return mu, logvar, z
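
# A minimal usage sketch for histo_process, kept out of the import path as an
# uncalled helper. The config below only sets the fields the constructor
# actually reads; style_dims=256 and the batch size are illustrative
# assumptions, not values from any shipped config.
def _example_histo_process():
    from easydict import EasyDict as edict
    cfg = edict()
    cfg.style_dims = 256               # assumed style-code size
    histo_enc = histo_process(cfg)
    histo = torch.randn(4, 270)        # 3 channels x 90 bins, flattened
    mu, logvar, z = histo_enc(histo)   # each of shape (4, 256)
    return z
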
""" def __init__(self, style_enc_cfg): super(StyleEncoder, self).__init__() input_image_channels = style_enc_cfg.input_image_channels num_filters = style_enc_cfg.num_filters kernel_size = style_enc_cfg.kernel_size padding = int(np.ceil((kernel_size - 1.0) / 2)) style_dims = style_enc_cfg.style_dims weight_norm_type = style_enc_cfg.weight_norm_type self.no_vae = getattr(style_enc_cfg, 'no_vae', False) activation_norm_type = 'none' nonlinearity = 'leakyrelu' base_conv2d_block = \ partial(Conv2dBlock, kernel_size=kernel_size, stride=2, padding=padding, weight_norm_type=weight_norm_type, activation_norm_type=activation_norm_type, # inplace_nonlinearity=True, nonlinearity=nonlinearity) self.layer1 = base_conv2d_block(input_image_channels, num_filters) self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2) self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4) self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8) self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8) self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8) self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims,nonlinearity='tanh') if not self.no_vae: self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims,nonlinearity='tanh') def forward(self, input_x): r"""SPADE Style Encoder forward. Args: input_x (N x 3 x H x W tensor): input images. Returns: mu (N x C tensor): Mean vectors. logvar (N x C tensor): Log-variance vectors. z (N x C tensor): Style code vectors. """ if input_x.size(2) != 256 or input_x.size(3) != 256: input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear') x = self.layer1(input_x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.layer5(x) x = self.layer6(x) x = x.view(x.size(0), -1) mu = self.fc_mu(x) if not self.no_vae: logvar = self.fc_var(x) std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = eps.mul(std) + mu else: z = mu logvar = torch.zeros_like(mu) return mu, logvar, z class RenderCNN(nn.Module): r"""CNN converting intermediate feature map to final image.""" def __init__(self, in_channels, style_dim, hidden_channels=256, leaky_relu=True): super(RenderCNN, self).__init__() self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels) self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0) self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1) self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False) self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1) self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False) self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0) self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0) self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0) if leaky_relu: self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) else: self.act = partial(F.relu, inplace=True) def modulate(self, x, w_, b_): w_ = w_[..., None, None] b_ = b_[..., None, None] return x * (w_+1) + b_ +1e-9 def forward(self, x, z): r"""Forward network. Args: x (N x in_channels x H x W tensor): Intermediate feature map z (N x style_dim tensor): Style codes. 
""" z = self.fc_z_cond(z) adapt = torch.chunk(z, 2 * 2, dim=-1) y = self.act(self.conv1(x)) y = y + self.conv2b(self.act(self.conv2a(y))) y = self.act(self.modulate(y, adapt[0], adapt[1])) y = y + self.conv3b(self.act(self.conv3a(y))) y = self.act(self.modulate(y, adapt[2], adapt[3])) y = y + self.conv4b(self.act(self.conv4a(y))) y = self.act(y) y = self.conv4(y) y = torch.sigmoid(y) return y class inner_Generator(nn.Module): r"""Pix2pixHD coarse-to-fine generator constructor. Args: gen_cfg (obj): Generator definition part of the yaml config file. data_cfg (obj): Data definition part of the yaml config file. last_act: ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``,default is 'relu'. """ def __init__(self, gen_cfg,inner_cfg, data_cfg,num_input_channels=3,last_act='relu'): super().__init__() assert last_act in ['none', 'relu', 'leakyrelu', 'prelu', 'tanh' , 'sigmoid' , 'softmax'] # pix2pixHD has a global generator. global_gen_cfg = inner_cfg # By default, pix2pixHD using instance normalization. activation_norm_type = getattr(gen_cfg, 'activation_norm_type', 'instance') activation_norm_params = getattr(gen_cfg, 'activation_norm_params', None) weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '') padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect') base_conv_block = partial(Conv2dBlock, padding_mode=padding_mode, weight_norm_type=weight_norm_type, activation_norm_type=activation_norm_type, activation_norm_params=activation_norm_params, nonlinearity='relu') base_res_block = partial(Res2dBlock, padding_mode=padding_mode, weight_norm_type=weight_norm_type, activation_norm_type=activation_norm_type, activation_norm_params=activation_norm_params, nonlinearity='relu', order='CNACN') # Know what is the number of available segmentation labels. # Global generator model. global_model = GlobalGenerator(global_gen_cfg, data_cfg, num_input_channels, padding_mode, base_conv_block, base_res_block,last_act=last_act) self.global_model = global_model def forward(self, input, random_style=False): r"""Coarse-to-fine generator forward. Args: data (dict) : Dictionary of input data. random_style (bool): Always set to false for the pix2pixHD model. Returns: output (dict) : Dictionary of output data. """ return self.global_model(input) def load_pretrained_network(self, pretrained_dict): r"""Load a pretrained network.""" # print(pretrained_dict.keys()) model_dict = self.state_dict() print('Pretrained network has fewer layers; The following are ' 'not initialized:') not_initialized = set() for k, v in model_dict.items(): kp = 'module.' + k.replace('global_model.', 'global_model.model.') if kp in pretrained_dict and v.size() == pretrained_dict[kp].size(): model_dict[k] = pretrained_dict[kp] else: not_initialized.add('.'.join(k.split('.')[:2])) print(sorted(not_initialized)) self.load_state_dict(model_dict) def inference(self, data, **kwargs): r"""Generator inference. Args: data (dict) : Dictionary of input data. Returns: fake_images (tensor): Output fake images. file_names (str): Data file name. """ output = self.forward(data, **kwargs) return output['fake_images'], data['key']['seg_maps'][0] class GlobalGenerator(nn.Module): r"""Coarse generator constructor. This is the main generator in the pix2pixHD architecture. Args: gen_cfg (obj): Generator definition part of the yaml config file. data_cfg (obj): Data definition part of the yaml config file. num_input_channels (int): Number of segmentation labels. padding_mode (str): zero | reflect | ... 
class GlobalGenerator(nn.Module):
    r"""Coarse generator constructor. This is the main generator in the
    pix2pixHD architecture.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of segmentation labels.
        padding_mode (str): zero | reflect | ...
        base_conv_block (obj): Conv block with preset attributes.
        base_res_block (obj): Residual block with preset attributes.
        last_act (str, optional, default='relu'): Type of nonlinear activation
            function: ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'``, ``'sigmoid'`` or ``'softmax'``.
    """

    def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
                 base_conv_block, base_res_block, last_act='relu'):
        super(GlobalGenerator, self).__init__()
        num_output_channels = getattr(gen_cfg, 'output_nc', 64)
        num_filters = getattr(gen_cfg, 'num_filters', 64)
        num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)
        # First layer.
        model = [base_conv_block(num_input_channels, num_filters,
                                 kernel_size=7, padding=3)]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2 ** i)
            model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
        # ResNet blocks.
        ch = num_filters * (2 ** num_downsamples)
        for i in range(num_res_blocks):
            model += [base_res_block(ch, ch, 3, padding=1)]
        # Upsample.
        num_upsamples = num_downsamples
        for i in reversed(range(num_upsamples)):
            ch = num_filters * (2 ** i)
            model += [NearestUpsample(scale_factor=2),
                      base_conv_block(ch * 2, ch, 3, padding=1)]
        model += [Conv2dBlock(num_filters, num_output_channels, 7, padding=3,
                              padding_mode=padding_mode, nonlinearity=last_act)]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        r"""Coarse-to-fine generator forward.

        Args:
            input (4D tensor): Input semantic representations.

        Returns:
            output (4D tensor): Synthesized image by generator.
        """
        return self.model(input)
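
# A minimal usage sketch for GlobalGenerator as an uncalled helper. The config
# below shrinks the network (2 downsamples, 2 residual blocks) purely for
# illustration; data_cfg is unused by the constructor, so None is passed.
def _example_global_generator():
    from easydict import EasyDict as edict
    gen_cfg = edict(output_nc=3, num_filters=64,
                    num_downsamples=2, num_res_blocks=2)  # assumed values
    conv_block = partial(Conv2dBlock, padding_mode='reflect',
                         activation_norm_type='instance', nonlinearity='relu')
    res_block = partial(Res2dBlock, padding_mode='reflect',
                        activation_norm_type='instance', nonlinearity='relu',
                        order='CNACN')
    net = GlobalGenerator(gen_cfg, None, num_input_channels=3,
                          padding_mode='reflect',
                          base_conv_block=conv_block,
                          base_res_block=res_block)
    x = torch.randn(1, 3, 128, 128)
    return net(x)    # (1, 3, 128, 128)
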
class inner_Generator_split(nn.Module):
    r"""Pix2pixHD coarse-to-fine generator constructor with a split
    upsampling path that supports style modulation.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        inner_cfg (obj): Inner (global) generator definition.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of input channels.
        last_act (str): ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'``, ``'sigmoid'`` or ``'softmax'``; the default is
            ``'relu'``.
    """

    def __init__(self, gen_cfg, inner_cfg, data_cfg, num_input_channels=3,
                 last_act='relu'):
        super().__init__()
        assert last_act in ['none', 'relu', 'leakyrelu', 'prelu',
                            'tanh', 'sigmoid', 'softmax']
        # By default, pix2pixHD uses instance normalization.
        print(inner_cfg)
        style_dim = gen_cfg.style_enc_cfg.interm_style_dims
        activation_norm_type = getattr(gen_cfg, 'activation_norm_type', 'instance')
        activation_norm_params = getattr(gen_cfg, 'activation_norm_params', None)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  activation_norm_params=activation_norm_params)
        base_res_block = partial(Res2dBlock,
                                 padding_mode=padding_mode,
                                 weight_norm_type=weight_norm_type,
                                 activation_norm_type=activation_norm_type,
                                 activation_norm_params=activation_norm_params,
                                 nonlinearity='relu',
                                 order='CNACN')
        num_output_channels = getattr(inner_cfg, 'output_nc', 64)
        num_filters = getattr(inner_cfg, 'num_filters', 64)
        num_downsamples = 4
        num_res_blocks = getattr(inner_cfg, 'num_res_blocks', 9)
        # First layer.
        model = [base_conv_block(num_input_channels, num_filters,
                                 kernel_size=7, padding=3)]
        model += [nn.PReLU()]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2 ** i)
            model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
            model += [nn.PReLU()]
        # ResNet blocks.
        ch = num_filters * (2 ** num_downsamples)
        for i in range(num_res_blocks):
            model += [base_res_block(ch, ch, 3, padding=1)]
        self.model = nn.Sequential(*model)
        # Upsample. The channel multiplier per upsampling stage depends on
        # whether style modulation is injected into the render generator.
        assert num_downsamples == 4
        if not (inner_cfg.name == 'render' and gen_cfg.style_inject):
            channel_mults = [16, 8, 4, 2]
        else:
            channel_mults = [16, 6, 6, 6]
        self.up0_a = NearestUpsample(scale_factor=2)
        self.up0_b = base_conv_block(num_filters * channel_mults[0],
                                     num_filters * channel_mults[1], 3, padding=1)
        self.up1_a = NearestUpsample(scale_factor=2)
        self.up1_b = base_conv_block(num_filters * channel_mults[1],
                                     num_filters * channel_mults[2], 3, padding=1)
        self.up2_a = NearestUpsample(scale_factor=2)
        self.up2_b = base_conv_block(num_filters * channel_mults[2],
                                     num_filters * channel_mults[3], 3, padding=1)
        self.up3_a = NearestUpsample(scale_factor=2)
        self.up3_b = base_conv_block(num_filters * channel_mults[3],
                                     num_filters, 3, padding=1)
        self.up_end = Conv2dBlock(num_filters, num_output_channels, 7,
                                  padding=3, padding_mode=padding_mode,
                                  nonlinearity=last_act)
        if inner_cfg.name == 'render' and gen_cfg.style_inject:
            # Four chunks of (channel_mults[-1] * num_filters): a scale and a
            # bias for each of the two modulated upsampling stages.
            self.fc_z_cond = nn.Linear(style_dim,
                                       4 * channel_mults[-1] * num_filters)

    def modulate(self, x, w, b):
        w = w[..., None, None]
        b = b[..., None, None]
        return x * (w + 1) + b

    def forward(self, input, z=None):
        r"""Coarse-to-fine generator forward.

        Args:
            input (4D tensor): Input semantic representations.
            z (N x style_dim tensor, optional): Style codes.

        Returns:
            output (4D tensor): Synthesized image by generator.
        """
        if z is not None:
            z = self.fc_z_cond(z)
            adapt = torch.chunk(z, 2 * 2, dim=-1)
        input = self.model(input)

        input = self.up0_a(input)
        input = self.up0_b(input)
        input = F.leaky_relu(input, negative_slope=0.2, inplace=True)

        input = self.up1_a(input)
        input = self.up1_b(input)
        if z is not None:
            input = self.modulate(input, adapt[0], adapt[1])
        input = F.leaky_relu(input, negative_slope=0.2, inplace=True)

        input = self.up2_a(input)
        input = self.up2_b(input)
        if z is not None:
            input = self.modulate(input, adapt[2], adapt[3])
        input = F.leaky_relu(input, negative_slope=0.2, inplace=True)

        input = self.up3_a(input)
        input = self.up3_b(input)
        input = F.leaky_relu(input, negative_slope=0.2, inplace=True)

        input = self.up_end(input)
        return input


if __name__ == '__main__':
    from easydict import EasyDict as edict
    style_enc_cfg = edict()
    style_enc_cfg.histo = edict()
    style_enc_cfg.histo.mode = 'RGB'
    style_enc_cfg.histo.num_filters = 180
    # histo_process reads style_dims (and num_filters) from the top level of
    # the config; 256 here is an assumed value for the smoke test.
    style_enc_cfg.style_dims = 256
    model = histo_process(style_enc_cfg)
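    # Also exercise the illustrative helpers defined above; every shape they
    # use comes from assumed example configs, not from shipped settings.
    print('histo_process z:', tuple(_example_histo_process().shape))
    print('StyleEncoder z:', tuple(_example_style_encoder().shape))
    print('RenderCNN out:', tuple(_example_render_cnn().shape))
    print('GlobalGenerator out:', tuple(_example_global_generator().shape))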