import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torchvision import models


class ICN(nn.Module, PyTorchModelHubMixin):
    """Two-image network built from a split ResNet-50.

    The input is a 6-channel tensor holding two 3-channel images stacked
    along the channel axis (called ``real`` and ``fake`` in ``forward``).
    Both images go through the same ResNet-50 stem, the two feature maps
    are concatenated and fused by a 3x3 convolution, and the result is fed
    through the remaining ResNet-50 stages. Two linear heads then produce
    a 49-value grid (7 * 7) and 3 class logits.

    The layer sizes are hard-coded for 224x224 inputs (2048 x 7 x 7
    features after the backbone).
    """

    def __init__(self):
        super().__init__()

        # Architecture only: no pretrained weights are loaded here.
        # Calling resnet50() without arguments yields the same randomly
        # initialised network as the deprecated ``pretrained=False`` keyword
        # on every torchvision version.
        cnn = models.resnet50()
        first_block = cnn.layer1[0]

        # Shared stem: the ResNet stem plus the first two conv/bn pairs of the
        # first bottleneck in layer1. The module order fixes the Sequential
        # indices (and therefore the state_dict keys), so it must not change.
        self.cnn_head = nn.Sequential(
            cnn.conv1,
            cnn.bn1,
            cnn.relu,
            cnn.maxpool,
            first_block.conv1,
            first_block.bn1,
            first_block.conv2,
            first_block.bn2,
        )

        # Remainder of the backbone: the other bottlenecks of layer1, then
        # layer2..layer4. avgpool and fc of the ResNet are dropped.
        self.cnn_tail = nn.Sequential(
            *list(cnn.layer1.children())[1:],
            cnn.layer2,
            cnn.layer3,
            cnn.layer4,
        )

        # Fuses the two concatenated 64-channel maps into the 256 channels
        # that the remaining layer1 bottlenecks take as input.
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=256)

        self.fc1 = nn.Linear(2048 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 7 * 7)
        self.cls_fc = nn.Linear(256, 3)

        # Not used by forward(); kept as a public attribute.
        # NOTE(review): presumably applied to the class logits by the
        # training code -- confirm against the caller.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        """Run both stacked images through the network.

        Args:
            x: Tensor of shape ``[N, 6, 224, 224]``; channels 0-2 are the
                first image and channels 3-5 the second.

        Returns:
            A ``(grid, cl)`` tuple: ``grid`` has shape ``[N, 49]`` and
            ``cl`` has shape ``[N, 3]``. Both are raw linear outputs (no
            activation is applied).
        """
        # Input: [-1, 6, 224, 224]
        real = x[:, :3, :, :]
        fake = x[:, 3:, :, :]

        # Push both images through the shared backbone stem. The stem ends on
        # a batch norm, so the activation is applied here.
        real_features = F.relu(self.cnn_head(real))  # [-1, 64, 56, 56]
        fake_features = F.relu(self.cnn_head(fake))  # [-1, 64, 56, 56]

        # [-1, 128, 56, 56]
        combined = torch.cat((real_features, fake_features), 1)

        x = self.conv1(combined)  # [-1, 256, 56, 56]
        x = self.bn1(x)
        x = F.relu(x)

        x = self.cnn_tail(x)  # [-1, 2048, 7, 7]

        # Flatten everything except the batch dimension. Unlike
        # view(-1, 2048 * 7 * 7), this cannot silently fold a wrongly sized
        # feature map into extra batch rows; fc1 raises on a mismatch instead.
        x = torch.flatten(x, 1)

        # Final feature [-1, 256]
        d = F.relu(self.fc1(x))

        # Heatmap [-1, 49]
        grid = self.fc2(d)

        # Classifier [-1, 3]
        cl = self.cls_fc(d)

        return grid, cl