Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
class AddCoordsTh(nn.Module):
    """Append normalized coordinate channels to a 4-D tensor.

    The grid size is fixed at construction time: ``forward`` builds
    coordinate maps of shape ``(batch, 1, x_dim, y_dim)``, so the input's
    last two dimensions must equal ``(x_dim, y_dim)`` for the final
    concatenation to succeed.

    Args:
        x_dim: Size of dimension 2 of the expected input.
        y_dim: Size of dimension 3 of the expected input.
        with_r: If true, also append a radius channel (distance from the
            grid centre, scaled so that its maximum is 1).
        with_boundary: If true and a heatmap is passed to ``forward``,
            also append two coordinate channels masked by the heatmap's
            last channel.
    """

    def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
        super().__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.with_r = with_r
        self.with_boundary = with_boundary

    @staticmethod
    def _axis_coords(size, device):
        """Return ``size`` evenly spaced float32 coordinates in [-1, 1]."""
        steps = torch.arange(size, dtype=torch.float32, device=device)
        # max(..., 1) avoids 0/0 = NaN when the axis has a single element;
        # that lone position is mapped to -1.
        return steps / max(size - 1, 1) * 2 - 1

    def forward(self, input_tensor, heatmap=None):
        """Concatenate coordinate channels onto ``input_tensor`` along dim 1.

        Args:
            input_tensor: Tensor of shape (batch, c, x_dim, y_dim).
            heatmap: Optional 4-D tensor. Only its last channel is read, and
                only when ``with_boundary`` is set.

        Returns:
            ``input_tensor`` followed by, in order: the x-coordinate
            channel, the y-coordinate channel, the radius channel (if
            ``with_r``), and the two boundary-masked coordinate channels
            (if ``with_boundary`` and ``heatmap`` is not None).
        """
        batch_size = input_tensor.shape[0]
        device = input_tensor.device
        grid_shape = (batch_size, 1, self.x_dim, self.y_dim)

        # xx varies along dim 2, yy along dim 3; both are broadcast views
        # created directly on the input's device.
        xx_channel = (
            self._axis_coords(self.x_dim, device)
            .view(1, 1, self.x_dim, 1)
            .expand(grid_shape)
        )
        yy_channel = (
            self._axis_coords(self.y_dim, device)
            .view(1, 1, 1, self.y_dim)
            .expand(grid_shape)
        )

        channels = [input_tensor, xx_channel, yy_channel]

        if self.with_r:
            rr = torch.sqrt(xx_channel.pow(2) + yy_channel.pow(2))
            channels.append(rr / torch.max(rr))

        if self.with_boundary and heatmap is not None:
            boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
            on_boundary = boundary_channel > 0.05
            zero_tensor = torch.zeros_like(xx_channel)
            # Keep coordinates only where the heatmap response exceeds 0.05.
            channels.append(torch.where(on_boundary, xx_channel, zero_tensor))
            channels.append(torch.where(on_boundary, yy_channel, zero_tensor))

        return torch.cat(channels, dim=1)
