""" Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). """ import os import torch from torch.nn import init from spade.generator import SPADEGenerator class Pix2PixModel(torch.nn.Module): def __init__(self, opt): super().__init__() self.opt = opt self.FloatTensor = torch.cuda.FloatTensor if opt['use_gpu'] \ else torch.FloatTensor self.netG = self.initialize_networks(opt) def forward(self, data, mode): input_semantics, real_image = self.preprocess_input(data) if mode == 'inference': with torch.no_grad(): fake_image = self.generate_fake(input_semantics) return fake_image else: raise ValueError("|mode| is invalid") def preprocess_input(self, data): data['label'] = data['label'].long() # move to GPU and change data types if self.opt['use_gpu']: data['label'] = data['label'].cuda() data['instance'] = data['instance'].cuda() data['image'] = data['image'].cuda() # create one-hot label map label_map = data['label'] bs, _, h, w = label_map.size() input_label = self.FloatTensor(bs, self.opt['label_nc'], h, w).zero_() # one whole label map -> to one label map per class input_semantics = input_label.scatter_(1, label_map, 1.0) return input_semantics, data['image'] def generate_fake(self, input_semantics): fake_image = self.netG(input_semantics) return fake_image def create_network(self, cls, opt): net = cls(opt) if self.opt['use_gpu']: net.cuda() gain = 0.02 def init_weights(m): classname = m.__class__.__name__ if classname.find('BatchNorm2d') != -1: if hasattr(m, 'weight') and m.weight is not None: init.normal_(m.weight.data, 1.0, gain) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): init.xavier_normal_(m.weight.data, gain=gain) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) # Applies fn recursively to every submodule (as returned by .children()) as well as self net.apply(init_weights) return net def load_network(self, net, label, epoch, opt): save_filename = '%s_net_%s.pth' % (epoch, label) save_path = os.path.join( save_filename) weights = torch.load(save_path) net.load_state_dict(weights) return net def initialize_networks(self, opt): netG = self.create_network(SPADEGenerator, opt) if not opt['isTrain']: netG = self.load_network(netG, 'G', opt['which_epoch'], opt) # self.print_network(netG) return netG def print_network(self, net): num_params = 0 for param in net.parameters(): num_params += param.numel() print('Network [%s] was created. Total number of parameters: %.1f million. ' % (type(net).__name__, num_params / 1000000)) print(net)