AlexBlck commited on
Commit
074ffb4
1 Parent(s): 9dd60e4

Upload model code

Browse files
Files changed (1) hide show
  1. model.py +60 -0
model.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+ from torch import nn
5
+ from torchvision import models
6
+
7
+
8
+ class ICN(nn.Module, PyTorchModelHubMixin):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ cnn = models.resnet50(pretrained=False)
13
+ self.cnn_head = nn.Sequential(
14
+ *list(cnn.children())[:4],
15
+ *list(list(list(cnn.children())[4].children())[0].children())[:4],
16
+ )
17
+ self.cnn_tail = nn.Sequential(
18
+ *list(list(cnn.children())[4].children()
19
+ )[1:], *list(cnn.children())[5:-2]
20
+ )
21
+
22
+ self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
23
+ self.bn1 = nn.BatchNorm2d(num_features=256)
24
+
25
+ self.fc1 = nn.Linear(2048 * 7 * 7, 256)
26
+ self.fc2 = nn.Linear(256, 7 * 7)
27
+
28
+ self.cls_fc = nn.Linear(256, 3)
29
+
30
+ self.criterion = nn.CrossEntropyLoss()
31
+
32
+ def forward(self, x):
33
+ # Input: [-1, 6, 224, 224]
34
+ real = x[:, :3, :, :]
35
+ fake = x[:, 3:, :, :]
36
+
37
+ # Push both images through pretrained backbone
38
+ real_features = F.relu(self.cnn_head(real)) # [-1, 64, 56, 56]
39
+ fake_features = F.relu(self.cnn_head(fake)) # [-1, 64, 56, 56]
40
+
41
+ # [-1, 128, 56, 56]
42
+ combined = torch.cat((real_features, fake_features), 1)
43
+
44
+ x = self.conv1(combined) # [-1, 256, 56, 56]
45
+ x = self.bn1(x)
46
+ x = F.relu(x)
47
+
48
+ x = self.cnn_tail(x)
49
+ x = x.view(-1, 2048 * 7 * 7)
50
+
51
+ # Final feature [-1, 256]
52
+ d = F.relu(self.fc1(x))
53
+
54
+ # Heatmap [-1, 49]
55
+ grid = self.fc2(d)
56
+
57
+ # Classifier [-1, 1]
58
+ cl = self.cls_fc(d)
59
+
60
+ return grid, cl