import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks

NC = 20


def generate_discrete_label(inputs, label_nc, onehot=True, encode=True):
    """Argmax a (B, label_nc, 256, 192) score map into a discrete label map,
    optionally re-encoded as a one-hot tensor."""
    pred_batch = []
    size = inputs.size()
    for input in inputs:
        input = input.view(1, label_nc, size[2], size[3])
        pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
        pred_batch.append(pred)

    pred_batch = np.array(pred_batch)
    pred_batch = torch.from_numpy(pred_batch)
    label_map = []
    for p in pred_batch:
        p = p.view(1, 256, 192)
        label_map.append(p)
    label_map = torch.stack(label_map, 0)
    if not onehot:
        return label_map.float().cuda()

    size = label_map.size()
    oneHot_size = (size[0], label_nc, size[2], size[3])
    input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
    input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
    return input_label


def morpho(mask, iter, bigger=True):
    """Dilate (bigger=True) or erode a batch of (B, 1, 256, 192) masks with a
    3x3 elliptical kernel for `iter` iterations."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    new = []
    for i in range(len(mask)):
        tem = mask[i].cpu().detach().numpy().squeeze().reshape(256, 192, 1) * 255
        tem = tem.astype(np.uint8)
        if bigger:
            tem = cv2.dilate(tem, kernel, iterations=iter)
        else:
            tem = cv2.erode(tem, kernel, iterations=iter)
        tem = tem.reshape(1, 256, 192)
        new.append(tem.astype(np.float64) / 255.0)
    new = np.stack(new)
    return torch.FloatTensor(new).cuda()


def morpho_smaller(mask, iter, bigger=True):
    """Same as morpho() but with a 1x1 kernel; note a single-pixel structuring
    element makes dilation/erosion effectively a no-op."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (1, 1))
    new = []
    for i in range(len(mask)):
        tem = mask[i].cpu().detach().numpy().squeeze().reshape(256, 192, 1) * 255
        tem = tem.astype(np.uint8)
        if bigger:
            tem = cv2.dilate(tem, kernel, iterations=iter)
        else:
            tem = cv2.erode(tem, kernel, iterations=iter)
        tem = tem.reshape(1, 256, 192)
        new.append(tem.astype(np.float64) / 255.0)
    new = np.stack(new)
    return torch.FloatTensor(new).cuda()


def encode(label_map, size):
    """One-hot encode a discrete 14-class label map."""
    label_nc = 14
    oneHot_size = (size[0], label_nc, size[2], size[3])
    input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
    input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
    return input_label


class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real, d_fake), flags) if f]

        return loss_filter

    def get_G(self, in_C, out_c, n_blocks, opt, L=1, S=1):
        return networks.define_G(in_C, out_c, opt.ngf, opt.netG, L, S,
                                 opt.n_downsample_global, n_blocks, opt.n_local_enhancers,
                                 opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)

    def get_D(self, inc, opt):
        netD = networks.define_D(inc, opt.ndf, opt.n_layers_D, opt.norm, opt.no_lsgan,
                                 opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
        return netD

    def cross_entropy2d(self, input, target, weight=None, size_average=True):
        n, c, h, w = input.size()
        nt, ht, wt = target.size()

        # Handle inconsistent size between input and target
        if h != ht or w != wt:
            input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

        input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
        target = target.view(-1)
        # `size_average` was removed from F.cross_entropy; map it to `reduction`.
        loss = F.cross_entropy(
            input, target, weight=weight,
            reduction="mean" if size_average else "sum",
            ignore_index=250,
        )
        return loss
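    # -----------------------------------------------------------------------
    # Usage sketch (illustrative only, assumes a CUDA device): how the
    # module-level helpers turn raw parser logits into label maps.
    #
    #   logits = torch.randn(1, 14, 256, 192).cuda()       # e.g. G1 output
    #   parsing = generate_discrete_label(logits, 14, onehot=False)
    #   # parsing: (1, 1, 256, 192) float tensor with values in {0, ..., 13}
    #   one_hot = generate_discrete_label(logits, 14)      # (1, 14, 256, 192)
    # -----------------------------------------------------------------------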
    def ger_average_color(self, mask, arms):
        """Return the per-image mean color of `arms` inside `mask`, broadcast
        over the full map (zeros when the mask is nearly empty)."""
        color = torch.zeros(arms.shape).cuda()
        for i in range(arms.shape[0]):
            count = len(torch.nonzero(mask[i, :, :, :]))
            if count < 10:
                color[i, 0, :, :] = 0
                color[i, 1, :, :] = 0
                color[i, 2, :, :] = 0
            else:
                color[i, 0, :, :] = arms[i, 0, :, :].sum() / count
                color[i, 1, :, :] = arms[i, 1, :, :].sum() / count
                color[i, 2, :, :] = arms[i, 2, :, :].sum() / count
        return color

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:
            # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        self.count = 0

        # define networks
        # Generator network
        netG_input_nc = input_nc
        # Main Generator
        with torch.no_grad():
            self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()
            self.G1 = networks.define_Refine_ResUnet(37, 14, self.gpu_ids).eval()
            self.G2 = networks.define_Refine(19 + 18, 1, self.gpu_ids).eval()
            self.G = networks.define_Refine(24, 3, self.gpu_ids).eval()

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.BCE = torch.nn.BCEWithLogitsLoss()

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            netB_input_nc = opt.output_nc * 2
            # self.D1 = self.get_D(17, opt)
            # self.D2 = self.get_D(4, opt)
            # self.D3 = self.get_D(7 + 3, opt)
            # self.D = self.get_D(20, opt)
            # self.netB = networks.define_B(netB_input_nc, opt.output_nc, 32, 3, 3, opt.norm, gpu_ids=self.gpu_ids)

        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.Unet, 'U', opt.which_epoch, pretrained_path)
            self.load_network(self.G1, 'G1', opt.which_epoch, pretrained_path)
            self.load_network(self.G2, 'G2', opt.which_epoch, pretrained_path)
            self.load_network(self.G, 'G', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
                self.criterionStyle = networks.StyleLoss(self.gpu_ids)

            # Names so we can break out the losses
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------'
                      % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))
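    # -----------------------------------------------------------------------
    # Pipeline overview (summary of forward()/inference() below):
    #   G1   : (cloth mask, cloth, body labels, pose, noise; 37ch) -> 14-class
    #          semantic layout of the dressed person
    #   G2   : refined layout + pose (37ch) -> target clothes-on-body mask
    #   Unet : warps the in-shop clothes into that mask (grid-based warping
    #          plus a learned refinement)
    #   G    : composites warped clothes, skin color and the body image (24ch)
    #          into the final try-on result
    # -----------------------------------------------------------------------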
    def encode_input(self, label_map, clothes_mask, all_clothes_label):
        size = label_map.size()
        oneHot_size = (size[0], 14, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)

        masked_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        masked_label = masked_label.scatter_(
            1, (label_map * (1 - clothes_mask)).data.long().cuda(), 1.0)

        c_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        c_label = c_label.scatter_(1, all_clothes_label.data.long().cuda(), 1.0)

        input_label = Variable(input_label)
        return input_label, masked_label, c_label

    def encode_input_test(self, label_map, label_map_ref, real_image_ref, infer=False):
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
            input_label_ref = label_map_ref.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            input_label_ref = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()
                input_label_ref = input_label_ref.half()

        # `volatile` was removed in PyTorch 0.4; callers should wrap inference
        # in torch.no_grad() instead (the `infer` flag is kept for signature
        # compatibility).
        input_label = Variable(input_label)
        input_label_ref = Variable(input_label_ref)
        real_image_ref = Variable(real_image_ref.data.cuda())
        return input_label, input_label_ref, real_image_ref

    def discriminate(self, netD, input_label, test_image, use_pool=False):
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return netD.forward(fake_query)
        else:
            return netD.forward(input_concat)

    def gen_noise(self, shape):
        noise = np.zeros(shape, dtype=np.uint8)
        # Gaussian noise, quantized to a sparse binary {0, 1} map by the
        # uint8 cast below (kept as in the original).
        noise = cv2.randn(noise, 0, 255)
        noise = np.asarray(noise / 255, dtype=np.uint8)
        noise = torch.tensor(noise, dtype=torch.float32)
        return noise.cuda()

    def multi_scale_blend(self, fake_img, fake_c, mask, number=4):
        alpha = [0, 0.1, 0.3, 0.6, 0.9]
        smaller = mask
        out = 0
        for i in range(1, number + 1):
            bigger = smaller
            smaller = morpho(smaller, 2, False)
            mid = bigger - smaller
            out += mid * (alpha[i] * fake_c + (1 - alpha[i]) * fake_img)
        out += smaller * fake_c
        out += (1 - mask) * fake_img
        return out
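    # -----------------------------------------------------------------------
    # What multi_scale_blend() computes (comments only): the mask is split
    # into concentric rings by repeated erosion, blended with growing alpha:
    #
    #   ring_i = erode^{2(i-1)}(mask) - erode^{2i}(mask),  a_i in [0.1, 0.3, 0.6, 0.9]
    #   out    = sum_i ring_i * (a_i * fake_c + (1 - a_i) * fake_img)
    #            + erode^{2n}(mask) * fake_c + (1 - mask) * fake_img
    #
    # i.e. the mask interior is pure fake_c while the border feathers
    # gradually into fake_img.
    # -----------------------------------------------------------------------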
    def forward(self, label, pre_clothes_mask, img_fore, clothes_mask, clothes,
                all_clothes_label, real_image, pose, grid, mask_fore):
        # Encode Inputs
        input_label, masked_label, all_clothes_label = self.encode_input(
            label, clothes_mask, all_clothes_label)
        # `np.float` was removed from NumPy; use np.float32 instead.
        arm1_mask = torch.FloatTensor(
            (label.cpu().numpy() == 11).astype(np.float32)).cuda()
        arm2_mask = torch.FloatTensor(
            (label.cpu().numpy() == 13).astype(np.float32)).cuda()
        pre_clothes_mask = torch.FloatTensor(
            (pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
        clothes = clothes * pre_clothes_mask

        shape = pre_clothes_mask.shape

        # Stage 1: predict the semantic layout of the dressed person.
        G1_in = torch.cat([pre_clothes_mask, clothes, all_clothes_label,
                           pose, self.gen_noise(shape)], dim=1)
        arm_label = self.G1.refine(G1_in)
        arm_label = self.sigmoid(arm_label)
        CE_loss = self.cross_entropy2d(
            arm_label, (label * (1 - clothes_mask)).transpose(0, 1)[0].long()) * 10

        armlabel_map = generate_discrete_label(arm_label.detach(), 14, False)
        dis_label = generate_discrete_label(arm_label.detach(), 14)

        # Stage 2: predict the target clothes-on-body mask.
        G2_in = torch.cat([pre_clothes_mask, clothes, dis_label,
                           pose, self.gen_noise(shape)], 1)
        fake_cl = self.G2.refine(G2_in)
        fake_cl = self.sigmoid(fake_cl)
        CE_loss += self.BCE(fake_cl, clothes_mask) * 10

        fake_cl_dis = torch.FloatTensor(
            (fake_cl.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
        fake_cl_dis = morpho(fake_cl_dis, 1, True)

        new_arm1_mask = torch.FloatTensor(
            (armlabel_map.cpu().numpy() == 11).astype(np.float32)).cuda()
        new_arm2_mask = torch.FloatTensor(
            (armlabel_map.cpu().numpy() == 13).astype(np.float32)).cuda()
        fake_cl_dis = fake_cl_dis * (1 - new_arm1_mask) * (1 - new_arm2_mask)
        fake_cl_dis *= mask_fore

        arm1_occ = clothes_mask * new_arm1_mask
        arm2_occ = clothes_mask * new_arm2_mask
        bigger_arm1_occ = morpho(arm1_occ, 10)
        bigger_arm2_occ = morpho(arm2_occ, 10)
        arm1_full = arm1_occ + (1 - clothes_mask) * arm1_mask
        arm2_full = arm2_occ + (1 - clothes_mask) * arm2_mask
        armlabel_map *= (1 - new_arm1_mask)
        armlabel_map *= (1 - new_arm2_mask)
        armlabel_map = armlabel_map * (1 - arm1_full) + arm1_full * 11
        armlabel_map = armlabel_map * (1 - arm2_full) + arm2_full * 13
        armlabel_map *= (1 - fake_cl_dis)
        dis_label = encode(armlabel_map, armlabel_map.shape)

        # Stage 3: warp the in-shop clothes into the predicted mask.
        fake_c, warped, warped_mask, warped_grid = self.Unet(
            clothes, fake_cl_dis, pre_clothes_mask, grid)
        # Keep the channel dim (3:4 rather than 3) so the broadcasting below
        # also works for batch sizes > 1.
        mask = fake_c[:, 3:4, :, :]
        mask = self.sigmoid(mask) * fake_cl_dis
        fake_c = self.tanh(fake_c[:, 0:3, :, :])
        fake_c = fake_c * (1 - mask) + mask * warped

        # Stage 4: composite the final try-on image.
        skin_color = self.ger_average_color(
            (arm1_mask + arm2_mask - arm2_mask * arm1_mask),
            (arm1_mask + arm2_mask - arm2_mask * arm1_mask) * real_image)
        occlude = (1 - bigger_arm1_occ * (arm2_mask + arm1_mask + clothes_mask)) * \
                  (1 - bigger_arm2_occ * (arm2_mask + arm1_mask + clothes_mask))
        img_hole_hand = img_fore * (1 - clothes_mask) * occlude * (1 - fake_cl_dis)

        G_in = torch.cat([img_hole_hand, dis_label, fake_c, skin_color,
                          self.gen_noise(shape)], 1)
        fake_image = self.G.refine(G_in.detach())
        fake_image = self.tanh(fake_image)

        loss_D_fake = 0
        loss_D_real = 0
        loss_G_GAN = 0
        loss_G_VGG = 0

        L1_loss = 0
        style_loss = L1_loss

        return [self.loss_filter(loss_G_GAN, 0, loss_G_VGG, loss_D_real, loss_D_fake),
                fake_image, clothes, arm_label, L1_loss, style_loss, fake_cl,
                CE_loss, real_image, warped_grid]
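    # -----------------------------------------------------------------------
    # Note: inference() below repeats forward()'s generation path
    # (G1 -> G2 -> Unet -> G) but skips every loss computation and returns
    # only [fake_image, warped, fake_c].
    # -----------------------------------------------------------------------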
    def inference(self, label, pre_clothes_mask, img_fore, clothes_mask, clothes,
                  all_clothes_label, real_image, pose, grid, mask_fore):
        # Encode Inputs
        input_label, masked_label, all_clothes_label = self.encode_input(
            label, clothes_mask, all_clothes_label)
        arm1_mask = torch.FloatTensor(
            (label.cpu().numpy() == 11).astype(np.float32)).cuda()
        arm2_mask = torch.FloatTensor(
            (label.cpu().numpy() == 13).astype(np.float32)).cuda()
        pre_clothes_mask = torch.FloatTensor(
            (pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
        clothes = clothes * pre_clothes_mask

        shape = pre_clothes_mask.shape

        G1_in = torch.cat([pre_clothes_mask, clothes, all_clothes_label,
                           pose, self.gen_noise(shape)], dim=1)
        arm_label = self.G1.refine(G1_in)
        arm_label = self.sigmoid(arm_label)

        armlabel_map = generate_discrete_label(arm_label.detach(), 14, False)
        dis_label = generate_discrete_label(arm_label.detach(), 14)

        G2_in = torch.cat([pre_clothes_mask, clothes, dis_label,
                           pose, self.gen_noise(shape)], 1)
        fake_cl = self.G2.refine(G2_in)
        fake_cl = self.sigmoid(fake_cl)

        fake_cl_dis = torch.FloatTensor(
            (fake_cl.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
        fake_cl_dis = morpho(fake_cl_dis, 1, True)

        new_arm1_mask = torch.FloatTensor(
            (armlabel_map.cpu().numpy() == 11).astype(np.float32)).cuda()
        new_arm2_mask = torch.FloatTensor(
            (armlabel_map.cpu().numpy() == 13).astype(np.float32)).cuda()
        fake_cl_dis = fake_cl_dis * (1 - new_arm1_mask) * (1 - new_arm2_mask)
        fake_cl_dis *= mask_fore

        arm1_occ = clothes_mask * new_arm1_mask
        arm2_occ = clothes_mask * new_arm2_mask
        bigger_arm1_occ = morpho(arm1_occ, 10)
        bigger_arm2_occ = morpho(arm2_occ, 10)
        arm1_full = arm1_occ + (1 - clothes_mask) * arm1_mask
        arm2_full = arm2_occ + (1 - clothes_mask) * arm2_mask
        armlabel_map *= (1 - new_arm1_mask)
        armlabel_map *= (1 - new_arm2_mask)
        armlabel_map = armlabel_map * (1 - arm1_full) + arm1_full * 11
        armlabel_map = armlabel_map * (1 - arm2_full) + arm2_full * 13
        armlabel_map *= (1 - fake_cl_dis)
        dis_label = encode(armlabel_map, armlabel_map.shape)

        fake_c, warped, warped_mask, warped_grid = self.Unet(
            clothes, fake_cl_dis, pre_clothes_mask, grid)
        # Keep the channel dim (3:4 rather than 3), as in forward().
        mask = fake_c[:, 3:4, :, :]
        mask = self.sigmoid(mask) * fake_cl_dis
        fake_c = self.tanh(fake_c[:, 0:3, :, :])
        fake_c = fake_c * (1 - mask) + mask * warped

        skin_color = self.ger_average_color(
            (arm1_mask + arm2_mask - arm2_mask * arm1_mask),
            (arm1_mask + arm2_mask - arm2_mask * arm1_mask) * real_image)
        occlude = (1 - bigger_arm1_occ * (arm2_mask + arm1_mask + clothes_mask)) * \
                  (1 - bigger_arm2_occ * (arm2_mask + arm1_mask + clothes_mask))
        img_hole_hand = img_fore * (1 - clothes_mask) * occlude * (1 - fake_cl_dis)

        G_in = torch.cat([img_hole_hand, dis_label, fake_c, skin_color,
                          self.gen_noise(shape)], 1)
        fake_image = self.G.refine(G_in.detach())
        fake_image = self.tanh(fake_image)

        return [fake_image, warped, fake_c]

    def save(self, which_epoch):
        # self.save_network(self.Unet, 'U', which_epoch, self.gpu_ids)
        # self.save_network(self.G, 'G', which_epoch, self.gpu_ids)
        # self.save_network(self.G1, 'G1', which_epoch, self.gpu_ids)
        # self.save_network(self.G2, 'G2', which_epoch, self.gpu_ids)
        # # self.save_network(self.G3, 'G3', which_epoch, self.gpu_ids)
        # self.save_network(self.D, 'D', which_epoch, self.gpu_ids)
        # self.save_network(self.D1, 'D1', which_epoch, self.gpu_ids)
        # self.save_network(self.D2, 'D2', which_epoch, self.gpu_ids)
        # self.save_network(self.D3, 'D3', which_epoch, self.gpu_ids)
        pass
        # self.save_network(self.netB, 'B', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations,
        # also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(
            params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


class InferenceModel(Pix2PixHDModel):
    def forward(self, label, pre_clothes_mask, img_fore, clothes_mask, clothes,
                all_clothes_label, real_image, pose, grid, mask_fore):
        return self.inference(label, pre_clothes_mask, img_fore, clothes_mask,
                              clothes, all_clothes_label, real_image, pose,
                              grid, mask_fore)
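
# ---------------------------------------------------------------------------
# Minimal smoke test for the mask morphology helper: a sketch only, assuming
# a CUDA device (the 256x192 resolution is hard-coded throughout this file).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    if torch.cuda.is_available():
        m = torch.zeros(1, 1, 256, 192).cuda()
        m[:, :, 100:150, 80:120] = 1.0            # a rectangular dummy mask
        grown = morpho(m, iter=5, bigger=True)    # dilate for 5 iterations
        shrunk = morpho(m, iter=5, bigger=False)  # erode for 5 iterations
        print(grown.sum().item() > m.sum().item(),   # True: dilation grows it
              shrunk.sum().item() < m.sum().item())  # True: erosion shrinks it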