# taim-gan/src/models/losses.py
"""Module containing the loss functions for the GANs."""
from typing import Any, Dict
import torch
from torch import nn
# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
def generator_loss(
logits: Dict[str, Dict[str, torch.Tensor]],
local_fake_incept_feat: torch.Tensor,
global_fake_incept_feat: torch.Tensor,
real_labels: torch.Tensor,
words_emb: torch.Tensor,
sent_emb: torch.Tensor,
match_labels: torch.Tensor,
cap_lens: torch.Tensor,
class_ids: torch.Tensor,
real_vgg_feat: torch.Tensor,
fake_vgg_feat: torch.Tensor,
const_dict: Dict[str, float],
) -> Any:
"""Calculate the loss for the generator.
Args:
logits: Dictionary with fake/real and word-level/uncond/cond logits
local_fake_incept_feat: The local inception features for the fake images.
global_fake_incept_feat: The global inception features for the fake images.
        real_labels: Labels for "real" images as predicted by the discriminator;
            this is a tensor of ones. [shape: (batch_size, 1)]
words_emb: The embeddings for all the words in the captions.
shape: (batch_size, embedding_size, max_caption_length)
sent_emb: The embeddings for the sentences.
shape: (batch_size, embedding_size)
        match_labels: Tensor of shape (batch_size,), of the form
            torch.tensor([0, 1, 2, ..., batch_size - 1]); used as class-index
            targets for the matching cross-entropy losses.
        cap_lens: The unpadded lengths of the captions in the batch.
            shape: (batch_size, 1)
        class_ids: The class ids for the instance. shape: (batch_size, 1)
        real_vgg_feat: The VGG features for the real images. shape: (batch_size, 128, 128, 128)
        fake_vgg_feat: The VGG features for the fake images. shape: (batch_size, 128, 128, 128)
const_dict: The dictionary containing the constants.
"""
lambda1 = const_dict["lambda1"]
total_error_g = 0.0
cond_logits = logits["fake"]["cond"]
cond_err_g = nn.BCEWithLogitsLoss()(cond_logits, real_labels)
uncond_logits = logits["fake"]["uncond"]
uncond_err_g = nn.BCEWithLogitsLoss()(uncond_logits, real_labels)
# add up the conditional and unconditional losses
loss_g = cond_err_g + uncond_err_g
total_error_g += loss_g
# DAMSM Loss from attnGAN.
loss_damsm = damsm_loss(
local_fake_incept_feat,
global_fake_incept_feat,
words_emb,
sent_emb,
match_labels,
cap_lens,
class_ids,
const_dict,
)
total_error_g += loss_damsm
loss_per = 0.5 * nn.MSELoss()(real_vgg_feat, fake_vgg_feat) # perceptual loss
total_error_g += lambda1 * loss_per
return total_error_g
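

# Hedged usage sketch (illustrative, not part of the library API): a minimal
# smoke test for ``generator_loss`` on random tensors. All shapes and the
# hyperparameter values in ``demo_consts`` are assumptions for the demo, not
# values taken from the paper or the training config.
def _demo_generator_loss() -> None:
    batch, emb_dim, max_len = 4, 32, 7
    logits = {
        "fake": {
            "cond": torch.randn(batch, 1),
            "uncond": torch.randn(batch, 1),
        }
    }
    demo_consts = {
        "lambda1": 1.0,
        "lambda3": 1.0,
        "gamma1": 4.0,
        "gamma2": 5.0,
        "gamma3": 10.0,
    }
    loss = generator_loss(
        logits,
        torch.randn(batch, emb_dim, 17, 17),  # local inception features
        torch.randn(batch, emb_dim),  # global inception features
        torch.ones(batch, 1),  # real_labels
        torch.randn(batch, emb_dim, max_len),  # words_emb
        torch.randn(batch, emb_dim),  # sent_emb
        torch.arange(batch),  # match_labels
        torch.randint(3, max_len + 1, (batch, 1)),  # cap_lens
        torch.arange(batch).reshape(batch, 1),  # class_ids (all distinct)
        torch.randn(batch, 8, 8, 8),  # real_vgg_feat (downsized for the demo)
        torch.randn(batch, 8, 8, 8),  # fake_vgg_feat
        demo_consts,
    )
    print(f"generator loss: {loss.item():.4f}")

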
def damsm_loss(
local_incept_feat: torch.Tensor,
global_incept_feat: torch.Tensor,
words_emb: torch.Tensor,
sent_emb: torch.Tensor,
match_labels: torch.Tensor,
cap_lens: torch.Tensor,
class_ids: torch.Tensor,
const_dict: Dict[str, float],
) -> Any:
"""Calculate the DAMSM loss from the attnGAN paper.
Args:
local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
global_incept_feat: The global inception features. [shape: (batch, D)]
words_emb: The embeddings for all the words in the captions.
shape: (batch, D, max_caption_length)
sent_emb: The embeddings for the sentences. shape: (batch_size, D)
        match_labels: Tensor of shape (batch_size,), of the form
            torch.tensor([0, 1, 2, ..., batch_size - 1]); used as class-index
            targets for the matching cross-entropy losses.
        cap_lens: The unpadded lengths of the captions in the batch.
            shape: (batch_size, 1)
class_ids: The class ids for the instance. shape: (batch, 1)
const_dict: The dictionary containing the constants.
"""
batch_size = match_labels.size(0)
    # Mask mismatched samples that come from the same class as the real sample
masks = []
match_scores = []
gamma1 = const_dict["gamma1"]
gamma2 = const_dict["gamma2"]
gamma3 = const_dict["gamma3"]
lambda3 = const_dict["lambda3"]
for i in range(batch_size):
mask = (class_ids == class_ids[i]).int()
# This ensures that "correct class" index is not included in the mask.
mask[i] = 0
masks.append(mask.reshape(1, -1)) # shape: (1, batch)
numb_words = int(cap_lens[i])
# shape: (1, D, L), this picks the caption at ith batch index.
query_words = words_emb[i, :, :numb_words].unsqueeze(0)
# shape: (batch, D, L), this expands the same caption for all batch indices.
query_words = query_words.repeat(batch_size, 1, 1)
c_i = compute_region_context_vector(
local_incept_feat, query_words, gamma1
) # Taken from attnGAN paper. shape: (batch, D, L)
query_words = query_words.transpose(1, 2) # shape: (batch, L, D)
c_i = c_i.transpose(1, 2) # shape: (batch, L, D)
query_words = query_words.reshape(
batch_size * numb_words, -1
) # shape: (batch * L, D)
c_i = c_i.reshape(batch_size * numb_words, -1) # shape: (batch * L, D)
        r_i = compute_relevance(
            c_i, query_words
        )  # cosine similarity, i.e. R(c_i, e_i) from the AttnGAN paper. shape: (batch * L,)
        r_i = r_i.view(batch_size, numb_words)  # shape: (batch, L)
        r_i = torch.exp(r_i * gamma2)  # shape: (batch, L)
        r_i = r_i.sum(dim=1, keepdim=True)  # shape: (batch, 1)
        r_i = torch.log(
            r_i
        )  # image-text matching score between the whole image and the caption. shape: (batch, 1)
match_scores.append(r_i)
masks = torch.cat(masks, dim=0).bool() # type: ignore
match_scores = torch.cat(match_scores, dim=1) # type: ignore
# This corresponds to P(D|Q) from attnGAN.
match_scores = gamma3 * match_scores # type: ignore
match_scores.data.masked_fill_( # type: ignore
masks, -float("inf")
) # mask out the scores for mis-matched samples
match_scores_t = match_scores.transpose( # type: ignore
0, 1
) # This corresponds to P(Q|D) from attnGAN.
# This corresponds to L1_w from attnGAN.
l1_w = nn.CrossEntropyLoss()(match_scores, match_labels)
# This corresponds to L2_w from attnGAN.
l2_w = nn.CrossEntropyLoss()(match_scores_t, match_labels)
incept_feat_norm = torch.linalg.norm(global_incept_feat, dim=1)
sent_emb_norm = torch.linalg.norm(sent_emb, dim=1)
# shape: (batch, batch)
global_match_score = global_incept_feat @ (sent_emb.T)
    global_match_score = global_match_score / torch.outer(
        incept_feat_norm, sent_emb_norm
    ).clamp(min=1e-8)  # clamp the norm product, not the score, to avoid division by zero
global_match_score = gamma3 * global_match_score
# mask out the scores for mis-matched samples
global_match_score.data.masked_fill_(masks, -float("inf")) # type: ignore
global_match_t = global_match_score.T # shape: (batch, batch)
# This corresponds to L1_s from attnGAN.
l1_s = nn.CrossEntropyLoss()(global_match_score, match_labels)
# This corresponds to L2_s from attnGAN.
l2_s = nn.CrossEntropyLoss()(global_match_t, match_labels)
loss_damsm = lambda3 * (l1_w + l2_w + l1_s + l2_s)
    return loss_damsm


def compute_relevance(c_i: torch.Tensor, query_words: torch.Tensor) -> Any:
"""Computes the cosine similarity between the region context vector and the query words.
Args:
c_i: The region context vector. shape: (batch * L, D)
query_words: The query words. shape: (batch * L, D)
"""
    prod = c_i * query_words  # shape: (batch * L, D)
    numr = torch.sum(prod, dim=1)  # shape: (batch * L,)
    norm_c = torch.linalg.norm(c_i, ord=2, dim=1)
    norm_q = torch.linalg.norm(query_words, ord=2, dim=1)
    denr = (norm_c * norm_q).clamp(min=1e-8)  # clamp the norms, not the similarity
    r_i = numr / denr  # shape: (batch * L,)
    return r_i
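

# Hedged sanity check (illustrative): ``compute_relevance`` should agree with
# PyTorch's built-in cosine similarity, which uses the same epsilon-clamped
# denominator.
def _demo_compute_relevance() -> None:
    context = torch.randn(12, 32)
    words = torch.randn(12, 32)
    ours = compute_relevance(context, words)
    reference = nn.functional.cosine_similarity(context, words, dim=1)
    print(torch.allclose(ours, reference, atol=1e-6))  # expected: True

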
def compute_region_context_vector(
local_incept_feat: torch.Tensor, query_words: torch.Tensor, gamma1: float
) -> Any:
"""Compute the region context vector (c_i) from attnGAN paper.
Args:
local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
query_words: The embeddings for all the words in the captions. shape: (batch, D, L)
gamma1: The gamma1 value from attnGAN paper.
"""
batch, L = query_words.size(0), query_words.size(2) # pylint: disable=invalid-name
feat_height, feat_width = local_incept_feat.size(2), local_incept_feat.size(3)
N = feat_height * feat_width # pylint: disable=invalid-name
# Reshape the local inception features to (batch, D, N)
local_incept_feat = local_incept_feat.view(batch, -1, N)
# shape: (batch, N, D)
incept_feat_t = local_incept_feat.transpose(1, 2)
sim_matrix = incept_feat_t @ query_words # shape: (batch, N, L)
sim_matrix = sim_matrix.view(batch * N, L) # shape: (batch * N, L)
sim_matrix = nn.Softmax(dim=1)(sim_matrix) # shape: (batch * N, L)
sim_matrix = sim_matrix.view(batch, N, L) # shape: (batch, N, L)
sim_matrix = torch.transpose(sim_matrix, 1, 2) # shape: (batch, L, N)
sim_matrix = sim_matrix.reshape(batch * L, N) # shape: (batch * L, N)
alpha_j = gamma1 * sim_matrix # shape: (batch * L, N)
alpha_j = nn.Softmax(dim=1)(alpha_j) # shape: (batch * L, N)
alpha_j = alpha_j.view(batch, L, N) # shape: (batch, L, N)
alpha_j_t = torch.transpose(alpha_j, 1, 2) # shape: (batch, N, L)
c_i = (
local_incept_feat @ alpha_j_t
) # shape: (batch, D, L) [summing over N dimension in paper, so we multiply like this]
return c_i
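

# Hedged shape check (illustrative): the region context vector pairs one
# D-dimensional context with each word, independent of feature-map resolution.
def _demo_region_context_shapes() -> None:
    batch, dim, length = 2, 16, 5
    feat = torch.randn(batch, dim, 17, 17)
    words = torch.randn(batch, dim, length)
    ctx = compute_region_context_vector(feat, words, gamma1=4.0)
    print(ctx.shape)  # expected: torch.Size([2, 16, 5])

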
def discriminator_loss(
logits: Dict[str, Dict[str, torch.Tensor]],
labels: Dict[str, Dict[str, torch.Tensor]],
) -> Any:
"""
Calculate discriminator objective
:param dict[str, dict[str, torch.Tensor]] logits:
Dictionary with fake/real and word-level/uncond/cond logits
Example:
logits = {
"fake": {
"word_level": torch.Tensor (BxL)
"uncond": torch.Tensor (Bx1)
"cond": torch.Tensor (Bx1)
},
"real": {
"word_level": torch.Tensor (BxL)
"uncond": torch.Tensor (Bx1)
"cond": torch.Tensor (Bx1)
},
}
:param dict[str, dict[str, torch.Tensor]] labels:
Dictionary with fake/real and word-level/image labels
Example:
labels = {
"fake": {
"word_level": torch.Tensor (BxL)
"image": torch.Tensor (Bx1)
},
"real": {
"word_level": torch.Tensor (BxL)
"image": torch.Tensor (Bx1)
},
}
    :return: Discriminator objective loss
    :rtype: Any
"""
    # define the loss functions for the logit losses; word-level scores are
    # probabilities in [0, 1], so they use BCELoss rather than BCEWithLogitsLoss
    bce_logits = nn.BCEWithLogitsLoss()
    bce = nn.BCELoss()
# calculate word-level loss
word_loss = bce(logits["real"]["word_level"], labels["real"]["word_level"])
# calculate unconditional adversarial loss
uncond_loss = bce_logits(logits["real"]["uncond"], labels["real"]["image"])
# calculate conditional adversarial loss
cond_loss = bce_logits(logits["real"]["cond"], labels["real"]["image"])
tot_loss = (uncond_loss + cond_loss) / 2.0
fake_uncond_loss = bce_logits(logits["fake"]["uncond"], labels["fake"]["image"])
fake_cond_loss = bce_logits(logits["fake"]["cond"], labels["fake"]["image"])
tot_loss += (fake_uncond_loss + fake_cond_loss) / 3.0
tot_loss += word_loss
return tot_loss
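

# Hedged usage sketch (illustrative): building the nested logits/labels
# dictionaries that ``discriminator_loss`` expects. The word-level entries are
# probabilities in [0, 1] because they go through nn.BCELoss.
def _demo_discriminator_loss() -> None:
    batch, length = 4, 7
    logits = {
        side: {
            "word_level": torch.rand(batch, length),  # already in [0, 1]
            "uncond": torch.randn(batch, 1),
            "cond": torch.randn(batch, 1),
        }
        for side in ("fake", "real")
    }
    labels = {
        "fake": {"word_level": torch.zeros(batch, length), "image": torch.zeros(batch, 1)},
        "real": {"word_level": torch.ones(batch, length), "image": torch.ones(batch, 1)},
    }
    print(f"discriminator loss: {discriminator_loss(logits, labels).item():.4f}")

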
def kl_loss(mu_tensor: torch.Tensor, logvar: torch.Tensor) -> Any:
"""
Calculate KL loss
:param torch.Tensor mu_tensor: Mean of latent distribution
:param torch.Tensor logvar: Log variance of latent distribution
    :return: KL divergence to N(0, I): mean of -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
:rtype: Any
"""
    return torch.mean(-0.5 * (1 + logvar - mu_tensor.pow(2) - torch.exp(logvar)))
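

# Hedged sanity check (illustrative): for mu = 0 and logvar = 0, i.e. a
# standard-normal posterior, the closed-form KL divergence is exactly zero.
def _demo_kl_loss() -> None:
    mu = torch.zeros(8, 100)
    logvar = torch.zeros(8, 100)
    print(kl_loss(mu, logvar).item())  # expected: 0.0


if __name__ == "__main__":
    torch.manual_seed(0)
    _demo_generator_loss()
    _demo_compute_relevance()
    _demo_region_context_shapes()
    _demo_discriminator_loss()
    _demo_kl_loss()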