"""
****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ******************
Copyright (c) 2018 [Thomson Licensing]
All Rights Reserved
This program contains proprietary information which is a trade secret/business \
secret of [Thomson Licensing] and is protected, even if unpublished, under \
applicable Copyright laws (including French droit d'auteur) and/or may be \
subject to one or more patent(s).
Recipient is to retain this program in confidence and is not permitted to use \
or make copies thereof other than as permitted in a written agreement with \
[Thomson Licensing] unless otherwise expressly allowed by applicable laws or \
by [Thomson Licensing] under express agreement.
Thomson Licensing is a company of the group TECHNICOLOR
*******************************************************************************
This script permits one to reproduce training and experiments of:
Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April).
Finding beans in burgers: Deep semantic-visual embedding with localization.
In Proceedings of CVPR (pp. 3984-3993)
Author: Martin Engilberge
"""
import torch
import torch.nn as nn
from misc.config import path
from misc.weldonModel import ResNet_weldon
from sru import SRU
class SruEmb(nn.Module):
    """Recurrent text encoder built on stacked SRU layers.

    Encodes a padded batch of sequences and returns, for each sequence,
    the hidden state produced at its last valid (non-padding) time step.
    """

    def __init__(self, nb_layer, dim_in, dim_out, dropout=0.25):
        super(SruEmb, self).__init__()
        self.dim_out = dim_out
        # Stack of SRU layers used as the text feature extractor.
        self.rnn = SRU(dim_in, dim_out, num_layers=nb_layer,
                       dropout=dropout, rnn_dropout=dropout,
                       use_tanh=True, has_skip_term=True,
                       v1=True, rescale=False)

    def _select_last(self, x, lengths):
        # Keep only the hidden state at time step lengths[i] - 1 for each
        # sample: build a one-hot-over-time mask, multiply, then collapse
        # the time dimension (summing zeros leaves the selected state intact).
        n = x.size(0)
        mask = x.data.new().resize_as_(x.data).fill_(0)
        for i in range(n):
            mask[i][lengths[i] - 1].fill_(1)
        selected = x.mul(mask)
        return selected.sum(1, keepdim=True).view(n, self.dim_out)

    def _process_lengths(self, input):
        # Padding entries are encoded as 0, so subtracting the per-sequence
        # zero count from the padded length recovers each true length.
        # NOTE(review): assumes input's dim 1 is the time axis — confirm
        # against the caller's tensor layout.
        max_length = input.size(1)
        n_pad = input.data.eq(0).sum(1, keepdim=True).squeeze()
        return list(max_length - n_pad)

    def forward(self, input, lengths=None):
        if lengths is None:
            lengths = self._process_lengths(input)
        # SRU consumes time-major input: (seq, batch, dim).
        h, _ = self.rnn(input.permute(1, 0, 2))
        h = h.permute(1, 0, 2)
        if lengths:
            # Drop outputs produced on padding time steps.
            h = self._select_last(h, lengths)
        return h
class img_embedding(nn.Module):
    """Image encoder: frozen Weldon ResNet backbone yielding a flat feature vector."""

    def __init__(self, args):
        super(img_embedding, self).__init__()
        # Pretrained Weldon model (ResNet-152 backbone); its last child is dropped.
        weldon = ResNet_weldon(args, pretrained=False,
                               weldon_pretrained_path=path["WELDON_CLASSIF_PRETRAINED"])
        self.base_layer = nn.Sequential(*list(weldon.children())[:-1])
        # The image branch stays frozen: no gradients flow into the backbone.
        for param in self.base_layer.parameters():
            param.requires_grad = False

    def forward(self, x):
        features = self.base_layer(x)
        return features.view(features.size(0), -1)

    def get_activation_map(self, x):
        # Run the first three stages one by one to expose the spatial
        # activation map before it is reduced.
        out = self.base_layer[0](x)
        act_map = self.base_layer[1](out)
        act = self.base_layer[2](act_map)
        return act, act_map
class joint_embedding(nn.Module):
    """Joint visual-semantic embedding.

    Projects images and captions into a shared space and L2-normalizes
    each modality's embedding; either input may be None, in which case
    the corresponding output is None.
    """

    def __init__(self, args):
        super(joint_embedding, self).__init__()
        # Image encoder (frozen backbone), wrapped for multi-GPU use.
        self.img_emb = torch.nn.DataParallel(img_embedding(args))
        # Caption encoder.
        self.cap_emb = SruEmb(args.sru, 620, args.dimemb)
        # Projection from image features into the joint space.
        self.fc = torch.nn.DataParallel(nn.Linear(2400, args.dimemb, bias=True))
        # Regularization applied on the image branch only.
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, imgs, caps, lengths):
        x_imgs = None
        x_caps = None
        # Image branch: encode, regularize, project, L2-normalize.
        if imgs is not None:
            feats = self.fc(self.dropout(self.img_emb(imgs)))
            norm = torch.norm(feats, 2, dim=1, keepdim=True)
            x_imgs = feats / norm.expand_as(feats)
        # Caption branch: encode, L2-normalize.
        if caps is not None:
            embs = self.cap_emb(caps, lengths=lengths)
            norm = torch.norm(embs, 2, dim=1, keepdim=True)
            x_caps = embs / norm.expand_as(embs)
        return x_imgs, x_caps