|
"""Module containing the loss functions for the GANs.""" |
|
from typing import Any, Dict |
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
|
|
|
|
|
|
def generator_loss(
    logits: Dict[str, Dict[str, torch.Tensor]],
    local_fake_incept_feat: torch.Tensor,
    global_fake_incept_feat: torch.Tensor,
    real_labels: torch.Tensor,
    words_emb: torch.Tensor,
    sent_emb: torch.Tensor,
    match_labels: torch.Tensor,
    cap_lens: torch.Tensor,
    class_ids: torch.Tensor,
    real_vgg_feat: torch.Tensor,
    fake_vgg_feat: torch.Tensor,
    const_dict: Dict[str, float],
) -> Any:
"""Calculate the loss for the generator. |
|
|
|
Args: |
|
logits: Dictionary with fake/real and word-level/uncond/cond logits |
|
|
|
local_fake_incept_feat: The local inception features for the fake images. |
|
|
|
global_fake_incept_feat: The global inception features for the fake images. |
|
|
|
real_labels: Label for "real" image as predicted by discriminator, |
|
this is a tensor of ones. [shape: (batch_size, 1)]. |
|
|
|
word_labels: POS tagged word labels for the captions. [shape: (batch_size, L)] |
|
|
|
words_emb: The embeddings for all the words in the captions. |
|
shape: (batch_size, embedding_size, max_caption_length) |
|
|
|
sent_emb: The embeddings for the sentences. |
|
shape: (batch_size, embedding_size) |
|
|
|
match_labels: Tensor of shape: (batch_size, 1). |
|
This is of the form torch.tensor([0, 1, 2, ..., batch-1]) |
|
|
|
cap_lens: The length of the 'actual' captions in the batch [without padding] |
|
shape: (batch_size, 1) |
|
|
|
class_ids: The class ids for the instance. shape: (batch_size, 1) |
|
|
|
real_vgg_feat: The vgg features for the real images. shape: (batch_size, 128, 128, 128) |
|
fake_vgg_feat: The vgg features for the fake images. shape: (batch_size, 128, 128, 128) |
|
|
|
const_dict: The dictionary containing the constants. |
|
""" |
|
    lambda1 = const_dict["lambda1"]
    total_error_g = 0.0

    # Conditional adversarial loss: fake images conditioned on the sentence
    # embedding should be classified as real.
    cond_logits = logits["fake"]["cond"]
    cond_err_g = nn.BCEWithLogitsLoss()(cond_logits, real_labels)

    # Unconditional adversarial loss: fake images alone should look real.
    uncond_logits = logits["fake"]["uncond"]
    uncond_err_g = nn.BCEWithLogitsLoss()(uncond_logits, real_labels)

    loss_g = cond_err_g + uncond_err_g
    total_error_g += loss_g

    # Text-image matching (DAMSM) loss on the fake images.
    loss_damsm = damsm_loss(
        local_fake_incept_feat,
        global_fake_incept_feat,
        words_emb,
        sent_emb,
        match_labels,
        cap_lens,
        class_ids,
        const_dict,
    )
    total_error_g += loss_damsm

    # Perceptual loss: match the VGG features of real and fake images.
    loss_per = 0.5 * nn.MSELoss()(real_vgg_feat, fake_vgg_feat)
    total_error_g += lambda1 * loss_per

    return total_error_g
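

# A minimal usage sketch for generator_loss (not part of the training code).
# All sizes and constant values below are illustrative assumptions, and the
# VGG features are shrunk from the documented (batch, 128, 128, 128) to keep
# the example cheap to run.
def _example_generator_loss() -> None:
    batch, emb_dim, max_len = 4, 32, 12
    logits = {
        "fake": {
            "cond": torch.randn(batch, 1),
            "uncond": torch.randn(batch, 1),
        }
    }
    const_dict = {
        "lambda1": 1.0,
        "lambda3": 1.0,
        "gamma1": 4.0,
        "gamma2": 5.0,
        "gamma3": 10.0,
    }
    loss = generator_loss(
        logits,
        local_fake_incept_feat=torch.randn(batch, emb_dim, 17, 17),
        global_fake_incept_feat=torch.randn(batch, emb_dim),
        real_labels=torch.ones(batch, 1),
        words_emb=torch.randn(batch, emb_dim, max_len),
        sent_emb=torch.randn(batch, emb_dim),
        match_labels=torch.arange(batch),
        cap_lens=torch.randint(3, max_len + 1, (batch, 1)),
        class_ids=torch.randint(0, 10, (batch, 1)),
        real_vgg_feat=torch.randn(batch, 8, 8, 8),
        fake_vgg_feat=torch.randn(batch, 8, 8, 8),
        const_dict=const_dict,
    )
    assert loss.dim() == 0  # a single scalar generator objective

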
def damsm_loss(
    local_incept_feat: torch.Tensor,
    global_incept_feat: torch.Tensor,
    words_emb: torch.Tensor,
    sent_emb: torch.Tensor,
    match_labels: torch.Tensor,
    cap_lens: torch.Tensor,
    class_ids: torch.Tensor,
    const_dict: Dict[str, float],
) -> Any:
"""Calculate the DAMSM loss from the attnGAN paper. |
|
|
|
Args: |
|
local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)] |
|
|
|
global_incept_feat: The global inception features. [shape: (batch, D)] |
|
|
|
words_emb: The embeddings for all the words in the captions. |
|
|
|
shape: (batch, D, max_caption_length) |
|
|
|
sent_emb: The embeddings for the sentences. shape: (batch_size, D) |
|
|
|
match_labels: Tensor of shape: (batch_size, 1). |
|
This is of the form torch.tensor([0, 1, 2, ..., batch-1]) |
|
|
|
cap_lens: The length of the 'actual' captions in the batch [without padding] |
|
shape: (batch_size, 1) |
|
|
|
class_ids: The class ids for the instance. shape: (batch, 1) |
|
|
|
const_dict: The dictionary containing the constants. |
|
""" |
|
    batch_size = match_labels.size(0)

    masks = []
    match_scores = []
    gamma1 = const_dict["gamma1"]
    gamma2 = const_dict["gamma2"]
    gamma3 = const_dict["gamma3"]
    lambda3 = const_dict["lambda3"]

    for i in range(batch_size):
        # Mask out other instances of the same class, so they are not
        # treated as mismatched pairs for instance i.
        mask = (class_ids == class_ids[i]).int()
        mask[i] = 0
        masks.append(mask.reshape(1, -1))

        numb_words = int(cap_lens[i])

        # Pair the words of caption i against every image in the batch.
        query_words = words_emb[i, :, :numb_words].unsqueeze(0)
        query_words = query_words.repeat(batch_size, 1, 1)

        c_i = compute_region_context_vector(local_incept_feat, query_words, gamma1)

        # Flatten both to (batch * L, D) for the cosine similarity.
        query_words = query_words.transpose(1, 2).reshape(batch_size * numb_words, -1)
        c_i = c_i.transpose(1, 2).reshape(batch_size * numb_words, -1)

        r_i = compute_relevance(c_i, query_words)

        # Aggregate the word-level relevances into a caption-image match
        # score via a smooth maximum, as in the AttnGAN paper.
        r_i = r_i.view(batch_size, numb_words)
        r_i = torch.exp(r_i * gamma2).sum(dim=1, keepdim=True)
        r_i = torch.log(r_i)
        match_scores.append(r_i)

    masks = torch.cat(masks, dim=0).bool()
    match_scores = torch.cat(match_scores, dim=1)

    # Word-level loss: each image should match its own caption against all
    # mismatched captions in the batch, and vice versa.
    match_scores = gamma3 * match_scores
    match_scores.data.masked_fill_(masks, -float("inf"))
    match_scores_t = match_scores.transpose(0, 1)

    l1_w = nn.CrossEntropyLoss()(match_scores, match_labels)
    l2_w = nn.CrossEntropyLoss()(match_scores_t, match_labels)

    # Sentence-level loss: cosine similarity between the global image
    # features and the sentence embeddings.
    incept_feat_norm = torch.linalg.norm(global_incept_feat, dim=1)
    sent_emb_norm = torch.linalg.norm(sent_emb, dim=1)

    global_match_score = global_incept_feat @ sent_emb.T
    # Clamp the norm product (not the similarity itself), so that negative
    # cosine similarities are preserved.
    global_match_score = global_match_score / torch.outer(
        incept_feat_norm, sent_emb_norm
    ).clamp(min=1e-8)
    global_match_score = gamma3 * global_match_score
    global_match_score.data.masked_fill_(masks, -float("inf"))

    global_match_t = global_match_score.T

    l1_s = nn.CrossEntropyLoss()(global_match_score, match_labels)
    l2_s = nn.CrossEntropyLoss()(global_match_t, match_labels)

    loss_damsm = lambda3 * (l1_w + l2_w + l1_s + l2_s)
    return loss_damsm
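

# A minimal usage sketch for damsm_loss (not part of the training code); the
# tensor sizes and gamma/lambda constants are illustrative assumptions, not
# values taken from a config.
def _example_damsm_loss() -> None:
    batch, emb_dim, max_len = 4, 32, 12
    const_dict = {"gamma1": 4.0, "gamma2": 5.0, "gamma3": 10.0, "lambda3": 1.0}
    loss = damsm_loss(
        local_incept_feat=torch.randn(batch, emb_dim, 17, 17),
        global_incept_feat=torch.randn(batch, emb_dim),
        words_emb=torch.randn(batch, emb_dim, max_len),
        sent_emb=torch.randn(batch, emb_dim),
        match_labels=torch.arange(batch),
        cap_lens=torch.randint(3, max_len + 1, (batch, 1)),
        class_ids=torch.randint(0, 10, (batch, 1)),
        const_dict=const_dict,
    )
    assert loss.dim() == 0

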
def compute_relevance(c_i: torch.Tensor, query_words: torch.Tensor) -> Any:
    """Compute the cosine similarity between the region context vectors and the query words.

    Args:
        c_i: The region context vectors. shape: (batch * L, D)
        query_words: The query words. shape: (batch * L, D)
    """
    numr = torch.sum(c_i * query_words, dim=1)
    norm_c = torch.linalg.norm(c_i, ord=2, dim=1)
    norm_q = torch.linalg.norm(query_words, ord=2, dim=1)
    # Clamp the denominator (not the similarity itself), so that negative
    # cosine similarities are preserved.
    r_i = numr / (norm_c * norm_q).clamp(min=1e-8)
    return r_i
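

# Quick sanity check for compute_relevance (illustrative sizes): the
# relevance of a vector with itself is its cosine similarity, i.e. 1.
def _example_compute_relevance() -> None:
    vecs = torch.randn(6, 32)  # stands in for the (batch * L, D) inputs
    r_i = compute_relevance(vecs, vecs)
    assert torch.allclose(r_i, torch.ones(6), atol=1e-5)

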
def compute_region_context_vector(
    local_incept_feat: torch.Tensor, query_words: torch.Tensor, gamma1: float
) -> Any:
    """Compute the region context vector (c_i) from the AttnGAN paper.

    Args:
        local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
        query_words: The embeddings for all the words in the captions. shape: (batch, D, L)
        gamma1: The gamma1 value from the AttnGAN paper.
    """
    batch, L = query_words.size(0), query_words.size(2)

    feat_height, feat_width = local_incept_feat.size(2), local_incept_feat.size(3)
    N = feat_height * feat_width

    # Flatten the spatial grid: (batch, D, N), where N = 17 * 17 regions.
    local_incept_feat = local_incept_feat.view(batch, -1, N)
    incept_feat_t = local_incept_feat.transpose(1, 2)

    # Similarity of every region with every word: (batch, N, L).
    sim_matrix = incept_feat_t @ query_words

    # Normalize over the words for each region.
    sim_matrix = nn.Softmax(dim=1)(sim_matrix.view(batch * N, L))
    sim_matrix = sim_matrix.view(batch, N, L)

    # Attention over the regions for each word, sharpened by gamma1.
    sim_matrix = torch.transpose(sim_matrix, 1, 2).reshape(batch * L, N)
    alpha_j = nn.Softmax(dim=1)(gamma1 * sim_matrix)
    alpha_j = alpha_j.view(batch, L, N)
    alpha_j_t = torch.transpose(alpha_j, 1, 2)

    # Attention-weighted sum of the region features per word: (batch, D, L).
    c_i = local_incept_feat @ alpha_j_t
    return c_i
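

# Shape sanity check for compute_region_context_vector (illustrative sizes):
# each of the L words gets a D-dimensional context vector per image.
def _example_region_context_shapes() -> None:
    batch, emb_dim, num_words = 2, 32, 7
    feat = torch.randn(batch, emb_dim, 17, 17)  # local inception features
    words = torch.randn(batch, emb_dim, num_words)  # word embeddings
    c_i = compute_region_context_vector(feat, words, gamma1=4.0)
    assert c_i.shape == (batch, emb_dim, num_words)

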
def discriminator_loss(
    logits: Dict[str, Dict[str, torch.Tensor]],
    labels: Dict[str, Dict[str, torch.Tensor]],
) -> Any:
""" |
|
Calculate discriminator objective |
|
|
|
:param dict[str, dict[str, torch.Tensor]] logits: |
|
Dictionary with fake/real and word-level/uncond/cond logits |
|
|
|
Example: |
|
|
|
logits = { |
|
"fake": { |
|
"word_level": torch.Tensor (BxL) |
|
"uncond": torch.Tensor (Bx1) |
|
"cond": torch.Tensor (Bx1) |
|
}, |
|
"real": { |
|
"word_level": torch.Tensor (BxL) |
|
"uncond": torch.Tensor (Bx1) |
|
"cond": torch.Tensor (Bx1) |
|
}, |
|
} |
|
:param dict[str, dict[str, torch.Tensor]] labels: |
|
Dictionary with fake/real and word-level/image labels |
|
|
|
Example: |
|
|
|
labels = { |
|
"fake": { |
|
"word_level": torch.Tensor (BxL) |
|
"image": torch.Tensor (Bx1) |
|
}, |
|
"real": { |
|
"word_level": torch.Tensor (BxL) |
|
"image": torch.Tensor (Bx1) |
|
}, |
|
} |
|
:param float lambda_4: Hyperparameter for word loss in paper |
|
:return: Discriminator objective loss |
|
:rtype: Any |
|
""" |
|
|
|
    bce_logits = nn.BCEWithLogitsLoss()
    bce = nn.BCELoss()

    # Word-level loss; these outputs are already probabilities, hence BCELoss.
    word_loss = bce(logits["real"]["word_level"], labels["real"]["word_level"])

    # Real images should be classified as real, both unconditionally and
    # conditioned on the caption.
    uncond_loss = bce_logits(logits["real"]["uncond"], labels["real"]["image"])
    cond_loss = bce_logits(logits["real"]["cond"], labels["real"]["image"])
    tot_loss = (uncond_loss + cond_loss) / 2.0

    # Fake images should be classified as fake; averaged over the two terms
    # to mirror the real-image branch.
    fake_uncond_loss = bce_logits(logits["fake"]["uncond"], labels["fake"]["image"])
    fake_cond_loss = bce_logits(logits["fake"]["cond"], labels["fake"]["image"])
    tot_loss += (fake_uncond_loss + fake_cond_loss) / 2.0

    tot_loss += word_loss
    return tot_loss
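

# A minimal usage sketch for discriminator_loss (not part of the training
# code). Sizes and label values are illustrative; note the word-level logits
# go through nn.BCELoss, so they must already be probabilities in [0, 1].
def _example_discriminator_loss() -> None:
    batch, seq_len = 4, 12
    logits = {
        "real": {
            "word_level": torch.rand(batch, seq_len),  # probabilities
            "uncond": torch.randn(batch, 1),
            "cond": torch.randn(batch, 1),
        },
        "fake": {
            "uncond": torch.randn(batch, 1),
            "cond": torch.randn(batch, 1),
        },
    }
    labels = {
        "real": {
            "word_level": torch.ones(batch, seq_len),
            "image": torch.ones(batch, 1),
        },
        "fake": {"image": torch.zeros(batch, 1)},
    }
    loss = discriminator_loss(logits, labels)
    assert loss.dim() == 0

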
def kl_loss(mu_tensor: torch.Tensor, logvar: torch.Tensor) -> Any:
    """
    Calculate the KL divergence from the standard normal prior

    :param torch.Tensor mu_tensor: Mean of latent distribution
    :param torch.Tensor logvar: Log variance of latent distribution
    :return: KL loss [-0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)]
    :rtype: Any
    """
    return torch.mean(-0.5 * (1 + logvar - mu_tensor.pow(2) - torch.exp(logvar)))
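

# Sanity check for kl_loss (illustrative sizes): the closed form should agree
# with torch.distributions for a diagonal Gaussian against a N(0, I) prior.
def _example_kl_loss() -> None:
    mu = torch.randn(4, 100)
    logvar = torch.randn(4, 100)
    posterior = torch.distributions.Normal(mu, torch.exp(0.5 * logvar))
    prior = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(mu))
    reference = torch.distributions.kl_divergence(posterior, prior).mean()
    assert torch.allclose(kl_loss(mu, logvar), reference, atol=1e-5)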