Spaces:
Runtime error
Runtime error
import pdb | |
import torch.nn as nn | |
import math | |
import torch | |
from torch.nn.parameter import Parameter | |
import torch.nn.functional as F | |
from torch.nn import Module | |
from torch.nn.modules.conv import _ConvNd | |
from torch.nn.modules.utils import _quadruple | |
from torch.autograd import Variable | |
from torch.nn import Conv2d | |
def conv4d(data, filters, bias=None, permute_filters=True, use_half=False):
    """Compute a 4D convolution by stacking the results of several 3D convolutions.

    This is done by stacking results of multiple 3D convolutions, and is very slow.
    Taken from https://github.com/ignacio-rocco/ncnet

    Args:
        data: input tensor of shape (b, c, h, w, d, t).
        filters: weights of shape (c_out, c, k, k, k, k), or already permuted to
            (k, c_out, c, k, k, k) when ``permute_filters`` is False.
        bias: optional tensor of shape (c_out,), added exactly once per output element.
        permute_filters: if True, move the looped kernel axis of ``filters`` to the front.
        use_half: if True, the result is returned as float16.

    Returns:
        Tensor of shape (b, c_out, h, w, d, t). Zero padding of ``k // 2`` is applied
        on all four spatial dims, so sizes are preserved for odd ``k``.
    """
    b, c, h, w, d, t = data.size()
    # Permute once so that data[i] / filters[i] slices inside the loop are contiguous.
    data = data.permute(2, 0, 1, 3, 4, 5).contiguous()
    if permute_filters:
        filters = filters.permute(2, 0, 1, 3, 4, 5).contiguous()
    c_out = filters.size(1)
    padding = filters.size(0) // 2

    # Pad the looped dimension explicitly; F.conv3d pads the other three.
    # new_zeros follows data's dtype and device (CPU/CUDA, float/half).
    zeros = data.new_zeros((padding, b, c, w, d, t))
    data_padded = torch.cat((zeros, data, zeros), 0)

    # The output buffer must NOT be a leaf requiring grad: writing slices into such
    # a tensor raises "a leaf Variable that requires grad is being used in an
    # in-place operation". Autograd still records the slice assignments below, so
    # gradients flow back to data, filters and bias.
    out_dtype = torch.half if use_half else data.dtype
    output = data.new_zeros((h, b, c_out, w, d, t), dtype=out_dtype)
    for i in range(h):  # loop on first feature dimension
        # Centre slice of the kernel (position=padding) carries the bias, added once.
        acc = F.conv3d(data_padded[i + padding], filters[padding],
                       bias=bias, stride=1, padding=padding)
        # Upper/lower kernel slices (positions [:padding] and [padding+1:]).
        for p in range(1, padding + 1):
            acc = acc + F.conv3d(data_padded[i + padding - p], filters[padding - p],
                                 bias=None, stride=1, padding=padding)
            acc = acc + F.conv3d(data_padded[i + padding + p], filters[padding + p],
                                 bias=None, stride=1, padding=padding)
        output[i] = acc
    return output.permute(1, 2, 0, 3, 4, 5).contiguous()
class Conv4d(_ConvNd):
    """Applies a 4D convolution over an input signal composed of several input planes.

    Stride, dilation and groups are fixed to 1. Zero padding of ``kernel_size // 2``
    is applied inside :func:`conv4d`, so the four spatial sizes of the output match
    the input for odd kernel sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
        # stride, dilation and groups != 1 functionality is not supported by conv4d
        stride = _quadruple(1)
        dilation = _quadruple(1)
        groups = 1
        # zero padding is added in the conv4d function instead, to preserve tensor size
        padding = _quadruple(0)
        kernel_size = _quadruple(kernel_size)
        # padding_mode is a required argument of _ConvNd.__init__ in current PyTorch;
        # omitting it raises TypeError at construction time.
        super(Conv4d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _quadruple(0), groups, bias, padding_mode='zeros')
        # Weights are sliced along one kernel dimension during the convolution loop.
        # Storing that dimension first avoids calling contiguous() inside the loop:
        # (c_out, c_in, k, k, k, k) -> (k, c_out, c_in, k, k, k).
        self.pre_permuted_filters = pre_permuted_filters
        if self.pre_permuted_filters:
            self.weight.data = self.weight.data.permute(2, 0, 1, 3, 4, 5).contiguous()
        self.use_half = False

    def forward(self, input):
        """Convolve ``input`` of shape (b, in_channels, h, w, d, t)."""
        return conv4d(input, self.weight, bias=self.bias,
                      permute_filters=not self.pre_permuted_filters,
                      use_half=self.use_half)
class fullConv4d(torch.nn.Module):
    """A Conv4d layer, followed by per-channel BatchNorm when the conv has no bias.

    Without bias, the (b, c, u, v, h, w) conv output is flattened to (b, c, -1),
    normalised with BatchNorm1d, and reshaped back.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
        super(fullConv4d, self).__init__()
        self.conv = Conv4d(in_channels, out_channels, kernel_size,
                           bias=bias, pre_permuted_filters=pre_permuted_filters)
        self.isbias = bias
        if not bias:
            self.bn = torch.nn.BatchNorm1d(out_channels)

    def forward(self, input):
        out = self.conv(input)
        if self.isbias:
            return out
        batch, channels, u, v, h, w = out.shape
        normed = self.bn(out.view(batch, channels, -1))
        return normed.view(batch, channels, u, v, h, w)
