|
"""Discriminator providing word-level feedback""" |
|
from typing import Any |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from src.models.modules.conv_utils import conv1d, conv2d |
|
from src.models.modules.image_encoder import InceptionEncoder |
|
|
|
|
|
class WordLevelLogits(nn.Module):
    """Turns regional image features and word embeddings into per-word logits."""

    def __init__(self) -> None:
        """Set up word-dimension softmax, spatial flattening, and channel reduction."""
        super().__init__()
        self.softmax = nn.Softmax(dim=1)
        # collapse the 17x17 spatial grid into a single region axis
        self.flat = nn.Flatten(start_dim=2)
        # project word embeddings 256 -> 128 channels to match visual features
        self.chan_reduction = conv1d(256, 128)

    def forward(
        self, visual_features: torch.Tensor, word_embs: torch.Tensor, mask: torch.Tensor
    ) -> Any:
        """
        Fuse visual and word features into logits for the word-level classification loss.

        :param torch.Tensor visual_features:
            Feature maps of an image produced by the Inception encoder. Bx128x17x17
        :param torch.Tensor word_embs:
            Word-level embeddings from the text encoder. Bx256xL
        :param torch.Tensor mask:
            Boolean index marking word positions to exclude before the softmax
        :return: Logits for each word in the picture. BxL
        :rtype: Any
        """
        words = self.chan_reduction(word_embs)
        # Bx128x17x17 -> Bx128x289 and Bx128xL -> BxLx128
        regions = self.flat(visual_features)
        words = torch.transpose(words, 1, 2)
        # similarity of every word against every region: BxLx289
        correlations = words @ regions
        # normalize across words, then across regions
        norm_over_words = nn.functional.normalize(correlations, dim=1)
        attn = nn.functional.normalize(norm_over_words, dim=2)
        attn = torch.transpose(attn, 1, 2)
        # attention-weighted region features, then channel sum: Bx128xL -> BxL
        weighted = regions @ attn
        weighted = weighted.sum(dim=1)
        # masked words cannot receive probability mass
        weighted[mask] = -float("inf")
        deltas = self.softmax(weighted)
        return deltas
|
|
|
|
|
class UnconditionalLogits(nn.Module):
    """Head that collapses an image feature map into a single realness logit."""

    def __init__(self) -> None:
        """Build the full-map convolution and the flattening layer."""
        super().__init__()
        # kernel spans the whole 17x17 map: Bx128x17x17 -> Bx1x1x1
        self.conv = nn.Conv2d(128, 1, kernel_size=17)
        self.flat = nn.Flatten()

    def forward(self, visual_features: torch.Tensor) -> Any:
        """
        Compute logits for the unconditioned adversarial loss.

        :param visual_features: Local features from the Inception network. Bx128x17x17
        :return: Logits for the unconditioned adversarial loss. Bx1
        :rtype: Any
        """
        collapsed = self.conv(visual_features)
        return self.flat(collapsed)
|
|
|
|
|
class ConditionalLogits(nn.Module):
    """Logits extractor for the text-conditioned adversarial loss."""

    def __init__(self) -> None:
        """Build layers that map sentence embeddings into image space and fuse them."""
        super().__init__()
        # sentence vector -> one value per spatial location (17 * 17)
        self.text_to_fm = conv1d(256, 17 * 17)
        # lift the single text plane up to 128 channels to match visual features
        self.chan_aligner = conv2d(1, 128)
        # fuse the concatenated 256-channel map down to one logit
        self.joint_conv = nn.Conv2d(2 * 128, 1, kernel_size=17)
        self.flat = nn.Flatten()

    def forward(self, visual_features: torch.Tensor, sent_embs: torch.Tensor) -> Any:
        """
        Compute logits for the conditional adversarial loss.

        :param torch.Tensor visual_features: Features from the Inception encoder. Bx128x17x17
        :param torch.Tensor sent_embs: Sentence embeddings from the text encoder. Bx256
        :return: Logits for the conditional adversarial loss. Bx1
        :rtype: Any
        """
        # Bx256 -> Bx256x1 -> Bx289x1 -> Bx1x17x17
        text_plane = self.text_to_fm(sent_embs.view(-1, 256, 1))
        text_plane = text_plane.view(-1, 1, 17, 17)
        # Bx1x17x17 -> Bx128x17x17
        text_plane = self.chan_aligner(text_plane)
        # channel-wise concatenation with the image features, then joint convolution
        fused = torch.cat((visual_features, text_plane), dim=1)
        fused = self.joint_conv(fused)
        return self.flat(fused)
|
|
|
|
|
class Discriminator(nn.Module):
    """Simple CNN-based discriminator built on a pretrained Inception encoder."""

    def __init__(self) -> None:
        """Instantiate the Inception feature extractor and the three logits heads."""
        super().__init__()
        self.encoder = InceptionEncoder(D=128)
        # heads are applied externally to the features returned by forward()
        self.logits_word_level = WordLevelLogits()
        self.logits_uncond = UnconditionalLogits()
        self.logits_cond = ConditionalLogits()

    def forward(self, images: torch.Tensor) -> Any:
        """
        Encode images into local feature maps.

        :param torch.Tensor images: Images to be analyzed. Bx3x256x256
        :return: Image features from the image encoder. Bx128x17x17
        """
        # encoder returns a tuple; only the local feature maps are needed here
        return self.encoder(images)[0]
|
|