image-comparator / model.py
AlexBlck's picture
Streamlit upload
bd0a3d5
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torchvision import models
class ICN(nn.Module, PyTorchModelHubMixin):
    """Two-image comparison network built from a split ResNet-50.

    The input is a pair of 3-channel images stacked along the channel axis
    (6 channels total).  Both images go through the same ``cnn_head``, their
    feature maps are concatenated and fused by ``conv1``/``bn1``, and the
    result is processed by ``cnn_tail``.  Two linear heads then produce a
    49-value (7 x 7) grid and a 3-value classification output from a shared
    256-dimensional feature.

    NOTE: submodule names and the order of layers inside the ``nn.Sequential``
    containers determine the ``state_dict`` keys, so they must not be changed
    if existing checkpoints are to remain loadable.
    """

    def __init__(self):
        """Build the backbone halves, the fusion layer and the output heads."""
        super().__init__()

        # Randomly initialised ResNet-50.  Calling it without arguments is
        # equivalent to the former ``pretrained=False`` but avoids the
        # deprecation warning that torchvision >= 0.13 emits for ``pretrained``.
        # Trained weights are presumably loaded afterwards from a checkpoint
        # (e.g. through PyTorchModelHubMixin) -- not shown in this file.
        backbone = models.resnet50()

        # torchvision's ResNet top-level children, in order:
        # conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc
        top_level = list(backbone.children())
        layer1_blocks = list(top_level[4].children())
        first_block_layers = list(layer1_blocks[0].children())

        # Head: the ResNet stem plus the first four sub-layers of the first
        # block of layer1.
        self.cnn_head = nn.Sequential(
            *top_level[:4],
            *first_block_layers[:4],
        )
        # Tail: the remaining blocks of layer1 followed by layer2..layer4
        # (the final avgpool and fc are dropped).
        self.cnn_tail = nn.Sequential(
            *layer1_blocks[1:],
            *top_level[5:-2],
        )

        # Fuses the two concatenated head outputs (128 channels) into the
        # 256 channels consumed by the tail.
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=256)

        self.fc1 = nn.Linear(2048 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 7 * 7)
        self.cls_fc = nn.Linear(256, 3)

        # Stored on the module for use by external code; forward() does not
        # use it.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        """Compare the two images packed in ``x``.

        Args:
            x: Tensor of shape ``[-1, 6, 224, 224]``; channels 0-2 hold the
                first ("real") image and channels 3-5 the second ("fake") one.

        Returns:
            A tuple ``(grid, cl)`` where ``grid`` has shape ``[-1, 49]``
            (heatmap output of ``fc2``) and ``cl`` has shape ``[-1, 3]``
            (output of ``cls_fc``).  No activation is applied to either.
        """
        # Input: [-1, 6, 224, 224]
        real = x[:, :3, :, :]
        fake = x[:, 3:, :, :]

        # Push both images through the shared backbone head
        real_features = F.relu(self.cnn_head(real))  # [-1, 64, 56, 56]
        fake_features = F.relu(self.cnn_head(fake))  # [-1, 64, 56, 56]

        # [-1, 128, 56, 56]
        combined = torch.cat((real_features, fake_features), 1)

        x = self.conv1(combined)  # [-1, 256, 56, 56]
        x = self.bn1(x)
        x = F.relu(x)

        x = self.cnn_tail(x)  # [-1, 2048, 7, 7]

        # Flatten everything except the batch axis.  Unlike
        # ``view(-1, 2048 * 7 * 7)`` this also works for non-contiguous
        # activations and never folds spatial positions into the batch axis
        # when the input resolution is wrong (fc1 then raises instead).
        x = torch.flatten(x, 1)

        # Final feature [-1, 256]
        d = F.relu(self.fc1(x))

        # Heatmap [-1, 49]
        grid = self.fc2(d)

        # Classifier [-1, 3]
        cl = self.cls_fc(d)

        return grid, cl