TharunSivamani committed
Commit a445351 · verified · 1 Parent(s): 6605ab3

resnet model code

Files changed (1)
  1. model.py +121 -0
model.py ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn


class Bottleneck(nn.Module):
    # Bottleneck block used by ResNet-50/101/152: 1x1 -> 3x3 -> 1x1 convolutions,
    # with the final 1x1 conv widening the channels by `expansion`.
    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0)
        self.batch_norm3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x.clone()

        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.relu(self.batch_norm2(self.conv2(x)))

        x = self.conv3(x)
        x = self.batch_norm3(x)

        # downsample the identity if the spatial size or channel count changed
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)

        # add the skip connection, then apply the final activation
        x += identity
        x = self.relu(x)

        return x


class Block(nn.Module):
    # Basic residual block used by ResNet-18/34: two 3x3 convolutions.
    expansion = 1

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Block, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        # only the first conv strides; the second keeps the spatial size so the
        # residual addition below matches the (possibly downsampled) identity
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x.clone()

        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.batch_norm2(self.conv2(x))

        if self.i_downsample is not None:
            identity = self.i_downsample(identity)

        x += identity
        x = self.relu(x)
        return x


class ResNet(nn.Module):
    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        self.in_channels = 64

        # stem: 7x7 strided conv followed by 3x3 max pooling
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # four stages of residual blocks; layer_list gives the block count per stage
        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * ResBlock.expansion, num_classes)

    def forward(self, x):
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)

        return x

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        ii_downsample = None
        layers = []

        # project the identity with a 1x1 conv + BatchNorm when the stage changes
        # the spatial resolution or the channel count
        if stride != 1 or self.in_channels != planes * ResBlock.expansion:
            ii_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, planes * ResBlock.expansion, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes * ResBlock.expansion)
            )

        # first block may downsample; the remaining blocks keep the shape
        layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))
        self.in_channels = planes * ResBlock.expansion

        for i in range(blocks - 1):
            layers.append(ResBlock(self.in_channels, planes))

        return nn.Sequential(*layers)


def ResNet50(num_classes, channels=3):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels)
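
For reference, a minimal usage sketch of the ResNet50 factory above; the class count, batch size, and 224x224 input resolution here are illustrative assumptions, not part of the committed file:

    model = ResNet50(num_classes=10)        # Bottleneck blocks with the [3, 4, 6, 3] layout
    dummy = torch.randn(2, 3, 224, 224)     # batch of 2 RGB images (illustrative size)
    logits = model(dummy)
    print(logits.shape)                     # expected: torch.Size([2, 10])

Because the pooling layer is adaptive, other input resolutions also work; only the layer_list argument distinguishes deeper variants built from the same ResNet class.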