Bhaskar Saranga
Added tracker
e215925
raw
history blame contribute delete
No virus
13.8 kB
from __future__ import division, absolute_import
import torch
from torch import nn
from torch.nn import functional as F
# Public API of this module: ``from <module> import *`` exports only HACNN.
__all__ = ['HACNN']
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): convolution kernel size.
        s (int or tuple): convolution stride. Default: 1.
        p (int or tuple): zero-padding added to both sides. Default: 0.
    """

    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        # Registration order (conv before bn) fixes the state_dict layout.
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return F.relu(out)
class InceptionA(nn.Module):
    """Inception-style block that keeps the spatial resolution.

    Four parallel streams, each producing ``out_channels // 4`` channels, are
    concatenated along the channel axis:
      * stream1..stream3: 1x1 ConvBlock -> 3x3 ConvBlock (padding 1)
      * stream4: 3x3 average pool (stride 1, padding 1) -> 1x1 ConvBlock

    The output therefore has ``4 * (out_channels // 4)`` channels, which equals
    ``out_channels`` only when it is divisible by 4.
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionA, self).__init__()
        branch_c = out_channels // 4

        def reduce_then_3x3():
            # 1x1 channel reduction followed by a resolution-preserving 3x3 conv.
            return nn.Sequential(
                ConvBlock(in_channels, branch_c, 1),
                ConvBlock(branch_c, branch_c, 3, p=1),
            )

        # Creation order matches the stream numbering (affects init RNG order).
        self.stream1 = reduce_then_3x3()
        self.stream2 = reduce_then_3x3()
        self.stream3 = reduce_then_3x3()
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, branch_c, 1),
        )

    def forward(self, x):
        streams = (self.stream1, self.stream2, self.stream3, self.stream4)
        return torch.cat([stream(x) for stream in streams], dim=1)
class InceptionB(nn.Module):
    """Inception-style reduction block (stride 2).

    Three parallel streams are concatenated along the channel axis:
      * stream1: 1x1 ConvBlock -> stride-2 3x3 ConvBlock  (out_channels // 4)
      * stream2: 1x1 ConvBlock -> 3x3 ConvBlock -> stride-2 3x3 ConvBlock
        (out_channels // 4)
      * stream3: stride-2 3x3 max pool -> 1x1 ConvBlock  (2 * (out_channels // 4))

    Every stream maps an H x W input to ceil(H / 2) x ceil(W / 2). The output
    has ``4 * (out_channels // 4)`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        quarter = out_channels // 4
        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, quarter, 1),
            ConvBlock(quarter, quarter, 3, s=2, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, quarter, 1),
            ConvBlock(quarter, quarter, 3, p=1),
            ConvBlock(quarter, quarter, 3, s=2, p=1),
        )
        self.stream3 = nn.Sequential(
            nn.MaxPool2d(3, stride=2, padding=1),
            ConvBlock(in_channels, 2 * quarter, 1),
        )

    def forward(self, x):
        branches = [self.stream1(x), self.stream2(x), self.stream3(x)]
        return torch.cat(branches, dim=1)
class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1).

    Produces a single-channel spatial map with the same height and width as
    the input:
      1. average over channels,
      2. stride-2 3x3 ConvBlock,
      3. bilinear resize back to the input's height and width,
      4. 1x1 ConvBlock ("scaling conv").

    Input:  (batch, channels, height, width)
    Output: (batch, 1, height, width)
    """

    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # Record the input resolution so the map is restored exactly. Simply
        # doubling the downsampled size is off by one for odd heights/widths
        # (the stride-2 conv yields ceil(H / 2)).
        out_size = (x.size(2), x.size(3))
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv (stride 2)
        x = self.conv1(x)
        # bilinear resizing back to the input resolution
        # (F.interpolate replaces the deprecated F.upsample)
        x = F.interpolate(x, size=out_size, mode='bilinear', align_corners=True)
        # scaling conv
        x = self.conv2(x)
        return x
