Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn.functional as F | |
from huggingface_hub import PyTorchModelHubMixin | |
from torch import nn | |
from torchvision import models | |
class ICN(nn.Module, PyTorchModelHubMixin):
    """Two-image comparison network built from ResNet-50 pieces.

    The input stacks two RGB images along the channel axis. Each image is
    passed through the same truncated ResNet-50 stem (``cnn_head``); the two
    feature maps are concatenated, fused by a 3x3 convolution, and pushed
    through the rest of ResNet-50 (``cnn_tail``). The resulting 256-d feature
    feeds two linear heads: a 49-way (7 x 7) grid head and a 3-way
    classification head.

    NOTE: submodule names and the order of layers inside ``cnn_head`` and
    ``cnn_tail`` determine the ``state_dict`` keys. Keep them unchanged so
    existing checkpoints keep loading.
    """

    def __init__(self) -> None:
        super().__init__()
        # Called without arguments so the backbone is randomly initialised on
        # every torchvision version (the ``pretrained=`` keyword is
        # deprecated in newer releases and emits a warning).
        cnn = models.resnet50()
        first_block = cnn.layer1[0]

        # Stem shared by both images: the ResNet input layers followed by the
        # first two conv/bn pairs of the first bottleneck of layer1. The
        # bottleneck's conv3/bn3, downsample branch and residual add are
        # intentionally left out, so the output has 64 channels.
        self.cnn_head = nn.Sequential(
            cnn.conv1,
            cnn.bn1,
            cnn.relu,
            cnn.maxpool,
            first_block.conv1,
            first_block.bn1,
            first_block.conv2,
            first_block.bn2,
        )
        # Remainder of the backbone: the other bottlenecks of layer1 (which
        # take 256 input channels), then layer2..layer4. The average pool and
        # the final fc layer of ResNet-50 are dropped.
        self.cnn_tail = nn.Sequential(
            *cnn.layer1[1:],
            cnn.layer2,
            cnn.layer3,
            cnn.layer4,
        )

        # Fuses the concatenated (64 + 64)-channel features into the
        # 256 channels that ``cnn_tail`` expects.
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=256)

        # 2048 x 7 x 7 is the ``cnn_tail`` output size for 224 x 224 inputs.
        self.fc1 = nn.Linear(2048 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 7 * 7)
        self.cls_fc = nn.Linear(256, 3)

        # Not used by ``forward``; kept as an attribute (presumably for
        # training code outside this class -- verify against callers).
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        """Run both stacked images through the network.

        Args:
            x: Tensor of shape ``[N, 6, 224, 224]``. Channels 0-2 hold the
                first ("real") image and channels 3-5 the second ("fake")
                image.

        Returns:
            A tuple ``(grid, cl)`` of raw linear outputs (no softmax or
            sigmoid is applied here):

            * ``grid`` -- ``[N, 49]``, one value per cell of a 7 x 7 grid.
            * ``cl`` -- ``[N, 3]``, classification scores.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce a
                ``2048 x 7 x 7`` feature map (raised by ``fc1``).
        """
        # Input: [-1, 6, 224, 224]
        real = x[:, :3, :, :]
        fake = x[:, 3:, :, :]

        # Push both images through the same backbone stem.
        real_features = F.relu(self.cnn_head(real))  # [-1, 64, 56, 56]
        fake_features = F.relu(self.cnn_head(fake))  # [-1, 64, 56, 56]

        # [-1, 128, 56, 56]
        combined = torch.cat((real_features, fake_features), 1)

        x = self.conv1(combined)  # [-1, 256, 56, 56]
        x = self.bn1(x)
        x = F.relu(x)

        x = self.cnn_tail(x)  # [-1, 2048, 7, 7]
        # Flatten everything except the batch axis. Unlike ``view(-1, ...)``
        # this never folds extra spatial positions into the batch dimension;
        # a wrongly sized input fails loudly in ``fc1`` instead.
        x = torch.flatten(x, 1)

        # Final feature [-1, 256]
        d = F.relu(self.fc1(x))

        # Heatmap [-1, 49]
        grid = self.fc2(d)

        # Classifier [-1, 3]
        cl = self.cls_fc(d)

        return grid, cl