"""Discriminator providing word-level feedback"""
from typing import Any
import torch
from torch import nn
from src.models.modules.conv_utils import conv1d, conv2d
from src.models.modules.image_encoder import InceptionEncoder
class WordLevelLogits(nn.Module):
"""API for converting regional feature maps into logits for multi-class classification"""
def __init__(self) -> None:
"""
        Instantiate the module with a softmax over the word dimension
"""
super().__init__()
self.softmax = nn.Softmax(dim=1)
# layer for flattening the feature maps
self.flat = nn.Flatten(start_dim=2)
        # reduce the channel dim of the textual embs to match the Inception feature channels
        self.chan_reduction = conv1d(256, 128)
def forward(
self, visual_features: torch.Tensor, word_embs: torch.Tensor, mask: torch.Tensor
) -> Any:
"""
        Fuse textual and visual features to produce logits for the word-level classification loss
        :param torch.Tensor visual_features:
            Feature maps of an image after being processed by the Inception encoder. Bx128x17x17
        :param torch.Tensor word_embs:
            Word-level embeddings from the text encoder. Bx256xL
        :param torch.Tensor mask:
            Boolean mask marking padded word positions, excluded from the softmax. BxL
        :return: Logits for each word in the picture. BxL
:rtype: Any
"""
        # make textual and visual features have the same number of channels
        word_embs = self.chan_reduction(word_embs)
        # flatten the feature maps: Bx128x17x17 -> Bx128x289
        visual_features = self.flat(visual_features)
        # Bx128xL -> BxLx128
        word_embs = torch.transpose(word_embs, 1, 2)
        # word/region correlation matrix: BxLx128 @ Bx128x289 -> BxLx289
        word_region_correlations = word_embs @ visual_features
# normalize across L dimension
m_norm_l = nn.functional.normalize(word_region_correlations, dim=1)
# normalize across H*W dimension
m_norm_hw = nn.functional.normalize(m_norm_l, dim=2)
m_norm_hw = torch.transpose(m_norm_hw, 1, 2)
        # attend to image regions for every word: Bx128x289 @ Bx289xL -> Bx128xL
        weighted_img_feats = visual_features @ m_norm_hw
        weighted_img_feats = torch.sum(weighted_img_feats, dim=1)
        # mask out padded words before the softmax
        weighted_img_feats[mask] = -float("inf")
deltas = self.softmax(weighted_img_feats)
return deltas
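

# A minimal shape sketch of how this head might be wired up (hypothetical tensors,
# not part of the original training code):
#
#     word_head = WordLevelLogits()
#     visual = torch.rand(2, 128, 17, 17)              # InceptionEncoder local features
#     words = torch.rand(2, 256, 12)                   # text encoder output, L = 12
#     pad_mask = torch.zeros(2, 12, dtype=torch.bool)  # True marks padded words
#     deltas = word_head(visual, words, pad_mask)      # -> 2x12 word-level scores
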
class UnconditionalLogits(nn.Module):
"""Head for retrieving logits from an image"""
def __init__(self) -> None:
"""Initialize modules that reduce the features down to a set of logits"""
super().__init__()
self.conv = nn.Conv2d(128, 1, kernel_size=17)
        # flattening Bx1x1x1 into Bx1
self.flat = nn.Flatten()
def forward(self, visual_features: torch.Tensor) -> Any:
"""
Compute logits for unconditioned adversarial loss
:param visual_features: Local features from Inception network. Bx128x17x17
:return: Logits for unconditioned adversarial loss. Bx1
:rtype: Any
"""
# reduce channels and feature maps for visual features
visual_features = self.conv(visual_features)
# flatten Bx1x1x1 into Bx1
logits = self.flat(visual_features)
return logits
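

# A quick shape sketch for this head (hypothetical tensors, not part of the module):
#
#     uncond_head = UnconditionalLogits()
#     visual = torch.rand(2, 128, 17, 17)   # InceptionEncoder local features
#     logits = uncond_head(visual)          # -> 2x1 real/fake logits
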
class ConditionalLogits(nn.Module):
"""Logits extractor for conditioned adversarial loss"""
def __init__(self) -> None:
super().__init__()
# layer for forming the feature maps out of textual info
self.text_to_fm = conv1d(256, 17 * 17)
        # align the number of text channels with the number of visual channels
        self.chan_aligner = conv2d(1, 128)
        # reduce the joint textual + visual features down to a 1x1 feature map
self.joint_conv = nn.Conv2d(2 * 128, 1, kernel_size=17)
# converting Bx1x1x1 into Bx1
self.flat = nn.Flatten()
def forward(self, visual_features: torch.Tensor, sent_embs: torch.Tensor) -> Any:
"""
Compute logits for conditional adversarial loss
:param torch.Tensor visual_features: Features from Inception encoder. Bx128x17x17
:param torch.Tensor sent_embs: Sentence embeddings from text encoder. Bx256
        :return: Logits for conditional adversarial loss. Bx1
:rtype: Any
"""
# make text and visual features have the same sizes of feature maps
# Bx256 -> Bx256x1 -> Bx289x1
sent_embs = sent_embs.view(-1, 256, 1)
sent_embs = self.text_to_fm(sent_embs)
# transform textual info into shape of visual feature maps
# Bx289x1 -> Bx1x17x17
sent_embs = sent_embs.view(-1, 1, 17, 17)
        # propagate text embs through a 2d conv to
        # align channels with the visual feature maps
sent_embs = self.chan_aligner(sent_embs)
# unite textual and visual features across the dim of channels
cross_features = torch.cat((visual_features, sent_embs), dim=1)
        # collapse the joint feature maps into a single raw logit
cross_features = self.joint_conv(cross_features)
# form logits from Bx1x1x1 into Bx1
logits = self.flat(cross_features)
return logits
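

# A quick shape sketch for this head (hypothetical tensors, not part of the module):
#
#     cond_head = ConditionalLogits()
#     visual = torch.rand(2, 128, 17, 17)   # InceptionEncoder local features
#     sents = torch.rand(2, 256)            # sentence embeddings from the text encoder
#     logits = cond_head(visual, sents)     # -> 2x1 conditional logits
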
class Discriminator(nn.Module):
"""Simple CNN-based discriminator"""
def __init__(self) -> None:
"""Use a pretrained InceptionNet to extract features"""
super().__init__()
self.encoder = InceptionEncoder(D=128)
# define different logit extractors for different losses
self.logits_word_level = WordLevelLogits()
self.logits_uncond = UnconditionalLogits()
self.logits_cond = ConditionalLogits()
def forward(self, images: torch.Tensor) -> Any:
"""
        Retrieve image features encoded by the image encoder
        :param torch.Tensor images: Images to be analyzed. Bx3x256x256
        :return: Image features encoded by the image encoder. Bx128x17x17
"""
# only taking the local features from inception
# Bx3x256x256 -> Bx128x17x17
img_features, _ = self.encoder(images)
return img_features
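

if __name__ == "__main__":
    # Rough smoke test on random tensors (a sketch, not part of the original module).
    # Assumes InceptionEncoder(D=128) can be built locally (it may fetch pretrained
    # Inception-v3 weights) and that conv2d from conv_utils preserves the 17x17 grid.
    batch, length = 2, 12
    disc = Discriminator()
    images = torch.rand(batch, 3, 256, 256)
    visual_features = disc(images)                       # Bx128x17x17
    word_embs = torch.rand(batch, 256, length)           # word-level text features
    sent_embs = torch.rand(batch, 256)                   # sentence-level text features
    mask = torch.zeros(batch, length, dtype=torch.bool)  # True marks padded words
    print(disc.logits_word_level(visual_features, word_embs, mask).shape)  # (2, 12)
    print(disc.logits_uncond(visual_features).shape)     # (2, 1)
    print(disc.logits_cond(visual_features, sent_embs).shape)  # (2, 1)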