import os
import src.utils.utils as utils
from src.utils.video_utils import create_video_from_intermediate_results
import torch
from torch import nn
from torch.optim import Adam, LBFGS
from torch.autograd import Variable
class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # Detach the target so it acts as a fixed reference, not part of the graph.
        self.target = target.detach()

    def forward(self, current):
        return nn.MSELoss(reduction='mean')(self.target, current)
class StyleLoss(nn.Module):
    def __init__(self):
        super(StyleLoss, self).__init__()

    def forward(self, x, y):
        # Average the per-layer MSE between target and current Gram matrices.
        # Accumulate into a local so repeated forward calls carry no state.
        loss = 0.0
        for gram_gt, gram_hat in zip(x, y):
            loss += nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
        return loss / len(x)
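# StyleLoss compares Gram matrices produced by utils.gram_matrix (defined in
# src/utils/utils.py, not shown here). For reference, a typical implementation
# is sketched below; this is an illustrative assumption, not the repo's code.
def _gram_matrix_sketch(x, should_normalize=True):
    # x: feature maps of shape (b, ch, h, w); returns Grams of shape (b, ch, ch)
    (b, ch, h, w) = x.size()
    features = x.view(b, ch, h * w)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t)
    if should_normalize:
        # Normalize by the number of elements per channel map.
        gram /= ch * h * w
    return gram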
class Build(nn.Module):
    def __init__(
        self,
        config,
        target_content_representation,
        target_style_representation,
    ):
        super(Build, self).__init__()
        self.current_set_of_feature_maps = None
        self.current_content_representation = None
        self.current_style_representation = None
        self.config = config
        self.target_content_representation = target_content_representation
        self.target_style_representation = target_style_representation

    def forward(self, model, x):
        self.current_set_of_feature_maps = model(x)
        self.current_content_representation = self.current_set_of_feature_maps[
            self.config.content_feature_maps_index].squeeze(axis=0)
        # Gram matrices of the feature maps at the selected style layers.
        self.current_style_representation = [
            utils.gram_matrix(fmap)
            for cnt, fmap in enumerate(self.current_set_of_feature_maps)
            if cnt in self.config.style_feature_maps_indices
        ]
        content_loss = ContentLoss(self.target_content_representation)(
            self.current_content_representation)
        style_loss = StyleLoss()(
            self.target_style_representation,
            self.current_style_representation)
        tv_loss = TotalVariationLoss(x)()
        return Loss()(content_loss, style_loss, tv_loss)
class TotalVariationLoss(nn.Module):
    def __init__(self, y):
        super(TotalVariationLoss, self).__init__()
        # Sum of absolute differences between horizontally and vertically
        # adjacent pixels; penalizes high-frequency noise in the image.
        self.first = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:]))
        self.second = torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))

    def forward(self):
        return self.first + self.second
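# Sanity check (illustrative): a constant image has zero total variation,
# e.g. TotalVariationLoss(torch.ones(1, 3, 8, 8))() == 0.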
class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, x, y, z):
        # Weighted sum of the content, style, and total variation terms.
        return (utils.yamlGet("contentWeight") * x +
                utils.yamlGet("styleWeight") * y +
                utils.yamlGet("totalVariationWeight") * z)
def neural_style_transfer():
    dump_path = os.path.join(os.path.dirname(__file__), "data/transfer")
    config = utils.Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    content_img, style_img, init_img = utils.Images().getImages(device)
    optimizing_img = Variable(init_img, requires_grad=True)
    neural_net, content_feature_maps_index_name, \
        style_feature_maps_indices_names = utils.prepare_model(device)
    config.content_feature_maps_index = content_feature_maps_index_name[0]
    config.style_feature_maps_indices = style_feature_maps_indices_names[0]
    content_img_set_of_feature_maps = neural_net(content_img)
    style_img_set_of_feature_maps = neural_net(style_img)
    target_content_representation = content_img_set_of_feature_maps[
        config.content_feature_maps_index].squeeze(axis=0)
    target_style_representation = [
        utils.gram_matrix(fmap)
        for cnt, fmap in enumerate(style_img_set_of_feature_maps)
        if cnt in config.style_feature_maps_indices
    ]
    if utils.yamlGet('optimizer') == 'Adam':
        optimizer = Adam((optimizing_img, ), lr=utils.yamlGet('learning_rate'))
        for cnt in range(utils.yamlGet("iterations")):
            optimizer.zero_grad()
            total_loss = Build(config, target_content_representation,
                               target_style_representation)(neural_net,
                                                            optimizing_img)
            total_loss.backward()
            optimizer.step()
            with torch.no_grad():
                utils.save_optimizing_image(optimizing_img, dump_path, cnt)
    elif utils.yamlGet('optimizer') == 'LBFGS':
        # LBFGS performs up to `iterations` closure evaluations inside a
        # single step() call, so no outer loop is needed.
        optimizer = LBFGS((optimizing_img, ),
                          max_iter=utils.yamlGet('iterations'),
                          line_search_fn='strong_wolfe')
        cnt = 0

        def closure():
            nonlocal cnt
            optimizer.zero_grad()
            total_loss = Build(config, target_content_representation,
                               target_style_representation)(neural_net,
                                                            optimizing_img)
            total_loss.backward()
            with torch.no_grad():
                utils.save_optimizing_image(optimizing_img, dump_path, cnt)
            cnt += 1
            return total_loss

        optimizer.step(closure)
    create_video_from_intermediate_results(dump_path)
    return dump_path
# some values of weights that worked for figures.jpg, vg_starry_night.jpg
# (starting point for finding good images)
# once you understand what each one does it gets really easy -> also see
# README.md
# lbfgs, content init -> (cw, sw, tv) = (1e5, 3e4, 1e0)
# lbfgs, style init -> (cw, sw, tv) = (1e5, 1e1, 1e-1)
# lbfgs, random init -> (cw, sw, tv) = (1e5, 1e3, 1e0)
# adam, content init -> (cw, sw, tv, lr) = (1e5, 1e5, 1e-1, 1e1)
# adam, style init -> (cw, sw, tv, lr) = (1e5, 1e2, 1e-1, 1e1)
# adam, random init -> (cw, sw, tv, lr) = (1e5, 1e2, 1e-1, 1e1)
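# The yamlGet keys referenced above suggest a config along these lines
# (key names come from this file; the YAML layout itself is an assumption):
#
#   optimizer: LBFGS          # or Adam
#   iterations: 500
#   learning_rate: 1e1        # Adam only
#   contentWeight: 1e5
#   styleWeight: 3e4
#   totalVariationWeight: 1e0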
# original NST (Neural Style Transfer) algorithm (Gatys et al.)
# results_path = neural_style_transfer()
# create_video_from_intermediate_results(results_path)
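# Minimal entry-point sketch mirroring the commented-out usage above; assumes
# neural_style_transfer() returns the dump path (the video is already written
# inside the function itself).
if __name__ == "__main__":
    results_path = neural_style_transfer()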