File size: 3,988 Bytes
cc9f69c
 
 
91d9343
 
 
a9077eb
8d1740c
cc9f69c
 
 
91d9343
 
 
 
 
a9077eb
91d9343
 
 
a9077eb
 
91d9343
a9077eb
 
 
 
 
 
 
 
 
 
91d9343
a9077eb
349bdfb
 
91d9343
 
 
 
bbcd902
91d9343
 
a9077eb
91d9343
ce6dca2
d3ca146
91d9343
a9077eb
91d9343
cc9f69c
 
 
06894c7
d3ca146
c5d8238
91d9343
 
 
a9077eb
e21f7c8
 
bbcd902
a9077eb
 
 
8d1740c
e21f7c8
 
 
 
a9077eb
 
 
 
 
246dd82
b9f6209
91d9343
06894c7
349bdfb
a9077eb
 
91d9343
349bdfb
 
cc9f69c
 
 
 
 
 
b9f6209
349bdfb
246dd82
 
cc9f69c
fa762f9
a9077eb
 
 
 
 
 
cc9f69c
 
 
e21f7c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur

# Any non-empty DEV_MODE environment variable switches development mode on;
# the raw string (or None when unset) is kept so later code can test its truthiness.
DEV_MODE = os.environ.get('DEV_MODE')
print('DEV MODE:', bool(DEV_MODE))

def _gram_matrix(feature):
    """Return the unnormalised Gram matrix of a 4-D feature tensor.

    The tensor is flattened to ``(batch_size * n_feature_maps, height * width)``
    and multiplied by its own transpose, so the result is a square matrix of
    side ``batch_size * n_feature_maps``. No normalisation by the number of
    elements is applied.

    Args:
        feature: Tensor whose ``size()`` unpacks into exactly four dimensions
            ``(batch_size, n_feature_maps, height, width)``.

    Returns:
        A 2-D tensor ``F @ F.T`` where ``F`` is the flattened feature matrix.
    """
    batch_size, n_feature_maps, height, width = feature.size()
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # permuted or sliced feature maps; for contiguous tensors it is still a view.
    flat = feature.reshape(batch_size * n_feature_maps, height * width)
    return torch.mm(flat, flat.t())

def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
    """Compute the content, style and combined losses over all feature layers.

    Content loss is the sum of per-layer MSE between generated and content
    features. Style loss is the equally weighted (``1 / number of layers``)
    sum of per-layer MSE between Gram matrices of generated and style
    features. When ``resized_bg_masks`` is non-empty, both feature maps of
    layer ``i`` are first multiplied by ``resized_bg_masks[i]`` blurred with a
    Gaussian kernel of size 5, so only the masked region drives the style term.

    Args:
        generated_features: Per-layer features of the image being optimised.
        content_features: Per-layer features of the content image.
        style_features: Per-layer features of the style image.
        resized_bg_masks: Per-layer masks, or an empty sequence for no masking.
        alpha: Weight of the content loss in the total.
        beta: Weight of the style loss in the total.

    Returns:
        Tuple ``(content_loss, style_loss, total_loss)``.
    """
    layer_weight = 1 / len(generated_features)
    content_terms = []
    style_terms = []

    layers = zip(generated_features, content_features, style_features)
    for idx, (gen, content, style) in enumerate(layers):
        content_terms.append(F.mse_loss(gen, content))

        if resized_bg_masks:
            # Soften the mask edge so the style statistics fade out smoothly.
            soft_mask = gaussian_blur(resized_bg_masks[idx], kernel_size=5)
            gen = gen * soft_mask
            style = style * soft_mask

        gram_generated = _gram_matrix(gen)
        gram_style = _gram_matrix(style)
        style_terms.append(layer_weight * F.mse_loss(gram_generated, gram_style))

    content_loss = sum(content_terms)
    style_loss = sum(style_terms)
    total_loss = alpha * content_loss + beta * style_loss
    return content_loss, style_loss, total_loss

