"""Generator building blocks: style encoders/MLPs, a render CNN and pix2pixHD-style generators."""
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.nn import Upsample as NearestUpsample | |
import torch.nn.functional as F | |
from functools import partial | |
import sys | |
sys.path.append(".") | |
from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock | |
class StyleMLP(nn.Module):
    r"""MLP converting style code to intermediate style representation."""

    def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True,
                 output_act=True):
        super(StyleMLP, self).__init__()
        self.normalize_input = normalize_input
        self.output_act = output_act
        # One input projection followed by (num_layers - 1) hidden layers,
        # all of the same hidden width.
        layers = [nn.Linear(style_dim, hidden_channels, bias=True)]
        layers += [nn.Linear(hidden_channels, hidden_channels, bias=True)
                   for _ in range(num_layers - 1)]
        self.fc_layers = nn.ModuleList(layers)
        self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)
        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = partial(F.relu, inplace=True)

    def forward(self, z):
        r"""Forward network.

        Args:
            z (N x style_dim tensor): Style codes.
        Returns:
            (N x out_dim tensor): Intermediate style representation.
        """
        if self.normalize_input:
            # L2-normalize the style code before the MLP.
            z = F.normalize(z, p=2, dim=-1, eps=1e-6)
        for layer in self.fc_layers:
            z = self.act(layer(z))
        z = self.fc_out(z)
        return self.act(z) if self.output_act else z
class histo_process(nn.Module):
    r"""Histo process to replace Style Encoder constructor.

    Args:
        style_enc_cfg (obj): Style encoder definition file.
    """

    def __init__(self, style_enc_cfg):
        super().__init__()
        # Histogram feature size is fixed (RGB mode).
        input_channel = 270
        style_dims = style_enc_cfg.style_dims
        self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
        num_filters = getattr(style_enc_cfg, 'num_filters', 180)
        self.process_model = nn.ModuleList()  # kept for interface compatibility; unused
        self.layer1 = LinearBlock(input_channel, num_filters)
        self.layer4 = LinearBlock(num_filters, num_filters)
        # tanh heads keep mu / logvar within [-1, 1].
        self.fc_mu = LinearBlock(num_filters, style_dims, nonlinearity='tanh')
        if not self.no_vae:
            self.fc_var = LinearBlock(num_filters, style_dims, nonlinearity='tanh')

    def forward(self, histo):
        r"""Map a histogram feature to (mu, logvar, z)."""
        feat = self.layer4(self.layer1(histo))
        mu = self.fc_mu(feat)  # [-1, 1]
        if self.no_vae:
            # Deterministic path: z is the mean, logvar reported as zeros.
            return mu, torch.zeros_like(mu), mu
        logvar = self.fc_var(feat)  # [-1, 1]
        std = torch.exp(0.5 * logvar)  # [0.607, 1.624]
        # Reparameterization trick.
        z = mu + torch.randn_like(std) * std
        return mu, logvar, z
