Spaces:
Running
on
Zero
Running
on
Zero
from tqdm import tqdm | |
import torch | |
import torch.optim as optim | |
import torch.nn.functional as F | |
def _gram_matrix(feature):
    """Return the (unnormalized) Gram matrix of a 4-D feature tensor.

    ``feature`` is unpacked as ``(batch_size, n_feature_maps, height, width)``.
    It is flattened to ``(batch_size * n_feature_maps, height * width)`` and
    multiplied by its own transpose. The result is a square matrix of side
    ``batch_size * n_feature_maps``. No normalization by the number of
    elements is applied.

    NOTE(review): with batch_size > 1 the result also contains cross-batch
    feature-map products; presumably callers use batch_size == 1 — confirm.
    """
    batch_size, n_feature_maps, height, width = feature.size()
    # reshape (not view) so non-contiguous inputs, e.g. transposed or
    # channels_last feature maps, work too; it returns a view when possible.
    flat = feature.reshape(batch_size * n_feature_maps, height * width)
    return torch.mm(flat, flat.t())
def _compute_loss(generated_features, content_features, style_features, alpha, beta):
    """Return the weighted content + style loss for one optimization step.

    The three feature arguments are sequences of tensors, one per layer, in
    the same order.

    - Content loss is the MSE between generated and content features,
      summed over every layer.
    - Style loss is the MSE between the Gram matrices of generated and style
      features, with each layer weighted ``1 / n_layers``.

    Args:
        generated_features: Features of the image being optimized.
        content_features: Target features of the content image.
        style_features: Target features of the style image.
        alpha: Weight of the content term.
        beta: Weight of the style term.

    Returns:
        ``alpha * content_loss + beta * style_loss``.

    Raises:
        ValueError: If there are no layers, or if the three sequences differ
            in length. ``zip`` would otherwise silently drop layers, and the
            per-layer style weights would no longer sum to 1.
    """
    n_layers = len(generated_features)
    if n_layers == 0:
        raise ValueError('generated_features must contain at least one layer')
    if len(content_features) != n_layers or len(style_features) != n_layers:
        raise ValueError(
            'feature sequences must have equal length, got '
            f'{n_layers} generated, {len(content_features)} content, '
            f'{len(style_features)} style'
        )

    w_l = 1 / n_layers
    content_loss = 0
    style_loss = 0
    for gf, cf, sf in zip(generated_features, content_features, style_features):
        content_loss += F.mse_loss(gf, cf)
        style_loss += w_l * F.mse_loss(_gram_matrix(gf), _gram_matrix(sf))
    return alpha * content_loss + beta * style_loss
def inference(
    *,
    model,
    content_image,
    style_features,
    lr,
    iterations=100,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    """Optimize an image so its features match content and style targets.

    Starting from a copy of ``content_image``, the image pixels are updated
    by an optimizer built as ``optim_caller([image], lr=lr)``. The objective
    is ``_compute_loss(model(image), model(content_image), style_features,
    alpha, beta)``.

    Args:
        model: Callable applied to image tensors. Its output is passed to
            ``_compute_loss`` together with ``style_features``, so it is
            expected to return per-layer features in the same order.
        content_image: Content target. Also used as the starting point of
            the generated image.
        style_features: Per-layer style targets (treated as constants).
        lr: Learning rate passed to the optimizer.
        iterations: Number of optimizer steps.
        optim_caller: Optimizer class or factory, called as
            ``optim_caller(params, lr=lr)``.
        alpha: Weight of the content loss.
        beta: Weight of the style loss.

    Returns:
        The optimized image tensor. It is a leaf with ``requires_grad=True``,
        so detach it before converting to non-autograd formats.
    """
    # detach() first so the optimized tensor is always a leaf, even if
    # content_image carries autograd history (optimizers reject non-leaves).
    generated_image = content_image.detach().clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    # Lowest loss seen in each step. Closure-based optimizers (e.g. LBFGS)
    # may evaluate the closure several times per step.
    min_losses = [float('inf')] * iterations

    with torch.no_grad():
        content_features = model(content_image)
    # Style targets are constants. Detach them so backward() never walks
    # into (already freed) graphs that produced them. Materialize them once
    # so a one-shot iterable is not exhausted after the first step.
    style_targets = [sf.detach() for sf in style_features]

    def closure(step):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        total_loss = _compute_loss(
            generated_features, content_features, style_targets, alpha, beta
        )
        # NOTE: if model parameters require grad, this also accumulates
        # (unused) gradients into them on every evaluation.
        total_loss.backward()
        min_losses[step] = min(min_losses[step], total_loss.item())
        return total_loss

    for step in tqdm(range(iterations), desc='The magic is happening ✨'):
        # closure(step) runs synchronously inside step(), so the lambda
        # always sees the current loop value.
        optimizer.step(lambda: closure(step))
        if step % 10 == 0:
            # tqdm.write keeps the progress bar intact (plain print garbles it).
            tqdm.write(f'Loss ({step}): {min_losses[step]}')
    return generated_image