Spaces:
Paused
Paused
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def is_dist_initialized():
    """Return True when the default torch.distributed process group is initialized.

    ``torch.distributed.is_available()`` is checked first because PyTorch
    builds compiled without distributed support do not expose the rest of the
    ``torch.distributed`` API; on such builds this returns False instead of
    raising.
    """
    return torch.distributed.is_available() and torch.distributed.is_initialized()
def get_world_size():
    """Return the distributed world size, or 1 when no process group is initialized."""
    return torch.distributed.get_world_size() if is_dist_initialized() else 1
def all_gather_grad(x):
    """Concatenate ``x`` from all ranks along dim 0, keeping the local autograd path.

    When the world size is 1 (or distributed is not initialized) ``x`` is
    returned unchanged. Otherwise every rank's tensor is gathered into a
    same-shaped buffer and the results are concatenated in rank order.
    """
    world_size = get_world_size()
    if world_size <= 1:
        return x

    gathered = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, x)
    # The gathered copies are plain buffers; put the original tensor back in
    # this rank's slot so gradients still flow through the local shard.
    gathered[torch.distributed.get_rank()] = x
    return torch.cat(gathered, dim=0)
def all_gather_nograd(tensor):
    # from albef
    """
    Gather ``tensor`` from every rank and concatenate the results along dim 0.

    Returns ``tensor`` unchanged when the world size is 1.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if get_world_size() <= 1:
        return tensor

    n_ranks = torch.distributed.get_world_size()
    buffers = [torch.ones_like(tensor) for _ in range(n_ranks)]
    torch.distributed.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)
def image_text_contrastive_loss(image_feat, text_feat, temperature, image_id=None, text_id=None):
    """Symmetric image-text contrastive loss over the globally gathered batch.

    Features (and ids, when given) are gathered from all ranks, a similarity
    matrix ``image_feat @ text_feat.T / temperature`` is built, and the
    image-to-text and text-to-image losses are averaged.

    Args:
        image_feat: Image features; row ``i`` is paired with row ``i`` of
            ``text_feat``. Both must have the same number of rows, because the
            same targets are used for ``logits`` and ``logits.t()``.
            NOTE(review): presumably L2-normalized by the caller — confirm.
        text_feat: Text features, same shape as ``image_feat``.
        temperature: Divisor applied to the similarity logits.
        image_id: Optional per-sample ids; samples sharing an id are treated
            as positives of each other.
        text_id: Optional per-sample ids, used the same way as ``image_id``.
            If both id tensors are given, a pair is positive when either id
            matches. Either one may be supplied on its own.

    Returns:
        Scalar loss tensor, multiplied by the world size.
    """
    # Gather features from every rank so each sample is contrasted against
    # the global batch rather than only the local one.
    image_feat = all_gather_grad(image_feat)
    text_feat = all_gather_grad(text_feat)
    logits = torch.matmul(image_feat, text_feat.t())
    logits /= temperature

    if image_id is None and text_id is None:
        # No ids: the only positive for row i is column i.
        gt = torch.arange(logits.shape[0], device=logits.device)
        loss1 = F.cross_entropy(logits, gt)
        loss2 = F.cross_entropy(logits.t(), gt)
    else:
        # Build the positive mask from whichever ids were supplied. Handling
        # each one independently avoids crashing when only one id is given.
        # Gather order (image ids, then text ids) is kept fixed so the
        # collectives line up across ranks.
        gt = None
        for ids in (image_id, text_id):
            if ids is None:
                continue
            ids = all_gather_grad(ids)
            same = ids.reshape((-1, 1)) == ids.reshape((1, -1))
            gt = same if gt is None else torch.logical_or(gt, same)
        # Mean negative log-likelihood over all positive pairs.
        loss1 = -torch.sum(gt * F.log_softmax(logits, dim=1)) / gt.sum()
        loss2 = -torch.sum(gt.t() * F.log_softmax(logits.t(), dim=1)) / gt.sum()

    # Scale up by the number of GPUs.
    # NOTE(review): presumably offsets gradient averaging across ranks
    # (e.g. under DistributedDataParallel) — confirm against the training loop.
    return (loss1 + loss2) / 2 * get_world_size()