from __future__ import division, absolute_import

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['HACNN']


class ConvBlock(nn.Module):
    """Basic convolutional block.

    convolution + batch normalization + relu.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """

    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))


class InceptionA(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(InceptionA, self).__init__()
        mid_channels = out_channels // 4

        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, p=1),
        )
        self.stream3 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, p=1),
        )
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, mid_channels, 1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        y = torch.cat([s1, s2, s3, s4], dim=1)
        return y
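
# InceptionA preserves the spatial size; each of the four streams emits
# out_channels // 4 channels, so the concatenation yields exactly out_channels,
# e.g. InceptionA(32, 128): (b, 32, h, w) -> (b, 128, h, w).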


class InceptionB(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        mid_channels = out_channels // 4

        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, p=1),
            ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
        )
        self.stream3 = nn.Sequential(
            nn.MaxPool2d(3, stride=2, padding=1),
            ConvBlock(in_channels, mid_channels * 2, 1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        y = torch.cat([s1, s2, s3], dim=1)
        return y
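
# InceptionB halves the spatial size (all three streams use stride 2) and its
# streams emit mid, mid and 2*mid channels, which again sum to out_channels:
# InceptionB(128, 128): (b, 128, h, w) -> (b, 128, ceil(h/2), ceil(w/2)).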


class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1)"""

    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv
        x = self.conv1(x)
        # bilinear resizing (F.interpolate is the non-deprecated
        # replacement for F.upsample)
        x = F.interpolate(
            x, (x.size(2) * 2, x.size(3) * 2),
            mode='bilinear',
            align_corners=True
        )
        # scaling conv
        x = self.conv2(x)
        return x
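
# SpatialAttn collapses the channel dimension and, because the stride-2 conv is
# undone by the 2x bilinear resize, returns a single-channel map at the input
# resolution: (b, c, h, w) -> (b, 1, h, w) for the even h and w used here.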


class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2)"""

    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
        self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)

    def forward(self, x):
        # squeeze operation (global average pooling)
        x = F.avg_pool2d(x, x.size()[2:])
        # excitation operation (2 conv layers)
        x = self.conv1(x)
        x = self.conv2(x)
        return x
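
# ChannelAttn is a squeeze-and-excitation-style bottleneck over pooled
# features: (b, c, h, w) -> (b, c, 1, 1). No sigmoid here; SoftAttn applies it
# after fusing the spatial and channel maps.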


class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I)

    Aim: Spatial Attention + Channel Attention

    Output: attention maps with shape identical to input.
    """

    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        y_spatial = self.spatial_attn(x)
        y_channel = self.channel_attn(x)
        y = y_spatial * y_channel
        y = torch.sigmoid(self.conv(y))
        return y
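
# The (b, 1, h, w) spatial map and the (b, c, 1, 1) channel map broadcast to a
# full (b, c, h, w) tensor; the 1x1 conv plus sigmoid turns it into a soft
# attention mask with values in (0, 1).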


class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II)"""

    def __init__(self, in_channels):
        super(HardAttn, self).__init__()
        self.fc = nn.Linear(in_channels, 4 * 2)
        self.init_params()

    def init_params(self):
        self.fc.weight.data.zero_()
        # start the four regions as horizontal stripes stacked along the height
        self.fc.bias.data.copy_(
            torch.tensor(
                [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float
            )
        )

    def forward(self, x):
        # squeeze operation (global average pooling)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
        # predict transformation parameters
        theta = torch.tanh(self.fc(x))
        theta = theta.view(-1, 4, 2)
        return theta
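
# theta has shape (b, 4, 2): one learned translation (t_x, t_y) in (-1, 1) per
# region. With zero weights and the fixed bias, training starts from stripes
# centred at normalized heights -0.75, -0.25, 0.25 and 0.75.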


class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1)"""

    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels)

    def forward(self, x):
        y_soft_attn = self.soft_attn(x)
        theta = self.hard_attn(x)
        return y_soft_attn, theta


class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network.

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.

    Public keys:
        - ``hacnn``: HACNN.

    Args:
        num_classes (int): number of classes to predict.
        loss (str): loss to optimize, ``'softmax'`` or ``'triplet'``.
        nchannels (list): number of channels AFTER concatenation.
        feat_dim (int): feature dimension for a single stream.
        learn_region (bool): whether to learn region features (i.e. local branch).
        use_gpu (bool): whether to move the affine parameters to the GPU.
    """

    def __init__(
        self,
        num_classes,
        loss='softmax',
        nchannels=[128, 256, 384],
        feat_dim=512,
        learn_region=True,
        use_gpu=True,
        **kwargs
    ):
        super(HACNN, self).__init__()
        self.loss = loss
        self.learn_region = learn_region
        self.use_gpu = use_gpu

        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # Construct Inception + HarmAttn blocks
        # ============== Block 1 ==============
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0], nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0])

        # ============== Block 2 ==============
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0], nchannels[1]),
            InceptionB(nchannels[1], nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1])

        # ============== Block 3 ==============
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1], nchannels[2]),
            InceptionB(nchannels[2], nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2])

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2], feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier_global = nn.Linear(feat_dim, num_classes)

        if self.learn_region:
            self.init_scale_factors()
            self.local_conv1 = InceptionB(32, nchannels[0])
            self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
            self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
            self.fc_local = nn.Sequential(
                nn.Linear(nchannels[2] * 4, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
            )
            self.classifier_local = nn.Linear(feat_dim, num_classes)
            self.feat_dim = feat_dim * 2
        else:
            self.feat_dim = feat_dim

    def init_scale_factors(self):
        # initialize scale factors (s_w, s_h) for the four regions: each one is
        # a full-width horizontal stripe covering a quarter of the height
        self.scale_factors = [
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
            for _ in range(4)
        ]

    def stn(self, x, theta):
        """Performs spatial transform

        x: (batch, channel, height, width)
        theta: (batch, 2, 3)
        """
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x
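
    # Note: PyTorch >= 1.3 warns unless align_corners is passed explicitly to
    # affine_grid/grid_sample; the behaviour this code was written against
    # corresponds to align_corners=True.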

    def transform_theta(self, theta_i, region_idx):
        """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)"""
        scale_factors = self.scale_factors[region_idx]
        theta = torch.zeros(theta_i.size(0), 2, 3)
        theta[:, :, :2] = scale_factors
        theta[:, :, -1] = theta_i
        if self.use_gpu:
            theta = theta.cuda()
        return theta
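
    # Each region's affine matrix is
    #   [[s_w, 0,   t_x],
    #    [0,   s_h, t_y]]
    # i.e. a fixed scale (the stripe size) plus the translation learned by
    # HardAttn.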

    def forward(self, x):
        assert x.size(2) == 160 and x.size(3) == 64, \
            'Input size does not match: expected (160, 64) but got ({}, {})'.format(
                x.size(2), x.size(3)
            )
        x = self.conv(x)

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)
        x1_attn, x1_theta = self.ha1(x1)
        x1_out = x1 * x1_attn
        # local branch
        if self.learn_region:
            x1_local_list = []
            for region_idx in range(4):
                x1_theta_i = x1_theta[:, region_idx, :]
                x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
                x1_trans_i = self.stn(x, x1_theta_i)
                x1_trans_i = F.interpolate(
                    x1_trans_i, (24, 28), mode='bilinear', align_corners=True
                )
                x1_local_i = self.local_conv1(x1_trans_i)
                x1_local_list.append(x1_local_i)

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)
        x2_attn, x2_theta = self.ha2(x2)
        x2_out = x2 * x2_attn
        # local branch
        if self.learn_region:
            x2_local_list = []
            for region_idx in range(4):
                x2_theta_i = x2_theta[:, region_idx, :]
                x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
                x2_trans_i = self.stn(x1_out, x2_theta_i)
                x2_trans_i = F.interpolate(
                    x2_trans_i, (12, 14), mode='bilinear', align_corners=True
                )
                x2_local_i = x2_trans_i + x1_local_list[region_idx]
                x2_local_i = self.local_conv2(x2_local_i)
                x2_local_list.append(x2_local_i)

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)
        x3_attn, x3_theta = self.ha3(x3)
        x3_out = x3 * x3_attn
        # local branch
        if self.learn_region:
            x3_local_list = []
            for region_idx in range(4):
                x3_theta_i = x3_theta[:, region_idx, :]
                x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
                x3_trans_i = self.stn(x2_out, x3_theta_i)
                x3_trans_i = F.interpolate(
                    x3_trans_i, (6, 7), mode='bilinear', align_corners=True
                )
                x3_local_i = x3_trans_i + x2_local_list[region_idx]
                x3_local_i = self.local_conv3(x3_local_i)
                x3_local_list.append(x3_local_i)

        # ============== Feature generation ==============
        # global branch
        x_global = F.avg_pool2d(
            x3_out, x3_out.size()[2:]
        ).view(x3_out.size(0), x3_out.size(1))
        x_global = self.fc_global(x_global)
        # local branch
        if self.learn_region:
            x_local_list = []
            for region_idx in range(4):
                x_local_i = x3_local_list[region_idx]
                x_local_i = F.avg_pool2d(
                    x_local_i, x_local_i.size()[2:]
                ).view(x_local_i.size(0), -1)
                x_local_list.append(x_local_i)
            x_local = torch.cat(x_local_list, 1)
            x_local = self.fc_local(x_local)

        if not self.training:
            # l2 normalization before concatenation
            if self.learn_region:
                x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
                x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
                return torch.cat([x_global, x_local], 1)
            else:
                return x_global

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == 'softmax':
            if self.learn_region:
                return (prelogits_global, prelogits_local)
            else:
                return prelogits_global
        elif self.loss == 'triplet':
            if self.learn_region:
                return (prelogits_global, prelogits_local), (x_global, x_local)
            else:
                return prelogits_global, x_global
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
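

if __name__ == '__main__':
    # Minimal smoke test (a sketch; assumes CPU and the default configuration,
    # with num_classes=10 chosen arbitrarily): forward a random batch at the
    # required 160x64 input size. In eval mode the model returns concatenated
    # global + local features of dimension feat_dim * 2 = 1024, since the
    # local branch is enabled by default.
    model = HACNN(num_classes=10, use_gpu=False)
    model.eval()
    x = torch.rand(2, 3, 160, 64)
    with torch.no_grad():
        features = model(x)
    print(features.shape)  # torch.Size([2, 1024])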