""" | |
07-21 | |
StyleLoss to encourage style statistics to be consistent within each cluster. | |
""" | |
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

# VGG architecture, used for the perceptual loss with a pretrained VGG network.
class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False, device=torch.device('cuda:0')):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        # Split the feature extractor into five slices, ending at relu1_1,
        # relu2_1, relu3_1, relu4_1 and relu5_1 respectively.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
        cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)

    def forward(self, X):
        #X = self.normalization(X)
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
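
# Minimal usage sketch (an assumption, not part of the original training code):
# run a dummy batch through the five VGG slices and inspect the feature shapes.
# Instantiating VGG19 downloads the pretrained torchvision weights on first use;
# the device is kept on CPU here so the sketch runs without CUDA.
def _demo_vgg19_features():
    vgg = VGG19(device=torch.device('cpu'))
    x = torch.rand(1, 3, 128, 128)
    feats = vgg(x)
    for name, f in zip(['relu1', 'relu2', 'relu3', 'relu4', 'relu5'], feats):
        # e.g. relu1 -> (1, 64, 128, 128), relu5 -> (1, 512, 8, 8)
        print(name, tuple(f.shape))
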

# Create a module to normalize the input image so we can easily put it in an
# nn.Sequential.
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with an image tensor of shape [B x C x H x W].
        # B is batch size, C is number of channels, H is height and W is width.
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, img):
        # normalize img channel-wise
        return (img - self.mean) / self.std

class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        f = input.view(b, c, h * w)  # b x c x (h*w)
        # torch.bmm(batch1, batch2, out=None)
        # batch1: b x m x p, batch2: b x p x n -> b x m x n
        G = torch.bmm(f, f.transpose(1, 2))  # f: b x c x (h*w), f.transpose: b x (h*w) x c -> b x c x c
        return G.div_(c * h * w)
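
# Minimal sketch of the Gram-matrix computation on a random feature map
# (illustrative only; the shapes below are arbitrary assumptions).
def _demo_gram_matrix():
    feat = torch.rand(2, 64, 32, 32)   # (B, C, H, W)
    gram = GramMatrix()(feat)          # (B, C, C), normalized by C*H*W
    print(gram.shape)                  # torch.Size([2, 64, 64])
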

class StyleLoss(nn.Module):
    """
    Version 1. Compare Gram-matrix style statistics of the whole image
    (no clustering / segmentation mask).
    """
    def __init__(self, style_layers='relu3, relu4, relu5', device=torch.device('cuda:0'), style_mode='gram'):
        super().__init__()
        self.vgg = VGG19(device=device)
        self.style_layers = []
        for style_layer in style_layers.split(','):
            # 'reluN' -> index N-1 into the list of VGG feature maps
            self.style_layers.append(int(style_layer[-1]) - 1)
        self.style_mode = style_mode
        cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)

    def forward(self, pred, gt):
        """
        INPUTS:
         - pred: (B, 3, H, W)
         - gt: (B, 3, H, W)
        """
        # extract features for images
        B, _, H, W = pred.shape
        pred = self.normalization(pred)
        gt = self.normalization(gt)
        pred_feats = self.vgg(pred)
        gt_feats = self.vgg(gt)
        loss = 0
        for style_layer in self.style_layers:
            pred_feat = pred_feats[style_layer]
            gt_feat = gt_feats[style_layer]
            pred_gram = GramMatrix()(pred_feat)
            gt_gram = GramMatrix()(gt_feat)
            loss += torch.sum((pred_gram - gt_gram) ** 2) / B
        return loss
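
# Minimal usage sketch for StyleLoss (assumption: CPU execution with random
# images; in the actual pipeline `pred` and `gt` are the model output and the
# ground-truth image in [0, 1]).
def _demo_style_loss():
    criterion = StyleLoss(device=torch.device('cpu'))
    pred = torch.rand(1, 3, 128, 128)
    gt = torch.rand(1, 3, 128, 128)
    print(criterion(pred, gt).item())
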

class styleLossMask(nn.Module):
    """
    Version 1. Compare mean and Gram-matrix statistics cluster-wise,
    where each channel of `mask` selects one cluster / part.
    """
    def __init__(self, style_layers='relu3, relu4, relu5', device=torch.device('cuda:0'), style_mode='gram'):
        super().__init__()
        self.vgg = VGG19(device=device)
        self.style_layers = []
        for style_layer in style_layers.split(','):
            self.style_layers.append(int(style_layer[-1]) - 1)
        self.style_mode = style_mode
        #cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        #cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        #self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)

    def forward(self, input, target, mask):
        B, _, H, W = input.shape
        #pred = self.normalization(input)
        #target = self.normalization(target)
        pred_feats = self.vgg(input)
        gt_feats = self.vgg(target)
        loss = 0
        mb, mc, mh, mw = mask.shape
        for style_layer in self.style_layers:
            pred_feat = pred_feats[style_layer]
            gt_feat = gt_feats[style_layer]
            ib, ic, ih, iw = pred_feat.size()
            iF = pred_feat.view(ib, ic, -1)
            tb, tc, th, tw = gt_feat.size()
            tF = gt_feat.view(tb, tc, -1)
            for i in range(mb):
                # resize mask to the spatial size of the feature map
                maski = F.interpolate(mask[i:i+1], size=(ih, iw), mode='nearest')
                mask_flat_i = maski.view(mc, -1)
                maskt = F.interpolate(mask[i:i+1], size=(th, tw), mode='nearest')
                mask_flat_t = maskt.view(mc, -1)
                for j in range(mc):
                    # get features for each part, for sample i only
                    idx = torch.nonzero(mask_flat_i[j]).squeeze()
                    if len(idx.shape) == 0 or idx.shape[0] == 0:
                        continue
                    ipart = torch.index_select(iF[i:i+1], 2, idx)
                    idx = torch.nonzero(mask_flat_t[j]).squeeze()
                    if len(idx.shape) == 0 or idx.shape[0] == 0:
                        continue
                    tpart = torch.index_select(tF[i:i+1], 2, idx)
                    iMean = torch.mean(ipart, dim=2)
                    iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ic * ih * iw)  # (1, C, N) x (1, N, C) -> (1, C, C)
                    tMean = torch.mean(tpart, dim=2)
                    tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tc * th * tw)  # (1, C, N) x (1, N, C) -> (1, C, C)
                    loss_j = nn.MSELoss()(iMean, tMean) + nn.MSELoss()(iGram, tGram)
                    loss += loss_j
        return loss / tb
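
# Minimal usage sketch for styleLossMask (assumption: a 2-part one-hot mask of
# shape (B, num_parts, H, W), where each channel selects one cluster / part).
def _demo_style_loss_mask():
    criterion = styleLossMask(device=torch.device('cpu'))
    x = torch.rand(1, 3, 128, 128)
    y = torch.rand(1, 3, 128, 128)
    mask = torch.zeros(1, 2, 128, 128)
    mask[:, 0, :, :64] = 1.0   # part 0: left half
    mask[:, 1, :, 64:] = 1.0   # part 1: right half
    print(criterion(x, y, mask).item())
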

# Perceptual loss that uses a pretrained VGG network
class VGGLoss(nn.Module):
    def __init__(self, weights=[1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0], device=torch.device('cuda:0')):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19(device=device)
        self.criterion = nn.L1Loss()
        self.weights = weights

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
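
# Minimal usage sketch for the perceptual VGGLoss (assumption: CPU execution,
# random tensors standing in for a generated image and its ground truth).
def _demo_vgg_loss():
    criterion = VGGLoss(device=torch.device('cpu'))
    fake = torch.rand(2, 3, 128, 128)
    real = torch.rand(2, 3, 128, 128)
    print(criterion(fake, real).item())
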

class styleLossMaskv2(nn.Module):
    def __init__(self, style_layers='relu3, relu4, relu5', device=torch.device('cuda:0')):
        super().__init__()
        self.vgg = VGG19(device=device)
        self.style_layers = []
        for style_layer in style_layers.split(','):
            self.style_layers.append(int(style_layer[-1]) - 1)
        cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)

    def forward(self, input, target, mask_input, mask_target):
        B, _, H, W = input.shape
        input = self.normalization(input)
        target = self.normalization(target)
        pred_feats = self.vgg(input)
        gt_feats = self.vgg(target)
        loss = 0
        mb, mc, mh, mw = mask_input.shape
        for style_layer in self.style_layers:
            pred_feat = pred_feats[style_layer]
            gt_feat = gt_feats[style_layer]
            ib, ic, ih, iw = pred_feat.size()
            iF = pred_feat.view(ib, ic, -1)
            tb, tc, th, tw = gt_feat.size()
            tF = gt_feat.view(tb, tc, -1)
            for i in range(mb):
                # resize mask to the spatial size of the feature map
                maski = F.interpolate(mask_input[i:i+1], size=(ih, iw), mode='nearest')
                mask_flat_i = maski.view(mc, -1)
                maskt = F.interpolate(mask_target[i:i+1], size=(th, tw), mode='nearest')
                mask_flat_t = maskt.view(mc, -1)
                for j in range(mc):
                    # get features for each part
                    idx = torch.nonzero(mask_flat_i[j]).squeeze()
                    if len(idx.shape) == 0 or idx.shape[0] == 0:
                        continue
                    ipart = torch.index_select(iF[i:i+1], 2, idx)
                    idx = torch.nonzero(mask_flat_t[j]).squeeze()
                    if len(idx.shape) == 0 or idx.shape[0] == 0:
                        continue
                    tpart = torch.index_select(tF[i:i+1], 2, idx)
                    iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ipart.shape[1] * ipart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                    tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tpart.shape[1] * tpart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                    loss += torch.sum((iGram - tGram) ** 2)
        return loss / tb

class styleLossMaskv3(nn.Module):
    def __init__(self, style_layers='relu1, relu2, relu3, relu4, relu5', device=torch.device('cuda:0')):
        super().__init__()
        self.vgg = VGG19(device=device)
        self.style_layers = []
        for style_layer in style_layers.split(','):
            self.style_layers.append(int(style_layer[-1]) - 1)
        cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)

    def forward_img_img(self, input, target, mask_input, mask_target):
        B, _, H, W = input.shape
        input = self.normalization(input)
        target = self.normalization(target)
        pred_feats = self.vgg(input)
        gt_feats = self.vgg(target)
        loss = 0
        mb, mc, mh, mw = mask_input.shape
        for style_layer in self.style_layers:
            pred_feat = pred_feats[style_layer]
            gt_feat = gt_feats[style_layer]
            ib, ic, ih, iw = pred_feat.size()
            iF = pred_feat.view(ib, ic, -1)
            tb, tc, th, tw = gt_feat.size()
            tF = gt_feat.view(tb, tc, -1)
            for i in range(mb):
                # resize mask to the spatial size of the feature map
                maski = F.interpolate(mask_input[i:i+1], size=(ih, iw), mode='nearest')
                mask_flat_i = maski.view(mc, -1)
                maskt = F.interpolate(mask_target[i:i+1], size=(th, tw), mode='nearest')
                mask_flat_t = maskt.view(mc, -1)
                for j in range(mc):
                    # get features for each part
                    idx = torch.nonzero(mask_flat_i[j]).squeeze()
                    if len(idx.shape) == 0 or idx.shape[0] == 0:
                        continue
                    ipart = torch.index_select(iF[i:i+1], 2, idx)
                    idx = torch.nonzero(mask_flat_t[j]).squeeze()
                    if len(idx.shape) == 0 or idx.shape[0] == 0:
                        continue
                    tpart = torch.index_select(tF[i:i+1], 2, idx)
                    iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ipart.shape[1] * ipart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                    tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tpart.shape[1] * tpart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                    #loss += torch.sum((iGram - tGram) ** 2)
                    loss += F.mse_loss(iGram, tGram)
        #return loss/tb
        return loss * 100000 / tb

    def forward_patch_img(self, input, target, mask_target):
        input = self.normalization(input)
        target = self.normalization(target)
        pred_feats = self.vgg(input)
        gt_feats = self.vgg(target)
        patch_num = input.shape[0] // target.shape[0]
        loss = 0
        for style_layer in self.style_layers:
            pred_feat = pred_feats[style_layer]
            gt_feat = gt_feats[style_layer]
            ib, ic, ih, iw = pred_feat.size()
            iF = pred_feat.view(ib, ic, -1)
            tb, tc, th, tw = gt_feat.size()
            tF = gt_feat.view(tb, tc, -1)
            for i in range(tb):
                # resize mask to the spatial size of the feature map
                maskt = F.interpolate(mask_target[i:i+1], size=(th, tw), mode='nearest')
                mask_flat_t = maskt.view(-1)
                idx = torch.nonzero(mask_flat_t).squeeze()
                if len(idx.shape) == 0 or idx.shape[0] == 0:
                    continue
                tpart = torch.index_select(tF[i:i+1], 2, idx)
                tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tpart.shape[1] * tpart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                ipart = iF[i * patch_num: (i + 1) * patch_num]
                iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ipart.shape[1] * ipart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                #loss += torch.sum((iGram - tGram.repeat(patch_num, 1, 1)) ** 2)
                loss += F.mse_loss(iGram, tGram.repeat(patch_num, 1, 1))
        return loss / ib * 100000
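
# Minimal usage sketch for styleLossMaskv3.forward_patch_img (assumptions:
# `input` holds patch_num crops per target image, stacked along the batch
# dimension, and `mask_target` is a single-channel foreground mask).
def _demo_style_loss_mask_v3_patch():
    criterion = styleLossMaskv3(device=torch.device('cpu'))
    patch_num, B = 4, 1
    patches = torch.rand(B * patch_num, 3, 64, 64)   # crops from the generated image
    target = torch.rand(B, 3, 64, 64)
    mask_target = torch.ones(B, 1, 64, 64)           # whole target counts as foreground
    print(criterion.forward_patch_img(patches, target, mask_target).item())
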

class KLDLoss(nn.Module):
    def forward(self, mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
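
# Minimal sketch of the KL-divergence term for a VAE-style latent
# (mu and logvar are assumed to come from an encoder head elsewhere).
def _demo_kld_loss():
    mu = torch.zeros(4, 256)
    logvar = torch.zeros(4, 256)   # unit Gaussian -> KLD is exactly 0
    print(KLDLoss()(mu, logvar).item())
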

class LPIPSorGramMatch(nn.Module):
    """
    Computes either an LPIPS-style weighted perceptual loss ('lpips') or a
    Gram-matrix matching loss ('gram_match'), selected per call via `mode`.
    """
    def __init__(self, style_layers='relu3, relu4, relu5', device=torch.device('cuda:0'),
                 weights=[1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]):
        super().__init__()
        self.vgg = VGG19(device=device)
        self.style_layers = []
        for style_layer in style_layers.split(','):
            self.style_layers.append(int(style_layer[-1]) - 1)
        cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)
        self.criterion = nn.L1Loss()
        self.weights = weights

    def forward(self, pred, gt, mode='lpips'):
        """
        INPUTS:
         - pred: (B, 3, H, W)
         - gt: (B, 3, H, W)
         - mode: 'lpips' or 'gram_match'
        """
        # extract features for images
        B, _, H, W = pred.shape
        pred = self.normalization(pred)
        gt = self.normalization(gt)
        pred_feats = self.vgg(pred)
        gt_feats = self.vgg(gt)
        if mode == 'lpips':
            lpips_loss = 0
            for i in range(len(pred_feats)):
                lpips_loss += self.weights[i] * self.criterion(pred_feats[i], gt_feats[i].detach())
            return lpips_loss
        elif mode == 'gram_match':
            gram_match_loss = 0
            for style_layer in self.style_layers:
                pred_feat = pred_feats[style_layer]
                gt_feat = gt_feats[style_layer]
                pred_gram = GramMatrix()(pred_feat)
                gt_gram = GramMatrix()(gt_feat)
                gram_match_loss += torch.sum((pred_gram - gt_gram) ** 2) / B
            return gram_match_loss
        else:
            raise ValueError("Only computes lpips or gram match loss.")
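
# Minimal usage sketch for LPIPSorGramMatch, showing both modes on the same
# random inputs (assumption: CPU execution).
def _demo_lpips_or_gram_match():
    criterion = LPIPSorGramMatch(device=torch.device('cpu'))
    pred = torch.rand(1, 3, 128, 128)
    gt = torch.rand(1, 3, 128, 128)
    print('lpips     :', criterion(pred, gt, mode='lpips').item())
    print('gram_match:', criterion(pred, gt, mode='gram_match').item())
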

class styleLossMaskv4(nn.Module):
    def __init__(self, style_layers='relu3, relu4, relu5', device=torch.device('cuda:0')):
        super().__init__()
        self.vgg = VGG19(device=device)
        self.style_layers = []
        for style_layer in style_layers.split(','):
            self.style_layers.append(int(style_layer[-1]) - 1)
        cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std)

    def forward_img_img(self, input, target, mask_input, mask_target):
        B, _, H, W = input.shape
        input = self.normalization(input)
        target = self.normalization(target)
        pred_feats = self.vgg(input)
        gt_feats = self.vgg(target)
        loss = 0
        mb, mc, mh, mw = mask_input.shape
        for style_layer in self.style_layers:
            pred_feat = pred_feats[style_layer]
            gt_feat = gt_feats[style_layer]
            ib, ic, ih, iw = pred_feat.size()
            iF = pred_feat.view(ib, ic, -1)
            tb, tc, th, tw = gt_feat.size()
            tF = gt_feat.view(tb, tc, -1)
            for i in range(mb):
                # resize mask to the spatial size of the feature map
                maski = F.interpolate(mask_input[i:i+1], size=(ih, iw), mode='nearest')
                mask_flat_i = maski.view(mc, -1)
                maskt = F.interpolate(mask_target[i:i+1], size=(th, tw), mode='nearest')
                mask_flat_t = maskt.view(mc, -1)
                for j in range(mc):
                    # get features for each part
                    idx = torch.nonzero(mask_flat_i[j]).squeeze()
                    if len(idx.shape) == 0 or idx.shape[0] == 0:
                        continue
                    ipart = torch.index_select(iF[i:i+1], 2, idx)
                    idx = torch.nonzero(mask_flat_t[j]).squeeze()
                    if len(idx.shape) == 0 or idx.shape[0] == 0:
                        continue
                    tpart = torch.index_select(tF[i:i+1], 2, idx)
                    iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ipart.shape[1] * ipart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                    iMean = torch.mean(ipart, dim=2)
                    tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tpart.shape[1] * tpart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                    tMean = torch.mean(tpart, dim=2)
                    loss += torch.sum((iGram - tGram) ** 2) + torch.sum((iMean - tMean) ** 2) * 0.01
        return loss / tb

    def forward_patch_img(self, input, target, mask_target):
        input = self.normalization(input)
        target = self.normalization(target)
        pred_feats = self.vgg(input)
        gt_feats = self.vgg(target)
        patch_num = input.shape[0] // target.shape[0]
        loss = 0
        for style_layer in self.style_layers:
            pred_feat = pred_feats[style_layer]
            gt_feat = gt_feats[style_layer]
            ib, ic, ih, iw = pred_feat.size()
            iF = pred_feat.view(ib, ic, -1)
            tb, tc, th, tw = gt_feat.size()
            tF = gt_feat.view(tb, tc, -1)
            for i in range(tb):
                # resize mask to the spatial size of the feature map
                maskt = F.interpolate(mask_target[i:i+1], size=(th, tw), mode='nearest')
                mask_flat_t = maskt.view(-1)
                idx = torch.nonzero(mask_flat_t).squeeze()
                if len(idx.shape) == 0 or idx.shape[0] == 0:
                    continue
                tpart = torch.index_select(tF[i:i+1], 2, idx)
                tGram = torch.bmm(tpart, tpart.transpose(1, 2)).div_(tpart.shape[1] * tpart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                tMean = torch.mean(tpart, dim=2)
                ipart = iF[i * patch_num: (i + 1) * patch_num]
                iMean = torch.mean(ipart, dim=2)
                iGram = torch.bmm(ipart, ipart.transpose(1, 2)).div_(ipart.shape[1] * ipart.shape[2])  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
                loss += torch.sum((iGram - tGram.repeat(patch_num, 1, 1)) ** 2)
                loss += torch.sum((iMean - tMean.repeat(patch_num, 1)) ** 2) * 0.01
        return loss / ib
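
# Minimal usage sketch for styleLossMaskv4.forward_img_img (assumption: one-hot
# part masks with the same number of channels for input and target).
def _demo_style_loss_mask_v4():
    criterion = styleLossMaskv4(device=torch.device('cpu'))
    x = torch.rand(1, 3, 128, 128)
    y = torch.rand(1, 3, 128, 128)
    mask = torch.zeros(1, 2, 128, 128)
    mask[:, 0, :, :64] = 1.0   # part 0: left half
    mask[:, 1, :, 64:] = 1.0   # part 1: right half
    print(criterion.forward_img_img(x, y, mask, mask).item())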