import torch
import numpy as np
import torchvision.transforms as T

from models.clip_relevancy import ClipRelevancy
from util.aug_utils import RandomSizeCrop
from util.util import get_screen_template, get_text_criterion, get_augmentations_template


class LossG(torch.nn.Module):
    def __init__(self, cfg, clip_extractor):
        super().__init__()
        self.cfg = cfg

        # calculate target text embeddings
        template = get_augmentations_template()
        self.src_e = clip_extractor.get_text_embedding(cfg["src_text"], template)
        self.target_comp_e = clip_extractor.get_text_embedding(cfg["comp_text"], template)
        self.target_greenscreen_e = clip_extractor.get_text_embedding(cfg["screen_text"], get_screen_template())

        self.clip_extractor = clip_extractor
        self.text_criterion = get_text_criterion(cfg)

        # the relevancy extractor is only needed while alpha bootstrapping is active
        if cfg["bootstrap_epoch"] > 0 and cfg["lambda_bootstrap"] > 0:
            self.relevancy_extractor = ClipRelevancy(cfg)
            self.relevancy_criterion = torch.nn.MSELoss()
            self.lambda_bootstrap = cfg["lambda_bootstrap"]

    def forward(self, outputs, inputs):
        losses = {}
        loss_G = 0

        # gather crop-level and full-image outputs into flat lists
        all_outputs_composite = []
        all_outputs_greenscreen = []
        all_outputs_edit = []  # collected alongside the rest; not used by any loss term below
        all_outputs_alpha = []
        all_inputs = []
        for out, ins in zip(["output_crop", "output_image"], ["input_crop", "input_image"]):
            if out not in outputs:
                continue
            all_outputs_composite += outputs[out]["composite"]
            all_outputs_greenscreen += outputs[out]["edit_on_greenscreen"]
            all_outputs_edit += outputs[out]["edit"]
            all_outputs_alpha += outputs[out]["alpha"]
            all_inputs += inputs[ins]

        # calculate alpha bootstrapping loss
        if inputs["step"] < self.cfg["bootstrap_epoch"] and self.cfg["lambda_bootstrap"] > 0:
            losses["loss_bootstrap"] = self.calculate_relevancy_loss(all_outputs_alpha, all_inputs)

            # anneal the bootstrap weight according to the configured scheduler
            if self.cfg["bootstrap_scheduler"] == "linear":
                lambda_bootstrap = self.cfg["lambda_bootstrap"] * (
                    1 - (inputs["step"] + 1) / self.cfg["bootstrap_epoch"]
                )
            elif self.cfg["bootstrap_scheduler"] == "exponential":
                lambda_bootstrap = self.lambda_bootstrap * 0.99
                self.lambda_bootstrap = lambda_bootstrap
            elif self.cfg["bootstrap_scheduler"] == "none":
                lambda_bootstrap = self.lambda_bootstrap
            else:
                raise ValueError("Unknown bootstrap scheduler")
            lambda_bootstrap = max(lambda_bootstrap, self.cfg["lambda_bootstrap_min"])
            loss_G += losses["loss_bootstrap"] * lambda_bootstrap

        # calculate structure loss
        if self.cfg["lambda_structure"] > 0:
            losses["loss_structure"] = self.calculate_structure_loss(all_outputs_composite, all_inputs)
            loss_G += losses["loss_structure"] * self.cfg["lambda_structure"]

        # calculate composition loss
        if self.cfg["lambda_composition"] > 0:
            losses["loss_comp_clip"] = self.calculate_clip_loss(all_outputs_composite, self.target_comp_e)
            losses["loss_comp_dir"] = self.calculate_clip_dir_loss(
                all_inputs, all_outputs_composite, self.target_comp_e
            )
            loss_G += (losses["loss_comp_clip"] + losses["loss_comp_dir"]) * self.cfg["lambda_composition"]

        # calculate sparsity loss
        if self.cfg["lambda_sparsity"] > 0:
            total, l0, l1 = self.calculate_alpha_reg(all_outputs_alpha)
            losses["loss_sparsity"] = total
            losses["loss_sparsity_l0"] = l0
            losses["loss_sparsity_l1"] = l1
            loss_G += losses["loss_sparsity"] * self.cfg["lambda_sparsity"]

        # calculate screen loss
        if self.cfg["lambda_screen"] > 0:
            losses["loss_screen"] = self.calculate_clip_loss(all_outputs_greenscreen, self.target_greenscreen_e)
            loss_G += losses["loss_screen"] * self.cfg["lambda_screen"]

        losses["loss"] = loss_G
        return losses

    def calculate_alpha_reg(self, prediction):
        """
        Calculate the alpha sparsity term: a linear combination of L1 and pseudo-L0 penalties.
        """
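        # The L1 term penalizes the mean alpha value directly, while the pseudo-L0
        # term assigns a roughly constant penalty to any clearly nonzero alpha:
        # 2 * (sigmoid(5x) - 0.5) is 0 at x = 0 and saturates toward 1 as x grows,
        # approximating a (non-differentiable) count of nonzero entries.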
        l1_loss = 0.0
        for el in prediction:
            l1_loss += el.mean()
        l1_loss = l1_loss / len(prediction)
        loss = self.cfg["lambda_alpha_l1"] * l1_loss

        # Pseudo L0 loss using a squished sigmoid curve.
        l0_loss = 0.0
        for el in prediction:
            l0_loss += torch.mean((torch.sigmoid(el * 5.0) - 0.5) * 2.0)
        l0_loss = l0_loss / len(prediction)
        loss += self.cfg["lambda_alpha_l0"] * l0_loss

        return loss, l0_loss, l1_loss

    def calculate_clip_loss(self, outputs, target_embeddings):
        # randomly select a subset of the target text embeddings
        n_embeddings = np.random.randint(1, len(target_embeddings) + 1)
        target_embeddings = target_embeddings[torch.randint(len(target_embeddings), (n_embeddings,))]

        loss = 0.0
        for img in outputs:  # embed images one at a time to avoid memory limitations
            img_e = self.clip_extractor.get_image_embedding(img.unsqueeze(0))
            for target_embedding in target_embeddings:
                loss += self.text_criterion(img_e, target_embedding.unsqueeze(0))

        loss /= len(outputs) * len(target_embeddings)
        return loss

    def calculate_clip_dir_loss(self, inputs, outputs, target_embeddings):
        # randomly select matching source/target embedding pairs
        n_embeddings = np.random.randint(1, min(len(self.src_e), len(target_embeddings)) + 1)
        idx = torch.randint(min(len(self.src_e), len(target_embeddings)), (n_embeddings,))
        src_embeddings = self.src_e[idx]
        target_embeddings = target_embeddings[idx]
        target_dirs = target_embeddings - src_embeddings

        loss = 0.0
        for in_img, out_img in zip(inputs, outputs):  # embed images one at a time to avoid memory limitations
            in_e = self.clip_extractor.get_image_embedding(in_img.unsqueeze(0))
            out_e = self.clip_extractor.get_image_embedding(out_img.unsqueeze(0))
            for target_dir in target_dirs:
                loss += 1 - torch.nn.CosineSimilarity()(out_e - in_e, target_dir.unsqueeze(0)).mean()

        loss /= len(outputs) * len(target_dirs)
        return loss

    def calculate_structure_loss(self, outputs, inputs):
        loss = 0.0
        for in_img, out_img in zip(inputs, outputs):
            # the target self-similarity comes from the source frame, so it needs no gradients
            with torch.no_grad():
                target_self_sim = self.clip_extractor.get_self_sim(in_img.unsqueeze(0))
            current_self_sim = self.clip_extractor.get_self_sim(out_img.unsqueeze(0))
            loss = loss + torch.nn.MSELoss()(current_self_sim, target_self_sim)
        loss = loss / len(outputs)
        return loss

    def calculate_relevancy_loss(self, alpha, input_img):
        positive_relevance_loss = 0.0
        for curr_alpha, curr_img in zip(alpha, input_img):
            # crop and resize the alpha matte and image together so they stay aligned
            x = torch.stack([curr_alpha, curr_img], dim=0)  # [2, 3, H, W]
            x = T.Compose(
                [
                    RandomSizeCrop(min_cover=self.cfg["bootstrapping_min_cover"]),
                    T.Resize((224, 224)),
                ]
            )(x)
            curr_alpha, curr_img = x[0].unsqueeze(0), x[1].unsqueeze(0)
            positive_relevance = self.relevancy_extractor(curr_img)
            # accumulate (rather than overwrite) so the division below yields a batch mean
            positive_relevance_loss += self.relevancy_criterion(curr_alpha[0], positive_relevance.repeat(3, 1, 1))
            if self.cfg["use_negative_bootstrap"]:
                # push alpha toward zero wherever the negative relevancy map exceeds the threshold
                negative_relevance = self.relevancy_extractor(curr_img, negative=True)
                relevant_values = negative_relevance > self.cfg["bootstrap_negative_map_threshold"]
                negative_alpha_local = (1 - curr_alpha) * relevant_values.unsqueeze(1)
                negative_relevance_local = negative_relevance * relevant_values
                negative_relevance_loss = self.relevancy_criterion(
                    negative_alpha_local,
                    negative_relevance_local.unsqueeze(1).repeat(1, 3, 1, 1),
                )
                positive_relevance_loss += negative_relevance_loss
        positive_relevance_loss = positive_relevance_loss / len(alpha)
        return positive_relevance_loss
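

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training code): how a caller
# might evaluate LossG for a single crop. The tensor shapes and the `cfg` /
# `clip_extractor` objects are assumptions; in practice they come from the
# project's configuration and its CLIP wrapper.
#
#   loss_fn = LossG(cfg, clip_extractor)
#   outputs = {
#       "output_crop": {
#           "composite": [composite],             # list of [3, H, W] tensors
#           "edit_on_greenscreen": [greenscreen],
#           "edit": [edit],
#           "alpha": [alpha],                     # alpha mattes in [0, 1]
#       }
#   }
#   inputs = {"input_crop": [source_crop], "step": step}
#   losses = loss_fn(outputs, inputs)             # dict of individual loss terms
#   losses["loss"].backward()                     # total weighted objective
# ---------------------------------------------------------------------------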