Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import absolute_import | |
import math | |
import numpy as np | |
import sys | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torch.nn import init | |
def conv3x3_block(in_planes, out_planes, stride=1):
    """Build a 3x3 convolution (padding 1) followed by BatchNorm and ReLU.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of output channels (also the BatchNorm width).
        stride: stride of the convolution. Defaults to 1, which keeps the
            spatial size unchanged because of the padding of 1.

    Returns:
        An ``nn.Sequential`` of ``Conv2d -> BatchNorm2d -> ReLU(inplace)``.
    """
    # The stride argument used to be ignored (the conv was always built with
    # stride=1); forward it so callers asking for a strided block get one.
    conv_layer = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
    block = nn.Sequential(
        conv_layer,
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
    return block
class STNHead(nn.Module):
    """Convolutional head that regresses 2-D control points from an image.

    The network is six ``conv3x3_block`` stages separated by five 2x2
    max-pools, followed by two fully connected layers. ``stn_fc1`` is built
    for ``8 * 8 * 256`` flattened features, so the input must satisfy
    ``(H // 32) * (W // 32) == 64`` (for example a 256x256 image).

    Args:
        in_planes: number of channels of the input image.
        num_ctrlpoints: number of control points to predict. Must be a
            multiple of 4 and at least 4, because the initial points are laid
            out on the four sides of a rectangle (see ``init_stn``).
        activation: ``'sigmoid'`` squashes the predicted coordinates with a
            sigmoid; any other value (default ``'none'``) leaves them raw.

    Raises:
        ValueError: if ``num_ctrlpoints`` cannot be laid out on a rectangle.
    """

    def __init__(self, in_planes, num_ctrlpoints, activation='none'):
        super(STNHead, self).__init__()
        # Fail fast: any other count makes init_stn produce a bias whose size
        # does not match stn_fc2, which previously went unnoticed until the
        # first forward pass.
        if num_ctrlpoints < 4 or num_ctrlpoints % 4 != 0:
            raise ValueError(
                "num_ctrlpoints must be a multiple of 4 and >= 4, got %r"
                % (num_ctrlpoints,))

        self.in_planes = in_planes
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation

        # Spatial sizes in the comments are for a 256x256 input; the convs
        # keep the size (3x3, padding 1) and each pool halves it.
        self.stn_convnet = nn.Sequential(
            conv3x3_block(in_planes, 32),           # 256*256
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(32, 64),                  # 128*128
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(64, 128),                 # 64*64
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(128, 256),                # 32*32
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256),                # 16*16
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256))                # 8*8 -> 256*8*8 features

        self.stn_fc1 = nn.Sequential(
            nn.Linear(8 * 8 * 256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_ctrlpoints * 2)

        self.init_weights(self.stn_convnet)
        self.init_weights(self.stn_fc1)
        self.init_stn(self.stn_fc2)

    def init_weights(self, module):
        """Initialise every Conv2d, BatchNorm2d and Linear inside ``module``.

        Convolutions get zero-mean normal weights with std
        ``sqrt(2 / (kh * kw * out_channels))``; BatchNorm2d gets weight 1 and
        bias 0; Linear layers get normal(0, 0.001) weights. Biases are zeroed.
        """
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()

    def init_stn(self, stn_fc2):
        """Make ``stn_fc2`` initially output a fixed rectangle of points.

        The weights are zeroed and the bias is set to the (x, y) coordinates
        of points placed on the border of the rectangle
        ``[margin_x, 1 - margin_x] x [margin_y, 1 - margin_y]``: the top side,
        then the bottom side, then the left and right sides without their
        corners. With ``'sigmoid'`` activation the bias stores the logit of
        each coordinate so that the sigmoid in ``forward`` recovers it.

        Args:
            stn_fc2: the ``nn.Linear`` layer with ``num_ctrlpoints * 2``
                outputs whose parameters are overwritten in place.
        """
        margin_x, margin_y = 0.35, 0.35
        # Points per side including both corners; top and bottom keep the
        # corners, left and right drop them, giving 4 * per_side - 4 points.
        num_ctrl_pts_per_side = (self.num_ctrlpoints - 4) // 4 + 2

        ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
        # Separate y spacing for the vertical sides. The original reused the
        # x spacing, which is only correct while margin_x == margin_y.
        ctrl_pts_y = np.linspace(margin_y, 1.0 - margin_y, num_ctrl_pts_per_side)

        ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
        ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)

        ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
        ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_x)
        ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_y[1:-1]], axis=1)
        ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_y[1:-1]], axis=1)

        ctrl_points = np.concatenate(
            [ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right],
            axis=0).astype(np.float32)

        # Compare strings by value; `is` on a literal only works by accident
        # of interning and triggers a SyntaxWarning on Python 3.8+.
        if self.activation == 'sigmoid':
            # Inverse sigmoid, so forward() maps the bias back to the points.
            ctrl_points = -np.log(1. / ctrl_points - 1.)

        stn_fc2.weight.data.zero_()
        # copy_ keeps the parameter's shape/device and raises on a size
        # mismatch instead of silently rebinding the bias to a wrong shape.
        stn_fc2.bias.data.copy_(torch.from_numpy(ctrl_points).reshape(-1))

    def forward(self, x):
        """Run the head on a batch of images.

        Args:
            x: tensor of shape ``(batch, in_planes, H, W)`` whose conv
                features flatten to ``8 * 8 * 256`` values per sample.

        Returns:
            A tuple ``(img_feat, ctrl_points)`` where ``img_feat`` is the
            ``(batch, 512)`` output of ``stn_fc1`` and ``ctrl_points`` has
            shape ``(batch, num_ctrlpoints, 2)``.
        """
        x = self.stn_convnet(x)
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        img_feat = self.stn_fc1(x)
        # The features are scaled by 0.1 before the final regression layer.
        x = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            # torch.sigmoid replaces the deprecated F.sigmoid.
            x = torch.sigmoid(x)
        x = x.view(-1, self.num_ctrlpoints, 2)
        return img_feat, x
if __name__ == "__main__":
    # Smoke test: build the head and print the shape of its prediction.
    in_planes = 3
    num_ctrlpoints = 20
    activation = 'none'  # 'sigmoid'
    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
    # stn_fc1 is nn.Linear(8*8*256, 512) and the conv net halves the spatial
    # size five times, so the demo input has to be 256x256 (the former 32x64
    # input flattened to only 512 features and failed in stn_fc1).
    images = torch.randn(10, in_planes, 256, 256)
    # forward() returns a (img_feat, control_points) tuple; calling .size()
    # on the tuple itself raised AttributeError.
    img_feat, control_points = stn_head(images)
    print(control_points.size())