""" ****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ****************** Copyright (c) 2018 [Thomson Licensing] All Rights Reserved This program contains proprietary information which is a trade secret/business \ secret of [Thomson Licensing] and is protected, even if unpublished, under \ applicable Copyright laws (including French droit d'auteur) and/or may be \ subject to one or more patent(s). Recipient is to retain this program in confidence and is not permitted to use \ or make copies thereof other than as permitted in a written agreement with \ [Thomson Licensing] unless otherwise expressly allowed by applicable laws or \ by [Thomson Licensing] under express agreement. Thomson Licensing is a company of the group TECHNICOLOR ******************************************************************************* This scripts permits one to reproduce training and experiments of: Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April). Finding beans in burgers: Deep semantic-visual embedding with localization. In Proceedings of CVPR (pp. 
3984-3993)

Author: Martin Engilberge
"""

import torch
import torch.nn as nn

from misc.config import path
from misc.weldonModel import ResNet_weldon
from sru import SRU


class SruEmb(nn.Module):
    """Caption (text) encoder: a stack of SRU recurrent layers.

    Consumes a batch of pre-embedded caption sequences and returns one
    fixed-size vector per caption, taken from the hidden state at each
    sequence's last non-padded timestep.
    """

    def __init__(self, nb_layer, dim_in, dim_out, dropout=0.25):
        """
        Args:
            nb_layer: number of stacked SRU layers.
            dim_in: dimensionality of the input word embeddings.
            dim_out: dimensionality of the output caption embedding.
            dropout: dropout rate applied both between layers and
                inside the SRU cells (``rnn_dropout``).
        """
        super(SruEmb, self).__init__()
        self.dim_out = dim_out
        # SRU used as the text feature extractor.
        self.rnn = SRU(dim_in, dim_out, num_layers=nb_layer,
                       dropout=dropout, rnn_dropout=dropout,
                       use_tanh=True, has_skip_term=True,
                       v1=True, rescale=False)

    def _select_last(self, x, lengths):
        """Pick, for each batch element, the hidden state at its last
        valid timestep.

        Args:
            x: RNN outputs, batch-first — (batch, time, dim_out).
            lengths: per-sequence valid lengths (1-based).

        Returns:
            Tensor of shape (batch, dim_out).
        """
        batch_size = x.size(0)
        # Zero mask the same shape as x; a 1 is written only at the row
        # corresponding to each sequence's last valid timestep.
        # NOTE(review): ``.data.new().resize_as_()`` is a legacy idiom
        # (pre-0.4 style) that bypasses autograd; kept as-is.
        mask = x.data.new().resize_as_(x.data).fill_(0)
        for i in range(batch_size):
            mask[i][lengths[i] - 1].fill_(1)
        # Use the mask to zero out the padded positions, then sum over
        # time — only the selected timestep survives.
        x = x.mul(mask)
        x = x.sum(1, keepdim=True).view(batch_size, self.dim_out)
        return x

    def _process_lengths(self, input):
        """Infer the length of each caption in the batch.

        Counts zero entries along the time axis; assumes padded
        timesteps are all-zero embedding vectors — TODO confirm with
        the data pipeline.

        NOTE(review): ``squeeze()`` with no dim argument would also drop
        the batch dimension when batch size is 1 — verify this case is
        never hit by callers.
        """
        max_length = input.size(1)
        # Get the length of each caption: total timesteps minus the
        # number of zero (padding) entries.
        lengths = list(
            max_length - input.data.eq(0).sum(1, keepdim=True).squeeze())
        return lengths

    def forward(self, input, lengths=None):
        """Encode a batch of captions.

        Args:
            input: embedded captions, batch-first — (batch, time, dim_in).
            lengths: optional per-caption lengths; inferred from zero
                padding when omitted.

        Returns:
            (batch, dim_out) caption embeddings (or the full per-timestep
            outputs if ``lengths`` evaluates falsy).
        """
        if lengths is None:
            lengths = self._process_lengths(input)
        # SRU expects time-first input: (time, batch, dim).
        x = input.permute(1, 0, 2)
        # rnn
        x, hn = self.rnn(x)
        # Back to batch-first.
        x = x.permute(1, 0, 2)
        if lengths:
            # Use the mask to discard padded positions and keep only the
            # last valid hidden state per caption.
            x = self._select_last(x, lengths)
        return x


class img_embedding(nn.Module):
    """Image encoder: frozen WELDON-pretrained backbone (per the original
    comment, a ResNet-152) with its final layer stripped off."""

    def __init__(self, args):
        super(img_embedding, self).__init__()
        # Image backbone (ResNet-152 per original comment), initialized
        # from WELDON classification-pretrained weights.
        model_weldon2 = ResNet_weldon(args, pretrained=False,
                                      weldon_pretrained_path=path["WELDON_CLASSIF_PRETRAINED"])
        # Drop the last child module; keep the rest as the feature extractor.
        self.base_layer = nn.Sequential(*list(model_weldon2.children())[:-1])
        # Freeze gradients on the image side — the backbone is not fine-tuned.
        for param in self.base_layer.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return a flattened (batch, features) image descriptor."""
        x = self.base_layer(x)
        x = x.view(x.size()[0], -1)
        return x

    # Image activation maps, for localization/visualization.
    def get_activation_map(self, x):
        """Run the first three stages of the backbone individually and
        return (activation, activation_map) — presumably the spatial map
        before pooling; verify against ResNet_weldon's child ordering."""
        x = self.base_layer[0](x)
        act_map = self.base_layer[1](x)
        act = self.base_layer[2](act_map)
        return act, act_map


class joint_embedding(nn.Module):
    """Joint visual-semantic embedding: projects images and captions into
    a shared, L2-normalized space of dimension ``args.dimemb``."""

    def __init__(self, args):
        super(joint_embedding, self).__init__()
        # Image encoder (wrapped in DataParallel for multi-GPU).
        self.img_emb = torch.nn.DataParallel(img_embedding(args))
        # Caption encoder (620-dim input word embeddings).
        self.cap_emb = SruEmb(args.sru, 620, args.dimemb)
        # Fully connected projection of the 2400-dim image descriptor
        # into the joint space.
        self.fc = torch.nn.DataParallel(nn.Linear(2400, args.dimemb, bias=True))
        # Dropout layer applied to image features only.
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, imgs, caps, lengths):
        """Embed a batch of images and/or captions.

        Either ``imgs`` or ``caps`` may be None, in which case the
        corresponding output is None — the two sides are independent.

        Returns:
            (x_imgs, x_caps): L2-normalized embeddings or None.
        """
        # Image side.
        if imgs is not None:
            x_imgs = self.img_emb(imgs)
            x_imgs = self.dropout(x_imgs)
            x_imgs = self.fc(x_imgs)
            # L2-normalize each row.
            x_imgs = x_imgs / torch.norm(x_imgs, 2, dim=1, keepdim=True).expand_as(x_imgs)
        else:
            x_imgs = None

        # Caption side.
        if caps is not None:
            x_caps = self.cap_emb(caps, lengths=lengths)
            # L2-normalize each row.
            x_caps = x_caps / torch.norm(x_caps, 2, dim=1, keepdim=True).expand_as(x_caps)
        else:
            x_caps = None

        return x_imgs, x_caps