import torch
import numpy as np
import torchvision.transforms as T

from models.clip_relevancy import ClipRelevancy
from util.aug_utils import RandomSizeCrop
from util.util import get_screen_template, get_text_criterion, get_augmentations_template


class LossG(torch.nn.Module):
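    """Generator loss for CLIP-guided, alpha-composited (layered) editing.

    Aggregates up to five terms, each gated by a lambda in cfg: alpha
    bootstrapping (CLIP-relevancy supervision of the alpha maps), structure
    preservation, CLIP composition and directional losses, alpha sparsity,
    and a green-screen text loss.
    """
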
    def __init__(self, cfg, clip_extractor):
        super().__init__()

        self.cfg = cfg

        # calculate target text embeddings
        template = get_augmentations_template()
        self.src_e = clip_extractor.get_text_embedding(cfg["src_text"], template)
        self.target_comp_e = clip_extractor.get_text_embedding(cfg["comp_text"], template)
        self.target_greenscreen_e = clip_extractor.get_text_embedding(cfg["screen_text"], get_screen_template())

        self.clip_extractor = clip_extractor
        self.text_criterion = get_text_criterion(cfg)

        if cfg["bootstrap_epoch"] > 0 and cfg["lambda_bootstrap"] > 0:
            self.relevancy_extractor = ClipRelevancy(cfg)
            self.relevancy_criterion = torch.nn.MSELoss()
            self.lambda_bootstrap = cfg["lambda_bootstrap"]

    def forward(self, outputs, inputs):
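        """Aggregate all enabled loss terms into a single weighted objective.

        outputs is expected to hold "output_crop" / "output_image" dicts, each
        with "composite", "edit", "edit_on_greenscreen", and "alpha" lists;
        inputs holds the matching "input_crop" / "input_image" lists and the
        current "step". Returns the individual terms plus the weighted total
        under "loss".
        """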
        losses = {}
        loss_G = 0

        all_outputs_composite = []
        all_outputs_greenscreen = []
        all_outputs_edit = []
        all_outputs_alpha = []
        all_inputs = []
        for out, ins in zip(["output_crop", "output_image"], ["input_crop", "input_image"]):
            if out not in outputs:
                continue
            all_outputs_composite += outputs[out]["composite"]
            all_outputs_greenscreen += outputs[out]["edit_on_greenscreen"]
            all_outputs_edit += outputs[out]["edit"]
            all_outputs_alpha += outputs[out]["alpha"]
            all_inputs += inputs[ins]

        # calculate alpha bootstrapping loss
        if inputs["step"] < self.cfg["bootstrap_epoch"] and self.cfg["lambda_bootstrap"] > 0:
            losses["loss_bootstrap"] = self.calculate_relevancy_loss(all_outputs_alpha, all_inputs)

            # anneal the bootstrapping weight over training
            if self.cfg["bootstrap_scheduler"] == "linear":
                # decays linearly, reaching zero at bootstrap_epoch
                lambda_bootstrap = self.cfg["lambda_bootstrap"] * (
                    1 - (inputs["step"] + 1) / self.cfg["bootstrap_epoch"]
                )
            elif self.cfg["bootstrap_scheduler"] == "exponential":
                # multiplicative 0.99 decay per step
                lambda_bootstrap = self.lambda_bootstrap * 0.99
                self.lambda_bootstrap = lambda_bootstrap
            elif self.cfg["bootstrap_scheduler"] == "none":
                lambda_bootstrap = self.lambda_bootstrap
            else:
                raise ValueError(f"Unknown bootstrap scheduler: {self.cfg['bootstrap_scheduler']}")
            # never decay below the configured floor
            lambda_bootstrap = max(lambda_bootstrap, self.cfg["lambda_bootstrap_min"])
            loss_G += losses["loss_bootstrap"] * lambda_bootstrap

        # calculate structure loss
        if self.cfg["lambda_structure"] > 0:
            losses["loss_structure"] = self.calculate_structure_loss(all_outputs_composite, all_inputs)
            loss_G += losses["loss_structure"] * self.cfg["lambda_structure"]

        # calculate composition loss
        if self.cfg["lambda_composition"] > 0:
            losses["loss_comp_clip"] = self.calculate_clip_loss(all_outputs_composite, self.target_comp_e)

            losses["loss_comp_dir"] = self.calculate_clip_dir_loss(
                all_inputs, all_outputs_composite, self.target_comp_e
            )

            loss_G += (losses["loss_comp_clip"] + losses["loss_comp_dir"]) * self.cfg["lambda_composition"]

        # calculate sparsity loss
        if self.cfg["lambda_sparsity"] > 0:
            total, l0, l1 = self.calculate_alpha_reg(all_outputs_alpha)
            losses["loss_sparsity"] = total
            losses["loss_sparsity_l0"] = l0
            losses["loss_sparsity_l1"] = l1

            loss_G += losses["loss_sparsity"] * self.cfg["lambda_sparsity"]

        # calculate screen loss
        if self.cfg["lambda_screen"] > 0:
            losses["loss_screen"] = self.calculate_clip_loss(all_outputs_greenscreen, self.target_greenscreen_e)
            loss_G += losses["loss_screen"] * self.cfg["lambda_screen"]

        losses["loss"] = loss_G
        return losses

    def calculate_alpha_reg(self, prediction):
        """
        Calculate the alpha sparsity term: linear combination between L1 and pseudo L0 penalties
        """
        l1_loss = 0.0
        for el in prediction:
            l1_loss += el.mean()
        l1_loss = l1_loss / len(prediction)
        loss = self.cfg["lambda_alpha_l1"] * l1_loss
        # Pseudo L0 loss using a squished sigmoid curve.
        l0_loss = 0.0
        for el in prediction:
            l0_loss += torch.mean((torch.sigmoid(el * 5.0) - 0.5) * 2.0)
        l0_loss = l0_loss / len(prediction)
        loss += self.cfg["lambda_alpha_l0"] * l0_loss
        return loss, l0_loss, l1_loss

    def calculate_clip_loss(self, outputs, target_embeddings):
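        """CLIP loss between each output image and sampled target text
        embeddings (a single prompt yields several template-augmented
        embeddings, so a random subset is drawn each call)."""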
        # sample a random subset of the target embeddings (with replacement)
        n_embeddings = np.random.randint(1, len(target_embeddings) + 1)
        target_embeddings = target_embeddings[torch.randint(len(target_embeddings), (n_embeddings,))]

        loss = 0.0
        for img in outputs:  # embed one image at a time to limit memory use
            img_e = self.clip_extractor.get_image_embedding(img.unsqueeze(0))
            for target_embedding in target_embeddings:
                loss += self.text_criterion(img_e, target_embedding.unsqueeze(0))

        loss /= len(outputs) * len(target_embeddings)
        return loss

    def calculate_clip_dir_loss(self, inputs, outputs, target_embeddings):
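        """Directional CLIP loss: the edit direction in image-embedding space
        (output - input) should align with the sampled text direction
        (target text embedding - source text embedding)."""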
        # sample matched (source, target) embedding pairs via a shared index
        n_embeddings = np.random.randint(1, min(len(self.src_e), len(target_embeddings)) + 1)
        idx = torch.randint(min(len(self.src_e), len(target_embeddings)), (n_embeddings,))
        src_embeddings = self.src_e[idx]
        target_embeddings = target_embeddings[idx]
        target_dirs = target_embeddings - src_embeddings

        loss = 0.0
        for in_img, out_img in zip(inputs, outputs):  # embed one pair at a time to limit memory use
            in_e = self.clip_extractor.get_image_embedding(in_img.unsqueeze(0))
            out_e = self.clip_extractor.get_image_embedding(out_img.unsqueeze(0))
            for target_dir in target_dirs:
                # cosine distance between the image edit direction and the text direction
                loss += 1 - torch.nn.functional.cosine_similarity(out_e - in_e, target_dir.unsqueeze(0)).mean()

        loss /= len(outputs) * len(target_dirs)
        return loss

    def calculate_structure_loss(self, outputs, inputs):
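        """Structure preservation: match the CLIP self-similarity matrix of
        each output to that of its input (the input's matrix is computed
        without gradients)."""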
        loss = 0.0
        for in_img, out_img in zip(inputs, outputs):  # renamed to avoid shadowing built-ins
            with torch.no_grad():
                target_self_sim = self.clip_extractor.get_self_sim(in_img.unsqueeze(0))
            current_self_sim = self.clip_extractor.get_self_sim(out_img.unsqueeze(0))
            loss = loss + torch.nn.functional.mse_loss(current_self_sim, target_self_sim)
        loss = loss / len(outputs)
        return loss

    def calculate_relevancy_loss(self, alpha, input_img):
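        """Supervise the alpha maps with CLIP relevancy maps during the
        bootstrapping phase. Each (alpha, image) pair is cropped and resized
        jointly so the relevancy map stays aligned with the alpha map."""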
        relevancy_loss = 0.0
        for curr_alpha, curr_img in zip(alpha, input_img):
            # augment alpha and image together so they stay spatially aligned
            x = torch.stack([curr_alpha, curr_img], dim=0)  # [2, 3, H, W]
            x = T.Compose(
                [
                    RandomSizeCrop(min_cover=self.cfg["bootstrapping_min_cover"]),
                    T.Resize((224, 224)),
                ]
            )(x)
            curr_alpha, curr_img = x[0].unsqueeze(0), x[1].unsqueeze(0)
            positive_relevance = self.relevancy_extractor(curr_img)
            # accumulate across pairs (a plain `=` here would keep only the
            # last pair's loss, which the mean below would then mis-scale)
            relevancy_loss += self.relevancy_criterion(curr_alpha[0], positive_relevance.repeat(3, 1, 1))
            if self.cfg["use_negative_bootstrap"]:
                # push alpha toward zero wherever the negative-prompt relevancy is high
                negative_relevance = self.relevancy_extractor(curr_img, negative=True)
                relevant_values = negative_relevance > self.cfg["bootstrap_negative_map_threshold"]
                negative_alpha_local = (1 - curr_alpha) * relevant_values.unsqueeze(1)
                negative_relevance_local = negative_relevance * relevant_values
                relevancy_loss += self.relevancy_criterion(
                    negative_alpha_local,
                    negative_relevance_local.unsqueeze(1).repeat(1, 3, 1, 1),
                )
        return relevancy_loss / len(alpha)
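

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, hence kept as comments).
# `ClipExtractor` and its import path are assumptions standing in for
# whatever wrapper provides get_text_embedding / get_image_embedding /
# get_self_sim; cfg must supply the keys read throughout this file
# (src_text, comp_text, screen_text, the lambda_* weights, the bootstrap
# settings), and get_text_criterion / ClipRelevancy may read further keys.
#
#   from models.clip_extractor import ClipExtractor  # hypothetical path
#
#   clip_extractor = ClipExtractor(cfg)
#   criterion = LossG(cfg, clip_extractor)
#   losses = criterion(outputs, inputs)  # see forward() for expected structure
#   losses["loss"].backward()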