Vrk commited on
Commit
d8455aa
1 Parent(s): c5d6aa2

Upload Models.py

Browse files
Files changed (1) hide show
  1. Models.py +76 -0
Models.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torch.nn.functional as F
5
+ import timm
6
+
7
+ # ResNet50 Model
8
+ class ResNet(nn.Module):
9
+ def __init__(self, num_classes, is_freeze=True):
10
+ super(ResNet, self).__init__()
11
+
12
+ self.num_classes = num_classes
13
+ self.is_freeze = is_freeze
14
+ self.base_model = timm.create_model('resnet50', pretrained=True)
15
+
16
+ if self.is_freeze:
17
+ for param in self.base_model.parameters():
18
+ param.requires_grad = False
19
+
20
+ self.base_model.fc = nn.Linear(2048, self.num_classes)
21
+
22
+ def forward(self, x):
23
+ x = self.base_model(x)
24
+ return x
25
+
26
+ # EfficientNet Model
27
+ class EfficientNet(nn.Module):
28
+ def __init__(self, num_classes):
29
+ super(EfficientNet, self).__init__()
30
+
31
+ self.num_classes = num_classes
32
+ self.base_model = timm.create_model('efficientnet_b0', pretrained=True)
33
+ self.base_model.classifier = nn.Linear(1280, self.num_classes)
34
+
35
+ def forward(self, x):
36
+ x = self.base_model(x)
37
+ return x
38
+
39
+ # BaseLine Model
40
+ class BaseLine(nn.Module):
41
+ def __init__(self, num_classes):
42
+ super(BaseLine, self).__init__()
43
+
44
+ self.Conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1)
45
+ self.Conv2 = nn.Conv2d(96, 256, kernel_size=5, padding=2)
46
+ self.Conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1)
47
+ self.Conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
48
+
49
+ self.Linear1 = nn.Linear(2304, 512)
50
+ self.Linear3 = nn.Linear(512, num_classes)
51
+
52
+ self.relu = nn.ReLU()
53
+ self.dropout = nn.Dropout(p=0.5)
54
+
55
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
56
+ self.flatten = nn.Flatten()
57
+
58
+ def forward(self, x):
59
+ x = self.Conv1(x)
60
+ x = self.relu(x)
61
+ x = self.maxpool(x)
62
+
63
+ x = self.Conv2(x)
64
+ x = self.maxpool(x)
65
+
66
+ x = self.Conv3(x)
67
+ x = self.Conv4(x)
68
+ x = self.maxpool(x)
69
+
70
+ x = self.flatten(x)
71
+ x = self.Linear1(x)
72
+ x = self.relu(x)
73
+ x = self.dropout(x)
74
+
75
+ x = self.Linear3(x)
76
+ return x