Spaces:
Runtime error
Runtime error
# encoding: utf-8 | |
import math | |
import torch | |
import itertools | |
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from grid_sample import grid_sample | |
from torch.autograd import Variable | |
from tps_grid_gen import TPSGridGen | |
class CNN(nn.Module): | |
def __init__(self, num_output): | |
super(CNN, self).__init__() | |
self.conv1 = nn.Conv2d(1, 10, kernel_size=5) | |
self.conv2 = nn.Conv2d(10, 20, kernel_size=5) | |
self.conv2_drop = nn.Dropout2d() | |
self.fc1 = nn.Linear(320, 50) | |
self.fc2 = nn.Linear(50, num_output) | |
def forward(self, x): | |
x = F.relu(F.max_pool2d(self.conv1(x), 2)) | |
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) | |
x = x.view(-1, 320) | |
x = F.relu(self.fc1(x)) | |
x = F.dropout(x, training=self.training) | |
x = self.fc2(x) | |
return x | |
class ClsNet(nn.Module): | |
def __init__(self): | |
super(ClsNet, self).__init__() | |
self.cnn = CNN(10) | |
def forward(self, x): | |
return F.log_softmax(self.cnn(x)) | |
class BoundedGridLocNet(nn.Module): | |
def __init__(self, grid_height, grid_width, target_control_points): | |
super(BoundedGridLocNet, self).__init__() | |
self.cnn = CNN(grid_height * grid_width * 2) | |
bias = torch.from_numpy(np.arctanh(target_control_points.numpy())) | |
bias = bias.view(-1) | |
self.cnn.fc2.bias.data.copy_(bias) | |
self.cnn.fc2.weight.data.zero_() | |
def forward(self, x): | |
batch_size = x.size(0) | |
points = F.tanh(self.cnn(x)) | |
return points.view(batch_size, -1, 2) | |
class UnBoundedGridLocNet(nn.Module): | |
def __init__(self, grid_height, grid_width, target_control_points): | |
super(UnBoundedGridLocNet, self).__init__() | |
self.cnn = CNN(grid_height * grid_width * 2) | |
bias = target_control_points.view(-1) | |
self.cnn.fc2.bias.data.copy_(bias) | |
self.cnn.fc2.weight.data.zero_() | |
def forward(self, x): | |
batch_size = x.size(0) | |
points = self.cnn(x) | |
return points.view(batch_size, -1, 2) | |
class STNClsNet(nn.Module): | |
def __init__(self, args): | |
super(STNClsNet, self).__init__() | |
self.args = args | |
r1 = args.span_range_height | |
r2 = args.span_range_width | |
assert r1 < 1 and r2 < 1 # if >= 1, arctanh will cause error in BoundedGridLocNet | |
target_control_points = torch.Tensor(list(itertools.product( | |
np.arange(-r1, r1 + 0.00001, 2.0 * r1 / (args.grid_height - 1)), | |
np.arange(-r2, r2 + 0.00001, 2.0 * r2 / (args.grid_width - 1)), | |
))) | |
Y, X = target_control_points.split(1, dim=1) | |
target_control_points = torch.cat([X, Y], dim=1) | |
GridLocNet = { | |
'unbounded_stn': UnBoundedGridLocNet, | |
'bounded_stn': BoundedGridLocNet, | |
}[args.model] | |
self.loc_net = GridLocNet( | |
args.grid_height, args.grid_width, target_control_points) | |
self.tps = TPSGridGen( | |
args.image_height, args.image_width, target_control_points) | |
self.cls_net = ClsNet() | |
def forward(self, x): | |
batch_size = x.size(0) | |
source_control_points = self.loc_net(x) | |
source_coordinate = self.tps(source_control_points) | |
grid = source_coordinate.view( | |
batch_size, self.args.image_height, self.args.image_width, 2) | |
transformed_x = grid_sample(x, grid) | |
logit = self.cls_net(transformed_x) | |
return logit | |
def get_model(args): | |
if args.model == 'no_stn': | |
print('create model without STN') | |
model = ClsNet() | |
else: | |
print('create model with STN') | |
model = STNClsNet(args) | |
return model | |