import torch
import numpy as np
import torchvision.transforms as T
from models.clip_relevancy import ClipRelevancy
from util.aug_utils import RandomSizeCrop
from util.util import get_screen_template, get_text_criterion, get_augmentations_template


class LossG(torch.nn.Module):
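    """Generator loss for CLIP-guided layered editing.

    Combines alpha bootstrapping (CLIP relevancy), structure preservation
    (CLIP self-similarity), composition (CLIP text + directional), alpha
    sparsity (L1 + pseudo-L0), and green-screen terms, each weighted by its
    lambda in cfg.
    """
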
def __init__(self, cfg, clip_extractor):
super().__init__()
self.cfg = cfg
# calculate target text embeddings
template = get_augmentations_template()
self.src_e = clip_extractor.get_text_embedding(cfg["src_text"], template)
self.target_comp_e = clip_extractor.get_text_embedding(cfg["comp_text"], template)
self.target_greenscreen_e = clip_extractor.get_text_embedding(cfg["screen_text"], get_screen_template())
self.clip_extractor = clip_extractor
self.text_criterion = get_text_criterion(cfg)
if cfg["bootstrap_epoch"] > 0 and cfg["lambda_bootstrap"] > 0:
self.relevancy_extractor = ClipRelevancy(cfg)
self.relevancy_criterion = torch.nn.MSELoss()
self.lambda_bootstrap = cfg["lambda_bootstrap"]

    def forward(self, outputs, inputs):
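        """Compute all enabled loss terms over the available crop/full-image
        outputs; returns a dict of named losses with the weighted total under
        "loss"."""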
losses = {}
loss_G = 0
all_outputs_composite = []
all_outputs_greenscreen = []
all_outputs_edit = []
all_outputs_alpha = []
all_inputs = []
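        # gather per-sample tensors from whichever branches are present (crop and/or full image)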
for out, ins in zip(["output_crop", "output_image"], ["input_crop", "input_image"]):
if out not in outputs:
continue
all_outputs_composite += outputs[out]["composite"]
all_outputs_greenscreen += outputs[out]["edit_on_greenscreen"]
all_outputs_edit += outputs[out]["edit"]
all_outputs_alpha += outputs[out]["alpha"]
all_inputs += inputs[ins]

        # calculate alpha bootstrapping loss
if inputs["step"] < self.cfg["bootstrap_epoch"] and self.cfg["lambda_bootstrap"] > 0:
losses["loss_bootstrap"] = self.calculate_relevancy_loss(all_outputs_alpha, all_inputs)
if self.cfg["bootstrap_scheduler"] == "linear":
lambda_bootstrap = self.cfg["lambda_bootstrap"] * (
1 - (inputs["step"] + 1) / self.cfg["bootstrap_epoch"]
)
elif self.cfg["bootstrap_scheduler"] == "exponential":
lambda_bootstrap = self.lambda_bootstrap * 0.99
self.lambda_bootstrap = lambda_bootstrap
elif self.cfg["bootstrap_scheduler"] == "none":
lambda_bootstrap = self.lambda_bootstrap
else:
raise ValueError("Unknown bootstrap scheduler")
lambda_bootstrap = max(lambda_bootstrap, self.cfg["lambda_bootstrap_min"])
loss_G += losses["loss_bootstrap"] * lambda_bootstrap

        # calculate structure loss
if self.cfg["lambda_structure"] > 0:
losses["loss_structure"] = self.calculate_structure_loss(all_outputs_composite, all_inputs)
loss_G += losses["loss_structure"] * self.cfg["lambda_structure"]

        # calculate composition loss
if self.cfg["lambda_composition"] > 0:
losses["loss_comp_clip"] = self.calculate_clip_loss(all_outputs_composite, self.target_comp_e)
losses["loss_comp_dir"] = self.calculate_clip_dir_loss(
all_inputs, all_outputs_composite, self.target_comp_e
)
loss_G += (losses["loss_comp_clip"] + losses["loss_comp_dir"]) * self.cfg["lambda_composition"]

        # calculate sparsity loss
if self.cfg["lambda_sparsity"] > 0:
total, l0, l1 = self.calculate_alpha_reg(all_outputs_alpha)
losses["loss_sparsity"] = total
losses["loss_sparsity_l0"] = l0
losses["loss_sparsity_l1"] = l1
loss_G += losses["loss_sparsity"] * self.cfg["lambda_sparsity"]

        # calculate screen loss
if self.cfg["lambda_screen"] > 0:
losses["loss_screen"] = self.calculate_clip_loss(all_outputs_greenscreen, self.target_greenscreen_e)
loss_G += losses["loss_screen"] * self.cfg["lambda_screen"]
losses["loss"] = loss_G
return losses

    def calculate_alpha_reg(self, prediction):
"""
        Calculate the alpha sparsity term: a linear combination of L1 and pseudo-L0 penalties.
"""
l1_loss = 0.0
for el in prediction:
l1_loss += el.mean()
l1_loss = l1_loss / len(prediction)
loss = self.cfg["lambda_alpha_l1"] * l1_loss
        # Pseudo-L0 penalty: (sigmoid(5a) - 0.5) * 2 is 0 at a = 0 and saturates toward 1
        # for positive alpha, a smooth surrogate for counting nonzero alpha values.
l0_loss = 0.0
for el in prediction:
l0_loss += torch.mean((torch.sigmoid(el * 5.0) - 0.5) * 2.0)
l0_loss = l0_loss / len(prediction)
loss += self.cfg["lambda_alpha_l0"] * l0_loss
return loss, l0_loss, l1_loss

    def calculate_clip_loss(self, outputs, target_embeddings):
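        """Text-image loss between each output image embedding and a randomly
        sampled subset of the target text embeddings."""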
        # sample a random number of target text embeddings (with replacement)
n_embeddings = np.random.randint(1, len(target_embeddings) + 1)
target_embeddings = target_embeddings[torch.randint(len(target_embeddings), (n_embeddings,))]
loss = 0.0
        for img in outputs:  # embed one image at a time to avoid memory limitations
img_e = self.clip_extractor.get_image_embedding(img.unsqueeze(0))
for target_embedding in target_embeddings:
loss += self.text_criterion(img_e, target_embedding.unsqueeze(0))
loss /= len(outputs) * len(target_embeddings)
return loss

    def calculate_clip_dir_loss(self, inputs, outputs, target_embeddings):
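        """Directional CLIP loss: the embedding-space edit direction
        (output - input) should align with the text direction (target - source)."""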
        # sample matched source/target embedding pairs at shared random indices
n_embeddings = np.random.randint(1, min(len(self.src_e), len(target_embeddings)) + 1)
idx = torch.randint(min(len(self.src_e), len(target_embeddings)), (n_embeddings,))
src_embeddings = self.src_e[idx]
target_embeddings = target_embeddings[idx]
target_dirs = target_embeddings - src_embeddings
loss = 0.0
        for in_img, out_img in zip(inputs, outputs):  # one pair at a time to avoid memory limitations
in_e = self.clip_extractor.get_image_embedding(in_img.unsqueeze(0))
out_e = self.clip_extractor.get_image_embedding(out_img.unsqueeze(0))
for target_dir in target_dirs:
loss += 1 - torch.nn.CosineSimilarity()(out_e - in_e, target_dir.unsqueeze(0)).mean()
loss /= len(outputs) * len(target_dirs)
return loss

    def calculate_structure_loss(self, outputs, inputs):
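        """Structure preservation: match each output's CLIP self-similarity
        matrix to that of its input."""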
loss = 0.0
        for in_img, out_img in zip(inputs, outputs):
            with torch.no_grad():
                target_self_sim = self.clip_extractor.get_self_sim(in_img.unsqueeze(0))
            current_self_sim = self.clip_extractor.get_self_sim(out_img.unsqueeze(0))
            loss = loss + torch.nn.MSELoss()(current_self_sim, target_self_sim)
loss = loss / len(outputs)
return loss

    def calculate_relevancy_loss(self, alpha, input_img):
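        """Alpha bootstrapping: push alpha toward the CLIP relevancy map of the
        input image. Each (alpha, image) pair is jointly random-cropped and
        resized to CLIP's 224x224 input before comparison."""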
        relevancy_loss = 0.0
for curr_alpha, curr_img in zip(alpha, input_img):
x = torch.stack([curr_alpha, curr_img], dim=0) # [2, 3, H, W]
x = T.Compose(
[
RandomSizeCrop(min_cover=self.cfg["bootstrapping_min_cover"]),
T.Resize((224, 224)),
]
)(x)
curr_alpha, curr_img = x[0].unsqueeze(0), x[1].unsqueeze(0)
positive_relevance = self.relevancy_extractor(curr_img)
            # accumulate per-sample losses; the running total is normalized below
            relevancy_loss += self.relevancy_criterion(curr_alpha[0], positive_relevance.repeat(3, 1, 1))
if self.cfg["use_negative_bootstrap"]:
negative_relevance = self.relevancy_extractor(curr_img, negative=True)
relevant_values = negative_relevance > self.cfg["bootstrap_negative_map_threshold"]
negative_alpha_local = (1 - curr_alpha) * relevant_values.unsqueeze(1)
negative_relevance_local = negative_relevance * relevant_values
negative_relevance_loss = self.relevancy_criterion(
negative_alpha_local,
negative_relevance_local.unsqueeze(1).repeat(1, 3, 1, 1),
)
                relevancy_loss += negative_relevance_loss
        relevancy_loss = relevancy_loss / len(alpha)
        return relevancy_loss
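

# Usage sketch (illustrative only; the names and shapes below are assumptions,
# not part of this module). `cfg` must carry the keys referenced above
# (src_text, comp_text, screen_text, the lambda_* weights, and the bootstrap
# settings); `clip_extractor` is the project's CLIP wrapper exposing
# get_text_embedding, get_image_embedding, and get_self_sim. The dict layout
# mirrors what forward() indexes:
#
#   loss_fn = LossG(cfg, clip_extractor)
#   outputs = {
#       "output_crop": {
#           "composite": composite,              # [B, 3, H, W] edit composited over input
#           "edit_on_greenscreen": greenscreen,  # [B, 3, H, W] edit layer on green background
#           "edit": edit,                        # [B, 3, H, W] raw edit layer
#           "alpha": alpha,                      # [B, 3, H, W] opacity of the edit layer
#       },
#   }
#   inputs = {"input_crop": crop, "step": step}  # crop: [B, 3, H, W]; step: int
#   losses = loss_fn(outputs, inputs)
#   losses["loss"].backward()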