Spaces:
Runtime error
Runtime error
File size: 2,294 Bytes
9eae6e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from .vgg import VGG19, VGG16
class Perceptual16Loss(nn.Module):
    """Content (perceptual) and style losses computed on VGG16 features.

    Both losses run the two inputs through ``self.vgg`` and compare the first
    three returned feature maps with an L1 criterion.

    Args:
        weights: Sequence of per-layer weights. It is stored on the instance
            as ``self.weights`` but is not read by any method of this class.
    """

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        super().__init__()
        self.vgg = VGG16()
        self.criterion = torch.nn.L1Loss()
        # Copy into a fresh list so instances never share one default object
        # (the original stored a shared mutable default by reference).
        self.weights = list(weights)

    def calculate_pl(self, x, y):
        """Return the summed L1 distance between the first three feature maps.

        Args:
            x: Predicted image batch fed to ``self.vgg``.
            y: Reference image batch fed to ``self.vgg``.

        Returns:
            Scalar tensor on ``x.device``.
        """
        # Assumes self.vgg returns something indexable by 0..2 -- TODO confirm
        # against the VGG16 wrapper in .vgg.
        feat_output = self.vgg(x)
        feat_gt = self.vgg(y)
        content_loss = 0.0
        for i in range(3):
            content_loss += self.criterion(feat_output[i], feat_gt[i])
        return content_loss.to(device=x.device)

    def compute_gram(self, x):
        """Return the normalized Gram matrix of a 4-D feature map.

        Args:
            x: Tensor whose ``size()`` unpacks to ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, c, c)``: ``F @ F^T / (h * w * c)`` where
            ``F`` is ``x`` flattened over its last two dimensions.
        """
        b, c, h, w = x.size()
        # reshape (not view) so non-contiguous feature maps are accepted too.
        f = x.reshape(b, c, w * h)
        f_T = f.transpose(1, 2)
        return f.bmm(f_T) / (h * w * c)

    def calc_style(self, x, y):
        """Return the summed L1 distance between Gram matrices of the features.

        Args:
            x: Predicted image batch fed to ``self.vgg``.
            y: Reference image batch fed to ``self.vgg``.

        Returns:
            Scalar tensor: sum over the first three feature maps of the L1
            loss between their Gram matrices.
        """
        # The original called ``self.extractor``, which is never assigned on
        # this class; the feature extractor created in __init__ is ``self.vgg``.
        feat_output = self.vgg(x)
        feat_gt = self.vgg(y)
        style_loss = 0.0
        for i in range(3):
            style_loss += self.criterion(
                self.compute_gram(feat_output[i]),
                self.compute_gram(feat_gt[i]),
            )
        return style_loss
class Perceptual19Loss(nn.Module):
    """Content (perceptual) and style losses computed on VGG19 features.

    ``self.vgg`` is called on each input and its result is indexed by string
    keys of the form ``'relu{stage}_{index}'``.

    Args:
        weights: Sequence of five per-stage weights applied to the content
            loss terms for ``relu1_1`` .. ``relu5_1`` (in that order).
    """

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        super().__init__()
        self.vgg = VGG19()
        self.criterion = torch.nn.L1Loss()
        # Copy into a fresh list so instances never share one default object
        # (the original stored a shared mutable default by reference).
        self.weights = list(weights)

    def calculate_pl(self, x, y):
        """Return the weighted sum of L1 distances between ``relu*_1`` features.

        Args:
            x: Predicted image batch fed to ``self.vgg``.
            y: Reference image batch fed to ``self.vgg``.

        Returns:
            Scalar tensor on ``x.device``.

        Raises:
            IndexError: If ``self.weights`` has fewer than five entries.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        for i, stage in enumerate((1, 2, 3, 4, 5)):
            key = f'relu{stage}_1'
            content_loss += self.weights[i] * self.criterion(x_vgg[key], y_vgg[key])
        return content_loss.to(device=x.device)

    def compute_gram(self, x):
        """Return the normalized Gram matrix of a 4-D feature map.

        Args:
            x: Tensor whose ``size()`` unpacks to ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, c, c)``: ``F @ F^T / (h * w * c)`` where
            ``F`` is ``x`` flattened over its last two dimensions.
        """
        b, c, h, w = x.size()
        # reshape (not view) so non-contiguous feature maps are accepted too.
        f = x.reshape(b, c, w * h)
        f_T = f.transpose(1, 2)
        return f.bmm(f_T) / (h * w * c)

    def calc_style(self, x, y):
        """Return the summed L1 distance between Gram matrices of the features.

        The layers compared are ``relu2_2``, ``relu3_4``, ``relu4_4`` and
        ``relu5_2``; the terms are unweighted.

        Args:
            x: Predicted image batch fed to ``self.vgg``.
            y: Reference image batch fed to ``self.vgg``.

        Returns:
            Scalar tensor with the accumulated style loss.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        style_loss = 0.0
        prefix = (2, 3, 4, 5)
        posfix = (2, 4, 4, 2)
        for pre, pos in zip(prefix, posfix):
            key = f'relu{pre}_{pos}'
            style_loss += self.criterion(
                self.compute_gram(x_vgg[key]),
                self.compute_gram(y_vgg[key]),
            )
        return style_loss
|