class StyleEncoder(nn.Module):
    r"""Style Encoder constructor.

    Args:
        style_enc_cfg (obj): Style encoder definition file.
    """

    def __init__(self, style_enc_cfg):
        super(StyleEncoder, self).__init__()
        input_image_channels = style_enc_cfg.input_image_channels
        num_filters = style_enc_cfg.num_filters
        kernel_size = style_enc_cfg.kernel_size
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        style_dims = style_enc_cfg.style_dims
        weight_norm_type = style_enc_cfg.weight_norm_type
        self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
        conv = partial(Conv2dBlock,
                       kernel_size=kernel_size,
                       stride=2,
                       padding=padding,
                       weight_norm_type=weight_norm_type,
                       activation_norm_type='none',
                       nonlinearity='leakyrelu')
        # Six stride-2 convolutions: a 256x256 input ends up as a 4x4 map.
        self.layer1 = conv(input_image_channels, num_filters)
        self.layer2 = conv(num_filters * 1, num_filters * 2)
        self.layer3 = conv(num_filters * 2, num_filters * 4)
        self.layer4 = conv(num_filters * 4, num_filters * 8)
        self.layer5 = conv(num_filters * 8, num_filters * 8)
        self.layer6 = conv(num_filters * 8, num_filters * 8)
        flat_dim = num_filters * 8 * 4 * 4
        # tanh heads keep mu / logvar within [-1, 1].
        self.fc_mu = LinearBlock(flat_dim, style_dims, nonlinearity='tanh')
        if not self.no_vae:
            self.fc_var = LinearBlock(flat_dim, style_dims, nonlinearity='tanh')

    def forward(self, input_x):
        r"""SPADE Style Encoder forward.

        Args:
            input_x (N x 3 x H x W tensor): input images.
        Returns:
            mu (N x C tensor): Mean vectors.
            logvar (N x C tensor): Log-variance vectors.
            z (N x C tensor): Style code vectors.
        """
        # The fully-connected heads assume a fixed 4x4 spatial map, so
        # resize anything that is not already 256x256.
        if input_x.size(2) != 256 or input_x.size(3) != 256:
            input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
        x = input_x
        for layer in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5, self.layer6):
            x = layer(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        if self.no_vae:
            # Deterministic path: z is the mean, logvar reported as zeros.
            return mu, torch.zeros_like(mu), mu
        logvar = self.fc_var(x)
        std = torch.exp(0.5 * logvar)
        # Reparameterization trick.
        z = mu + torch.randn_like(std) * std
        return mu, logvar, z
class RenderCNN(nn.Module):
    r"""CNN converting intermediate feature map to final image."""

    def __init__(self, in_channels, style_dim, hidden_channels=256,
                 leaky_relu=True):
        super(RenderCNN, self).__init__()
        # Style code is projected to four modulation vectors:
        # (scale, shift) for each of the two modulated stages.
        self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
        self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
        self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
        self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
        self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
        self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)
        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = partial(F.relu, inplace=True)

    def modulate(self, x, w_, b_):
        """Affine feature modulation: scale x by (w_ + 1) and shift by b_.

        The 1e-9 offset is kept from the original implementation
        (presumably to avoid exact zeros — TODO confirm).
        """
        w_ = w_[..., None, None]
        b_ = b_[..., None, None]
        return x * (w_ + 1) + b_ + 1e-9

    def forward(self, x, z):
        r"""Forward network.

        Args:
            x (N x in_channels x H x W tensor): Intermediate feature map
            z (N x style_dim tensor): Style codes.
        Returns:
            (N x 3 x H x W tensor): Sigmoid-activated RGB output.
        """
        w0, b0, w1, b1 = torch.chunk(self.fc_z_cond(z), 2 * 2, dim=-1)
        y = self.act(self.conv1(x))
        # Two residual blocks, each followed by style modulation.
        y = y + self.conv2b(self.act(self.conv2a(y)))
        y = self.act(self.modulate(y, w0, b0))
        y = y + self.conv3b(self.act(self.conv3a(y)))
        y = self.act(self.modulate(y, w1, b1))
        # Final unmodulated residual block and RGB projection.
        y = y + self.conv4b(self.act(self.conv4a(y)))
        y = self.act(y)
        return torch.sigmoid(self.conv4(y))
class inner_Generator(nn.Module):
    r"""Pix2pixHD coarse-to-fine generator constructor.

    Thin wrapper that builds a :class:`GlobalGenerator` from the configs
    and delegates all computation to it.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        inner_cfg (obj): Config for this inner (global) generator stage.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of input channels.
        last_act: ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``, default is 'relu'.
    """
    def __init__(self, gen_cfg, inner_cfg, data_cfg, num_input_channels=3, last_act='relu'):
        super().__init__()
        assert last_act in ['none', 'relu', 'leakyrelu', 'prelu',
                            'tanh', 'sigmoid', 'softmax']
        # pix2pixHD has a global generator.
        global_gen_cfg = inner_cfg
        # By default, pix2pixHD using instance normalization.
        activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
                                       'instance')
        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
                                         None)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
        # Pre-configured conv / residual block factories shared by the
        # global generator.
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  activation_norm_params=activation_norm_params,
                                  # inplace_nonlinearity=True,
                                  nonlinearity='relu')
        base_res_block = partial(Res2dBlock,
                                 padding_mode=padding_mode,
                                 weight_norm_type=weight_norm_type,
                                 activation_norm_type=activation_norm_type,
                                 activation_norm_params=activation_norm_params,
                                 nonlinearity='relu', order='CNACN')
        # Know what is the number of available segmentation labels.
        # Global generator model.
        global_model = GlobalGenerator(global_gen_cfg, data_cfg,
                                       num_input_channels, padding_mode,
                                       base_conv_block, base_res_block, last_act=last_act)
        self.global_model = global_model

    def forward(self, input, random_style=False):
        r"""Generator forward; delegates directly to the global generator.

        Args:
            input (4D tensor): Input semantic representations.
            random_style (bool): Always set to false for the pix2pixHD model
                (accepted for interface compatibility; unused here).
        Returns:
            output (4D tensor): Output of the global generator.
        """
        return self.global_model(input)

    def load_pretrained_network(self, pretrained_dict):
        r"""Load a pretrained network.

        Copies every tensor from ``pretrained_dict`` whose remapped key
        exists and matches in size; all other entries are reported as not
        initialized.
        """
        # print(pretrained_dict.keys())
        model_dict = self.state_dict()
        print('Pretrained network has fewer layers; The following are '
              'not initialized:')
        not_initialized = set()
        for k, v in model_dict.items():
            # Checkpoint keys carry a 'module.' prefix (presumably saved
            # from a DataParallel wrapper — TODO confirm) and nest the
            # generator under 'global_model.model.'.
            kp = 'module.' + k.replace('global_model.', 'global_model.model.')
            if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
                model_dict[k] = pretrained_dict[kp]
            else:
                # Report only the first two key components for readability.
                not_initialized.add('.'.join(k.split('.')[:2]))
        print(sorted(not_initialized))
        self.load_state_dict(model_dict)

    def inference(self, data, **kwargs):
        r"""Generator inference.

        Args:
            data (dict) : Dictionary of input data.
        Returns:
            fake_images (tensor): Output fake images.
            file_names (str): Data file name.
        """
        # NOTE(review): forward() returns the global generator's raw tensor,
        # yet the result is indexed like a dict ('fake_images') below —
        # confirm whether this path is ever exercised.
        output = self.forward(data, **kwargs)
        return output['fake_images'], data['key']['seg_maps'][0]
