Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,517 Bytes
91d9343 246dd82 91d9343 06894c7 246dd82 91d9343 246dd82 91d9343 06894c7 91d9343 246dd82 91d9343 06894c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn.functional as F
def _gram_matrix(feature):
    """Return the (unnormalised) Gram matrix of a 4-D feature tensor.

    The tensor is flattened to ``(batch_size * n_feature_maps, height * width)``
    and multiplied by its own transpose. The result is a square matrix of inner
    products between every pair of flattened feature maps.

    With ``batch_size > 1`` the maps of all images are stacked into a single
    matrix, so cross-image products are included as well. No division by the
    number of elements is applied.

    Args:
        feature: Tensor unpacked as ``(batch_size, n_feature_maps, height, width)``.
            It may be non-contiguous.

    Returns:
        Tensor of shape
        ``(batch_size * n_feature_maps, batch_size * n_feature_maps)``.
    """
    batch_size, n_feature_maps, height, width = feature.size()
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. maps
    # produced by permute/transpose, and only copies when it has to.
    flat = feature.reshape(batch_size * n_feature_maps, height * width)
    return torch.mm(flat, flat.t())
def _compute_loss(generated_features, content_features, style_features, alpha, beta):
    """Return the weighted content + style loss for one optimisation step.

    For each layer position:

    * The content term is ``F.mse_loss`` between the generated and content
      feature maps. These terms are summed over all layers without per-layer
      weighting.
    * The style term is ``F.mse_loss`` between the Gram matrices (see
      ``_gram_matrix``) of the generated and style feature maps. Each layer is
      weighted equally by ``1 / number_of_layers``.

    Args:
        generated_features: Sequence of feature maps of the image being optimised.
        content_features: Sequence of content-image feature maps, same length.
        style_features: Sequence of style-image feature maps, same length.
        alpha: Multiplier for the summed content loss.
        beta: Multiplier for the averaged style loss.

    Returns:
        ``alpha * content_loss + beta * style_loss`` (a scalar tensor, since
        ``F.mse_loss`` uses mean reduction by default).

    Raises:
        ValueError: If no feature maps are given, or the three sequences
            differ in length. Without this check ``zip`` would silently drop
            layers.
    """
    n_layers = len(generated_features)
    if n_layers == 0:
        raise ValueError("generated_features must contain at least one feature map")
    if len(content_features) != n_layers or len(style_features) != n_layers:
        raise ValueError(
            "generated_features, content_features and style_features must have "
            f"the same length, got {n_layers}, {len(content_features)} and "
            f"{len(style_features)}"
        )

    w_l = 1 / n_layers
    content_loss = 0
    style_loss = 0
    for gf, cf, sf in zip(generated_features, content_features, style_features):
        content_loss += F.mse_loss(gf, cf)
        # NOTE: the style-target Gram matrix is rebuilt on every call even
        # though style features are fixed; a caller could precompute it.
        G = _gram_matrix(gf)
        A = _gram_matrix(sf)
        style_loss += w_l * F.mse_loss(G, A)
    return alpha * content_loss + beta * style_loss
def inference(
    *,
    model,
    content_image,
    style_features,
    lr,
    iterations=3,
    optim_caller=optim.LBFGS,
    alpha=1,
    beta=1
):
    """Optimise a copy of ``content_image`` towards the content and style targets.

    A fresh leaf tensor is initialised from ``content_image`` and updated by an
    optimizer built from ``optim_caller``. The objective is ``_compute_loss``
    evaluated on ``model``'s outputs.

    Args:
        model: Callable mapping an image tensor to a sequence of feature maps.
            Its outputs are zipped against ``style_features``, so the two
            presumably come from the same layers (TODO confirm with caller).
        content_image: Image used both as the content target and as the
            starting point. It is not modified.
        style_features: Style-image feature maps, in the same order as
            ``model``'s outputs.
        lr: Learning rate passed to the optimizer.
        iterations: Number of ``optimizer.step(closure)`` calls.
            Optimizers such as LBFGS may evaluate the closure several times
            per step.
        optim_caller: Optimizer factory, called as ``optim_caller(params, lr=lr)``.
        alpha: Weight of the content loss.
        beta: Weight of the style loss.

    Returns:
        The optimised image tensor. It is a leaf with ``requires_grad=True``.
    """
    # Detach before cloning. If content_image requires grad or belongs to an
    # autograd graph, clone() alone yields a non-leaf tensor, which
    # torch.optim rejects ("can't optimize a non-leaf Tensor").
    generated_image = content_image.detach().clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)

    # Content targets are fixed, so no graph is recorded for them.
    with torch.no_grad():
        content_features = model(content_image)

    def closure():
        # Recompute loss and gradient for the current image. step() may
        # invoke this multiple times (e.g. LBFGS line search).
        optimizer.zero_grad()
        generated_features = model(generated_image)
        total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
        total_loss.backward()
        return total_loss

    for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
        optimizer.step(closure)
    return generated_image