from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur
from torchvision import models
def _gram_matrix(feature):
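    """Gram matrix of a (B, C, H, W) feature map, used as the style representation."""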
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())

def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
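    """Weighted content + style loss; style terms can be restricted to the background masks."""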
    content_loss = 0
    style_loss = 0
    w_l = 1 / len(generated_features)

    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
        content_loss += F.mse_loss(gf, cf)

        if resized_bg_masks:
            # Restrict the style statistics to the (blurred) background region only
            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
            masked_gf = gf * blurred_bg_mask
            masked_sf = sf * blurred_bg_mask
            G = _gram_matrix(masked_gf)
            A = _gram_matrix(masked_sf)
        else:
            G = _gram_matrix(gf)
            A = _gram_matrix(sf)
        style_loss += w_l * F.mse_loss(G, A)

    return alpha * content_loss + beta * style_loss

def inference(
    *,
    model,
    segmentation_model,
    content_image,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
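    """Run style-transfer optimization on `content_image`.

    When `apply_to_background` is True, only the segmented background (class 0)
    is stylized and the original foreground is restored after each step.
    """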
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    min_losses = [float('inf')] * iterations

    with torch.no_grad():
        content_features = model(content_image)

        resized_bg_masks = []
        if apply_to_background:
            # Class 0 of the segmentation output is treated as background
            segmentation_output = segmentation_model(content_image)['out']
            segmentation_mask = segmentation_output.argmax(dim=1)
            background_mask = (segmentation_mask == 0).float()
            foreground_mask = (segmentation_mask != 0).float()

            # Resize the background mask to match each feature map's spatial size
            for cf in content_features:
                _, _, h_i, w_i = cf.shape
                bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
                resized_bg_masks.append(bg_mask)
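
    # Closure re-evaluates the loss for the optimizer and records the best loss per iteration.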
    def closure(iter):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        total_loss = _compute_loss(
            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
        )
        total_loss.backward()
        min_losses[iter] = min(min_losses[iter], total_loss.item())
        return total_loss
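
    # Optimization loop: when only the background is stylized, the original
    # foreground pixels are pasted back after every optimizer step.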
    for iter in tqdm(range(iterations), desc='The magic is happening ✨'):
        optimizer.step(lambda: closure(iter))

        if apply_to_background:
            with torch.no_grad():
                foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
                generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized

        if iter % 10 == 0:
            print(f'Loss ({iter}):', min_losses[iter])

    return generated_image
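

if __name__ == '__main__':
    # Illustrative sketch only: the Space builds its feature extractor elsewhere, so the small
    # VGG-based extractor below is a hypothetical stand-in (with an arbitrary choice of layers),
    # and the random tensors stand in for preprocessed (1, 3, H, W) content/style images.
    class _VGGFeatures(torch.nn.Module):
        def __init__(self, layers=(3, 8, 17, 26)):
            super().__init__()
            self.features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
            self.layers = set(layers)
            for p in self.features.parameters():
                p.requires_grad_(False)

        def forward(self, x):
            outputs = []
            for i, layer in enumerate(self.features):
                x = layer(x)
                if i in self.layers:
                    outputs.append(x)
            return outputs

    extractor = _VGGFeatures()
    seg_model = models.segmentation.deeplabv3_resnet50(
        weights=models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT
    ).eval()

    content = torch.rand(1, 3, 256, 256)  # stand-in for a preprocessed content image
    style = torch.rand(1, 3, 256, 256)    # stand-in for a preprocessed style image
    with torch.no_grad():
        style_feats = extractor(style)

    result = inference(
        model=extractor,
        segmentation_model=seg_model,
        content_image=content,
        style_features=style_feats,
        apply_to_background=True,
        lr=0.05,
        iterations=11,
    )
    print('stylized image:', result.shape)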