class GlobalGenerator(nn.Module):
    r"""Coarse generator constructor. This is the main generator in the
    pix2pixHD architecture.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of segmentation labels.
        padding_mode (str): zero | reflect | ...
        base_conv_block (obj): Conv block with preset attributes.
        base_res_block (obj): Residual block with preset attributes.
        last_act (str, optional, default='relu'):
            Type of nonlinear activation function.
            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
    """

    def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
                 base_conv_block, base_res_block, last_act='relu'):
        super(GlobalGenerator, self).__init__()
        num_out_put_channels = getattr(gen_cfg, 'output_nc', 64)
        num_filters = getattr(gen_cfg, 'num_filters', 64)
        num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)
        # Stem: wide-kernel convolution on the raw input.
        layers = [base_conv_block(num_input_channels, num_filters,
                                  kernel_size=7, padding=3)]
        # Encoder: halve the resolution and double the channels each step.
        for level in range(num_downsamples):
            width = num_filters * (2 ** level)
            layers.append(base_conv_block(width, width * 2, 3, padding=1, stride=2))
        # Bottleneck: residual stack at the coarsest resolution.
        width = num_filters * (2 ** num_downsamples)
        layers.extend(base_res_block(width, width, 3, padding=1)
                      for _ in range(num_res_blocks))
        # Decoder: mirror the encoder with nearest-neighbour upsampling.
        for level in reversed(range(num_downsamples)):
            width = num_filters * (2 ** level)
            layers.append(NearestUpsample(scale_factor=2))
            layers.append(base_conv_block(width * 2, width, 3, padding=1))
        # Output projection with the configured final activation.
        layers.append(Conv2dBlock(num_filters, num_out_put_channels, 7, padding=3,
                                  padding_mode=padding_mode, nonlinearity=last_act))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        r"""Coarse-to-fine generator forward.

        Args:
            input (4D tensor) : Input semantic representations.
        Returns:
            output (4D tensor) : Synthesized image by generator.
        """
        return self.model(input)
