Spaces:
Build error
Build error
""" | |
****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ****************** | |
Copyright (c) 2018 [Thomson Licensing] | |
All Rights Reserved | |
This program contains proprietary information which is a trade secret/business \ | |
secret of [Thomson Licensing] and is protected, even if unpublished, under \ | |
applicable Copyright laws (including French droit d'auteur) and/or may be \ | |
subject to one or more patent(s). | |
Recipient is to retain this program in confidence and is not permitted to use \ | |
or make copies thereof other than as permitted in a written agreement with \ | |
[Thomson Licensing] unless otherwise expressly allowed by applicable laws or \ | |
by [Thomson Licensing] under express agreement. | |
Thomson Licensing is a company of the group TECHNICOLOR | |
******************************************************************************* | |
This scripts permits one to reproduce training and experiments of: | |
Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April). | |
Finding beans in burgers: Deep semantic-visual embedding with localization. | |
In Proceedings of CVPR (pp. 3984-3993) | |
Author: Martin Engilberge | |
""" | |
import torch.nn as nn | |
import torch | |
class ContrastiveLoss(nn.Module): | |
def __init__(self, margin=0.2): | |
super(ContrastiveLoss, self).__init__() | |
self.margin = margin | |
def forward(self, imgs, caps): | |
scores = torch.mm(imgs, caps.t()) | |
diag = scores.diag() | |
cost_s = torch.clamp((self.margin - diag).expand_as(scores) + scores, min=0) | |
# compare every diagonal score to scores in its row (i.e, all | |
# contrastive sentences for each image) | |
cost_im = torch.clamp((self.margin - diag.view(-1, 1)).expand_as(scores) + scores, min=0) | |
# clear diagonals | |
diag_s = torch.diag(cost_s.diag()) | |
diag_im = torch.diag(cost_im.diag()) | |
cost_s = cost_s - diag_s | |
cost_im = cost_im - diag_im | |
return cost_s.sum() + cost_im.sum() | |
class HardNegativeContrastiveLoss(nn.Module): | |
def __init__(self, nmax=1, margin=0.2): | |
super(HardNegativeContrastiveLoss, self).__init__() | |
self.margin = margin | |
self.nmax = nmax | |
def forward(self, imgs, caps): | |
scores = torch.mm(imgs, caps.t()) | |
diag = scores.diag() | |
# Reducing the score on diagonal so there are not selected as hard negative | |
scores = (scores - 2 * torch.diag(scores.diag())) | |
sorted_cap, _ = torch.sort(scores, 0, descending=True) | |
sorted_img, _ = torch.sort(scores, 1, descending=True) | |
# Selecting the nmax hardest negative examples | |
max_c = sorted_cap[:self.nmax, :] | |
max_i = sorted_img[:, :self.nmax] | |
# Margin based loss with hard negative instead of random negative | |
neg_cap = torch.sum(torch.clamp(max_c + (self.margin - diag).view(1, -1).expand_as(max_c), min=0)) | |
neg_img = torch.sum(torch.clamp(max_i + (self.margin - diag).view(-1, 1).expand_as(max_i), min=0)) | |
loss = neg_cap + neg_img | |
return loss | |