def inference(
    *,
    model,
    segmentation_model,
    content_image,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    """Optimise a copy of ``content_image`` so its features match ``style_features``.

    A clone of the content image is treated as the trainable parameter and
    updated for ``iterations`` optimizer steps against the combined
    content/style loss. When ``apply_to_background`` is truthy, the
    segmentation model is used to build a background mask (pixels whose
    argmax class index is 0); the style loss is restricted to that mask and
    after every step the foreground pixels are reset to the content image.

    When the module-level ``DEV_MODE`` flag is truthy, per-step losses are
    logged to a TensorBoard ``SummaryWriter``.

    Args:
        model: Callable returning a sequence of feature tensors for an image.
        segmentation_model: Callable returning a mapping with an ``'out'``
            entry of per-class scores; only used when ``apply_to_background``.
        content_image: Image tensor to start from.
            # assumes layout (batch, channels, H, W) - TODO confirm with caller
        style_features: Per-layer features of the style image.
        apply_to_background: Whether to stylise only the background.
        lr: Learning rate passed to the optimizer.
        iterations: Number of optimizer steps.
        optim_caller: Optimizer factory called as ``optim_caller([image], lr=lr)``.
        alpha: Content-loss weight.
        beta: Style-loss weight.

    Returns:
        Tuple ``(generated_image, background_ratio)`` where ``background_ratio``
        is the fraction of pixels classified as background, or ``None`` when
        ``apply_to_background`` is falsy.
    """
    writer = None
    if DEV_MODE:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter()

    try:
        # detach() guarantees a leaf tensor even if the caller's image is part
        # of an autograd graph; optimizers reject non-leaf parameters.
        generated_image = content_image.detach().clone().requires_grad_(True)
        optimizer = optim_caller([generated_image], lr=lr)

        resized_bg_masks = []
        background_ratio = None
        foreground_mask_resized = None

        with torch.no_grad():
            content_features = model(content_image)

            if apply_to_background:
                segmentation_output = segmentation_model(content_image)['out']
                segmentation_mask = segmentation_output.argmax(dim=1)
                background_mask = (segmentation_mask == 0).float()
                foreground_mask = 1 - background_mask

                background_pixel_count = background_mask.sum().item()
                total_pixel_count = segmentation_mask.numel()
                background_ratio = background_pixel_count / total_pixel_count

                # One mask per feature layer, resized to that layer's spatial size.
                for cf in content_features:
                    _, _, h_i, w_i = cf.shape
                    bg_mask = F.interpolate(
                        background_mask.unsqueeze(1),
                        size=(h_i, w_i),
                        mode='bilinear',
                        align_corners=False,
                    )
                    resized_bg_masks.append(bg_mask)

                # Loop-invariant: the mask and image size never change, so
                # resize the foreground mask once instead of on every step.
                foreground_mask_resized = F.interpolate(
                    foreground_mask.unsqueeze(1),
                    size=generated_image.shape[2:],
                    mode='nearest',
                )

        log_tag = f'style-{"background" if apply_to_background else "image"}'

        def closure(step):
            """Evaluate the loss, backpropagate and log for one optimizer call."""
            optimizer.zero_grad()
            generated_features = model(generated_image)
            content_loss, style_loss, total_loss = _compute_loss(
                generated_features, content_features, style_features, resized_bg_masks, alpha, beta
            )
            total_loss.backward()

            if writer is not None:
                writer.add_scalars(log_tag, {
                    'Loss/content': content_loss.item(),
                    'Loss/style': style_loss.item(),
                    'Loss/total': total_loss.item()
                }, step)

            return total_loss

        for step in tqdm(range(iterations)):
            optimizer.step(lambda step=step: closure(step))

            if apply_to_background:
                with torch.no_grad():
                    # Restore the original foreground so only the background is stylised.
                    generated_image.data = (
                        generated_image.data * (1 - foreground_mask_resized)
                        + content_image.data * foreground_mask_resized
                    )
    finally:
        # Close the writer even if optimisation fails, so no event file is left open.
        if writer is not None:
            writer.flush()
            writer.close()

    return generated_image, background_ratio