"""
07-21
StyleLoss to encourage style statistics to be consistent within each cluster.
"""
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
# VGG-19 feature extractor, used for the perceptual and style losses via a pretrained VGG network
class VGG19(torch.nn.Module):
def __init__(self, requires_grad=False, device = torch.device(f'cuda:0')):
super().__init__()
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
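        # slice1..slice5 end at relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1
        # of torchvision's VGG-19 `features`, so forward() returns those activations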
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)
def forward(self, X):
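        # NOTE: ImageNet normalization is expected to be applied by the caller,
        # which is why the normalization below is commented out.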
#X = self.normalization(X)
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
# create a module to normalize input image so we can easily put it in a
# nn.Sequential
class Normalization(nn.Module):
def __init__(self, mean, std):
super(Normalization, self).__init__()
# .view the mean and std to make them [C x 1 x 1] so that they can
# directly work with image Tensor of shape [B x C x H x W].
# B is batch size. C is number of channels. H is height and W is width.
        self.mean = mean.clone().detach().view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)
def forward(self, img):
# normalize img
return (img - self.mean) / self.std
class GramMatrix(nn.Module):
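    """Batched Gram matrix: for input of shape (B, C, H, W), returns G = F F^T / (C*H*W),
    where F is the (B, C, H*W) reshaped feature map."""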
def forward(self,input):
b, c, h, w = input.size()
f = input.view(b,c,h*w) # bxcx(hxw)
# torch.bmm(batch1, batch2, out=None) #
# batch1: bxmxp, batch2: bxpxn -> bxmxn #
G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
return G.div_(c*h*w)
class StyleLoss(nn.Module):
"""
Version 1. Compare mean and variance cluster-wise.
"""
def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'), style_mode = 'gram'):
super().__init__()
self.vgg = VGG19()
self.style_layers = []
for style_layer in style_layers.split(','):
self.style_layers.append(int(style_layer[-1]) - 1)
self.style_mode = style_mode
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)
def forward(self, pred, gt):
"""
INPUTS:
- pred: (B, 3, H, W)
- gt: (B, 3, H, W)
- seg: (B, H, W)
"""
# extract features for images
B, _, H, W = pred.shape
pred = self.normalization(pred)
gt = self.normalization(gt)
pred_feats = self.vgg(pred)
gt_feats = self.vgg(gt)
loss = 0
for style_layer in self.style_layers:
pred_feat = pred_feats[style_layer]
gt_feat = gt_feats[style_layer]
pred_gram = GramMatrix()(pred_feat)
gt_gram = GramMatrix()(gt_feat)
loss += torch.sum((pred_gram - gt_gram) ** 2) / B
return loss
class styleLossMask(nn.Module):
"""
Version 1. Compare mean and variance cluster-wise.
"""
def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'), style_mode = 'gram'):
super().__init__()
self.vgg = VGG19()
self.style_layers = []
for style_layer in style_layers.split(','):
self.style_layers.append(int(style_layer[-1]) - 1)
self.style_mode = style_mode
#cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
#cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
#self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)
def forward(self, input, target, mask):
B, _, H, W = input.shape
#pred = self.normalization(input)
#target = self.normalization(target)
pred_feats = self.vgg(input)
gt_feats = self.vgg(target)
loss = 0
mb, mc, mh, mw = mask.shape
for style_layer in self.style_layers:
pred_feat = pred_feats[style_layer]
gt_feat = gt_feats[style_layer]
ib,ic,ih,iw = pred_feat.size()
iF = pred_feat.view(ib,ic,-1)
tb,tc,th,tw = gt_feat.size()
tF = gt_feat.view(tb,tc,-1)
for i in range(mb):
                # resize mask to the same spatial size as the feature
maski = F.interpolate(mask[i:i+1], size = (ih, iw), mode = 'nearest')
mask_flat_i = maski.view(mc, -1)
maskt = F.interpolate(mask[i:i+1], size = (th, tw), mode = 'nearest')
mask_flat_t = maskt.view(mc, -1)
for j in range(mc):
# get features for each part
idx = torch.nonzero(mask_flat_i[j]).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
                    # select the masked locations of the i-th sample only
                    ipart = torch.index_select(iF[i:i+1], 2, idx)
idx = torch.nonzero(mask_flat_t[j]).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
                    tpart = torch.index_select(tF[i:i+1], 2, idx)
iMean = torch.mean(ipart,dim=2)
iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ic*ih*iw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
tMean = torch.mean(tpart,dim=2)
tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tc*th*tw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                    loss_j = F.mse_loss(iMean, tMean) + F.mse_loss(iGram, tGram)
loss += loss_j
return loss/tb
# Perceptual loss that uses a pretrained VGG network
class VGGLoss(nn.Module):
def __init__(self, weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0], device = torch.device(f'cuda:0')):
super(VGGLoss, self).__init__()
self.vgg = VGG19(device = device)
self.criterion = nn.L1Loss()
self.weights = weights
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
class styleLossMaskv2(nn.Module):
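    """
    Masked style loss v2: for each sample and each mask channel, Gram matrices
    of the masked input and target features are compared with a summed squared
    error; separate masks are used for input and target.
    """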
def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0')):
super().__init__()
self.vgg = VGG19(device = device)
self.style_layers = []
for style_layer in style_layers.split(','):
self.style_layers.append(int(style_layer[-1]) - 1)
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)
def forward(self, input, target, mask_input, mask_target):
B, _, H, W = input.shape
input = self.normalization(input)
target = self.normalization(target)
pred_feats = self.vgg(input)
gt_feats = self.vgg(target)
loss = 0
mb, mc, mh, mw = mask_input.shape
for style_layer in self.style_layers:
pred_feat = pred_feats[style_layer]
gt_feat = gt_feats[style_layer]
ib,ic,ih,iw = pred_feat.size()
iF = pred_feat.view(ib,ic,-1)
tb,tc,th,tw = gt_feat.size()
tF = gt_feat.view(tb,tc,-1)
for i in range(mb):
                # resize mask to the same spatial size as the feature
maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest')
mask_flat_i = maski.view(mc, -1)
maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest')
mask_flat_t = maskt.view(mc, -1)
for j in range(mc):
# get features for each part
idx = torch.nonzero(mask_flat_i[j]).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
ipart = torch.index_select(iF[i:i+1], 2, idx)
idx = torch.nonzero(mask_flat_t[j]).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
tpart = torch.index_select(tF[i:i+1], 2, idx)
iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
loss += torch.sum((iGram - tGram) ** 2)
return loss/tb
class styleLossMaskv3(nn.Module):
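    """
    Masked style loss v3: like v2, but defaults to all five relu layers, uses
    MSE between per-cluster Gram matrices, scales the loss by 1e5, and adds
    forward_patch_img for matching patch features to masked target features.
    """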
def __init__(self, style_layers = 'relu1, relu2, relu3, relu4, relu5', device = torch.device(f'cuda:0')):
super().__init__()
self.vgg = VGG19(device = device)
self.style_layers = []
for style_layer in style_layers.split(','):
self.style_layers.append(int(style_layer[-1]) - 1)
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)
def forward_img_img(self, input, target, mask_input, mask_target):
B, _, H, W = input.shape
input = self.normalization(input)
target = self.normalization(target)
pred_feats = self.vgg(input)
gt_feats = self.vgg(target)
loss = 0
mb, mc, mh, mw = mask_input.shape
for style_layer in self.style_layers:
pred_feat = pred_feats[style_layer]
gt_feat = gt_feats[style_layer]
ib,ic,ih,iw = pred_feat.size()
iF = pred_feat.view(ib,ic,-1)
tb,tc,th,tw = gt_feat.size()
tF = gt_feat.view(tb,tc,-1)
for i in range(mb):
                # resize mask to the same spatial size as the feature
maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest')
mask_flat_i = maski.view(mc, -1)
maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest')
mask_flat_t = maskt.view(mc, -1)
for j in range(mc):
# get features for each part
idx = torch.nonzero(mask_flat_i[j]).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
ipart = torch.index_select(iF[i:i+1], 2, idx)
idx = torch.nonzero(mask_flat_t[j]).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
tpart = torch.index_select(tF[i:i+1], 2, idx)
iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
#loss += torch.sum((iGram - tGram) ** 2)
loss += F.mse_loss(iGram, tGram)
#return loss/tb
return loss * 100000 / tb
def forward_patch_img(self, input, target, mask_target):
input = self.normalization(input)
target = self.normalization(target)
pred_feats = self.vgg(input)
gt_feats = self.vgg(target)
patch_num = input.shape[0] // target.shape[0]
loss = 0
for style_layer in self.style_layers:
pred_feat = pred_feats[style_layer]
gt_feat = gt_feats[style_layer]
ib,ic,ih,iw = pred_feat.size()
iF = pred_feat.view(ib,ic,-1)
tb,tc,th,tw = gt_feat.size()
tF = gt_feat.view(tb,tc,-1)
for i in range(tb):
                # resize mask to the same spatial size as the feature
maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest')
mask_flat_t = maskt.view(-1)
idx = torch.nonzero(mask_flat_t).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
tpart = torch.index_select(tF[i:i+1], 2, idx)
tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
ipart = iF[i * patch_num: (i + 1) * patch_num]
iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
#loss += torch.sum((iGram - tGram.repeat(patch_num, 1, 1)) ** 2)
loss += F.mse_loss(iGram, tGram.repeat(patch_num, 1, 1))
return loss/ib * 100000
class KLDLoss(nn.Module):
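    """Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I):
    -0.5 * sum(1 + logvar - mu^2 - exp(logvar))."""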
def forward(self, mu, logvar):
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
class LPIPSorGramMatch(nn.Module):
"""
Version 1. Compare mean and variance cluster-wise.
"""
def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'),
weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0],):
super().__init__()
self.vgg = VGG19(device = device)
self.style_layers = []
for style_layer in style_layers.split(','):
self.style_layers.append(int(style_layer[-1]) - 1)
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)
self.criterion = nn.L1Loss()
self.weights = weights
def forward(self, pred, gt, mode = 'lpips'):
"""
INPUTS:
- pred: (B, 3, H, W)
- gt: (B, 3, H, W)
- seg: (B, H, W)
"""
# extract features for images
B, _, H, W = pred.shape
pred = self.normalization(pred)
gt = self.normalization(gt)
pred_feats = self.vgg(pred)
gt_feats = self.vgg(gt)
if mode == 'lpips':
lpips_loss = 0
for i in range(len(pred_feats)):
lpips_loss += self.weights[i] * self.criterion(pred_feats[i], gt_feats[i].detach())
return lpips_loss
elif mode == 'gram_match':
gram_match_loss = 0
for style_layer in self.style_layers:
pred_feat = pred_feats[style_layer]
gt_feat = gt_feats[style_layer]
pred_gram = GramMatrix()(pred_feat)
gt_gram = GramMatrix()(gt_feat)
gram_match_loss += torch.sum((pred_gram - gt_gram) ** 2) / B
return gram_match_loss
else:
raise ValueError("Only computes lpips or gram match loss.")
class styleLossMaskv4(nn.Module):
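    """
    Masked style loss v4: like v2, but matches both per-cluster Gram matrices
    and channel means (the mean term is weighted by 0.01), using summed squared
    errors.
    """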
def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0')):
super().__init__()
self.vgg = VGG19(device = device)
self.style_layers = []
for style_layer in style_layers.split(','):
self.style_layers.append(int(style_layer[-1]) - 1)
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)
def forward_img_img(self, input, target, mask_input, mask_target):
B, _, H, W = input.shape
input = self.normalization(input)
target = self.normalization(target)
pred_feats = self.vgg(input)
gt_feats = self.vgg(target)
loss = 0
mb, mc, mh, mw = mask_input.shape
for style_layer in self.style_layers:
pred_feat = pred_feats[style_layer]
gt_feat = gt_feats[style_layer]
ib,ic,ih,iw = pred_feat.size()
iF = pred_feat.view(ib,ic,-1)
tb,tc,th,tw = gt_feat.size()
tF = gt_feat.view(tb,tc,-1)
for i in range(mb):
                # resize mask to the same spatial size as the feature
maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest')
mask_flat_i = maski.view(mc, -1)
maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest')
mask_flat_t = maskt.view(mc, -1)
for j in range(mc):
# get features for each part
idx = torch.nonzero(mask_flat_i[j]).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
ipart = torch.index_select(iF[i:i+1], 2, idx)
idx = torch.nonzero(mask_flat_t[j]).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
tpart = torch.index_select(tF[i:i+1], 2, idx)
iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
iMean = torch.mean(ipart, dim=2)
tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
tMean = torch.mean(tpart, dim=2)
loss += torch.sum((iGram - tGram) ** 2) + torch.sum((iMean - tMean) ** 2) * 0.01
return loss/tb
def forward_patch_img(self, input, target, mask_target):
input = self.normalization(input)
target = self.normalization(target)
pred_feats = self.vgg(input)
gt_feats = self.vgg(target)
patch_num = input.shape[0] // target.shape[0]
loss = 0
for style_layer in self.style_layers:
pred_feat = pred_feats[style_layer]
gt_feat = gt_feats[style_layer]
ib,ic,ih,iw = pred_feat.size()
iF = pred_feat.view(ib,ic,-1)
tb,tc,th,tw = gt_feat.size()
tF = gt_feat.view(tb,tc,-1)
for i in range(tb):
                # resize mask to the same spatial size as the feature
maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest')
mask_flat_t = maskt.view(-1)
idx = torch.nonzero(mask_flat_t).squeeze()
if len(idx.shape) == 0 or idx.shape[0] == 0:
continue
tpart = torch.index_select(tF[i:i+1], 2, idx)
tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
tMean = torch.mean(tpart, dim=2)
ipart = iF[i * patch_num: (i + 1) * patch_num]
iMean = torch.mean(ipart, dim=2)
iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
loss += torch.sum((iGram - tGram.repeat(patch_num, 1, 1)) ** 2)
loss += torch.sum((iMean - tMean.repeat(patch_num, 1)) ** 2) * 0.01
return loss/ib
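
# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the training pipeline. It assumes
# a CUDA device 'cuda:0' is available and that torchvision can load pretrained
# VGG-19 weights; the tensor shapes and two-channel masks are arbitrary
# illustrative choices.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    device = torch.device('cuda:0')
    pred = torch.rand(2, 3, 64, 64, device=device)
    gt = torch.rand(2, 3, 64, 64, device=device)
    # binary cluster masks, one channel per cluster (float, as F.interpolate expects)
    mask = (torch.rand(2, 2, 64, 64, device=device) > 0.5).float()

    style_loss = StyleLoss(device=device).to(device)
    masked_loss = styleLossMaskv2(device=device).to(device)
    print('StyleLoss       :', style_loss(pred, gt).item())
    print('styleLossMaskv2 :', masked_loss(pred, gt, mask, mask).item())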