from typing import Any, Dict

import torch
from torch import nn
import torch.distributed as dist

from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class ImageTextContrastiveModel(nn.Module):
    """Train a visual backbone and a textual head with a contrastive loss.

    Images and captions are embedded into a shared space of size
    ``textual.textual_feature_size``. Within a batch, image ``i`` and
    caption ``i`` form the positive pair; every other caption (image),
    including those gathered from other distributed processes, acts as a
    negative. The final loss is the mean of the image-to-text and
    text-to-image cross-entropy losses.

    Args:
        visual: Backbone producing features of shape
            ``(batch_size, channels, height, width)``.
        textual: Head producing features of shape
            ``(batch_size, max_caption_length, hidden_size)``.
        label_smoothing: Value forwarded to
            ``CrossEntropyLossWithLabelSmoothing``.
    """

    # Upper bound for the learned log-scale; exp(3.912) is roughly 50, so the
    # multiplier applied to cosine similarities stays within [1, 50].
    _MAX_LOG_LOGIT_SCALE = 3.912

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        label_smoothing: float = 0.0
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx

        # Projects pooled visual features to the size of textual features.
        self.visual_projection = nn.Linear(
            self.visual.visual_feature_size,
            self.textual.textual_feature_size,
            bias=False,
        )
        # Learned log of the similarity scale, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07)))

        # NOTE(review): targets built in ``forward`` are row indices
        # ``0 .. batch_size - 1``. If this loss honours ``ignore_index`` the
        # way ``torch.nn.functional.cross_entropy`` does, the row whose index
        # equals ``padding_idx`` would be left out of the loss -- confirm
        # against ``CrossEntropyLossWithLabelSmoothing``.
        self.loss = CrossEntropyLossWithLabelSmoothing(
            label_smoothing, ignore_index=self.padding_idx
        )

    @staticmethod
    def _gather_local_first(features: torch.Tensor) -> torch.Tensor:
        """Concatenate ``features`` from all processes, local chunk first.

        The returned tensor has shape ``(batch_size * world_size, dim)`` and
        is not connected to the autograd graph: ``all_gather`` fills freshly
        allocated buffers. When no process group is available or initialised,
        a detached view of ``features`` is returned, which matches what a
        single-process group would yield.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return features.detach()

        rank = dist.get_rank()
        gathered = [
            torch.zeros_like(features) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered, features)

        # Move this rank's chunk to the front so that positives sit at
        # column ``i`` for row ``i`` regardless of rank.
        gathered[0], gathered[rank] = gathered[rank], gathered[0]
        return torch.cat(gathered, dim=0)

    def _directional_loss(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        logit_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Cross-entropy of local ``queries`` against ``keys`` of all ranks."""
        # shape: (batch_size * world_size, textual_feature_size)
        all_keys = self._gather_local_first(keys)

        # Scaled pairwise cosine similarity.
        # shape: (batch_size, batch_size * world_size)
        logits = logit_scale * queries @ all_keys.t()

        # Row ``i`` should match column ``i``. Targets are created anew for
        # every call so they are never shared between the two loss calls.
        targets = torch.arange(logits.size(0), device=logits.device)
        return self.loss(logits, targets)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Compute the symmetric contrastive loss for one batch.

        Args:
            batch: Must contain ``"image"``, ``"caption_tokens"`` and
                ``"caption_lengths"``.

        Returns:
            A dict with ``"loss"`` (scalar used for back-propagation) and
            ``"loss_components"`` (``{"contrastive": <detached copy>}``) for
            logging in the training script.
        """
        # Clip logit_scale in case the last optimizer step pushed it out of
        # range. Done in-place without tracking gradients.
        with torch.no_grad():
            self.logit_scale.clamp_(0, self._MAX_LOG_LOGIT_SCALE)

        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        # Global average pooling over the spatial dimensions.
        # shape: (batch_size, channels)
        visual_features = visual_features.mean(dim=[2, 3]).view(batch_size, -1)

        # shape: (batch_size, textual_feature_size)
        visual_features = self.visual_projection(visual_features)

        caption_tokens = batch["caption_tokens"]
        caption_lengths = batch["caption_lengths"]

        # shape: (batch_size, max_caption_length, hidden_size)
        textual_features = self.textual(caption_tokens, caption_lengths)

        # Take features from the first time-step (as BERT-* models do).
        # shape: (batch_size, hidden_size)
        textual_features = textual_features[:, 0, :]

        # Normalize visual and textual features to unit L2 norm.
        # shape: (batch_size, textual_feature_size)
        visual_features = visual_features / visual_features.norm(
            dim=-1, keepdim=True
        )
        textual_features = textual_features / textual_features.norm(
            dim=-1, keepdim=True
        )

        logit_scale = self.logit_scale.exp()

        # Image -> text first, then text -> image. Every process must issue
        # the two gathers in this same order.
        visual_loss = self._directional_loss(
            visual_features, textual_features, logit_scale
        )
        textual_loss = self._directional_loss(
            textual_features, visual_features, logit_scale
        )
        loss = 0.5 * (visual_loss + textual_loss)

        output_dict: Dict[str, Any] = {
            "loss": loss,
            # Single scalar per batch for logging in training script.
            "loss_components": {"contrastive": loss.clone().detach()},
        }
        return output_dict