"""Generator loss terms: CLIP composition/direction, green-screen, structure,
alpha sparsity, and CLIP-relevancy alpha bootstrapping."""
import numpy as np
import torch
import torchvision.transforms as T

from models.clip_relevancy import ClipRelevancy
from util.aug_utils import RandomSizeCrop
from util.util import get_augmentations_template, get_screen_template, get_text_criterion
class LossG(torch.nn.Module):
    """Aggregate generator loss for CLIP-guided layered editing.

    Combines, weighted by the ``lambda_*`` entries of ``cfg``:
      * a CLIP similarity + CLIP directional term on the composited output,
      * a CLIP similarity term on the edit rendered over a green screen,
      * a structure-preservation term on CLIP self-similarity maps,
      * an L1/pseudo-L0 sparsity term on the alpha matte,
      * optionally, an early-training "bootstrap" term pulling alpha toward
        CLIP relevancy maps (active while step < cfg["bootstrap_epoch"]).
    """

    def __init__(self, cfg, clip_extractor):
        """
        Args:
            cfg: configuration dict (loss weights, text prompts, scheduler knobs).
            clip_extractor: provides text/image embeddings and self-similarity.
        """
        super().__init__()
        self.cfg = cfg

        # Pre-compute target text embeddings once; they are constant per run.
        template = get_augmentations_template()
        self.src_e = clip_extractor.get_text_embedding(cfg["src_text"], template)
        self.target_comp_e = clip_extractor.get_text_embedding(cfg["comp_text"], template)
        self.target_greenscreen_e = clip_extractor.get_text_embedding(cfg["screen_text"], get_screen_template())

        self.clip_extractor = clip_extractor
        self.text_criterion = get_text_criterion(cfg)

        # Relevancy machinery is only needed while bootstrapping is enabled.
        if cfg["bootstrap_epoch"] > 0 and cfg["lambda_bootstrap"] > 0:
            self.relevancy_extractor = ClipRelevancy(cfg)
            self.relevancy_criterion = torch.nn.MSELoss()
            self.lambda_bootstrap = cfg["lambda_bootstrap"]

    def forward(self, outputs, inputs):
        """Compute all enabled loss terms.

        Args:
            outputs: dict with optional "output_crop"/"output_image" entries,
                each holding lists under "composite", "edit_on_greenscreen",
                "edit", "alpha".
            inputs: dict with matching "input_crop"/"input_image" lists and
                the current training "step".

        Returns:
            dict of individual loss terms plus the weighted total under "loss".
        """
        losses = {}
        loss_G = 0

        # Flatten crop-level and full-image-level outputs into single lists.
        all_outputs_composite = []
        all_outputs_greenscreen = []
        all_outputs_edit = []  # collected for parity with the outputs dict; unused below
        all_outputs_alpha = []
        all_inputs = []
        for out, ins in zip(["output_crop", "output_image"], ["input_crop", "input_image"]):
            if out not in outputs:
                continue
            all_outputs_composite += outputs[out]["composite"]
            all_outputs_greenscreen += outputs[out]["edit_on_greenscreen"]
            all_outputs_edit += outputs[out]["edit"]
            all_outputs_alpha += outputs[out]["alpha"]
            all_inputs += inputs[ins]

        # Alpha bootstrapping: supervise alpha with CLIP relevancy early on,
        # with a weight decayed per the configured scheduler.
        if inputs["step"] < self.cfg["bootstrap_epoch"] and self.cfg["lambda_bootstrap"] > 0:
            losses["loss_bootstrap"] = self.calculate_relevancy_loss(all_outputs_alpha, all_inputs)

            scheduler = self.cfg["bootstrap_scheduler"]
            if scheduler == "linear":
                lambda_bootstrap = self.cfg["lambda_bootstrap"] * (
                    1 - (inputs["step"] + 1) / self.cfg["bootstrap_epoch"]
                )
            elif scheduler == "exponential":
                # Stateful decay: multiply the stored weight by 0.99 each call.
                lambda_bootstrap = self.lambda_bootstrap * 0.99
                self.lambda_bootstrap = lambda_bootstrap
            elif scheduler == "none":
                lambda_bootstrap = self.lambda_bootstrap
            else:
                raise ValueError("Unknown bootstrap scheduler")
            # Never decay below the configured floor.
            lambda_bootstrap = max(lambda_bootstrap, self.cfg["lambda_bootstrap_min"])
            loss_G += losses["loss_bootstrap"] * lambda_bootstrap

        # Structure preservation on the composited result.
        if self.cfg["lambda_structure"] > 0:
            losses["loss_structure"] = self.calculate_structure_loss(all_outputs_composite, all_inputs)
            loss_G += losses["loss_structure"] * self.cfg["lambda_structure"]

        # CLIP similarity + directional loss toward the composition prompt.
        if self.cfg["lambda_composition"] > 0:
            losses["loss_comp_clip"] = self.calculate_clip_loss(all_outputs_composite, self.target_comp_e)
            losses["loss_comp_dir"] = self.calculate_clip_dir_loss(
                all_inputs, all_outputs_composite, self.target_comp_e
            )
            loss_G += (losses["loss_comp_clip"] + losses["loss_comp_dir"]) * self.cfg["lambda_composition"]

        # Alpha sparsity (L1 + pseudo-L0).
        if self.cfg["lambda_sparsity"] > 0:
            total, l0, l1 = self.calculate_alpha_reg(all_outputs_alpha)
            losses["loss_sparsity"] = total
            losses["loss_sparsity_l0"] = l0
            losses["loss_sparsity_l1"] = l1
            loss_G += losses["loss_sparsity"] * self.cfg["lambda_sparsity"]

        # CLIP similarity of the edit-on-greenscreen render to the screen prompt.
        if self.cfg["lambda_screen"] > 0:
            losses["loss_screen"] = self.calculate_clip_loss(all_outputs_greenscreen, self.target_greenscreen_e)
            loss_G += losses["loss_screen"] * self.cfg["lambda_screen"]

        losses["loss"] = loss_G
        return losses

    def calculate_alpha_reg(self, prediction):
        """Alpha sparsity: linear combination of L1 and pseudo-L0 penalties.

        Returns:
            (total, l0, l1): the weighted sum and the two unweighted parts,
            each averaged over the elements of ``prediction``.
        """
        l1_loss = 0.0
        for el in prediction:
            l1_loss += el.mean()
        l1_loss = l1_loss / len(prediction)
        loss = self.cfg["lambda_alpha_l1"] * l1_loss

        # Pseudo-L0: a squished sigmoid that is ~0 at alpha=0 and saturates
        # toward 1 for clearly non-zero alpha, penalizing "how many" pixels
        # are active rather than their magnitude.
        l0_loss = 0.0
        for el in prediction:
            l0_loss += torch.mean((torch.sigmoid(el * 5.0) - 0.5) * 2.0)
        l0_loss = l0_loss / len(prediction)
        loss += self.cfg["lambda_alpha_l0"] * l0_loss
        return loss, l0_loss, l1_loss

    def calculate_clip_loss(self, outputs, target_embeddings):
        """Average text-image criterion between each output image and a random
        subset of the target text embeddings (sampled with replacement)."""
        n_embeddings = np.random.randint(1, len(target_embeddings) + 1)
        target_embeddings = target_embeddings[torch.randint(len(target_embeddings), (n_embeddings,))]

        loss = 0.0
        for img in outputs:  # process one image at a time to bound memory use
            img_e = self.clip_extractor.get_image_embedding(img.unsqueeze(0))
            for target_embedding in target_embeddings:
                loss += self.text_criterion(img_e, target_embedding.unsqueeze(0))
        loss /= len(outputs) * len(target_embeddings)
        return loss

    def calculate_clip_dir_loss(self, inputs, outputs, target_embeddings):
        """CLIP directional loss: the edit direction in image space
        (out - in) should align with the text direction (target - source)."""
        # Sample matched source/target embedding pairs (same indices for both).
        n_embeddings = np.random.randint(1, min(len(self.src_e), len(target_embeddings)) + 1)
        idx = torch.randint(min(len(self.src_e), len(target_embeddings)), (n_embeddings,))
        src_embeddings = self.src_e[idx]
        target_embeddings = target_embeddings[idx]
        target_dirs = target_embeddings - src_embeddings

        loss = 0.0
        for in_img, out_img in zip(inputs, outputs):  # one pair at a time to bound memory
            in_e = self.clip_extractor.get_image_embedding(in_img.unsqueeze(0))
            out_e = self.clip_extractor.get_image_embedding(out_img.unsqueeze(0))
            for target_dir in target_dirs:
                # F.cosine_similarity with default dim=1/eps matches the
                # original nn.CosineSimilarity() but avoids constructing a
                # module object inside the inner loop.
                loss += 1 - torch.nn.functional.cosine_similarity(
                    out_e - in_e, target_dir.unsqueeze(0)
                ).mean()
        loss /= len(outputs) * len(target_dirs)
        return loss

    def calculate_structure_loss(self, outputs, inputs):
        """MSE between CLIP self-similarity maps of each output and its input;
        the input's map is computed without gradients (fixed target)."""
        loss = 0.0
        for input, output in zip(inputs, outputs):
            with torch.no_grad():
                target_self_sim = self.clip_extractor.get_self_sim(input.unsqueeze(0))
            current_self_sim = self.clip_extractor.get_self_sim(output.unsqueeze(0))
            # Functional form hoists the per-iteration MSELoss() construction.
            loss = loss + torch.nn.functional.mse_loss(current_self_sim, target_self_sim)
        loss = loss / len(outputs)
        return loss

    def calculate_relevancy_loss(self, alpha, input_img):
        """Bootstrap loss pulling alpha toward CLIP relevancy maps.

        Each (alpha, image) pair is jointly random-cropped and resized to the
        CLIP input size; the alpha matte is regressed toward the positive
        relevancy map, and optionally pushed away from thresholded negative
        relevancy. Returns the mean over all pairs.
        """
        loss = 0.0
        # The transform pipeline is loop-invariant; its randomness is drawn
        # per call, so each pair still gets an independent crop.
        crop_and_resize = T.Compose(
            [
                RandomSizeCrop(min_cover=self.cfg["bootstrapping_min_cover"]),
                T.Resize((224, 224)),
            ]
        )
        for curr_alpha, curr_img in zip(alpha, input_img):
            # Stack so alpha and image receive the SAME random crop.
            x = torch.stack([curr_alpha, curr_img], dim=0)  # [2, 3, H, W]
            x = crop_and_resize(x)
            curr_alpha, curr_img = x[0].unsqueeze(0), x[1].unsqueeze(0)

            positive_relevance = self.relevancy_extractor(curr_img)
            # BUGFIX: accumulate with += (was `=`), so every pair contributes;
            # previously only the last pair survived before the division below.
            loss += self.relevancy_criterion(curr_alpha[0], positive_relevance.repeat(3, 1, 1))

            if self.cfg["use_negative_bootstrap"]:
                negative_relevance = self.relevancy_extractor(curr_img, negative=True)
                # Only penalize pixels whose negative relevancy is confident.
                relevant_values = negative_relevance > self.cfg["bootstrap_negative_map_threshold"]
                negative_alpha_local = (1 - curr_alpha) * relevant_values.unsqueeze(1)
                negative_relevance_local = negative_relevance * relevant_values
                loss += self.relevancy_criterion(
                    negative_alpha_local,
                    negative_relevance_local.unsqueeze(1).repeat(1, 3, 1, 1),
                )
        return loss / len(alpha)