# incrl's picture
# Initial Upload (attempt 2)
# 5b557cf verified
import torch
import torch.nn as nn
class styleLoss(nn.Module):
    """Style loss built from per-channel means and GramMatrix statistics."""

    def forward(self, input, target):
        """Return the summed squared error between the statistics of two maps.

        Args:
            input: 4-D tensor unpacked as (batch, channels, height, width).
            target: 4-D tensor unpacked the same way.

        Returns:
            Scalar tensor: sum-reduced squared error of the per-channel
            means plus sum-reduced squared error of the ``GramMatrix``
            outputs, divided by the batch size of ``target``.
        """
        ib, ic, _, _ = input.size()
        # reshape (rather than view) so this flattening step does not itself
        # require contiguous memory.
        iF = input.reshape(ib, ic, -1)
        iMean = torch.mean(iF, dim=2)
        iCov = GramMatrix()(input)

        tb, tc, _, _ = target.size()
        tF = target.reshape(tb, tc, -1)
        tMean = torch.mean(tF, dim=2)
        tCov = GramMatrix()(target)

        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False; both sum the squared differences.
        sum_squared_error = nn.MSELoss(reduction='sum')
        loss = sum_squared_error(iMean, tMean) + sum_squared_error(iCov, tCov)
        return loss / tb
class GramMatrix(nn.Module):
    """Batched Gram matrix of a feature map, normalised by its element count."""

    def forward(self, input):
        """Compute ``f @ f^T / (c*h*w)`` for every sample in the batch.

        Args:
            input: 4-D tensor unpacked as (b, c, h, w).

        Returns:
            Tensor of shape (b, c, c), where ``f`` is ``input`` flattened
            to (b, c, h*w).
        """
        b, c, h, w = input.size()
        # reshape instead of view: view raises on non-contiguous input (for
        # example after transpose/permute), reshape copies only when needed.
        f = input.reshape(b, c, h * w)  # b x c x (h*w)
        # bmm: (b x c x hw) @ (b x hw x c) -> b x c x c
        G = torch.bmm(f, f.transpose(1, 2))
        # G is a fresh tensor produced by bmm, so scaling it in place does
        # not touch the caller's data.
        return G.div_(c * h * w)
class LossCriterion(nn.Module):
    """Weighted combination of per-layer content and style losses."""

    def __init__(self, style_layers, content_layers, style_weight, content_weight):
        """Store the layer keys and weights and build one loss per layer.

        Args:
            style_layers: sequence of keys used to index the feature
                containers when computing the style loss.
            content_layers: sequence of keys used to index the feature
                containers when computing the content loss.
            style_weight: multiplier applied to the summed style loss.
            content_weight: multiplier applied to the summed content loss.
        """
        super(LossCriterion, self).__init__()
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight
        # One independent loss module per layer, held in a ModuleList so they
        # are registered as sub-modules. `[module] * n` would instead put a
        # single aliased instance n times into a plain Python list.
        self.styleLosses = nn.ModuleList([styleLoss() for _ in style_layers])
        self.contentLosses = nn.ModuleList([nn.MSELoss() for _ in content_layers])

    def forward(self, tF, sF, cF):
        """Compute the total, style and content losses.

        Args:
            tF: container indexable by layer key; its entries are compared
                against both ``sF`` and ``cF`` and are not detached
                (presumably the features of the generated image - confirm
                against the caller).
            sF: container indexable by the style layer keys; entries are
                detached before use.
            cF: container indexable by the content layer keys; entries are
                detached before use.

        Returns:
            Tuple ``(loss, totalStyleLoss, totalContentLoss)`` where ``loss``
            is the sum of the two weighted terms.
        """
        # content loss
        totalContentLoss = 0
        for i, layer in enumerate(self.content_layers):
            cf_i = cF[layer].detach()  # the reference features get no gradient
            tf_i = tF[layer]
            totalContentLoss = totalContentLoss + self.contentLosses[i](tf_i, cf_i)
        totalContentLoss = totalContentLoss * self.content_weight

        # style loss
        totalStyleLoss = 0
        for i, layer in enumerate(self.style_layers):
            sf_i = sF[layer].detach()  # the reference features get no gradient
            tf_i = tF[layer]
            totalStyleLoss = totalStyleLoss + self.styleLosses[i](tf_i, sf_i)
        totalStyleLoss = totalStyleLoss * self.style_weight

        loss = totalStyleLoss + totalContentLoss
        return loss, totalStyleLoss, totalContentLoss