class butterfly4D(torch.nn.Module):
    '''
    Butterfly (encoder-decoder) block over a 4D volume (b, c, u, v, h, w).

    The input is projected to ``fdimb`` channels. It is then downsampled twice
    with strided separable 4D conv blocks, and brought back up with two resizing
    stages. Each resize adds a skip connection and is followed by a further
    separable conv block.
    '''
    def __init__(self, fdima, fdimb, withbn=True, full=True, groups=1):
        super(butterfly4D, self).__init__()
        self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn, groups=groups),
                                  nn.ReLU(inplace=True),)
        self.conva1 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(2, 1, 1), full=full, groups=groups)
        self.conva2 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(2, 1, 1), full=full, groups=groups)
        self.convb3 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(1, 1, 1), full=full, groups=groups)
        self.convb2 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(1, 1, 1), full=full, groups=groups)
        self.convb1 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(1, 1, 1), full=full, groups=groups)

    @staticmethod
    def _upsample4d(x, size):
        """Trilinearly resize the four trailing dims of a (b, c, u, v, h, w) tensor.

        The resize runs in two passes. The first resizes (u, v) with (h, w)
        flattened; the second resizes (h, w) with (u, v) flattened. Using
        ``align_corners=False`` reproduces the default that the deprecated
        ``F.upsample`` applied.
        """
        b, c, u, v, h, w = x.shape
        tu, tv, th, tw = size
        x = F.interpolate(x.reshape(b, c, u, v, h * w), size=(tu, tv, h * w),
                          mode='trilinear', align_corners=False)
        x = F.interpolate(x.reshape(b, c, tu * tv, h, w), size=(tu * tv, th, tw),
                          mode='trilinear', align_corners=False)
        return x.view(b, c, tu, tv, th, tw)

    def forward(self, x):
        out = self.proj(x)           # e.g. 9x9
        out1 = self.conva1(out)      # e.g. 5x5
        out2 = self.conva2(out1)     # e.g. 3x3
        out2 = self.convb3(out2)
        # decoder: upsample coarse features and add the skip connection
        out1 = out1 + self._upsample4d(out2, out1.shape[2:])
        out1 = self.convb2(out1)
        out = out + self._upsample4d(out1, out.shape[2:])
        return self.convb1(out)
class projfeat4d(torch.nn.Module):
    '''
    1x1 projection of a 4D volume (b, c, u, v, h, w) implemented with a Conv3d.

    (h, w) is flattened into one axis, and the (u, v) axes are strided by ``stride``.
    BatchNorm3d is applied only when ``with_bn`` is True, although the layer is
    always created.
    '''
    def __init__(self, in_planes, out_planes, stride, with_bn=True, groups=1):
        super(projfeat4d, self).__init__()
        self.with_bn = with_bn
        self.stride = stride
        uv_stride = (stride, stride, 1)
        self.conv1 = nn.Conv3d(in_planes, out_planes, 1, uv_stride, padding=0,
                               bias=not with_bn, groups=groups)
        self.bn = nn.BatchNorm3d(out_planes)

    def forward(self, x):
        batch, chans, u, v, h, w = x.size()
        y = self.conv1(x.view(batch, chans, u, v, h * w))
        if self.with_bn:
            y = self.bn(y)
        out_c, out_u, out_v = y.shape[1:4]
        return y.view(batch, out_c, out_u, out_v, h, w)