class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2).

    Squeeze (global average pooling over height and width) followed by an
    excitation bottleneck of two 1x1 ConvBlocks:
    ``in_channels -> in_channels // reduction_rate -> in_channels``.
    The output has shape (batch, in_channels, 1, 1).

    Args:
        in_channels (int): number of input channels; must be divisible by
            ``reduction_rate``.
        reduction_rate (int): bottleneck reduction factor. Default: 16.
    """

    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        hidden = in_channels // reduction_rate
        self.conv1 = ConvBlock(in_channels, hidden, 1)
        self.conv2 = ConvBlock(hidden, in_channels, 1)

    def forward(self, x):
        # squeeze: pooling window spans the full spatial extent -> 1x1 map
        pooled = F.avg_pool2d(x, x.shape[2:])
        # excitation: two 1x1 ConvBlocks
        return self.conv2(self.conv1(pooled))
class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I).

    Multiplies the spatial map (batch, 1, H', W') by the channel map
    (batch, C, 1, 1); broadcasting gives (batch, C, H', W'). The product is
    then mixed by a 1x1 ConvBlock and squashed into (0, 1) with a sigmoid.

    Output: attention maps with shape identical to the input, provided
    SpatialAttn returns the input's height and width.
    """

    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        combined = self.spatial_attn(x) * self.channel_attn(x)
        return torch.sigmoid(self.conv(combined))
class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II).

    Global-average-pools the input, applies a linear layer with 8 outputs and
    a tanh, and reshapes to (batch, 4, 2). That is one pair of values in
    (-1, 1) per region; presumably (t_x, t_y) translations consumed by the
    caller.
    """

    def __init__(self, in_channels):
        super(HardAttn, self).__init__()
        self.fc = nn.Linear(in_channels, 8)  # 4 regions x 2 values
        self.init_params()

    def init_params(self):
        # Zero weights make the initial output independent of the input:
        # tanh of the bias pairs (0, -0.75), (0, -0.25), (0, 0.25), (0, 0.75).
        initial_bias = torch.tensor(
            [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float
        )
        with torch.no_grad():
            self.fc.weight.zero_()
            self.fc.bias.copy_(initial_bias)

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        # squeeze operation (global average pooling) -> (batch, channels)
        pooled = F.avg_pool2d(x, x.shape[2:]).view(batch, channels)
        # predict transformation parameters
        return torch.tanh(self.fc(pooled)).view(-1, 4, 2)
class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1).

    Bundles soft attention (SoftAttn) and hard attention (HardAttn) over the
    same input and returns ``(soft_attention_map, theta)``, computed in that
    order.
    """

    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels)

    def forward(self, x):
        return self.soft_attn(x), self.hard_attn(x)
