"""
****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ******************
Copyright (c) 2018 [Thomson Licensing]
All Rights Reserved
This program contains proprietary information which is a trade secret/business \
secret of [Thomson Licensing] and is protected, even if unpublished, under \
applicable Copyright laws (including French droit d'auteur) and/or may be \
subject to one or more patent(s).
Recipient is to retain this program in confidence and is not permitted to use \
or make copies thereof other than as permitted in a written agreement with \
[Thomson Licensing] unless otherwise expressly allowed by applicable laws or \
by [Thomson Licensing] under express agreement.
Thomson Licensing is a company of the group TECHNICOLOR
*******************************************************************************
This script permits one to reproduce the training and experiments of:
Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April).
Finding beans in burgers: Deep semantic-visual embedding with localization.
In Proceedings of CVPR (pp. 3984-3993)
Author: Martin Engilberge
"""
import torch
import torch.nn as nn
import torchvision.models as models
##########################################################
# translated from torch version: #
# https://github.com/durandtibo/weldon.resnet.pytorch #
##########################################################
"""
****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ******************
Copyright (c) 2018 [Thomson Licensing]
All Rights Reserved
This program contains proprietary information which is a trade secret/business \
secret of [Thomson Licensing] and is protected, even if unpublished, under \
applicable Copyright laws (including French droit d'auteur) and/or may be \
subject to one or more patent(s).
Recipient is to retain this program in confidence and is not permitted to use \
or make copies thereof other than as permitted in a written agreement with \
[Thomson Licensing] unless otherwise expressly allowed by applicable laws or \
by [Thomson Licensing] under express agreement.
Thomson Licensing is a company of the group TECHNICOLOR
*******************************************************************************
This scripts permits one to reproduce training and experiments of:
Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April).
Finding beans in burgers: Deep semantic-visual embedding with localization.
In Proceedings of CVPR (pp. 3984-3993)
Author: Martin Engilberge
"""
import torch
import torch.nn as nn
import torchvision.models as models
##########################################################
# translated from torch version: #
# https://github.com/durandtibo/weldon.resnet.pytorch #
##########################################################
class WeldonPooling(nn.Module):
    # PyTorch implementation of WELDON pooling: average of the nMax highest
    # spatial activations plus the average of the nMin lowest ones
    def __init__(self, nMax=1, nMin=None):
        super(WeldonPooling, self).__init__()
        self.nMax = nMax
        # by default, use as many minimum regions as maximum regions
        self.nMin = nMax if nMin is None else nMin
        self.input = torch.Tensor()
        self.output = torch.Tensor()
        self.indicesMax = torch.Tensor()
        self.indicesMin = torch.Tensor()
    def forward(self, input):
        self.batchSize = 0
        self.numChannels = 0
        self.h = 0
        self.w = 0
        if input.dim() == 4:
            self.batchSize = input.size(0)
            self.numChannels = input.size(1)
            self.h = input.size(2)
            self.w = input.size(3)
        elif input.dim() == 3:
            self.batchSize = 1
            self.numChannels = input.size(0)
            self.h = input.size(1)
            self.w = input.size(2)
        else:
            raise ValueError(
                'WeldonPooling.forward: input must be a 3D or 4D tensor')
        self.input = input

        # a fractional nMax/nMin selects that proportion of the h*w locations
        nMax = self.nMax
        if nMax <= 0:
            nMax = 0
        elif nMax < 1:
            nMax = max(1, int(nMax * self.h * self.w))

        nMin = self.nMin
        if nMin <= 0:
            nMin = 0
        elif nMin < 1:
            nMin = max(1, int(nMin * self.h * self.w))

        x = input.view(self.batchSize, self.numChannels, self.h * self.w)

        # sort spatial scores in decreasing order
        scoreSorted, indices = torch.sort(x, x.dim() - 1, True)

        # average of the nMax highest activations
        self.indicesMax = indices[:, :, 0:nMax]
        self.output = torch.sum(scoreSorted[:, :, 0:nMax], dim=2, keepdim=True)
        self.output = self.output.div(nMax)

        # add the average of the nMin lowest activations (negative evidence)
        if nMin > 0:
            self.indicesMin = indices[
                :, :, self.h * self.w - nMin:self.h * self.w]
            yMin = torch.sum(
                scoreSorted[:, :, self.h * self.w - nMin:self.h * self.w], 2, keepdim=True).div(nMin)
            self.output = torch.add(self.output, yMin)

        if input.dim() == 4:
            self.output = self.output.view(
                self.batchSize, self.numChannels, 1, 1)
        elif input.dim() == 3:
            self.output = self.output.view(self.numChannels, 1, 1)
        return self.output
    # Mirrors the manual backward of the original Torch module: the incoming
    # gradient is scattered back to the selected max (and min) locations.
    def backward(self, grad_output, _indices_grad=None):
        nMax = self.nMax
        if nMax <= 0:
            nMax = 0
        elif nMax < 1:
            nMax = max(1, int(nMax * self.h * self.w))

        nMin = self.nMin
        if nMin <= 0:
            nMin = 0
        elif nMin < 1:
            nMin = max(1, int(nMin * self.h * self.w))

        yMax = grad_output.clone().view(self.batchSize, self.numChannels,
                                        1).expand(self.batchSize, self.numChannels, nMax)
        z = torch.zeros(self.batchSize, self.numChannels,
                        self.h * self.w).type_as(self.input)
        z = z.scatter_(2, self.indicesMax, yMax).div(nMax)

        if nMin > 0:
            yMin = grad_output.clone().view(self.batchSize, self.numChannels, 1).div(
                nMin).expand(self.batchSize, self.numChannels, nMin)
            self.gradInput = z.scatter_(2, self.indicesMin, yMin).view(
                self.batchSize, self.numChannels, self.h, self.w)
        else:
            self.gradInput = z.view(
                self.batchSize, self.numChannels, self.h, self.w)
        if self.input.dim() == 3:
            self.gradInput = self.gradInput.view(
                self.numChannels, self.h, self.w)
        return self.gradInput
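

# A minimal usage sketch of WeldonPooling (dummy tensor for illustration):
#
#   pool = WeldonPooling(nMax=15)
#   feat = torch.randn(2, 2400, 8, 8)        # (batch, channels, h, w)
#   out = pool(feat)                         # -> torch.Size([2, 2400, 1, 1])
#
# A fractional nMax selects a proportion of the spatial locations, e.g.
# nMax=0.3 on an 8x8 map keeps the top int(0.3 * 64) = 19 activations.
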
class ResNet_weldon(nn.Module):
    # ResNet-152 backbone followed by a 1x1 convolution and WELDON pooling
    def __init__(self, args, pretrained=True, weldon_pretrained_path=None):
        super(ResNet_weldon, self).__init__()
        resnet = models.resnet152(pretrained=pretrained)
        # keep every layer up to (but excluding) the average pooling and fc
        self.base_layer = nn.Sequential(*list(resnet.children())[:-2])
        self.spaConv = nn.Conv2d(2048, 2400, 1)
        # spatial aggregation layer
        self.wldPool = WeldonPooling(15)
        # linear layer for ImageNet classification
        self.fc = nn.Linear(2400, 1000)
        # load the weights of a ResNet-WELDON pretrained on ImageNet classification
        if pretrained:
            try:
                state_dict = torch.load(
                    weldon_pretrained_path, map_location=lambda storage, loc: storage)['state_dict']
                self.load_state_dict(state_dict)
            except Exception:
                print("Error when loading pretrained ResNet-WELDON weights")

    def forward(self, x):
        x = self.base_layer(x)
        x = self.spaConv(x)
        x = self.wldPool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
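

# A minimal instantiation sketch; the checkpoint path below is hypothetical,
# and the unused `args` placeholder is passed as None:
#
#   model = ResNet_weldon(None, pretrained=True,
#                         weldon_pretrained_path='path/to/weldon_checkpoint.pth.tar')
#   logits = model(torch.randn(1, 3, 224, 224))   # -> torch.Size([1, 1000])
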
class DynamicPooling(nn.Module):
    # WELDON pooling variant in which the top-max and top-min terms are
    # rescaled by coefficients predicted from the input feature map
    def __init__(self, nMax=1, nMin=None):
        super(DynamicPooling, self).__init__()
        self.nMax = nMax
        self.nMin = nMax if nMin is None else nMin
        self.input = torch.Tensor()
        self.output = torch.Tensor()
        self.indicesMax = torch.Tensor()
        self.indicesMin = torch.Tensor()
        # depthwise 3x3 convolution producing one coefficient map per channel
        self.conv2d = nn.Conv2d(in_channels=2400, out_channels=2400,
                                kernel_size=3, groups=2400)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.act = nn.ReLU()

    def fore_back_layer(self, x):
        # note: the same depthwise conv is applied to both branches, so the
        # foreground and background coefficients share weights as written
        x_fore = self.conv2d(x)
        x_back = self.conv2d(x)
        x_fore = self.avgpool(x_fore)
        x_back = self.avgpool(x_back)
        x_fore = self.act(x_fore)
        x_back = self.act(x_back)
        return x_fore, x_back
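
    # Shape sketch (dummy tensor for illustration): since the depthwise conv is
    # unpadded, the input must be at least 3x3 spatially.
    #
    #   dyn = DynamicPooling(nMax=15)
    #   f, b = dyn.fore_back_layer(torch.randn(2, 2400, 8, 8))
    #   f.shape, b.shape                     # -> torch.Size([2, 2400, 1, 1]) each
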
    def forward(self, input):
        self.batchSize = 0
        self.numChannels = 0
        self.h = 0
        self.w = 0
        if input.dim() == 4:
            self.batchSize = input.size(0)
            self.numChannels = input.size(1)
            self.h = input.size(2)
            self.w = input.size(3)
        elif input.dim() == 3:
            self.batchSize = 1
            self.numChannels = input.size(0)
            self.h = input.size(1)
            self.w = input.size(2)
        else:
            raise ValueError(
                'DynamicPooling.forward: input must be a 3D or 4D tensor')
        self.input = input

        # a fractional nMax/nMin selects that proportion of the h*w locations
        nMax = self.nMax
        if nMax <= 0:
            nMax = 0
        elif nMax < 1:
            nMax = max(1, int(nMax * self.h * self.w))

        nMin = self.nMin
        if nMin <= 0:
            nMin = 0
        elif nMin < 1:
            nMin = max(1, int(nMin * self.h * self.w))

        # predict the foreground and background coefficients
        weight_fore, weight_back = self.fore_back_layer(input)

        x = input.view(self.batchSize, self.numChannels, self.h * self.w)

        # sort spatial scores in decreasing order
        scoreSorted, indices = torch.sort(x, x.dim() - 1, True)

        # weighted average of the nMax highest activations
        self.indicesMax = indices[:, :, 0:nMax]  # e.g. torch.Size([40, 2400, 15])
        self.output = weight_fore.squeeze(dim=-1) * torch.sum(
            scoreSorted[:, :, 0:nMax], dim=2, keepdim=True)
        self.output = self.output.div(nMax)

        # add the weighted average of the nMin lowest activations
        if nMin > 0:
            self.indicesMin = indices[
                :, :, self.h * self.w - nMin:self.h * self.w]
            yMin = weight_back.squeeze(dim=-1) * torch.sum(
                scoreSorted[:, :, self.h * self.w - nMin:self.h * self.w], 2, keepdim=True).div(nMin)
            self.output = torch.add(self.output, yMin)

        if input.dim() == 4:
            self.output = self.output.view(
                self.batchSize, self.numChannels, 1, 1)
        elif input.dim() == 3:
            self.output = self.output.view(self.numChannels, 1, 1)
        return self.output
    # Mirrors WeldonPooling.backward; note that this manual gradient does not
    # propagate through the learned foreground/background coefficients.
    def backward(self, grad_output, _indices_grad=None):
        nMax = self.nMax
        if nMax <= 0:
            nMax = 0
        elif nMax < 1:
            nMax = max(1, int(nMax * self.h * self.w))

        nMin = self.nMin
        if nMin <= 0:
            nMin = 0
        elif nMin < 1:
            nMin = max(1, int(nMin * self.h * self.w))

        yMax = grad_output.clone().view(self.batchSize, self.numChannels,
                                        1).expand(self.batchSize, self.numChannels, nMax)
        z = torch.zeros(self.batchSize, self.numChannels,
                        self.h * self.w).type_as(self.input)
        z = z.scatter_(2, self.indicesMax, yMax).div(nMax)

        if nMin > 0:
            yMin = grad_output.clone().view(self.batchSize, self.numChannels, 1).div(
                nMin).expand(self.batchSize, self.numChannels, nMin)
            self.gradInput = z.scatter_(2, self.indicesMin, yMin).view(
                self.batchSize, self.numChannels, self.h, self.w)
        else:
            self.gradInput = z.view(
                self.batchSize, self.numChannels, self.h, self.w)
        if self.input.dim() == 3:
            self.gradInput = self.gradInput.view(
                self.numChannels, self.h, self.w)
        return self.gradInput
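

# A minimal usage sketch of DynamicPooling (dummy tensor for illustration);
# the 3x3 depthwise conv is unpadded, so the feature map must be at least 3x3:
#
#   dyn = DynamicPooling(nMax=15)
#   feat = torch.randn(2, 2400, 8, 8)
#   out = dyn(feat)                          # -> torch.Size([2, 2400, 1, 1])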