"""Discriminator providing word-level feedback"""
from typing import Any
import torch
from torch import nn
from src.models.modules.conv_utils import conv1d, conv2d
from src.models.modules.image_encoder import InceptionEncoder


class WordLevelLogits(nn.Module):
"""API for converting regional feature maps into logits for multi-class classification"""
def __init__(self) -> None:
"""
Instantiate the module with softmax over the word dimension
"""
super().__init__()
self.softmax = nn.Softmax(dim=1)
# layer for flattening the feature maps
self.flat = nn.Flatten(start_dim=2)
        # align channel dim of textual embeddings with the channels of Inception features
        self.chan_reduction = conv1d(256, 128)

    def forward(
        self, visual_features: torch.Tensor, word_embs: torch.Tensor, mask: torch.Tensor
) -> Any:
"""
        Fuse the two feature types into scores fed into the word-level classification loss
        :param torch.Tensor visual_features:
            Feature maps of an image after being processed by the Inception encoder. Bx128x17x17
        :param torch.Tensor word_embs:
            Word-level embeddings from the text encoder. Bx256xL
        :param torch.Tensor mask:
            Boolean mask marking padded tokens in each caption. BxL
        :return: Softmax-normalized deltas for each word in the picture. BxL
        :rtype: Any
"""
# make textual and visual features have the same amount of channels
word_embs = self.chan_reduction(word_embs)
        # flatten the feature maps: Bx128x17x17 -> Bx128x289
        visual_features = self.flat(visual_features)
        word_embs = torch.transpose(word_embs, 1, 2)
        # (BxLx128) @ (Bx128x289) -> BxLx289 word-region correlation matrix
        word_region_correlations = word_embs @ visual_features
# normalize across L dimension
m_norm_l = nn.functional.normalize(word_region_correlations, dim=1)
# normalize across H*W dimension
m_norm_hw = nn.functional.normalize(m_norm_l, dim=2)
m_norm_hw = torch.transpose(m_norm_hw, 1, 2)
weighted_img_feats = visual_features @ m_norm_hw
weighted_img_feats = torch.sum(weighted_img_feats, dim=1)
        # exclude padded words so they receive zero probability after softmax
        weighted_img_feats[mask] = -float("inf")
deltas = self.softmax(weighted_img_feats)
return deltas
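
# Illustrative usage sketch for WordLevelLogits (B=4 and L=12 below are assumed
# example values, not fixed by the model; shapes follow the docstring above):
#
#   head = WordLevelLogits()
#   visual = torch.randn(4, 128, 17, 17)         # Inception local features
#   words = torch.randn(4, 256, 12)              # word embeddings, Bx256xL
#   mask = torch.zeros(4, 12, dtype=torch.bool)  # True marks padded tokens
#   deltas = head(visual, words, mask)           # -> (4, 12), rows sum to 1
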
class UnconditionalLogits(nn.Module):
"""Head for retrieving logits from an image"""
def __init__(self) -> None:
"""Initialize modules that reduce the features down to a set of logits"""
super().__init__()
self.conv = nn.Conv2d(128, 1, kernel_size=17)
        # flattening Bx1x1x1 into Bx1
        self.flat = nn.Flatten()

    def forward(self, visual_features: torch.Tensor) -> Any:
"""
Compute logits for unconditioned adversarial loss
:param visual_features: Local features from Inception network. Bx128x17x17
:return: Logits for unconditioned adversarial loss. Bx1
:rtype: Any
"""
# reduce channels and feature maps for visual features
visual_features = self.conv(visual_features)
# flatten Bx1x1x1 into Bx1
logits = self.flat(visual_features)
return logits
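
# Illustrative usage sketch for UnconditionalLogits (B=4 is an assumed value):
#
#   head = UnconditionalLogits()
#   visual = torch.randn(4, 128, 17, 17)  # Inception local features
#   logits = head(visual)                 # -> (4, 1), one realism logit per image
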
class ConditionalLogits(nn.Module):
"""Logits extractor for conditioned adversarial loss"""
    def __init__(self) -> None:
        """Initialize modules that fuse textual and visual features into logits"""
        super().__init__()
# layer for forming the feature maps out of textual info
self.text_to_fm = conv1d(256, 17 * 17)
# fitting the size of text channels to the size of visual channels
self.chan_aligner = conv2d(1, 128)
# for reduced textual + visual features down to 1x1 feature map
        # reduces joint textual + visual features down to a 1x1 feature map
        self.joint_conv = nn.Conv2d(2 * 128, 1, kernel_size=17)
# converting Bx1x1x1 into Bx1
        self.flat = nn.Flatten()

    def forward(self, visual_features: torch.Tensor, sent_embs: torch.Tensor) -> Any:
"""
Compute logits for conditional adversarial loss
:param torch.Tensor visual_features: Features from Inception encoder. Bx128x17x17
:param torch.Tensor sent_embs: Sentence embeddings from text encoder. Bx256
        :return: Logits for conditional adversarial loss. Bx1
:rtype: Any
"""
# make text and visual features have the same sizes of feature maps
# Bx256 -> Bx256x1 -> Bx289x1
sent_embs = sent_embs.view(-1, 256, 1)
sent_embs = self.text_to_fm(sent_embs)
# transform textual info into shape of visual feature maps
# Bx289x1 -> Bx1x17x17
sent_embs = sent_embs.view(-1, 1, 17, 17)
        # propagate text embs through a 2d conv to
        # align channels with the visual feature maps
        sent_embs = self.chan_aligner(sent_embs)
# unite textual and visual features across the dim of channels
cross_features = torch.cat((visual_features, sent_embs), dim=1)
        # reduce joint features down to a single raw logit per image
        cross_features = self.joint_conv(cross_features)
        # flatten Bx1x1x1 into Bx1
logits = self.flat(cross_features)
return logits
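
# Illustrative usage sketch for ConditionalLogits (B=4 is an assumed value):
#
#   head = ConditionalLogits()
#   visual = torch.randn(4, 128, 17, 17)  # Inception local features
#   sents = torch.randn(4, 256)           # sentence embeddings, Bx256
#   logits = head(visual, sents)          # -> (4, 1), one image-text match logit
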
class Discriminator(nn.Module):
"""Simple CNN-based discriminator"""
def __init__(self) -> None:
"""Use a pretrained InceptionNet to extract features"""
super().__init__()
self.encoder = InceptionEncoder(D=128)
# define different logit extractors for different losses
self.logits_word_level = WordLevelLogits()
self.logits_uncond = UnconditionalLogits()
        self.logits_cond = ConditionalLogits()

    def forward(self, images: torch.Tensor) -> Any:
"""
        Encode images into local visual features for the logit extractors
        :param torch.Tensor images: Images to be analyzed. Bx3x256x256
        :return: Image features encoded by the image encoder. Bx128x17x17
        :rtype: Any
"""
# only taking the local features from inception
# Bx3x256x256 -> Bx128x17x17
img_features, _ = self.encoder(images)
return img_features
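

if __name__ == "__main__":
    # Minimal smoke test wiring all three heads together. Shapes follow the
    # docstrings above; B=2 and L=10 are assumed example values, and the test
    # presumes InceptionEncoder can load its pretrained backbone locally.
    disc = Discriminator()
    images = torch.randn(2, 3, 256, 256)
    word_embs = torch.randn(2, 256, 10)
    sent_embs = torch.randn(2, 256)
    mask = torch.zeros(2, 10, dtype=torch.bool)

    features = disc(images)  # Bx3x256x256 -> Bx128x17x17
    print(disc.logits_word_level(features, word_embs, mask).shape)  # (2, 10)
    print(disc.logits_uncond(features).shape)  # (2, 1)
    print(disc.logits_cond(features, sent_embs).shape)  # (2, 1)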