File size: 3,258 Bytes
cc9f69c
 
 
91d9343
 
 
a9077eb
8d1740c
91d9343
 
 
 
 
a9077eb
91d9343
 
 
a9077eb
 
91d9343
a9077eb
 
 
 
 
 
 
 
 
 
91d9343
a9077eb
349bdfb
 
91d9343
 
 
 
814e69a
91d9343
89e4ae0
91d9343
a9077eb
91d9343
ce6dca2
d3ca146
91d9343
a9077eb
91d9343
06894c7
d3ca146
c5d8238
91d9343
 
 
a9077eb
e21f7c8
814e69a
89e4ae0
fc92636
 
a9077eb
8d1740c
a9077eb
 
 
 
 
246dd82
b9f6209
91d9343
06894c7
349bdfb
a9077eb
 
91d9343
349bdfb
 
b9f6209
349bdfb
246dd82
 
cc9f69c
fa762f9
a9077eb
 
 
 
 
814e69a
3b42de6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur

def _gram_matrix(feature):
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())

def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
    """Combine per-layer content and style losses into a weighted total.

    For every layer (the three feature sequences are walked in lockstep):

    * the content term is the MSE between the generated and content features;
    * the style term is the MSE between the Gram matrices of the generated
      and style features, weighted by ``1 / len(generated_features)``.

    When ``resized_bg_masks`` is non-empty, the mask at the same layer index
    is Gaussian-blurred (kernel size 5) and multiplied into both the generated
    and the style features before their Gram matrices are taken.

    Args:
        generated_features: Per-layer features of the image being optimised.
        content_features: Per-layer features of the content image.
        style_features: Per-layer features of the style image.
        resized_bg_masks: Per-layer masks, or an empty sequence to disable
            masking. Presumably each mask broadcasts against the matching
            feature map -- confirm against the caller.
        alpha: Weight of the content loss in the total.
        beta: Weight of the style loss in the total.

    Returns:
        Tuple ``(content_loss, style_loss, total_loss)``.
    """
    layer_weight = 1 / len(generated_features)

    content_terms = []
    style_terms = []
    layers = zip(generated_features, content_features, style_features)
    for idx, (gen, content, style) in enumerate(layers):
        content_terms.append(F.mse_loss(gen, content))

        if resized_bg_masks:
            # Blur the mask so the masked region fades out smoothly.
            soft_mask = gaussian_blur(resized_bg_masks[idx], kernel_size=5)
            gen = gen * soft_mask
            style = style * soft_mask

        gram_generated = _gram_matrix(gen)
        gram_style = _gram_matrix(style)
        style_terms.append(layer_weight * F.mse_loss(gram_generated, gram_style))

    # sum() starts from 0 and adds left to right, matching a running total.
    content_loss = sum(content_terms)
    style_loss = sum(style_terms)
    total_loss = alpha * content_loss + beta * style_loss
    return content_loss, style_loss, total_loss

def inference(
    *,
    model,
    sod_model,
    content_image,
    content_image_norm,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    """Optimise a copy of ``content_image`` against content and style losses.

    A clone of ``content_image`` is made trainable and updated for
    ``iterations`` optimiser steps. Each step extracts features of the current
    image with ``model`` and minimises the total loss from ``_compute_loss``.

    When ``apply_to_background`` is true, ``sod_model`` is run once on
    ``content_image_norm``; its first output is passed through a sigmoid and
    thresholded at 0.7 to obtain a binary foreground mask. The complementary
    background mask is resized to each content-feature resolution and handed
    to the loss, and after every step the foreground pixels of the generated
    image are overwritten with the original content pixels.

    Args:
        model: Callable returning an iterable of 4-D feature tensors for an
            image tensor.
        sod_model: Callable whose result is indexed with ``[0]`` to obtain
            segmentation logits. Only called when ``apply_to_background`` is
            true. NOTE(review): the ``unsqueeze(1)`` below implies the indexed
            output is 3-D, presumably ``(N, H, W)`` -- confirm.
        content_image: Image tensor to start from; assumed 4-D since
            ``shape[2:]`` is used as its spatial size.
        content_image_norm: Input passed to ``sod_model``.
        style_features: Per-layer style features forwarded to the loss.
        apply_to_background: Whether to restrict styling via the mask.
        lr: Learning rate passed to the optimiser.
        iterations: Number of optimiser steps.
        optim_caller: Optimiser factory called as
            ``optim_caller([generated_image], lr=lr)``.
        alpha: Content-loss weight.
        beta: Style-loss weight.

    Returns:
        The optimised image tensor (still flagged as requiring grad).
    """
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)

    resized_bg_masks = []
    background_weight = None
    pinned_foreground = None

    with torch.no_grad():
        content_features = model(content_image)

        if apply_to_background:
            segmentation_output = torch.sigmoid(sod_model(content_image_norm)[0])
            # Scores above 0.7 are foreground; everything else is background.
            foreground_mask = (segmentation_output > 0.7).float()
            background_mask = 1 - foreground_mask

            # One background mask per feature level, at that level's size.
            for cf in content_features:
                _, _, h_i, w_i = cf.shape
                resized_bg_masks.append(
                    F.interpolate(
                        background_mask.unsqueeze(1),
                        size=(h_i, w_i),
                        mode='bilinear',
                        align_corners=False,
                    )
                )

            # The image-sized mask and the content pixels it keeps never
            # change during optimisation, so compute them once up front.
            foreground_mask_resized = F.interpolate(
                foreground_mask.unsqueeze(1),
                size=generated_image.shape[2:],
                mode='nearest',
            )
            background_weight = 1 - foreground_mask_resized
            pinned_foreground = content_image.detach() * foreground_mask_resized

    def closure():
        # The optimiser may call this more than once per step.
        optimizer.zero_grad()
        generated_features = model(generated_image)
        _, _, total_loss = _compute_loss(
            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
        )
        total_loss.backward()
        return total_loss

    for _ in tqdm(range(iterations)):
        optimizer.step(closure)

        if apply_to_background:
            with torch.no_grad():
                # Restore the original content wherever the mask is foreground.
                generated_image.data = generated_image.data * background_weight + pinned_foreground

    return generated_image