File size: 1,689 Bytes
bd0a3d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torchvision import models


class ICN(nn.Module, PyTorchModelHubMixin):
    """Two-stream ResNet-50 network over a pair of channel-stacked RGB images.

    The input holds two 3-channel images concatenated along the channel axis
    (channels 0-2 and 3-5). Each image is passed separately through a shared
    ResNet-50 stem. The two feature maps are concatenated and fused by a 3x3
    convolution, then run through the rest of ResNet-50. Two linear heads then
    produce a 49-element (7x7) grid output and 3 class scores.
    """

    def __init__(self):
        super().__init__()

        # Randomly initialised ResNet-50: no weights are downloaded here.
        # `weights=None` replaces the deprecated `pretrained=False`; the
        # fallback covers torchvision releases (< 0.13) without `weights`.
        try:
            cnn = models.resnet50(weights=None)
        except TypeError:
            cnn = models.resnet50(pretrained=False)

        # Shared per-image stem: the ResNet stem (conv1, bn1, relu, maxpool)
        # followed by conv1/bn1/conv2/bn2 of layer1[0], stopping before that
        # block's 1x1 expansion. Output has 64 channels.
        # The module order matches the original children()-based slicing, so
        # state_dict keys (cnn_head.0 ... cnn_head.7) are unchanged.
        first_block = cnn.layer1[0]
        self.cnn_head = nn.Sequential(
            cnn.conv1,
            cnn.bn1,
            cnn.relu,
            cnn.maxpool,
            first_block.conv1,
            first_block.bn1,
            first_block.conv2,
            first_block.bn2,
        )

        # Remainder of the backbone: the other layer1 blocks, then
        # layer2..layer4. The avgpool and fc classifier are deliberately
        # dropped.
        self.cnn_tail = nn.Sequential(
            *cnn.layer1[1:],
            cnn.layer2,
            cnn.layer3,
            cnn.layer4,
        )

        # Fuses the two 64-channel stem outputs (128 channels) into the
        # 256-channel width the remaining layer1 blocks consume.
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=256)

        # layer4 yields 2048 channels at 7x7 for a 224x224 input.
        self.fc1 = nn.Linear(2048 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 7 * 7)

        self.cls_fc = nn.Linear(256, 3)

        # Not called in forward(); presumably used by external training code.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the two-stream network on a channel-stacked image pair.

        Args:
            x: Tensor of shape ``[B, 6, 224, 224]``. Channels 0-2 hold the
                first ("real") image and channels 3-5 the second ("fake").

        Returns:
            Tuple ``(grid, cl)``:
                grid: ``[B, 49]`` output of ``fc2`` (a flattened 7x7 map).
                cl: ``[B, 3]`` unnormalised scores from ``cls_fc``; no
                    softmax is applied.
        """
        real = x[:, :3, :, :]
        fake = x[:, 3:, :, :]

        # Both images go through the same stem (shared weights).
        real_features = F.relu(self.cnn_head(real))  # [B, 64, 56, 56]
        fake_features = F.relu(self.cnn_head(fake))  # [B, 64, 56, 56]

        combined = torch.cat((real_features, fake_features), 1)  # [B, 128, 56, 56]

        x = self.conv1(combined)  # [B, 256, 56, 56]
        x = self.bn1(x)
        x = F.relu(x)

        x = self.cnn_tail(x)  # [B, 2048, 7, 7]
        # Flatten per sample. Unlike view(-1, 2048 * 7 * 7), this never folds
        # spatial elements into the batch dimension when the input is not
        # 224x224; fc1 raises a shape error instead.
        x = torch.flatten(x, 1)

        # Final shared feature [B, 256]
        d = F.relu(self.fc1(x))

        # Grid output [B, 49]
        grid = self.fc2(d)

        # Class scores [B, 3]
        cl = self.cls_fc(d)

        return grid, cl