'''RetinaFPN in PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        # Projection shortcut when the spatial size or channel count changes.
        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.downsample(x)
        out = F.relu(out)
        return out


class FPN(nn.Module):
    def __init__(self, block, num_blocks):
        super(FPN, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Bottom-up layers
        self.layer1 = self._make_layer(block,  64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Extra layers that produce the coarse pyramid levels p6 and p7.
        self.conv6 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)
        self.conv7 = nn.Conv2d( 256, 256, kernel_size=3, stride=2, padding=1)

        # Lateral layers (1x1 convs that reduce c3/c4/c5 to 256 channels)
        self.latlayer1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv2d( 512, 256, kernel_size=1, stride=1, padding=0)

        # Top-down layers (3x3 convs that smooth the merged maps)
        self.toplayer1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.toplayer2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _upsample_add(self, x, y):
        '''Upsample x and add it to y.

        Args:
          x: (tensor) top feature map to be upsampled.
          y: (tensor) lateral feature map.

        Returns:
          (tensor) added feature map.

        Note: in PyTorch, when the input size is odd, the feature map upsampled
        with `F.interpolate(..., scale_factor=2, mode='nearest')` may not match
        the size of the lateral feature map, e.g.
          original input size: [N,_,15,15] ->
          conv2d feature map size: [N,_,8,8] ->
          upsampled feature map size: [N,_,16,16]
        So we use bilinear upsampling to an explicit size, which supports
        arbitrary output sizes.
        '''
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y

    def forward(self, x):
        # Bottom-up
        c1 = F.relu(self.bn1(self.conv1(x)))
        c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        p6 = self.conv6(c5)
        p7 = self.conv7(F.relu(p6))
        # Top-down
        p5 = self.latlayer1(c5)
        p4 = self._upsample_add(p5, self.latlayer2(c4))
        p4 = self.toplayer1(p4)
        p3 = self._upsample_add(p4, self.latlayer3(c3))
        p3 = self.toplayer2(p3)
        return p3, p4, p5, p6, p7


def FPN50():
    # FPN on a ResNet-50 backbone ([3,4,6,3] bottleneck blocks).
    print('load fpn50')
    return FPN(Bottleneck, [3,4,6,3])


def FPN101():
    # FPN on a ResNet-101 backbone ([3,4,23,3] bottleneck blocks).
    return FPN(Bottleneck, [3,4,23,3])


def test():
    fpn = FPN101()
    dd = fpn.state_dict()
    print(len(dd))
    for d in dd:
        print(d)
    print('*'*100)


# test()
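
# A minimal forward-pass sanity check (a sketch, not part of the original file):
# it assumes a CPU run and an arbitrary 600x300 input, and simply prints the
# size of each pyramid level p3..p7 returned by the network.
def test_forward():
    net = FPN50()
    fms = net(torch.randn(1, 3, 600, 300))
    for fm in fms:
        print(fm.size())


# test_forward()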