class sepConv4d(torch.nn.Module):
    '''
    Separable 4D convolution made of two 3D convolutions.

    ``conv2`` convolves the (u, v) axes with (h, w) flattened. ``conv1`` then
    convolves the (h, w) axes with (u, v) flattened. A 1x1 Conv2d projection
    changes the channel count when ``in_planes != out_planes``. BatchNorm follows
    each conv when ``with_bn`` is True.
    '''
    def __init__(self, in_planes, out_planes, stride=(1, 1, 1), with_bn=True, ksize=3, full=True, groups=1):
        super(sepConv4d, self).__init__()
        bias = not with_bn
        self.isproj = in_planes != out_planes
        self.stride = stride[0]
        expand = 1
        half_k = ksize // 2
        s = self.stride
        # only the "full" variant also strides the (h, w) axes
        hw_stride = (1, s, s) if full else 1

        if self.isproj:
            proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0, groups=groups)
            self.proj = nn.Sequential(proj, nn.BatchNorm2d(out_planes)) if with_bn else proj

        conv_hw = nn.Conv3d(in_planes * expand, in_planes, (1, ksize, ksize), stride=hw_stride,
                            bias=bias, padding=(0, half_k, half_k), groups=groups)
        self.conv1 = nn.Sequential(conv_hw, nn.BatchNorm3d(in_planes)) if with_bn else conv_hw

        conv_uv = nn.Conv3d(in_planes, in_planes * expand, (ksize, ksize, 1), stride=(s, s, 1),
                            bias=bias, padding=(half_k, half_k, 0), groups=groups)
        self.conv2 = nn.Sequential(conv_uv, nn.BatchNorm3d(in_planes * expand)) if with_bn else conv_uv

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        b, c, u, v, h, w = x.shape
        # pass over (u, v): (h, w) flattened into the last axis
        y = self.conv2(x.view(b, c, u, v, -1))
        b, c, u, v = y.shape[:4]
        y = self.relu(y)
        # pass over (h, w): (u, v) flattened into the depth axis
        y = self.conv1(y.view(b, c, -1, h, w))
        b, c = y.shape[0], y.shape[1]
        h, w = y.shape[3], y.shape[4]
        if self.isproj:
            y = self.proj(y.view(b, c, -1, w))
        return y.view(b, -1, u, v, h, w)
class sepConv4dBlock(torch.nn.Module):
    '''
    Residual block of two separable 4D convolutions.

    The shortcut is the identity when the channel count and stride are unchanged.
    Otherwise it is a ``sepConv4d`` with ``ksize=1`` (``full=True``) or a
    ``projfeat4d`` (``full=False``).
    '''
    def __init__(self, in_planes, out_planes, stride=(1, 1, 1), with_bn=True, full=True, groups=1):
        super(sepConv4dBlock, self).__init__()
        identity_shortcut = in_planes == out_planes and stride == (1, 1, 1)
        if identity_shortcut:
            self.downsample = None
        elif full:
            self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,
                                        ksize=1, full=full, groups=groups)
        else:
            self.downsample = projfeat4d(in_planes, out_planes, stride[0],
                                         with_bn=with_bn, groups=groups)
        self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full, groups=groups)
        self.conv2 = sepConv4d(out_planes, out_planes, (1, 1, 1), with_bn=with_bn, full=full, groups=groups)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        branch = self.relu1(self.conv1(x))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu2(shortcut + self.conv2(branch))
##import torch.backends.cudnn as cudnn | |
##cudnn.benchmark = True | |
#import time | |
##im = torch.randn(9,64,9,160,224).cuda() | |
##net = torch.nn.Conv3d(64, 64, 3).cuda() | |
##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda() | |
##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda() | |
# | |
##im = torch.randn(1,16,9,9,96,320).cuda() | |
##net = sepConv4d(16,16,with_bn=False).cuda() | |
# | |
##im = torch.randn(1,16,81,96,320).cuda() | |
##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda() | |
# | |
##im = torch.randn(1,16,9,9,96*320).cuda() | |
##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda() | |
# | |
##im = torch.randn(10000,10,9,9).cuda() | |
##net = torch.nn.Conv2d(10,10,3,padding=1).cuda() | |
# | |
##im = torch.randn(81,16,96,320).cuda() | |
##net = torch.nn.Conv2d(16,16,3,padding=1).cuda() | |
#c= int(16 *1) | |
#cp = int(16 *1) | |
#h=int(96 *4) | |
#w=int(320 *4) | |
#k=3 | |
#im = torch.randn(1,c,h,w).cuda() | |
#net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda() | |
# | |
#im2 = torch.randn(cp,k*k*c).cuda() | |
#im1 = F.unfold(im, (k,k), padding=k//2)[0] | |
# | |
# | |
#net(im) | |
#net(im) | |
#torch.mm(im2,im1) | |
#torch.mm(im2,im1) | |
#torch.cuda.synchronize() | |
#beg = time.time() | |
#for i in range(100): | |
# net(im) | |
# #im1 = F.unfold(im, (k,k), padding=k//2)[0] | |
# torch.mm(im2,im1) | |
#torch.cuda.synchronize() | |
#print('%f'%((time.time()-beg)*10.)) | |