import torch.nn as nn
import torch


def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    The tensor is unpacked as ``(a, b, c, d)``; presumably this is the usual
    (batch, channels, height, width) layout — confirm against callers. The
    first two dims are flattened into rows and the last two into columns,
    and the row-by-row inner products are divided by the total element count.

    Args:
        input: 4-D tensor; may be non-contiguous.

    Returns:
        Tensor of shape ``(a * b, a * b)``.
    """
    a, b, c, d = input.size()
    # reshape (not view) so non-contiguous inputs (e.g. transposed maps) work;
    # for contiguous inputs it is a zero-copy view exactly like before.
    features = input.reshape(a * b, c * d)
    G = torch.mm(features, features.t())
    # Normalize by element count so the magnitude does not scale with size.
    return G.div(a * b * c * d)


class ContentLoss(nn.Module):
    """Transparent layer that records MSE against a fixed target on each pass.

    ``forward`` returns its input unchanged and stores the loss in
    ``self.loss``; the attribute exists only after the first forward call.
    """

    def __init__(self, target):
        super().__init__()
        # Detach so the target is treated as a constant, not part of the graph.
        self.target = target.detach()

    def forward(self, input):
        self.loss = nn.functional.mse_loss(input, self.target)
        return input


class StyleLoss(nn.Module):
    """Transparent layer that records MSE between Gram matrices.

    The target Gram matrix is computed once from ``target_feature``.
    ``forward`` returns its input unchanged and stores the loss in
    ``self.loss``.
    """

    def __init__(self, target_feature):
        super().__init__()
        # Detach so the precomputed target Gram matrix is a constant.
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = nn.functional.mse_loss(G, self.target)
        return input


class Normalization(nn.Module):
    """Normalize an image tensor per channel: ``(img - mean) / std``.

    Args:
        mean: per-channel means, as a 1-D tensor or a sequence of numbers.
        std: per-channel standard deviations, as a tensor or a sequence.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Stored as buffers so module.to(device) / .double() move and cast
        # them with the module; persistent=False keeps them out of state_dict,
        # preserving the original checkpoint format. Reshaped to (-1, 1, 1)
        # so they broadcast over the trailing two (spatial) dims.
        self.register_buffer(
            "mean", torch.as_tensor(mean).reshape(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.as_tensor(std).reshape(-1, 1, 1), persistent=False
        )

    def forward(self, img):
        return (img - self.mean) / self.std