JohnJoelMota commited on
Commit
96a848f
·
verified ·
1 Parent(s): e245366

Created ResNet50_for_CC.py file.

Browse files
Files changed (1) hide show
  1. ResNet50_for_CC.py +56 -0
ResNet50_for_CC.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+
5
+ class ResClassifier(nn.Module):
6
+ def __init__(self, class_num=14):
7
+ super(ResClassifier, self).__init__()
8
+ self.fc1 = nn.Sequential(
9
+ nn.Linear(128, 64),
10
+ nn.BatchNorm1d(64, affine=True),
11
+ nn.ReLU(inplace=True),
12
+ nn.Dropout()
13
+ )
14
+ self.fc2 = nn.Sequential(
15
+ nn.Linear(64, 64),
16
+ nn.BatchNorm1d(64, affine=True),
17
+ nn.ReLU(inplace=True),
18
+ nn.Dropout()
19
+ )
20
+ self.fc3 = nn.Linear(64, class_num)
21
+
22
+ def forward(self, x):
23
+ fc1_emb = self.fc1(x)
24
+ fc2_emb = self.fc2(fc1_emb)
25
+ logit = self.fc3(fc2_emb)
26
+ return logit
27
+
28
+ class CC_model(nn.Module):
29
+ def __init__(self, num_classes1=14, num_classes2=None):
30
+ if num_classes2 is None:
31
+ num_classes2 = num_classes1
32
+
33
+ super(CC_model, self).__init__()
34
+ assert num_classes1 == num_classes2
35
+ self.num_classes = num_classes1
36
+ self.model_resnet = models.resnet50(weights='ResNet50_Weights.DEFAULT')
37
+ num_ftrs = self.model_resnet.fc.in_features
38
+ self.model_resnet.fc = nn.Identity()
39
+ self.classification_fc = nn.Linear(num_ftrs, num_classes1)
40
+ self.dr = nn.Linear(num_ftrs, 128)
41
+ self.fc1 = ResClassifier(num_classes1)
42
+ self.fc2 = ResClassifier(num_classes1)
43
+
44
+ def forward(self, x, detach_feature=False):
45
+ with torch.no_grad():
46
+ feature = self.model_resnet(x)
47
+ res_out = self.classification_fc(feature)
48
+ if detach_feature:
49
+ feature = feature.detach()
50
+ dr_feature = self.dr(feature)
51
+ out1 = self.fc1(dr_feature)
52
+ out2 = self.fc2(dr_feature)
53
+ output_mean = (out1 + out2) / 2
54
+ return dr_feature, output_mean
55
+
56
+ return dr_feature