# Neural-Style-Transfer / loss_functions.py
# Author: FrozenWolf — revision 4acfd41 ("Functional Processing Loop")
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContentLoss(nn.Module):
    """Pass-through layer that measures how far its input is from a fixed target.

    Calling the module returns ``input`` untouched. As a side effect, each call
    stores ``F.mse_loss(input, target)`` on ``self.loss`` so the caller can
    collect it after the forward pass. ``self.loss`` does not exist until the
    first call.
    """

    def __init__(self, target):
        """Keep a detached copy of ``target`` to compare later inputs against.

        Args:
            target: tensor that subsequent inputs are compared with; it is
                expected to have the same shape as those inputs.
        """
        super().__init__()
        # Detaching makes the reference a constant: no gradient will flow back
        # into whatever graph produced it.
        reference = target.detach()
        self.target = reference

    def forward(self, input):
        """Record the MSE between ``input`` and the stored target, then return ``input``."""
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input
class StyleLoss(nn.Module):
    """Pass-through layer that measures a Gram-matrix (style) loss.

    At construction the Gram matrix of ``target_feature`` is computed and
    stored (detached) as ``self.target``. Each forward call computes the Gram
    matrix of its input, stores ``F.mse_loss(G, self.target)`` on
    ``self.loss`` and returns the input unchanged. ``self.loss`` does not exist
    until the first call.
    """

    def __init__(self, target_feature):
        """Precompute the reference Gram matrix from ``target_feature``.

        Args:
            target_feature: 4-D tensor; presumably (batch, channels, height,
                width) feature maps — TODO confirm against callers.
        """
        super().__init__()
        # Detached so the reference is a constant and not part of any graph.
        self.target = self.gram_matrix(target_feature).detach()

    def gram_matrix(self, input):
        """Return the normalised Gram matrix of a 4-D tensor.

        The tensor of size ``(a, b, c, d)`` is flattened to ``(a*b, c*d)`` and
        multiplied by its own transpose, giving an ``(a*b, a*b)`` matrix. That
        matrix is divided by the total element count ``a*b*c*d``.

        Raises:
            ValueError: if ``input`` is not 4-D (from the size unpacking).
        """
        a, b, c, d = input.size()
        # reshape rather than view: view raises on non-contiguous tensors
        # (e.g. permuted/transposed feature maps); reshape only copies when
        # it has to and returns the same values otherwise.
        features = input.reshape(a * b, c * d)
        G = torch.mm(features, features.t())
        return G.div(a * b * c * d)

    def forward(self, input):
        """Record the style loss of ``input`` against the target, then return ``input``."""
        G = self.gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input
class Normalization(nn.Module):
    """Normalise an image tensor as ``(img - mean) / std``.

    ``mean`` and ``std`` are reshaped to ``(-1, 1, 1)`` so that they broadcast
    over the last two dimensions of ``img`` (presumably one value per channel
    of a (..., C, H, W) image — TODO confirm against callers).
    """

    def __init__(self, mean, std):
        """Store copies of ``mean`` and ``std``.

        Args:
            mean: sequence, array or tensor of values to subtract.
            std: sequence, array or tensor of values to divide by.
        """
        super().__init__()
        # as_tensor(...).detach().clone() accepts lists, arrays and tensors and
        # always yields an independent, grad-free copy. torch.tensor() on an
        # existing tensor does the same but emits a UserWarning.
        # Registering them as buffers lets Module.to()/.cuda()/.half() move and
        # cast them along with the model. persistent=False keeps them out of
        # state_dict(), so previously saved state dicts still load strictly.
        self.register_buffer(
            "mean",
            torch.as_tensor(mean).detach().clone().view(-1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.as_tensor(std).detach().clone().view(-1, 1, 1),
            persistent=False,
        )

    def forward(self, img):
        """Return ``(img - mean) / std`` with the stored statistics broadcast."""
        return (img - self.mean) / self.std