endo-yuki-t
initial commit
d7dbcdd
raw
history blame
No virus
12.2 kB
import pdb
import torch.nn as nn
import math
import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _quadruple
from torch.autograd import Variable
from torch.nn import Conv2d
def conv4d(data, filters, bias=None, permute_filters=True, use_half=False):
    """
    4D convolution (cross-correlation) assembled from 3D convolutions.

    This is done by stacking results of multiple 3D convolutions, and is very slow.
    Taken from https://github.com/ignacio-rocco/ncnet

    Args:
        data: input tensor of shape (b, c, h, w, d, t).
        filters: weights of shape (c_out, c, k, k, k, k) or, when
            ``permute_filters`` is False, already permuted to
            (k, c_out, c, k, k, k).
        bias: optional bias of shape (c_out,). It is added exactly once per
            output position.
        permute_filters: if True, move the first kernel axis of ``filters``
            to the front before looping.
        use_half: if True, the result is cast to float16.

    Returns:
        Tensor of shape (b, c_out, h, w, d, t) when every kernel dimension
        equals the same odd ``k``. Padding is ``k // 2`` on all four spatial
        axes, where ``k`` is taken from the first kernel axis.
    """
    b, c, h, w, d, t = data.size()
    # Put the looped spatial axis first so every slice data[i] is already a
    # contiguous (b, c, w, d, t) batch for F.conv3d.
    data = data.permute(2, 0, 1, 3, 4, 5).contiguous()
    # Apply the same permutation to the filters unless the caller already did.
    if permute_filters:
        filters = filters.permute(2, 0, 1, 3, 4, 5).contiguous()
    padding = filters.size(0) // 2
    # Zero-pad only the looped axis here; F.conv3d pads the other three axes.
    # new_zeros inherits dtype and device from ``data``, so no CPU->GPU copy
    # or dtype mismatch occurs.
    zeros = data.new_zeros((padding, b, c, w, d, t))
    data_padded = torch.cat((zeros, data, zeros), 0)

    # Collect the per-slice results and stack them. Writing into a
    # preallocated tensor in place breaks autograd when that tensor is a leaf
    # requiring grad.
    slices = []
    for i in range(h):
        # Centre slice of the kernel (index == padding) carries the bias.
        acc = F.conv3d(data_padded[i + padding], filters[padding],
                       bias=bias, stride=1, padding=padding)
        # Add the lower/upper kernel slices (indices [:padding] and [padding+1:]).
        for p in range(1, padding + 1):
            acc = acc + F.conv3d(data_padded[i + padding - p], filters[padding - p],
                                 bias=None, stride=1, padding=padding)
            acc = acc + F.conv3d(data_padded[i + padding + p], filters[padding + p],
                                 bias=None, stride=1, padding=padding)
        slices.append(acc)
    output = torch.stack(slices, 0)  # (h, b, c_out, w', d', t')
    if use_half:
        output = output.half()
    return output.permute(1, 2, 0, 3, 4, 5).contiguous()
class Conv4d(_ConvNd):
    """Applies a 4D convolution over an input signal composed of several input
    planes.

    Stride, dilation and groups are fixed to 1, because ``conv4d`` neither
    passes nor supports other values. Zero padding of ``kernel_size // 2`` is
    added inside ``conv4d``, so the padding recorded on the module is 0.

    When ``pre_permuted_filters`` is True, the weight is stored permuted as
    (k, out_channels, in_channels, k, k, k). The axis looped over in
    ``conv4d`` is then already first, and no permute/contiguous copy happens
    on each forward call.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
        stride = _quadruple(1)
        dilation = _quadruple(1)
        groups = 1
        # conv4d() pads internally to preserve the spatial size.
        padding = _quadruple(0)
        kernel_size = _quadruple(kernel_size)
        # _ConvNd.__init__ in current PyTorch takes padding_mode as a required
        # positional argument; omitting it raises TypeError.
        super(Conv4d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _quadruple(0), groups, bias, 'zeros')
        self.pre_permuted_filters = pre_permuted_filters
        if self.pre_permuted_filters:
            # (out, in, k1, k2, k3, k4) -> (k1, out, in, k2, k3, k4)
            self.weight.data = self.weight.data.permute(2, 0, 1, 3, 4, 5).contiguous()
        self.use_half = False

    def forward(self, input):
        # The filters were already permuted in the constructor when
        # pre_permuted_filters is True, so conv4d must not permute them again.
        return conv4d(input, self.weight, bias=self.bias,
                      permute_filters=not self.pre_permuted_filters,
                      use_half=self.use_half)
class fullConv4d(torch.nn.Module):
    """Conv4d wrapper that adds per-channel BatchNorm when the conv has no bias.

    With ``bias=True`` the output of the wrapped ``Conv4d`` is returned as is.
    With ``bias=False`` it is normalised by ``BatchNorm1d(out_channels)`` over
    the flattened spatial positions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
        super(fullConv4d, self).__init__()
        self.conv = Conv4d(in_channels, out_channels, kernel_size,
                           bias=bias, pre_permuted_filters=pre_permuted_filters)
        self.isbias = bias
        if not bias:
            self.bn = torch.nn.BatchNorm1d(out_channels)

    def forward(self, input):
        result = self.conv(input)
        if self.isbias:
            return result
        # BatchNorm1d expects (N, C, L): collapse the trailing spatial axes into L,
        # then restore the original 6D shape.
        n, ch = result.shape[:2]
        normed = self.bn(result.view(n, ch, -1))
        return normed.view(result.shape)
