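"""Style-transfer inference utilities.

Optimizes a copy of the content image so that its features match a set of
precomputed style features. Optionally restricts stylization to the image
background, using a semantic segmentation mask (class 0 is treated as
background) and restoring the original foreground after every step.
"""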
import os
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur
# Extra TensorBoard logging is enabled when the DEV_MODE env var is set.
DEV_MODE = bool(os.environ.get('DEV_MODE'))
print('DEV MODE:', DEV_MODE)
def _gram_matrix(feature):
    """Return the Gram matrix of a feature map, which captures style statistics."""
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())
def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
    """Weighted content + style loss; style is compared via Gram matrices,
    optionally restricted to the background masks."""
    content_loss = 0
    style_loss = 0
    w_l = 1 / len(generated_features)  # uniform weight per feature layer
    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
        content_loss += F.mse_loss(gf, cf)
        if resized_bg_masks:
            # Restrict the style comparison to the background; blur the mask
            # to soften the foreground/background boundary.
            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
            masked_gf = gf * blurred_bg_mask
            masked_sf = sf * blurred_bg_mask
            G = _gram_matrix(masked_gf)
            A = _gram_matrix(masked_sf)
        else:
            G = _gram_matrix(gf)
            A = _gram_matrix(sf)
        style_loss += w_l * F.mse_loss(G, A)
    total_loss = alpha * content_loss + beta * style_loss
    return content_loss, style_loss, total_loss
def inference(
    *,
    model,
    segmentation_model,
    content_image,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    if DEV_MODE:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter()

    # Start the optimization from a copy of the content image.
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    min_losses = [float('inf')] * iterations

    with torch.no_grad():
        content_features = model(content_image)
        resized_bg_masks = []
        background_ratio = None
        if apply_to_background:
            # Class 0 of the segmentation output is taken to be the background.
            segmentation_output = segmentation_model(content_image)['out']
            segmentation_mask = segmentation_output.argmax(dim=1)
            background_mask = (segmentation_mask == 0).float()
            foreground_mask = 1 - background_mask

            background_pixel_count = background_mask.sum().item()
            total_pixel_count = segmentation_mask.numel()
            background_ratio = background_pixel_count / total_pixel_count

            # Resize the background mask to the spatial size of each feature map.
            for cf in content_features:
                _, _, h_i, w_i = cf.shape
                bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
                resized_bg_masks.append(bg_mask)
    def closure(step):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        content_loss, style_loss, total_loss = _compute_loss(
            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
        )
        total_loss.backward()

        # Log losses in dev mode.
        if DEV_MODE:
            writer.add_scalars(f'style-{"background" if apply_to_background else "image"}', {
                'Loss/content': content_loss.item(),
                'Loss/style': style_loss.item(),
                'Loss/total': total_loss.item(),
            }, step)
        min_losses[step] = min(min_losses[step], total_loss.item())
        return total_loss
    for step in tqdm(range(iterations)):
        optimizer.step(lambda: closure(step))

        if apply_to_background:
            # Paste the original foreground back so only the background is stylized.
            with torch.no_grad():
                foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
                generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized

    if DEV_MODE:
        writer.flush()
        writer.close()

    return generated_image, background_ratio
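if __name__ == '__main__':
    # Usage sketch, not part of the original module: one plausible way to
    # drive `inference`, assuming a VGG19 multi-layer feature extractor and a
    # torchvision DeepLabV3 segmentation model (whose output dict exposes the
    # 'out' key used above). The layer indices, file names, and
    # hyperparameters below are illustrative assumptions.
    from PIL import Image
    import torchvision.models as models
    import torchvision.transforms as T
    from torchvision.models.segmentation import deeplabv3_resnet50

    class VGGFeatures(torch.nn.Module):
        """Feature extractor returning activations from several VGG19 layers."""

        def __init__(self, layer_ids=(0, 5, 10, 19, 28)):  # assumed: conv1_1..conv5_1
            super().__init__()
            self.features = models.vgg19(weights='DEFAULT').features.eval()
            self.layer_ids = set(layer_ids)
            for p in self.features.parameters():
                p.requires_grad_(False)

        def forward(self, x):
            outputs = []
            for i, layer in enumerate(self.features):
                x = layer(x)
                if i in self.layer_ids:
                    outputs.append(x)
            return outputs

    preprocess = T.Compose([
        T.Resize(512),
        T.CenterCrop(512),
        T.ToTensor(),
        # ImageNet statistics, expected by both VGG19 and DeepLabV3.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    content_image = preprocess(Image.open('content.jpg').convert('RGB')).unsqueeze(0)
    style_image = preprocess(Image.open('style.jpg').convert('RGB')).unsqueeze(0)

    model = VGGFeatures()
    with torch.no_grad():
        style_features = model(style_image)

    stylized, bg_ratio = inference(
        model=model,
        segmentation_model=deeplabv3_resnet50(weights='DEFAULT').eval(),
        content_image=content_image,
        style_features=style_features,
        apply_to_background=True,
        lr=0.05,
    )
    print('background ratio:', bg_ratio)
    # Note: `stylized` is still in normalized space; undo the normalization
    # before saving or displaying it.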