# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils


class CLIPLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.labels = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        image_embed = outputs['image_embed']
        text_embed = outputs['text_embed']
        logit_scale = outputs['logit_scale']
        local_batch_size = image_embed.size(0)

        if local_batch_size != self.last_local_batch_size:
            # Each rank's positives sit at offset rank * local_batch_size in the
            # gathered batch; the targets are cached until the batch size changes.
            self.labels = local_batch_size * utils.get_rank() + torch.arange(
                local_batch_size, device=image_embed.device
            )
            self.last_local_batch_size = local_batch_size

        # normalized features
        image_embed = F.normalize(image_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)

        # gather features from all GPUs
        image_embed_all, text_embed_all = \
            utils.all_gather_batch([image_embed, text_embed])

        # cosine similarity as logits
        logits_per_image = logit_scale * image_embed @ text_embed_all.t()
        logits_per_text = logit_scale * text_embed @ image_embed_all.t()

        # symmetric InfoNCE: match images to texts and texts to images
        loss = (F.cross_entropy(logits_per_image, self.labels) +
                F.cross_entropy(logits_per_text, self.labels)) / 2

        # compute accuracy
        with torch.no_grad():
            pred = torch.argmax(logits_per_image, dim=-1)
            correct = pred.eq(self.labels).sum()
            acc = 100 * correct / local_batch_size

        return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}


class SIMCLRLoss(nn.Module):
    """
    This is the SimCLR loss in https://arxiv.org/abs/2002.05709

    The embedding vectors are assumed to have size (2 x batch_size, embedding_dim)
    and the memory layout that can be reshaped into shape
    (2, batch_size, embedding_dim). This memory layout is consistent with
    the SimCLR collator in
    https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py

    Config params:
        temperature (float): the temperature to be applied on the logits
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.tau = temperature
        self.labels = None
        self.masks = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        q_a = outputs['aug1_embed']
        q_b = outputs['aug2_embed']

        q_a = F.normalize(q_a, dim=-1, p=2)
        q_b = F.normalize(q_b, dim=-1, p=2)

        local_batch_size = q_a.size(0)
        # gather with gradients so the loss backpropagates through all negatives
        k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b])

        if local_batch_size != self.last_local_batch_size:
            self.labels = local_batch_size * utils.get_rank() + torch.arange(
                local_batch_size, device=q_a.device
            )
            total_batch_size = local_batch_size * utils.get_world_size()
            # one-hot mask scaled by 1e9: subtracted below so each sample's
            # trivial self-similarity in logits_aa / logits_bb never wins
            self.masks = F.one_hot(self.labels, total_batch_size) * 1e9
            self.last_local_batch_size = local_batch_size

        logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau
        logits_aa = logits_aa - self.masks
        logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau
        logits_bb = logits_bb - self.masks
        logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau
        logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau

        loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), self.labels)
        loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), self.labels)
        loss = (loss_a + loss_b) / 2  # divide by 2 to average over both views

        # compute accuracy
        with torch.no_grad():
            pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), dim=-1)
            correct = pred.eq(self.labels).sum()
            acc = 100 * correct / local_batch_size

        return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc}
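

# A minimal usage sketch for SIMCLRLoss, kept as a comment so importing this
# module stays side-effect free. The dict keys are the ones read in forward()
# above; the batch and embedding sizes are arbitrary, and a single-process run
# assumes the utils gather helpers fall back to a no-op when torch.distributed
# is not initialized:
#
#   criterion = SIMCLRLoss(temperature=0.1)
#   outputs = {'aug1_embed': torch.randn(8, 128),   # first augmented views
#              'aug2_embed': torch.randn(8, 128)}   # second augmented views
#   losses = criterion(outputs)  # {'loss': ..., 'ssl_loss': ..., 'ssl_acc': ...}
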
class SLIPLoss(nn.Module):
    def __init__(self, ssl_loss, ssl_scale):
        super().__init__()
        self.clip_loss = CLIPLoss()
        self.ssl_loss = ssl_loss
        self.ssl_scale = ssl_scale

    def forward(self, outputs):
        clip_loss_dict = self.clip_loss(outputs)
        clip_loss = clip_loss_dict['clip_loss']
        clip_acc = clip_loss_dict['clip_acc']

        ssl_loss_dict = self.ssl_loss(outputs)
        ssl_loss = ssl_loss_dict['ssl_loss']
        ssl_acc = ssl_loss_dict['ssl_acc']

        return {'loss': clip_loss + self.ssl_scale * ssl_loss,
                'clip_loss': clip_loss,
                'clip_acc': clip_acc,
                'ssl_loss': ssl_loss,
                'ssl_acc': ssl_acc}
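

if __name__ == "__main__":
    # Smoke test, illustrative only: it assumes the utils helpers (get_rank /
    # get_world_size / all_gather_batch*) degrade to single-process behavior
    # when torch.distributed is not initialized. The sizes are arbitrary, and
    # logit_scale is fixed at 1.0 here in place of the learned scale a real
    # model would supply in its outputs dict.
    batch_size, embed_dim = 8, 128
    outputs = {
        'image_embed': torch.randn(batch_size, embed_dim),
        'text_embed': torch.randn(batch_size, embed_dim),
        'logit_scale': torch.tensor(1.0),
        'aug1_embed': torch.randn(batch_size, embed_dim),
        'aug2_embed': torch.randn(batch_size, embed_dim),
    }
    criterion = SLIPLoss(SIMCLRLoss(temperature=0.1), ssl_scale=1.0)
    print({k: v.item() for k, v in criterion(outputs).items()})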