class butterfly4D(torch.nn.Module):
    '''
    Butterfly (encoder-decoder) block over a 4D volume (b, c, u, v, h, w).

    The input is first projected from ``fdima`` to ``fdimb`` channels. Two
    stride-(2,1,1) separable blocks follow; with ``full=True`` these also
    reduce h, w. The coarse features are then upsampled back step by step,
    added to the matching skip connection and refined.
    '''

    def __init__(self, fdima, fdimb, withbn=True, full=True, groups=1):
        super(butterfly4D, self).__init__()
        self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn, groups=groups),
                                  nn.ReLU(inplace=True),)
        self.conva1 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(2, 1, 1), full=full, groups=groups)
        self.conva2 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(2, 1, 1), full=full, groups=groups)
        self.convb3 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(1, 1, 1), full=full, groups=groups)
        self.convb2 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(1, 1, 1), full=full, groups=groups)
        self.convb1 = sepConv4dBlock(fdimb, fdimb, with_bn=withbn, stride=(1, 1, 1), full=full, groups=groups)

    @staticmethod
    def _upsample4d(x, size):
        """Resize a (b, c, u, v, h, w) tensor to spatial ``size`` = (u', v', h', w').

        The resize is done separably with two trilinear interpolations. The
        first resizes (u, v) with the flattened h*w axis held fixed. The
        second resizes (h, w) with the flattened u'*v' axis held fixed.
        """
        b, c, u0, v0, h0, w0 = x.shape
        u, v, h, w = size
        x = F.interpolate(x.reshape(b, c, u0, v0, h0 * w0), (u, v, h0 * w0),
                          mode='trilinear', align_corners=False)
        x = F.interpolate(x.reshape(b, c, u * v, h0, w0), (u * v, h, w),
                          mode='trilinear', align_corners=False)
        return x.reshape(b, c, u, v, h, w)

    def forward(self, x):
        out = self.proj(x)                       # full resolution (noted as 9x9 originally)
        out1 = self.conva1(out)                  # first stride-2 stage
        out2 = self.convb3(self.conva2(out1))    # second stride-2 stage + refinement
        # Decoder: upsample to the skip's spatial size, add the skip, refine.
        out1 = self.convb2(self._upsample4d(out2, out1.shape[2:]) + out1)
        out = self.convb1(self._upsample4d(out1, out.shape[2:]) + out)
        return out
class projfeat4d(torch.nn.Module):
    '''
    Pointwise (1x1) projection of a 4D volume (b, c, u, v, h, w).

    The volume is flattened to (b, c, u, v, h*w) and passed through a Conv3d
    with kernel 1 and stride (stride, stride, 1). This changes the channel
    count and subsamples u, v while h, w are left untouched.
    '''

    def __init__(self, in_planes, out_planes, stride, with_bn=True, groups=1):
        super(projfeat4d, self).__init__()
        self.with_bn = with_bn
        self.stride = stride
        self.conv1 = nn.Conv3d(in_planes, out_planes, kernel_size=1,
                               stride=(stride, stride, 1), padding=0,
                               bias=not with_bn, groups=groups)
        # Registered even when with_bn is False (it is then unused); kept so
        # state_dict keys stay the same.
        self.bn = nn.BatchNorm3d(out_planes)

    def forward(self, x):
        batch, _, u_in, v_in, height, width = x.size()
        flat = x.view(batch, x.size(1), u_in, v_in, height * width)
        y = self.conv1(flat)
        y = self.bn(y) if self.with_bn else y
        channels, u_out, v_out = y.shape[1:4]
        return y.view(batch, channels, u_out, v_out, height, width)
