File size: 3,900 Bytes
362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 3120fa9 362a9f8 dacb584 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import chainer
import chainer.functions as F
import chainer.links as L
class UNet3D(chainer.Chain):
    """3-D encoder/decoder segmentation network with U-Net style skip connections.

    The encoder applies conv+BN+ReLU stages with two 2x max-pooling steps. The
    decoder upsamples twice with stride-2 deconvolutions and concatenates the
    matching encoder feature maps (channel axis) before further conv+BN+ReLU
    stages. A final 1x1x1 convolution maps to ``num_of_label`` channels,
    followed by a softmax over the channel axis.
    """

    def __init__(self, num_of_label):
        """Build all layers.

        Args:
            num_of_label: number of output channels of the final 1x1x1
                convolution, i.e. the number of classes the softmax runs over.
        """
        w = chainer.initializers.HeNormal()
        super(UNet3D, self).__init__()
        with self.init_scope():
            # encoder pass
            self.ce0 = L.ConvolutionND(ndim=3, in_channels=1, out_channels=16, ksize=3, pad=1, initialW=w)
            self.bne0 = L.BatchNormalization(16)
            self.ce1 = L.ConvolutionND(ndim=3, in_channels=16, out_channels=32, ksize=3, pad=1, initialW=w)
            self.bne1 = L.BatchNormalization(32)
            self.ce2 = L.ConvolutionND(ndim=3, in_channels=32, out_channels=32, ksize=3, pad=1, initialW=w)
            self.bne2 = L.BatchNormalization(32)
            self.ce3 = L.ConvolutionND(ndim=3, in_channels=32, out_channels=64, ksize=3, pad=1, initialW=w)
            self.bne3 = L.BatchNormalization(64)
            self.ce4 = L.ConvolutionND(ndim=3, in_channels=64, out_channels=64, ksize=3, pad=1, initialW=w)
            self.bne4 = L.BatchNormalization(64)
            # decoder pass
            self.cd4 = L.ConvolutionND(ndim=3, in_channels=64, out_channels=128, ksize=3, pad=1, initialW=w)
            self.bnd4 = L.BatchNormalization(128)
            self.deconv2 = L.DeconvolutionND(ndim=3, in_channels=128, out_channels=128, ksize=2, stride=2, initialW=w,
                                             nobias=True)
            # 64 skip channels from e3 + 128 upsampled channels from deconv2
            self.cd3 = L.ConvolutionND(ndim=3, in_channels=64 + 128, out_channels=64, ksize=3, pad=1, initialW=w)
            self.bnd3 = L.BatchNormalization(64)
            self.cd2 = L.ConvolutionND(ndim=3, in_channels=64, out_channels=64, ksize=3, pad=1, initialW=w)
            self.bnd2 = L.BatchNormalization(64)
            self.deconv1 = L.DeconvolutionND(ndim=3, in_channels=64, out_channels=64, ksize=2, stride=2, initialW=w,
                                             nobias=True)
            # 32 skip channels from e1 + 64 upsampled channels from deconv1
            self.cd1 = L.ConvolutionND(ndim=3, in_channels=32 + 64, out_channels=32, ksize=3, pad=1, initialW=w)
            self.bnd1 = L.BatchNormalization(32)
            self.cd0 = L.ConvolutionND(ndim=3, in_channels=32, out_channels=32, ksize=3, pad=1, initialW=w)
            self.bnd0 = L.BatchNormalization(32)
            self.lcl = L.ConvolutionND(ndim=3, in_channels=32, out_channels=num_of_label, ksize=1, pad=0, initialW=w)

    def __call__(self, x):
        """Run the forward pass.

        Args:
            x: input volume. ``ce0`` expects 1 input channel, so presumably
                (batchsize, 1, z, y, x). NOTE(review): the two 2x poolings are
                undone by two 2x deconvolutions before each concat, so spatial
                sizes presumably must be divisible by 4 — confirm.

        Returns:
            Per-voxel class probabilities, softmax over axis 1 (channels),
            shaped (batchsize, num_of_label, z, y, x).
        """
        # encoder pass
        e0 = F.relu(self.bne0(self.ce0(x)))
        e1 = F.relu(self.bne1(self.ce1(e0)))
        del e0  # drop references early; e1 is kept for the skip connection
        e2 = F.relu(self.bne2(self.ce2(F.max_pooling_nd(e1, ksize=2, stride=2))))
        e3 = F.relu(self.bne3(self.ce3(e2)))
        del e2  # e3 is kept for the skip connection
        e4 = F.relu(self.bne4(self.ce4(F.max_pooling_nd(e3, ksize=2, stride=2))))
        # decoder pass
        d4 = F.relu(self.bnd4(self.cd4(e4)))
        del e4
        d3 = F.relu(self.bnd3(self.cd3(F.concat([self.deconv2(d4), e3]))))
        del d4, e3
        d2 = F.relu(self.bnd2(self.cd2(d3)))
        del d3
        d1 = F.relu(self.bnd1(self.cd1(F.concat([self.deconv1(d2), e1]))))
        del d2, e1
        d0 = F.relu(self.bnd0(self.cd0(d1)))
        del d1
        lcl = F.softmax(self.lcl(d0), axis=1)
        return lcl  # (batchsize, ch, z, y, x)

    def cropping(self, input, ref):
        """Center-crop an encoder feature map to the spatial size of a decoder map.

        Args:
            input: encoder feature map; axes 2, 3 and 4 (z, y, x) are cropped,
                so a 5-D (batch, ch, z, y, x) array/Variable is assumed.
            ref: decoder feature map whose sizes on axes 2-4 are the target.

        Returns:
            chainer.Variable: ``input`` cropped so that its axes 2-4 exactly
            match ``ref``'s. When a size difference is odd, the extra voxel is
            removed from the end of that axis.

        Raises:
            ValueError: if ``ref`` is larger than ``input`` along a cropped axis.
        """
        # Keep batch and channel axes whole; crop each spatial axis independently
        # so every axis uses its own offset.
        slices = [slice(None), slice(None)]
        for axis in (2, 3, 4):
            size = input.shape[axis]
            target = ref.shape[axis]
            if target > size:
                raise ValueError(
                    'cannot crop axis %d of size %d to larger size %d' % (axis, size, target))
            start = (size - target) // 2
            # End is derived from the target size (not size - start) so the
            # result matches ref exactly even when size - target is odd.
            slices.append(slice(start, start + target))
        return F.get_item(input, tuple(slices))
|