import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class residualBlock(nn.Module):
    """Basic two-conv residual block with LeakyReLU and optional dilation."""
    expansion = 1

    def __init__(self, in_channels, n_filters, stride=1, downsample=None, dilation=1, with_bn=True):
        super(residualBlock, self).__init__()
        # Dilated 3x3 convs need padding == dilation to preserve spatial size.
        if dilation > 1:
            padding = dilation
        else:
            padding = 1
        self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding,
                                               dilation=dilation, with_bn=with_bn)
        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, with_bn=with_bn)
        self.downsample = downsample
        self.relu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        residual = x
        out = self.convbnrelu1(x)
        out = self.convbn2(out)
        # Project the identity path when stride or channel count changes.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)
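# Illustrative usage (shapes assumed, not from the original): a stride-2 block
# halves the spatial size, so the identity path needs a matching projection:
#   ds = nn.Sequential(nn.Conv2d(32, 64, 1, stride=2, bias=False), nn.BatchNorm2d(64))
#   block = residualBlock(32, 64, stride=2, downsample=ds)
#   block(torch.randn(1, 32, 64, 64)).shape  # torch.Size([1, 64, 32, 32])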

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d -> BatchNorm -> LeakyReLU helper."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(0.1, inplace=True))

class conv2DBatchNorm(nn.Module):
    def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
        super(conv2DBatchNorm, self).__init__()
        # The conv bias is redundant when followed by BatchNorm's affine shift.
        bias = not with_bn
        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)
        if with_bn:
            self.cb_unit = nn.Sequential(conv_mod,
                                         nn.BatchNorm2d(int(n_filters)))
        else:
            self.cb_unit = nn.Sequential(conv_mod)

    def forward(self, inputs):
        return self.cb_unit(inputs)

class conv2DBatchNormRelu(nn.Module):
    def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
        super(conv2DBatchNormRelu, self).__init__()
        # The conv bias is redundant when followed by BatchNorm's affine shift.
        bias = not with_bn
        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)
        if with_bn:
            self.cbr_unit = nn.Sequential(conv_mod,
                                          nn.BatchNorm2d(int(n_filters)),
                                          nn.LeakyReLU(0.1, inplace=True))
        else:
            self.cbr_unit = nn.Sequential(conv_mod,
                                          nn.LeakyReLU(0.1, inplace=True))

    def forward(self, inputs):
        return self.cbr_unit(inputs)
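# Illustrative usage (assumed shapes): with_bn=True drops the conv bias (redundant
# under BatchNorm's affine shift); with_bn=False keeps the bias and skips BN.
#   cbr = conv2DBatchNormRelu(16, 32, k_size=3, stride=1, padding=1)
#   cbr(torch.randn(1, 16, 8, 8)).shape  # torch.Size([1, 32, 8, 8])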

class pyramidPooling(nn.Module):
    """Spatial pyramid pooling: average-pool at several scales, project each
    pooled map with a 1x1 conv, upsample back, and average with the input."""

    def __init__(self, in_channels, with_bn=True, levels=4):
        super(pyramidPooling, self).__init__()
        self.levels = levels
        self.paths = []
        for i in range(levels):
            self.paths.append(conv2DBatchNormRelu(in_channels, in_channels, 1, 1, 0, with_bn=with_bn))
        self.path_module_list = nn.ModuleList(self.paths)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        h, w = x.shape[2:]
        k_sizes = []
        strides = []
        # pool_size is the number of bins per side, from 1 (global pooling)
        # up to min(h, w) // 2 (bins of roughly 2x2 pixels).
        for pool_size in np.linspace(1, min(h, w) // 2, self.levels, dtype=int):
            k_sizes.append((int(h / pool_size), int(w / pool_size)))
            strides.append((int(h / pool_size), int(w / pool_size)))
        k_sizes = k_sizes[::-1]
        strides = strides[::-1]
        pp_sum = x
        for i, module in enumerate(self.path_module_list):
            out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
            out = module(out)
            # F.interpolate replaces the deprecated F.upsample.
            out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=False)
            pp_sum = pp_sum + 1. / self.levels * out
        pp_sum = self.relu(pp_sum / 2.)
        return pp_sum
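# Illustrative example (input size assumed): for a 128-channel 32x64 map with
# levels=3, np.linspace(1, 16, 3) yields pool sizes [1, 8, 16], i.e. average-pool
# windows of (2, 2), (4, 8) and (32, 64) after reversal (global pooling last).
#   pp = pyramidPooling(128, levels=3)
#   pp(torch.randn(1, 128, 32, 64)).shape  # torch.Size([1, 128, 32, 64])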

class pspnet(nn.Module):
    """
    Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
    Returns a five-level feature pyramid, coarsest (1/64) to finest (1/4).
    """

    def __init__(self, is_proj=True, groups=1):
        super(pspnet, self).__init__()
        self.inplanes = 32
        self.is_proj = is_proj

        # Encoder
        self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
                                                 padding=1, stride=2)
        self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
                                                 padding=1, stride=1)
        self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
                                                 padding=1, stride=1)
        # Vanilla residual blocks
        self.res_block3 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block5 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block6 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block7 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.pyramid_pooling = pyramidPooling(128, levels=3)

        # Iconvs (decoder): upsample, fuse with the skip connection, convolve.
        self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                          padding=1, stride=1)
        self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                          padding=1, stride=1)
        self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
                                                         padding=1, stride=1))
        self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
                                          padding=1, stride=1)

        if self.is_proj:
            # 1x1 projection heads, one per pyramid level.
            self.proj6 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128 // groups, padding=0, stride=1)
            self.proj5 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128 // groups, padding=0, stride=1)
            self.proj4 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128 // groups, padding=0, stride=1)
            self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1, n_filters=64 // groups, padding=0, stride=1)
            self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1, n_filters=64 // groups, padding=0, stride=1)

        # He (Kaiming) initialization for all convolutions.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if hasattr(m.bias, 'data'):
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 conv shortcut to match shape when stride or width changes.
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
                                                 kernel_size=1, stride=stride, bias=False),
                                       nn.BatchNorm2d(planes * block.expansion))
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        # H, W -> H/2, W/2
        conv1 = self.convbnrelu1_1(x)
        conv1 = self.convbnrelu1_2(conv1)
        conv1 = self.convbnrelu1_3(conv1)
        # H/2, W/2 -> H/4, W/4
        pool1 = F.max_pool2d(conv1, 3, 2, 1)

        # Four stride-2 residual stages: H/4 -> H/8 -> H/16 -> H/32 -> H/64
        rconv3 = self.res_block3(pool1)
        conv4 = self.res_block5(rconv3)
        conv5 = self.res_block6(conv4)
        conv6 = self.res_block7(conv5)
        conv6 = self.pyramid_pooling(conv6)

        # Decoder. Only the conv part (index [1]) of each upconv is applied:
        # F.interpolate (replacing the deprecated F.upsample) already resized,
        # so the nn.Upsample at index [0] is bypassed.
        conv6x = F.interpolate(conv6, size=(conv5.size(2), conv5.size(3)), mode='bilinear', align_corners=False)
        concat5 = torch.cat((conv5, self.upconv6[1](conv6x)), dim=1)
        conv5 = self.iconv5(concat5)

        conv5x = F.interpolate(conv5, size=(conv4.size(2), conv4.size(3)), mode='bilinear', align_corners=False)
        concat4 = torch.cat((conv4, self.upconv5[1](conv5x)), dim=1)
        conv4 = self.iconv4(concat4)

        conv4x = F.interpolate(conv4, size=(rconv3.size(2), rconv3.size(3)), mode='bilinear', align_corners=False)
        concat3 = torch.cat((rconv3, self.upconv4[1](conv4x)), dim=1)
        conv3 = self.iconv3(concat3)

        conv3x = F.interpolate(conv3, size=(pool1.size(2), pool1.size(3)), mode='bilinear', align_corners=False)
        concat2 = torch.cat((pool1, self.upconv3[1](conv3x)), dim=1)
        conv2 = self.iconv2(concat2)

        if self.is_proj:
            proj6 = self.proj6(conv6)
            proj5 = self.proj5(conv5)
            proj4 = self.proj4(conv4)
            proj3 = self.proj3(conv3)
            proj2 = self.proj2(conv2)
            return proj6, proj5, proj4, proj3, proj2
        else:
            return conv6, conv5, conv4, conv3, conv2
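# Illustrative usage (input size assumed): for a 1x3x256x256 input,
#   net = pspnet(is_proj=True, groups=1)
#   p6, p5, p4, p3, p2 = net(torch.randn(1, 3, 256, 256))
#   # p6: 1x128x4x4, p5: 1x128x8x8, p4: 1x128x16x16, p3: 1x64x32x32, p2: 1x64x64x64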

class pspnet_s(nn.Module):
    """
    Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
    Same as pspnet, but with the final 1/4-resolution decoder stage
    (upconv3, iconv2, proj2) removed: returns a four-level pyramid (1/64 to 1/8).
    """

    def __init__(self, is_proj=True, groups=1):
        super(pspnet_s, self).__init__()
        self.inplanes = 32
        self.is_proj = is_proj

        # Encoder
        self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
                                                 padding=1, stride=2)
        self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
                                                 padding=1, stride=1)
        self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
                                                 padding=1, stride=1)
        # Vanilla residual blocks
        self.res_block3 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block5 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block6 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block7 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.pyramid_pooling = pyramidPooling(128, levels=3)

        # Iconvs (decoder); the 1/4-resolution stage of pspnet is intentionally absent.
        self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                          padding=1, stride=1)
        self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                          padding=1, stride=1)
        self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                          padding=1, stride=1)

        if self.is_proj:
            # 1x1 projection heads, one per pyramid level.
            self.proj6 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128 // groups, padding=0, stride=1)
            self.proj5 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128 // groups, padding=0, stride=1)
            self.proj4 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128 // groups, padding=0, stride=1)
            self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1, n_filters=64 // groups, padding=0, stride=1)

        # He (Kaiming) initialization for all convolutions.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if hasattr(m.bias, 'data'):
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 conv shortcut to match shape when stride or width changes.
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
                                                 kernel_size=1, stride=stride, bias=False),
                                       nn.BatchNorm2d(planes * block.expansion))
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        # H, W -> H/2, W/2
        conv1 = self.convbnrelu1_1(x)
        conv1 = self.convbnrelu1_2(conv1)
        conv1 = self.convbnrelu1_3(conv1)
        # H/2, W/2 -> H/4, W/4
        pool1 = F.max_pool2d(conv1, 3, 2, 1)

        # Four stride-2 residual stages: H/4 -> H/8 -> H/16 -> H/32 -> H/64
        rconv3 = self.res_block3(pool1)
        conv4 = self.res_block5(rconv3)
        conv5 = self.res_block6(conv4)
        conv6 = self.res_block7(conv5)
        conv6 = self.pyramid_pooling(conv6)

        # Decoder; as in pspnet, only the conv part (index [1]) of each upconv is
        # applied since F.interpolate (replacing F.upsample) already resized.
        conv6x = F.interpolate(conv6, size=(conv5.size(2), conv5.size(3)), mode='bilinear', align_corners=False)
        concat5 = torch.cat((conv5, self.upconv6[1](conv6x)), dim=1)
        conv5 = self.iconv5(concat5)

        conv5x = F.interpolate(conv5, size=(conv4.size(2), conv4.size(3)), mode='bilinear', align_corners=False)
        concat4 = torch.cat((conv4, self.upconv5[1](conv5x)), dim=1)
        conv4 = self.iconv4(concat4)

        conv4x = F.interpolate(conv4, size=(rconv3.size(2), rconv3.size(3)), mode='bilinear', align_corners=False)
        concat3 = torch.cat((rconv3, self.upconv4[1](conv4x)), dim=1)
        conv3 = self.iconv3(concat3)

        if self.is_proj:
            proj6 = self.proj6(conv6)
            proj5 = self.proj5(conv5)
            proj4 = self.proj4(conv4)
            proj3 = self.proj3(conv3)
            return proj6, proj5, proj4, proj3
        else:
            return conv6, conv5, conv4, conv3
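# Illustrative usage (input size assumed): pspnet_s stops one decoder stage early,
# so the finest output is at 1/8 resolution:
#   p6, p5, p4, p3 = pspnet_s()(torch.randn(1, 3, 256, 256))
#   # p6: 1x128x4x4, p5: 1x128x8x8, p4: 1x128x16x16, p3: 1x64x32x32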

class bfmodule(nn.Module):
    """Hourglass decoder over a projected feature map: emits predictions at
    five scales (strides 4x, 8x, 16x, 32x, 64x relative to the full image)."""

    def __init__(self, inplanes, outplanes):
        super(bfmodule, self).__init__()
        self.proj = conv2DBatchNormRelu(in_channels=inplanes, k_size=1, n_filters=64, padding=0, stride=1)
        self.inplanes = 64
        # Vanilla residual blocks
        self.res_block3 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block5 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block6 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block7 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.pyramid_pooling = pyramidPooling(128, levels=3)

        # Iconvs (decoder)
        self.upconv6 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                           padding=1, stride=1)
        self.upconv5 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
                                           padding=1, stride=1)
        self.upconv4 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
                                           padding=1, stride=1)
        self.upconv3 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
                                           padding=1, stride=1)
        self.iconv5 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        self.iconv4 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        self.iconv3 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        self.iconv2 = nn.Sequential(conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
                                                        padding=1, stride=1),
                                    nn.Conv2d(64, outplanes, kernel_size=3, stride=1, padding=1, bias=True))
        # Per-scale prediction heads.
        self.proj6 = nn.Conv2d(128, outplanes, kernel_size=3, stride=1, padding=1, bias=True)
        self.proj5 = nn.Conv2d(64, outplanes, kernel_size=3, stride=1, padding=1, bias=True)
        self.proj4 = nn.Conv2d(64, outplanes, kernel_size=3, stride=1, padding=1, bias=True)
        self.proj3 = nn.Conv2d(64, outplanes, kernel_size=3, stride=1, padding=1, bias=True)

        # He (Kaiming) initialization for all convolutions.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if hasattr(m.bias, 'data'):
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 conv shortcut to match shape when stride or width changes.
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
                                                 kernel_size=1, stride=stride, bias=False),
                                       nn.BatchNorm2d(planes * block.expansion))
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        proj = self.proj(x)                  # 4x
        rconv3 = self.res_block3(proj)       # 8x
        conv4 = self.res_block5(rconv3)      # 16x
        conv5 = self.res_block6(conv4)       # 32x
        conv6 = self.res_block7(conv5)       # 64x
        conv6 = self.pyramid_pooling(conv6)  # 64x
        pred6 = self.proj6(conv6)

        # F.interpolate replaces the deprecated F.upsample.
        conv6u = F.interpolate(conv6, size=(conv5.size(2), conv5.size(3)), mode='bilinear', align_corners=False)
        concat5 = torch.cat((conv5, self.upconv6(conv6u)), dim=1)
        conv5 = self.iconv5(concat5)         # 32x
        pred5 = self.proj5(conv5)

        conv5u = F.interpolate(conv5, size=(conv4.size(2), conv4.size(3)), mode='bilinear', align_corners=False)
        concat4 = torch.cat((conv4, self.upconv5(conv5u)), dim=1)
        conv4 = self.iconv4(concat4)         # 16x
        pred4 = self.proj4(conv4)

        conv4u = F.interpolate(conv4, size=(rconv3.size(2), rconv3.size(3)), mode='bilinear', align_corners=False)
        concat3 = torch.cat((rconv3, self.upconv4(conv4u)), dim=1)
        conv3 = self.iconv3(concat3)         # 8x
        pred3 = self.proj3(conv3)

        conv3u = F.interpolate(conv3, size=(x.size(2), x.size(3)), mode='bilinear', align_corners=False)
        concat2 = torch.cat((proj, self.upconv3(conv3u)), dim=1)
        pred2 = self.iconv2(concat2)         # 4x

        return pred2, pred3, pred4, pred5, pred6
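
# Minimal smoke test (illustrative; the input shape is an assumption). bfmodule
# refines a 1/4-scale feature map into predictions at 4x/8x/16x/32x/64x strides.
if __name__ == '__main__':
    feat = torch.randn(1, 16, 64, 64)       # stand-in for a 1/4-scale feature map
    preds = bfmodule(inplanes=16, outplanes=2)(feat)
    print([tuple(p.shape) for p in preds])  # pred2 matches the input resolution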