class CoordConvTh(nn.Module):
    """CoordConv layer as in the paper.

    Wraps ``AddCoordsTh`` followed by an ``nn.Conv2d`` whose input channel
    count is enlarged to account for the appended channels.

    Args:
        x_dim: Forwarded to ``AddCoordsTh``.
        y_dim: Forwarded to ``AddCoordsTh``.
        with_r: Forwarded to ``AddCoordsTh``; adds one conv input channel.
        with_boundary: Forwarded to ``AddCoordsTh``; adds two conv input
            channels unless ``first_one`` is set.
        in_channels: Channel count of the tensor passed to ``forward``,
            before any coordinate channels are appended.
        first_one: When true, the two boundary channels are not counted
            in the conv's input width.
            NOTE(review): presumably the first layer is called without a
            heatmap, so no boundary channels are produced; confirm against
            callers.
        *args: Remaining positional arguments for ``nn.Conv2d`` after
            ``in_channels`` (e.g. ``out_channels``, ``kernel_size``).
        **kwargs: Keyword arguments for ``nn.Conv2d``.
    """

    def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs):
        super().__init__()
        self.addcoords = AddCoordsTh(
            x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary
        )

        # Two coordinate channels are always appended.
        in_channels += 2
        if with_r:
            in_channels += 1
        if with_boundary and not first_one:
            in_channels += 2

        # in_channels is passed positionally: passing it by keyword ahead of
        # *args makes any extra positional argument collide with it and
        # raise TypeError.
        self.conv = nn.Conv2d(in_channels, *args, **kwargs)

    def forward(self, input_tensor, heatmap=None):
        """Append coordinate channels, then convolve.

        Args:
            input_tensor: 4-D tensor handed to ``AddCoordsTh``.
            heatmap: Optional tensor handed to ``AddCoordsTh``.

        Returns:
            A tuple ``(conv_output, last_channel)`` where ``last_channel``
            is the last two channels of the coordinate-augmented tensor
            (taken before the convolution).
        """
        augmented = self.addcoords(input_tensor, heatmap)
        last_channel = augmented[:, -2:, :, :]
        return self.conv(augmented), last_channel
""" | |
An alternative implementation for PyTorch that auto-infers the x-y dimensions.
""" | |
class AddCoords(nn.Module):
    """Append normalized coordinate channels, inferring the grid size.

    Unlike a fixed-size variant, the spatial dimensions are read from the
    input on every call.

    Args:
        with_r: If true, also append a radius channel.
    """

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    @staticmethod
    def _axis_coords(size, device):
        """Return ``size`` evenly spaced coordinates in [-1, 1]."""
        steps = torch.arange(size, dtype=torch.get_default_dtype(), device=device)
        # max(..., 1) avoids 0/0 = NaN when the axis has a single element;
        # that lone position is mapped to -1.
        return steps / max(size - 1, 1) * 2 - 1

    def forward(self, input_tensor):
        """Concatenate coordinate channels onto ``input_tensor`` along dim 1.

        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)

        Returns:
            ``input_tensor`` followed by the x-coordinate channel, the
            y-coordinate channel and, if ``with_r`` is set, the radius
            channel. Appended channels are cast to the input's dtype and
            created on the input's device.
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()
        device = input_tensor.device
        grid_shape = (batch_size, 1, x_dim, y_dim)

        # xx varies along dim 2, yy along dim 3. Building them on the input's
        # device covers every backend, not just CUDA.
        xx_channel = self._axis_coords(x_dim, device).view(1, 1, x_dim, 1).expand(grid_shape)
        yy_channel = self._axis_coords(y_dim, device).view(1, 1, 1, y_dim).expand(grid_shape)

        channels = [
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor),
        ]

        if self.with_r:
            # NOTE(review): the 0.5 offset is kept as-is even though the
            # coordinates span [-1, 1]; changing it would alter the output
            # that existing weights were trained on.
            rr = torch.sqrt(
                torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)
            )
            # Cast like the other channels so the output keeps the input dtype.
            channels.append(rr.type_as(input_tensor))

        return torch.cat(channels, dim=1)
class CoordConv(nn.Module):
    """Convolution applied after appending coordinate channels.

    Args:
        in_channels: Channel count of the tensor passed to ``forward``,
            before coordinate channels are appended.
        out_channels: Number of output channels of the convolution.
        with_r: If true, a radius channel is appended in addition to the
            two coordinate channels.
        **kwargs: Extra keyword arguments for ``nn.Conv2d`` (for example
            ``kernel_size``).
    """

    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        # Two coordinate channels are always appended; the radius channel
        # adds a third. Counting it keeps the conv's input width in sync with
        # what AddCoords emits, otherwise forward fails when with_r is set.
        extra_channels = 3 if with_r else 2
        self.conv = nn.Conv2d(in_channels + extra_channels, out_channels, **kwargs)

    def forward(self, x):
        """Append coordinate channels to ``x`` and convolve the result."""
        return self.conv(self.addcoords(x))