Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
import torch | |
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D tensor.

    The tensor's size is unpacked as ``(a, b, c, d)``; presumably this is
    (batch, channels, height, width) -- confirm against callers. The
    ``a * b`` slices are each flattened to length ``c * d``, all pairwise
    inner products between them are taken, and the result is divided by
    the total element count ``a * b * c * d``.

    Args:
        input: 4-D tensor. It does not need to be contiguous.

    Returns:
        Tensor of shape ``(a * b, a * b)``.

    Raises:
        ValueError: if ``input`` does not have exactly four dimensions
            (the size unpacking fails).
    """
    a, b, c, d = input.size()
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of transpose/permute; it only copies when a view is impossible.
    features = input.reshape(a * b, c * d)
    G = torch.mm(features, features.t())
    # Normalize by the total number of elements in `input`.
    return G.div(a * b * c * d)
class ContentLoss(nn.Module):
    """Pass-through module that records an MSE loss against a fixed target.

    Calling the module stores ``mse_loss(input, target)`` on ``self.loss``.
    It then returns ``input`` unchanged, so activations are not altered.
    ``self.loss`` only exists after the first forward call.
    """

    def __init__(self, target):
        super().__init__()
        # Detached copy: the target is a constant, cut off from the
        # autograd graph that produced it.
        frozen = target.detach()
        self.target = frozen

    def forward(self, input):
        criterion = nn.functional.mse_loss
        self.loss = criterion(input, self.target)
        # Identity on the data path; the loss is exposed as a side effect.
        return input
class StyleLoss(nn.Module):
    """Pass-through module that records a Gram-matrix MSE against a target.

    At construction, the Gram matrix of ``target_feature`` (via
    ``gram_matrix``) is computed once and detached. Each forward call
    stores the MSE between ``gram_matrix(input)`` and that target on
    ``self.loss``, then returns ``input`` unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        # Detach so the target Gram matrix is treated as a constant.
        self.target = target_gram.detach()

    def forward(self, input):
        self.loss = nn.functional.mse_loss(gram_matrix(input), self.target)
        # Identity on the data path; the loss is exposed as a side effect.
        return input
class Normalization(nn.Module):
    """Normalize a tensor per channel: ``(img - mean) / std``.

    ``mean`` and ``std`` are reshaped to ``(-1, 1, 1)``, so they broadcast
    over the last two dimensions of ``img``. Presumably ``img`` is
    (C, H, W) or (N, C, H, W) -- confirm against callers.

    The statistics are registered as non-persistent buffers. As a result,
    ``module.to(device)``, ``.cuda()`` and ``.double()`` move and cast them
    together with the rest of the model. Being non-persistent keeps them
    out of ``state_dict()``, exactly as when they were plain attributes.

    Args:
        mean: per-channel means, as a tensor or a sequence of numbers.
        std: per-channel standard deviations, as a tensor or a sequence of
            numbers.
    """

    def __init__(self, mean, std):
        super().__init__()
        # as_tensor returns an existing tensor unchanged and also accepts
        # lists/tuples.
        self.register_buffer(
            "mean", torch.as_tensor(mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.as_tensor(std).view(-1, 1, 1), persistent=False
        )

    def forward(self, img):
        return (img - self.mean) / self.std