class sepConv4d(torch.nn.Module):
    '''
    Separable 4D convolution on a (b, c, u, v, h, w) volume, built from two
    3D convolutions.

    ``conv2`` (kernel (k, k, 1)) mixes u, v over a (b, c, u, v, h*w) view.
    After a ReLU, ``conv1`` (kernel (1, k, k)) mixes h, w over a
    (b, c, u*v, h, w) view. When in_planes != out_planes, a 1x1 Conv2d
    ``proj`` changes the channel count. With ``with_bn`` each conv is
    followed by BatchNorm and has no bias.
    '''

    def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True, groups=1):
        super(sepConv4d, self).__init__()
        use_bias = not with_bn
        expand = 1
        half_k = ksize // 2
        self.stride = stride[0]
        self.isproj = in_planes != out_planes

        def wrap(layer, norm_cls, norm_planes):
            # Sequential(conv, norm) with batch-norm, bare conv otherwise;
            # this keeps the state_dict key layout of both configurations.
            if not with_bn:
                return layer
            return nn.Sequential(layer, norm_cls(norm_planes))

        if self.isproj:
            self.proj = wrap(nn.Conv2d(in_planes, out_planes, 1, bias=use_bias, padding=0, groups=groups),
                             nn.BatchNorm2d, out_planes)
        # Only the "full" variant also strides over the (h, w) axes.
        hw_stride = (1, self.stride, self.stride) if full else 1
        self.conv1 = wrap(nn.Conv3d(in_planes * expand, in_planes, (1, ksize, ksize), stride=hw_stride,
                                    bias=use_bias, padding=(0, half_k, half_k), groups=groups),
                          nn.BatchNorm3d, in_planes)
        self.conv2 = wrap(nn.Conv3d(in_planes, in_planes * expand, (ksize, ksize, 1),
                                    stride=(self.stride, self.stride, 1), bias=use_bias,
                                    padding=(half_k, half_k, 0), groups=groups),
                          nn.BatchNorm3d, in_planes * expand)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        n, _, u, v, h, w = x.shape
        # (u, v) pass: the flattened h*w axis is the third conv dimension.
        y = self.conv2(x.view(n, x.shape[1], u, v, h * w))
        n, ch, u, v = y.shape[:4]
        y = self.relu(y)
        # (h, w) pass: the flattened u*v axis is the first conv dimension.
        y = self.conv1(y.view(n, ch, u * v, h, w))
        n, ch, _, h, w = y.shape
        if self.isproj:
            y = self.proj(y.view(n, ch, -1, w))
        return y.view(n, -1, u, v, h, w)
class sepConv4dBlock(torch.nn.Module):
    '''
    Residual block of two separable 4D convolutions (``sepConv4d``) with an
    optional projection shortcut:

        out = relu(shortcut(x) + conv2(relu(conv1(x))))

    The shortcut is the identity when ``in_planes == out_planes`` and
    ``stride == (1,1,1)``. Otherwise it is a ksize-1 ``sepConv4d`` when
    ``full`` is True, or a ``projfeat4d`` when ``full`` is False.
    '''

    def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True, groups=1):
        super(sepConv4dBlock, self).__init__()
        if in_planes == out_planes and stride == (1, 1, 1):
            self.downsample = None
        elif full:
            self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,
                                        ksize=1, full=full, groups=groups)
        else:
            self.downsample = projfeat4d(in_planes, out_planes, stride[0],
                                         with_bn=with_bn, groups=groups)
        self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full, groups=groups)
        self.conv2 = sepConv4d(out_planes, out_planes, (1, 1, 1), with_bn=with_bn, full=full, groups=groups)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        # Explicit None check: relying on nn.Module truthiness is fragile
        # (container modules define __len__ and can evaluate False).
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu2(x + self.conv2(out))
##import torch.backends.cudnn as cudnn
##cudnn.benchmark = True
#import time
##im = torch.randn(9,64,9,160,224).cuda()
##net = torch.nn.Conv3d(64, 64, 3).cuda()
##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda()
##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda()
#
##im = torch.randn(1,16,9,9,96,320).cuda()
##net = sepConv4d(16,16,with_bn=False).cuda()
#
##im = torch.randn(1,16,81,96,320).cuda()
##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda()
#
##im = torch.randn(1,16,9,9,96*320).cuda()
##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda()
#
##im = torch.randn(10000,10,9,9).cuda()
##net = torch.nn.Conv2d(10,10,3,padding=1).cuda()
#
##im = torch.randn(81,16,96,320).cuda()
##net = torch.nn.Conv2d(16,16,3,padding=1).cuda()
#c= int(16 *1)
#cp = int(16 *1)
#h=int(96 *4)
#w=int(320 *4)
#k=3
#im = torch.randn(1,c,h,w).cuda()
#net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda()
#
#im2 = torch.randn(cp,k*k*c).cuda()
#im1 = F.unfold(im, (k,k), padding=k//2)[0]
#
#
#net(im)
#net(im)
#torch.mm(im2,im1)
#torch.mm(im2,im1)
#torch.cuda.synchronize()
#beg = time.time()
#for i in range(100):
# net(im)
# #im1 = F.unfold(im, (k,k), padding=k//2)[0]
# torch.mm(im2,im1)
#torch.cuda.synchronize()
#print('%f'%((time.time()-beg)*10.))