sanjanatule commited on
Commit
5c09ac9
·
1 Parent(s): 422d720

Create models/custom_resnet.py

Browse files
Files changed (1) hide show
  1. models/custom_resnet.py +85 -0
models/custom_resnet.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.optim as optim
6
+ from torchvision import datasets, transforms
7
+ from tqdm import tqdm
8
+
9
+ class Net(nn.Module):
10
+ def __init__(self):
11
+ super(Net, self).__init__()
12
+
13
+
14
+ self.prep_layer = nn.Sequential(
15
+ nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), padding=1, stride=1, bias=False),
16
+ nn.BatchNorm2d(64),
17
+ nn.ReLU(),
18
+ ) # output_size = 64,32,32
19
+
20
+ self.layer1_x = nn.Sequential(
21
+ nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=1, stride=1, bias=False),
22
+ nn.MaxPool2d(2, 2),
23
+ nn.BatchNorm2d(128),
24
+ nn.ReLU(),
25
+ ) # output_size = 128,16,16
26
+
27
+ self.layer1_r = nn.Sequential(
28
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=1, stride=1, bias=False),
29
+ nn.BatchNorm2d(128),
30
+ nn.ReLU(),
31
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=1, stride=1, bias=False),
32
+ nn.BatchNorm2d(128),
33
+ nn.ReLU(),
34
+ ) #output_size = 128,16,16
35
+
36
+ self.layer2 = nn.Sequential(
37
+ nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), padding=1, stride=1, bias=False),
38
+ nn.MaxPool2d(2, 2),
39
+ nn.BatchNorm2d(256),
40
+ nn.ReLU(),
41
+ ) # output_size = 256,8,8
42
+
43
+
44
+ self.layer3_x = nn.Sequential(
45
+ nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), padding=1, stride=1, bias=False),
46
+ nn.MaxPool2d(2, 2),
47
+ nn.BatchNorm2d(512),
48
+ nn.ReLU(),
49
+ ) # output_size = 512,4,4
50
+
51
+ self.layer3_r = nn.Sequential(
52
+ nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), padding=1, stride=1, bias=False),
53
+ nn.BatchNorm2d(512),
54
+ nn.ReLU(),
55
+ nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), padding=1, stride=1, bias=False),
56
+ nn.BatchNorm2d(512),
57
+ nn.ReLU(),
58
+ ) #output_size = 512,4,4
59
+
60
+ self.last_maxpool = nn.Sequential(
61
+ nn.MaxPool2d(4, 4), #512
62
+
63
+ )
64
+ self.last_fc = nn.Sequential(
65
+ nn.Linear(512,10,bias=False)
66
+ )
67
+
68
+ def forward(self, x):
69
+ x = self.prep_layer(x)
70
+ x = self.layer1_x(x)
71
+ x_layer1_identity = x.clone()
72
+ x = self.layer1_r(x)
73
+ #x = F.relu(x + x_layer1_identity)
74
+ x = x + x_layer1_identity
75
+ x = self.layer2(x)
76
+ x = self.layer3_x(x)
77
+ x_layer3_identity = x.clone()
78
+ x = self.layer3_r(x)
79
+ #x = F.relu(x + x_layer3_identity)
80
+ x = x + x_layer3_identity
81
+ x = self.last_maxpool(x)
82
+ x = x.view(-1,512)
83
+ x = self.last_fc(x)
84
+ x = x.view(-1, 10)
85
+ return F.log_softmax(x, dim=-1)