class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network.

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.

    Public keys:
        - ``hacnn``: HACNN.

    Args:
        num_classes (int): number of classes to predict.
        loss (str): ``'softmax'`` or ``'triplet'``; selects what ``forward``
            returns in training mode.
        nchannels (sequence of 3 ints): number of channels AFTER concatenation
            for the three Inception blocks.
        feat_dim (int): feature dimension for a single stream.
        learn_region (bool): whether to learn region features (i.e. local branch).
        use_gpu (bool): kept for backward compatibility. The affine matrices
            are now built on the same device as the activations, so this flag
            no longer has to match where the model lives.
        **kwargs: accepted and ignored.

    The input must have spatial size 160 x 64 (height x width).
    """

    def __init__(
        self,
        num_classes,
        loss='softmax',
        nchannels=(128, 256, 384),
        feat_dim=512,
        learn_region=True,
        use_gpu=True,
        **kwargs
    ):
        super(HACNN, self).__init__()
        self.loss = loss
        self.learn_region = learn_region
        self.use_gpu = use_gpu
        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # Construct Inception + HarmAttn blocks
        # ============== Block 1 ==============
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0], nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0])
        # ============== Block 2 ==============
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0], nchannels[1]),
            InceptionB(nchannels[1], nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1])
        # ============== Block 3 ==============
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1], nchannels[2]),
            InceptionB(nchannels[2], nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2])

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2], feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier_global = nn.Linear(feat_dim, num_classes)

        if self.learn_region:
            self.init_scale_factors()
            self.local_conv1 = InceptionB(32, nchannels[0])
            self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
            self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
            # four regions are concatenated before this layer
            self.fc_local = nn.Sequential(
                nn.Linear(nchannels[2] * 4, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
            )
            self.classifier_local = nn.Linear(feat_dim, num_classes)
            self.feat_dim = feat_dim * 2
        else:
            self.feat_dim = feat_dim

    def init_scale_factors(self):
        """Initializes the fixed scale factors (s_w, s_h) for the four regions.

        Each region keeps s_w = 1 and s_h = 0.25. These are plain tensors, not
        registered buffers, so they are not in ``state_dict`` and do not follow
        ``.to(device)``; ``transform_theta`` moves them where needed.
        """
        self.scale_factors = [
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
            for _ in range(4)
        ]

    def stn(self, x, theta):
        """Performs spatial transform.

        x: (batch, channel, height, width)
        theta: (batch, 2, 3)

        Returns a tensor of the same size as ``x``, sampled from ``x`` at the
        locations given by the affine transform ``theta``.
        """
        # align_corners=False is PyTorch's default since 1.3.0; passing it
        # explicitly keeps that behaviour and silences the UserWarning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def transform_theta(self, theta_i, region_idx):
        """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3).

        The 2x2 linear part is the region's fixed scale matrix. The last
        column is the predicted translation ``theta_i`` of shape (batch, 2).
        The result lives on ``theta_i``'s device and dtype, so this works on
        CPU and on any GPU regardless of ``use_gpu``.
        """
        scale_factors = self.scale_factors[region_idx]
        theta = theta_i.new_zeros(theta_i.size(0), 2, 3)
        theta[:, :, :2] = scale_factors.to(theta)
        theta[:, :, -1] = theta_i
        return theta

    def _local_branch(self, feature_map, theta, out_size, local_conv, prev_local=None):
        """Runs one stage of the local branch for all four regions.

        For each region:
          1. sample it from ``feature_map`` using ``theta[:, region_idx]``;
          2. bilinearly resize it to ``out_size``;
          3. add the previous stage's output for that region (if given);
          4. apply ``local_conv``.

        ``out_size`` of one stage must equal the stride-2 output size of the
        previous ``local_conv``, so that the residual addition lines up.
        Returns a list of four tensors.
        """
        local_list = []
        for region_idx in range(4):
            theta_i = self.transform_theta(theta[:, region_idx, :], region_idx)
            trans_i = self.stn(feature_map, theta_i)
            trans_i = F.interpolate(
                trans_i, out_size, mode='bilinear', align_corners=True
            )
            if prev_local is not None:
                trans_i = trans_i + prev_local[region_idx]
            local_list.append(local_conv(trans_i))
        return local_list

    @staticmethod
    def _global_avg_pool(x):
        """Average-pools a (batch, C, H, W) tensor over H and W to (batch, C)."""
        return F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)

    def forward(self, x):
        """Runs the network on a batch of (batch, 3, 160, 64) images.

        Returns:
            Eval mode: ``cat([global, local], 1)`` of L2-normalized features
            when ``learn_region`` is True, otherwise the (unnormalized) global
            feature.
            Training mode: logits (``loss='softmax'``) or ``(logits, features)``
            (``loss='triplet'``). Each is a ``(global, local)`` pair when
            ``learn_region`` is True.

        Raises:
            KeyError: in training mode, if ``self.loss`` is unsupported.
        """
        assert x.size(2) == 160 and x.size(3) == 64, \
            'Input size does not match, expected (160, 64) but got ({}, {})'.format(x.size(2), x.size(3))
        x = self.conv(x)

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)
        x1_attn, x1_theta = self.ha1(x1)
        x1_out = x1 * x1_attn
        # local branch: regions are sampled from the stem output
        if self.learn_region:
            x1_local_list = self._local_branch(
                x, x1_theta, (24, 28), self.local_conv1
            )

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)
        x2_attn, x2_theta = self.ha2(x2)
        x2_out = x2 * x2_attn
        # local branch: regions from the attended block-1 output, fused with
        # the block-1 local features
        if self.learn_region:
            x2_local_list = self._local_branch(
                x1_out, x2_theta, (12, 14), self.local_conv2, x1_local_list
            )

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)
        x3_attn, x3_theta = self.ha3(x3)
        x3_out = x3 * x3_attn
        # local branch
        if self.learn_region:
            x3_local_list = self._local_branch(
                x2_out, x3_theta, (6, 7), self.local_conv3, x2_local_list
            )

        # ============== Feature generation ==============
        # global branch
        x_global = self.fc_global(self._global_avg_pool(x3_out))
        # local branch
        if self.learn_region:
            x_local = torch.cat(
                [self._global_avg_pool(feat) for feat in x3_local_list], 1
            )
            x_local = self.fc_local(x_local)

        if not self.training:
            if self.learn_region:
                # l2 normalization before concatenation. F.normalize clamps the
                # norm at a small eps, so an all-zero feature (possible after
                # the final ReLU) yields zeros instead of NaN.
                x_global = F.normalize(x_global, p=2, dim=1)
                x_local = F.normalize(x_local, p=2, dim=1)
                return torch.cat([x_global, x_local], 1)
            return x_global

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == 'softmax':
            if self.learn_region:
                return (prelogits_global, prelogits_local)
            return prelogits_global
        elif self.loss == 'triplet':
            if self.learn_region:
                return (prelogits_global, prelogits_local), (x_global, x_local)
            return prelogits_global, x_global
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))