""" 07-21 StyleLoss to encourage style statistics to be consistent within each cluster. """ import torch import torchvision import torch.nn as nn import torch.nn.functional as F # VGG architecter, used for the perceptual loss using a pretrained VGG network class VGG19(torch.nn.Module): def __init__(self, requires_grad=False, device = torch.device(f'cuda:0')): super().__init__() vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) def forward(self, X): #X = self.normalization(X) h_relu1 = self.slice1(X) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out # create a module to normalize input image so we can easily put it in a # nn.Sequential class Normalization(nn.Module): def __init__(self, mean, std): super(Normalization, self).__init__() # .view the mean and std to make them [C x 1 x 1] so that they can # directly work with image Tensor of shape [B x C x H x W]. # B is batch size. C is number of channels. H is height and W is width. self.mean = torch.tensor(mean).view(-1, 1, 1) self.std = torch.tensor(std).view(-1, 1, 1) def forward(self, img): # normalize img return (img - self.mean) / self.std class GramMatrix(nn.Module): def forward(self,input): b, c, h, w = input.size() f = input.view(b,c,h*w) # bxcx(hxw) # torch.bmm(batch1, batch2, out=None) # # batch1: bxmxp, batch2: bxpxn -> bxmxn # G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc return G.div_(c*h*w) class StyleLoss(nn.Module): """ Version 1. Compare mean and variance cluster-wise. """ def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'), style_mode = 'gram'): super().__init__() self.vgg = VGG19() self.style_layers = [] for style_layer in style_layers.split(','): self.style_layers.append(int(style_layer[-1]) - 1) self.style_mode = style_mode cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) def forward(self, pred, gt): """ INPUTS: - pred: (B, 3, H, W) - gt: (B, 3, H, W) - seg: (B, H, W) """ # extract features for images B, _, H, W = pred.shape pred = self.normalization(pred) gt = self.normalization(gt) pred_feats = self.vgg(pred) gt_feats = self.vgg(gt) loss = 0 for style_layer in self.style_layers: pred_feat = pred_feats[style_layer] gt_feat = gt_feats[style_layer] pred_gram = GramMatrix()(pred_feat) gt_gram = GramMatrix()(gt_feat) loss += torch.sum((pred_gram - gt_gram) ** 2) / B return loss class styleLossMask(nn.Module): """ Version 1. Compare mean and variance cluster-wise. """ def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'), style_mode = 'gram'): super().__init__() self.vgg = VGG19() self.style_layers = [] for style_layer in style_layers.split(','): self.style_layers.append(int(style_layer[-1]) - 1) self.style_mode = style_mode #cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) #cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) #self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) def forward(self, input, target, mask): B, _, H, W = input.shape #pred = self.normalization(input) #target = self.normalization(target) pred_feats = self.vgg(input) gt_feats = self.vgg(target) loss = 0 mb, mc, mh, mw = mask.shape for style_layer in self.style_layers: pred_feat = pred_feats[style_layer] gt_feat = gt_feats[style_layer] ib,ic,ih,iw = pred_feat.size() iF = pred_feat.view(ib,ic,-1) tb,tc,th,tw = gt_feat.size() tF = gt_feat.view(tb,tc,-1) for i in range(mb): # resize mask to have the same size of the feature maski = F.interpolate(mask[i:i+1], size = (ih, iw), mode = 'nearest') mask_flat_i = maski.view(mc, -1) maskt = F.interpolate(mask[i:i+1], size = (th, tw), mode = 'nearest') mask_flat_t = maskt.view(mc, -1) for j in range(mc): # get features for each part idx = torch.nonzero(mask_flat_i[j]).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue ipart = torch.index_select(iF, 2, idx) idx = torch.nonzero(mask_flat_t[j]).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue tpart = torch.index_select(tF, 2, idx) iMean = torch.mean(ipart,dim=2) iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ic*ih*iw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc tMean = torch.mean(tpart,dim=2) tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tc*th*tw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc loss_j = nn.MSELoss()(iMean,tMean) + nn.MSELoss()(iGram,tGram) loss += loss_j return loss/tb # Perceptual loss that uses a pretrained VGG network class VGGLoss(nn.Module): def __init__(self, weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0], device = torch.device(f'cuda:0')): super(VGGLoss, self).__init__() self.vgg = VGG19(device = device) self.criterion = nn.L1Loss() self.weights = weights def forward(self, x, y): x_vgg, y_vgg = self.vgg(x), self.vgg(y) loss = 0 for i in range(len(x_vgg)): loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) return loss class styleLossMaskv2(nn.Module): def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0')): super().__init__() self.vgg = VGG19(device = device) self.style_layers = [] for style_layer in style_layers.split(','): self.style_layers.append(int(style_layer[-1]) - 1) cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) def forward(self, input, target, mask_input, mask_target): B, _, H, W = input.shape input = self.normalization(input) target = self.normalization(target) pred_feats = self.vgg(input) gt_feats = self.vgg(target) loss = 0 mb, mc, mh, mw = mask_input.shape for style_layer in self.style_layers: pred_feat = pred_feats[style_layer] gt_feat = gt_feats[style_layer] ib,ic,ih,iw = pred_feat.size() iF = pred_feat.view(ib,ic,-1) tb,tc,th,tw = gt_feat.size() tF = gt_feat.view(tb,tc,-1) for i in range(mb): # resize mask to have the same size of the feature maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest') mask_flat_i = maski.view(mc, -1) maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') mask_flat_t = maskt.view(mc, -1) for j in range(mc): # get features for each part idx = torch.nonzero(mask_flat_i[j]).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue ipart = torch.index_select(iF[i:i+1], 2, idx) idx = torch.nonzero(mask_flat_t[j]).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue tpart = torch.index_select(tF[i:i+1], 2, idx) iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc loss += torch.sum((iGram - tGram) ** 2) return loss/tb class styleLossMaskv3(nn.Module): def __init__(self, style_layers = 'relu1, relu2, relu3, relu4, relu5', device = torch.device(f'cuda:0')): super().__init__() self.vgg = VGG19(device = device) self.style_layers = [] for style_layer in style_layers.split(','): self.style_layers.append(int(style_layer[-1]) - 1) cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) def forward_img_img(self, input, target, mask_input, mask_target): B, _, H, W = input.shape input = self.normalization(input) target = self.normalization(target) pred_feats = self.vgg(input) gt_feats = self.vgg(target) loss = 0 mb, mc, mh, mw = mask_input.shape for style_layer in self.style_layers: pred_feat = pred_feats[style_layer] gt_feat = gt_feats[style_layer] ib,ic,ih,iw = pred_feat.size() iF = pred_feat.view(ib,ic,-1) tb,tc,th,tw = gt_feat.size() tF = gt_feat.view(tb,tc,-1) for i in range(mb): # resize mask to have the same size of the feature maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest') mask_flat_i = maski.view(mc, -1) maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') mask_flat_t = maskt.view(mc, -1) for j in range(mc): # get features for each part idx = torch.nonzero(mask_flat_i[j]).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue ipart = torch.index_select(iF[i:i+1], 2, idx) idx = torch.nonzero(mask_flat_t[j]).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue tpart = torch.index_select(tF[i:i+1], 2, idx) iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc #loss += torch.sum((iGram - tGram) ** 2) loss += F.mse_loss(iGram, tGram) #return loss/tb return loss * 100000 / tb def forward_patch_img(self, input, target, mask_target): input = self.normalization(input) target = self.normalization(target) pred_feats = self.vgg(input) gt_feats = self.vgg(target) patch_num = input.shape[0] // target.shape[0] loss = 0 for style_layer in self.style_layers: pred_feat = pred_feats[style_layer] gt_feat = gt_feats[style_layer] ib,ic,ih,iw = pred_feat.size() iF = pred_feat.view(ib,ic,-1) tb,tc,th,tw = gt_feat.size() tF = gt_feat.view(tb,tc,-1) for i in range(tb): # resize mask to have the same size of the feature maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') mask_flat_t = maskt.view(-1) idx = torch.nonzero(mask_flat_t).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue tpart = torch.index_select(tF[i:i+1], 2, idx) tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc ipart = iF[i * patch_num: (i + 1) * patch_num] iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc #loss += torch.sum((iGram - tGram.repeat(patch_num, 1, 1)) ** 2) loss += F.mse_loss(iGram, tGram.repeat(patch_num, 1, 1)) return loss/ib * 100000 class KLDLoss(nn.Module): def forward(self, mu, logvar): return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) class LPIPSorGramMatch(nn.Module): """ Version 1. Compare mean and variance cluster-wise. """ def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'), weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0],): super().__init__() self.vgg = VGG19(device = device) self.style_layers = [] for style_layer in style_layers.split(','): self.style_layers.append(int(style_layer[-1]) - 1) cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) self.criterion = nn.L1Loss() self.weights = weights def forward(self, pred, gt, mode = 'lpips'): """ INPUTS: - pred: (B, 3, H, W) - gt: (B, 3, H, W) - seg: (B, H, W) """ # extract features for images B, _, H, W = pred.shape pred = self.normalization(pred) gt = self.normalization(gt) pred_feats = self.vgg(pred) gt_feats = self.vgg(gt) if mode == 'lpips': lpips_loss = 0 for i in range(len(pred_feats)): lpips_loss += self.weights[i] * self.criterion(pred_feats[i], gt_feats[i].detach()) return lpips_loss elif mode == 'gram_match': gram_match_loss = 0 for style_layer in self.style_layers: pred_feat = pred_feats[style_layer] gt_feat = gt_feats[style_layer] pred_gram = GramMatrix()(pred_feat) gt_gram = GramMatrix()(gt_feat) gram_match_loss += torch.sum((pred_gram - gt_gram) ** 2) / B return gram_match_loss else: raise ValueError("Only computes lpips or gram match loss.") class styleLossMaskv4(nn.Module): def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0')): super().__init__() self.vgg = VGG19(device = device) self.style_layers = [] for style_layer in style_layers.split(','): self.style_layers.append(int(style_layer[-1]) - 1) cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) def forward_img_img(self, input, target, mask_input, mask_target): B, _, H, W = input.shape input = self.normalization(input) target = self.normalization(target) pred_feats = self.vgg(input) gt_feats = self.vgg(target) loss = 0 mb, mc, mh, mw = mask_input.shape for style_layer in self.style_layers: pred_feat = pred_feats[style_layer] gt_feat = gt_feats[style_layer] ib,ic,ih,iw = pred_feat.size() iF = pred_feat.view(ib,ic,-1) tb,tc,th,tw = gt_feat.size() tF = gt_feat.view(tb,tc,-1) for i in range(mb): # resize mask to have the same size of the feature maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest') mask_flat_i = maski.view(mc, -1) maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') mask_flat_t = maskt.view(mc, -1) for j in range(mc): # get features for each part idx = torch.nonzero(mask_flat_i[j]).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue ipart = torch.index_select(iF[i:i+1], 2, idx) idx = torch.nonzero(mask_flat_t[j]).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue tpart = torch.index_select(tF[i:i+1], 2, idx) iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc iMean = torch.mean(ipart, dim=2) tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc tMean = torch.mean(tpart, dim=2) loss += torch.sum((iGram - tGram) ** 2) + torch.sum((iMean - tMean) ** 2) * 0.01 return loss/tb def forward_patch_img(self, input, target, mask_target): input = self.normalization(input) target = self.normalization(target) pred_feats = self.vgg(input) gt_feats = self.vgg(target) patch_num = input.shape[0] // target.shape[0] loss = 0 for style_layer in self.style_layers: pred_feat = pred_feats[style_layer] gt_feat = gt_feats[style_layer] ib,ic,ih,iw = pred_feat.size() iF = pred_feat.view(ib,ic,-1) tb,tc,th,tw = gt_feat.size() tF = gt_feat.view(tb,tc,-1) for i in range(tb): # resize mask to have the same size of the feature maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') mask_flat_t = maskt.view(-1) idx = torch.nonzero(mask_flat_t).squeeze() if len(idx.shape) == 0 or idx.shape[0] == 0: continue tpart = torch.index_select(tF[i:i+1], 2, idx) tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc tMean = torch.mean(tpart, dim=2) ipart = iF[i * patch_num: (i + 1) * patch_num] iMean = torch.mean(ipart, dim=2) iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc loss += torch.sum((iGram - tGram.repeat(patch_num, 1, 1)) ** 2) loss += torch.sum((iMean - tMean.repeat(patch_num, 1)) ** 2) * 0.01 return loss/ib