import torch
import torch.nn as nn
import torch.nn.functional as F


class ContentLoss(nn.Module):
    """Pass-through layer that records an MSE content loss against a fixed target.

    The module returns its input unchanged, so it can be inserted between the
    layers of a feature extractor. The loss computed on the most recent forward
    pass is stored in ``self.loss``; the attribute does not exist until
    ``forward`` has run at least once.

    Args:
        target: Tensor the input is compared with. It is detached so that no
            gradient flows back into whatever produced it.
    """

    def __init__(self, target):
        super().__init__()
        self.target = target.detach()

    def forward(self, input):
        """Store ``mse_loss(input, self.target)`` in ``self.loss`` and return ``input``."""
        self.loss = F.mse_loss(input, self.target)
        return input


class StyleLoss(nn.Module):
    """Pass-through layer that records an MSE loss between Gram matrices.

    The Gram matrix of ``target_feature`` is computed once at construction and
    detached. Each forward pass compares the Gram matrix of its input with it,
    stores the result in ``self.loss`` and returns the input unchanged.

    Args:
        target_feature: 4-D tensor whose Gram matrix is the style target.
    """

    def __init__(self, target_feature):
        super().__init__()
        self.target = self.gram_matrix(target_feature).detach()

    def gram_matrix(self, input):
        """Return the Gram matrix of a 4-D tensor, divided by its element count.

        A tensor of size ``(a, b, c, d)`` is flattened to ``(a * b, c * d)``,
        and ``features @ features.T`` is divided by ``a * b * c * d``. The
        result has size ``(a * b, a * b)``. Presumably the input is laid out
        as (batch, channels, height, width), so each row is one channel's
        feature map — TODO confirm against callers.
        """
        a, b, c, d = input.size()
        # reshape rather than view: view raises on non-contiguous inputs
        # (e.g. transposed/permuted feature maps), while reshape only copies
        # when it has to and returns the same values either way.
        features = input.reshape(a * b, c * d)
        gram = torch.mm(features, features.t())
        return gram.div(a * b * c * d)

    def forward(self, input):
        """Store the Gram-matrix MSE against the target in ``self.loss``; return ``input``."""
        gram = self.gram_matrix(input)
        self.loss = F.mse_loss(gram, self.target)
        return input


class Normalization(nn.Module):
    """Subtract ``mean`` and divide by ``std``, broadcast along the third-from-last axis.

    ``mean`` and ``std`` are reshaped to ``(-1, 1, 1)``, so each of their
    entries applies to one slice along dimension -3 of the image (presumably
    the channel axis of a (C, H, W) or (N, C, H, W) tensor — TODO confirm).

    Args:
        mean: Sequence or tensor of per-slice means.
        std: Sequence or tensor of per-slice standard deviations.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Registered as buffers (not plain attributes) so module.to(device /
        # dtype) moves them together with the rest of the model; otherwise
        # forward() fails with a device mismatch after model.to("cuda").
        # persistent=False keeps them out of state_dict, which therefore stays
        # exactly as before (empty), so saved checkpoints remain compatible.
        # as_tensor(...).detach().clone() copies the input like torch.tensor()
        # did, without torch.tensor's warning when a tensor is passed in.
        self.register_buffer(
            "mean",
            torch.as_tensor(mean).detach().clone().view(-1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.as_tensor(std).detach().clone().view(-1, 1, 1),
            persistent=False,
        )

    def forward(self, img):
        """Return ``(img - mean) / std`` with broadcasting."""
        return (img - self.mean) / self.std