class inner_Generator_split(nn.Module):
    r"""Pix2pixHD coarse-to-fine generator with an explicitly split
    upsampling path, optionally modulated by a style code.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        inner_cfg (obj): Config for this inner generator stage.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of input channels.
        last_act: ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``, default is 'relu'.
    """

    def __init__(self, gen_cfg, inner_cfg, data_cfg, num_input_channels=3, last_act='relu'):
        super().__init__()
        assert last_act in ['none', 'relu', 'leakyrelu', 'prelu',
                            'tanh', 'sigmoid', 'softmax']
        # pix2pixHD has a global generator.
        # By default, pix2pixHD using instance normalization.
        print(inner_cfg)  # debug: log which inner stage is being built
        style_dim = gen_cfg.style_enc_cfg.interm_style_dims
        activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
                                       'instance')
        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
                                         None)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
        # Pre-configured block factories (note: no nonlinearity inside the
        # conv blocks here; PReLU / leaky_relu are applied explicitly).
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  activation_norm_params=activation_norm_params,
                                  )
        base_res_block = partial(Res2dBlock,
                                 padding_mode=padding_mode,
                                 weight_norm_type=weight_norm_type,
                                 activation_norm_type=activation_norm_type,
                                 activation_norm_params=activation_norm_params,
                                 nonlinearity='relu', order='CNACN')
        num_out_put_channels = getattr(inner_cfg, 'output_nc', 64)
        num_filters = getattr(inner_cfg, 'num_filters', 64)
        num_downsamples = 4
        num_res_blocks = getattr(inner_cfg, 'num_res_blocks', 9)
        # First layer.
        model = [base_conv_block(num_input_channels, num_filters,
                                 kernel_size=7, padding=3)]
        model += [nn.PReLU()]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2 ** i)
            model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
            model += [nn.PReLU()]
        # ResNet blocks.
        ch = num_filters * (2 ** num_downsamples)
        for i in range(num_res_blocks):
            model += [base_res_block(ch, ch, 3, padding=1)]
        self.model = nn.Sequential(*model)
        # Upsample. Per-stage channel multipliers (renamed from `list`,
        # which shadowed the builtin). The render+style_inject variant
        # keeps a constant width so one fc_z_cond projection fits both
        # modulated stages.
        assert num_downsamples == 4
        if not (inner_cfg.name == 'render' and gen_cfg.style_inject):
            up_mults = [16, 8, 4, 2]
        else:
            up_mults = [16, 6, 6, 6]
        self.up0_a = NearestUpsample(scale_factor=2)
        self.up0_b = base_conv_block(num_filters * up_mults[0], num_filters * up_mults[1], 3, padding=1)
        self.up1_a = NearestUpsample(scale_factor=2)
        self.up1_b = base_conv_block(num_filters * up_mults[1], num_filters * up_mults[2], 3, padding=1)
        self.up2_a = NearestUpsample(scale_factor=2)
        self.up2_b = base_conv_block(num_filters * up_mults[2], num_filters * up_mults[3], 3, padding=1)
        self.up3_a = NearestUpsample(scale_factor=2)
        self.up3_b = base_conv_block(num_filters * up_mults[3], num_filters, 3, padding=1)
        self.up_end = Conv2dBlock(num_filters, num_out_put_channels, 7, padding=3,
                                  padding_mode=padding_mode, nonlinearity=last_act)
        if inner_cfg.name == 'render' and gen_cfg.style_inject:
            # Style code -> (scale, shift) pairs for the two modulated stages.
            self.fc_z_cond = nn.Linear(style_dim, 4 * up_mults[-1] * num_filters)

    def modulate(self, x, w, b):
        """Affine feature modulation: scale x by (w + 1) and shift by b."""
        w = w[..., None, None]
        b = b[..., None, None]
        return x * (w + 1) + b

    def forward(self, input, z=None):
        r"""Generator forward.

        Args:
            input (4D tensor): Input semantic representations.
            z (tensor, optional): Style code; when provided, upsampling
                stages 1 and 2 are modulated with it.
        Returns:
            output (4D tensor): Synthesized image by generator.
        """
        if z is not None:
            z = self.fc_z_cond(z)
            adapt = torch.chunk(z, 2 * 2, dim=-1)
        input = self.model(input)
        input = self.up0_a(input)
        input = self.up0_b(input)
        input = F.leaky_relu(input, negative_slope=0.2, inplace=True)
        input = self.up1_a(input)
        input = self.up1_b(input)
        if z is not None:
            input = self.modulate(input, adapt[0], adapt[1])
        input = F.leaky_relu(input, negative_slope=0.2, inplace=True)
        input = self.up2_a(input)
        input = self.up2_b(input)
        if z is not None:
            input = self.modulate(input, adapt[2], adapt[3])
        input = F.leaky_relu(input, negative_slope=0.2, inplace=True)
        input = self.up3_a(input)
        input = self.up3_b(input)
        input = F.leaky_relu(input, negative_slope=0.2, inplace=True)
        input = self.up_end(input)
        return input
if __name__ == '__main__':
    # Smoke test: build a minimal config and instantiate histo_process.
    from easydict import EasyDict as edict
    style_enc_cfg = edict()
    # EasyDict does not auto-create nested objects, so 'histo' must exist
    # before assigning into it (the original raised AttributeError here).
    style_enc_cfg.histo = edict()
    style_enc_cfg.histo.mode = 'RGB'
    style_enc_cfg.histo.num_filters = 180
    # histo_process reads these from the top-level config; 'style_dims'
    # has no default and previously caused an AttributeError.
    style_enc_cfg.num_filters = 180
    style_enc_cfg.style_dims = 256  # arbitrary smoke-test value
    model = histo_process(style_enc_cfg)