# Spaces: Runtime error
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class ContentLoss(nn.Module):
    """
    Content Loss for the neural style transfer algorithm.

    Stores a detached, flattened copy of a target feature map and returns the
    mean squared error between that target and the (equally flattened) input.
    """

    def __init__(self, target: torch.Tensor, device: torch.device) -> None:
        """
        Args:
            target: 4-D feature map ``(batch, channels, height, width)``. It is
                detached, so no gradient ever flows back into it.
            device: device the stored target is moved to.

        Raises:
            ValueError: if ``target`` is not 4-D (tuple unpacking of its size).
        """
        super().__init__()
        batch_size, channels, height, width = target.size()
        # reshape (not view) so non-contiguous tensors, e.g. transposed or
        # channels_last feature maps, are accepted; it only copies when needed.
        flat_target = target.detach().reshape(batch_size * channels, height * width)
        # Non-persistent buffer: moves with Module.to()/cuda()/half() together
        # with the rest of the model, but is kept out of state_dict().
        self.register_buffer("target", flat_target.to(device), persistent=False)

    def __str__(self) -> str:
        return "Content loss"

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Return the mean squared error between ``input`` and the stored target.

        Args:
            input: 4-D feature map ``(batch, channels, height, width)``,
                flattened the same way as the target before comparison.

        Returns:
            Scalar loss tensor (``F.mse_loss`` with its default mean reduction).
        """
        batch_size, channels, height, width = input.size()
        flat_input = input.reshape(batch_size * channels, height * width)
        return F.mse_loss(flat_input, self.target)
class StyleLoss(nn.Module):
    """
    Style loss for the neural style transfer algorithm.

    Stores the Gram matrix of a detached target feature map and returns the
    mean squared error between it and the Gram matrix of the input.
    """

    def __init__(self, target: torch.Tensor, device: torch.device) -> None:
        """
        Args:
            target: 4-D feature map ``(batch, channels, height, width)``. It is
                detached, so no gradient ever flows back into it.
            device: device the stored target Gram matrix is moved to.

        Raises:
            ValueError: if ``target`` is not 4-D (tuple unpacking of its size).
        """
        super().__init__()
        # Detach before computing the Gram matrix so no autograd graph is built
        # for what is a constant; the resulting values are unchanged.
        target_gram = self.compute_gram_matrix(target.detach())
        # Non-persistent buffer: moves with Module.to()/cuda()/half() together
        # with the rest of the model, but is kept out of state_dict().
        self.register_buffer("target", target_gram.to(device), persistent=False)

    def __str__(self) -> str:
        return "Style loss"

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Return the mean squared error between the Gram matrix of ``input`` and
        the stored target Gram matrix.

        Returns:
            Scalar loss tensor (``F.mse_loss`` with its default mean reduction).
        """
        return F.mse_loss(self.compute_gram_matrix(input), self.target)

    def compute_gram_matrix(self, input: torch.Tensor) -> torch.Tensor:
        """
        Compute the normalised Gram matrix of a 4-D feature map.

        The batch and channel axes are flattened together, so the result has
        shape ``(batch * channels, batch * channels)``. For ``batch > 1`` this
        includes cross-sample channel products. Entries are divided by
        ``batch * channels * height * width``.
        """
        batch_size, channels, height, width = input.size()
        # reshape (not view) so non-contiguous feature maps are accepted.
        features = input.reshape(batch_size * channels, height * width)
        gram = torch.matmul(features, features.T)
        return gram.div(batch_size * channels * height * width)