AlexBlck commited on
Commit
bd0a3d5
1 Parent(s): b704fa2

Streamlit upload

Browse files
Files changed (2) hide show
  1. app.py +111 -0
  2. model.py +60 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import streamlit as st
4
+ import torch
5
+ from huggingface_hub import PyTorchModelHubMixin
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from torchvision.transforms.functional import to_pil_image
9
+
10
+ from model import ICN
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+
15
+ def mask_processing(x):
16
+ if x > 90:
17
+ return 140
18
+ elif x < 80:
19
+ return 0
20
+ else:
21
+ return 255
22
+
23
+
24
+ def grid_to_heatmap(grid, size=1024):
25
+ mask = to_pil_image(grid.view(7, 7))
26
+ mask = mask.resize((size, size), Image.BICUBIC)
27
+ mask = Image.eval(mask, mask_processing)
28
+
29
+ colormap = plt.get_cmap("Wistia")
30
+ heatmap = np.array(colormap(mask))
31
+ heatmap = (heatmap * 255).astype(np.uint8)
32
+ heatmap = Image.fromarray(heatmap)
33
+
34
+ return heatmap, mask
35
+
36
+
37
+ def summary_image(img, fake, prediction):
38
+ prediction -= prediction.min()
39
+ prediction = prediction / prediction.max()
40
+
41
+ size = 1024
42
+
43
+ img1 = img.resize((size, size))
44
+ img2 = fake.resize((size, size))
45
+
46
+ heatmap, mask = grid_to_heatmap(prediction)
47
+ img1.paste(heatmap, (0, 0), mask)
48
+ img2.paste(heatmap, (0, 0), mask)
49
+
50
+ return img1, img2
51
+
52
+
53
+ @st.cache_resource
54
+ def load_model():
55
+ model = torch.jit.load("traced_model.pt")
56
+ model.eval().to(device)
57
+ return model
58
+
59
+
60
+ model = ICN.from_pretrained("AlexBlck/image-comparator").eval().to(device)
61
+
62
+ # model = load_model()
63
+
64
+ st.title("Image Comparator Network")
65
+
66
+ st.write("## Upload a pair of images")
67
+ cols = st.columns(2)
68
+ with cols[0]:
69
+ im1 = st.file_uploader("Image 1", type=["jpg", "png"])
70
+ with cols[1]:
71
+ im2 = st.file_uploader("Image 2", type=["jpg", "png"])
72
+
73
+ if not (im1 and im2):
74
+ st.stop()
75
+
76
+ btn = st.button("Run")
77
+ if not btn:
78
+ st.stop()
79
+
80
+ im1 = Image.open(im1).convert("RGB")
81
+ im2 = Image.open(im2).convert("RGB")
82
+
83
+ tr = transforms.Compose(
84
+ [
85
+ transforms.Resize(size=(224, 224)),
86
+ transforms.ToTensor(),
87
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
88
+ ]
89
+ )
90
+
91
+ img = torch.vstack((tr(im1), tr(im2))).unsqueeze(0)
92
+
93
+ heatmap, cl = model(img.to(device))
94
+ confs = torch.softmax(cl, dim=1)
95
+ pred = torch.argmax(confs, dim=1).item()
96
+
97
+ if pred == 0:
98
+ st.success("No Manipulation Detected")
99
+ heatmap *= 0
100
+ elif pred == 1:
101
+ st.warning("Manipulation Detected!")
102
+ else:
103
+ st.error("Images are not related.")
104
+ heatmap *= 0
105
+
106
+ img1, img2 = summary_image(im1, im2, heatmap[0])
107
+ cols = st.columns(2)
108
+ with cols[0]:
109
+ st.image(img1)
110
+ with cols[1]:
111
+ st.image(img2)
model.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+ from torch import nn
5
+ from torchvision import models
6
+
7
+
8
+ class ICN(nn.Module, PyTorchModelHubMixin):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ cnn = models.resnet50(pretrained=False)
13
+ self.cnn_head = nn.Sequential(
14
+ *list(cnn.children())[:4],
15
+ *list(list(list(cnn.children())[4].children())[0].children())[:4],
16
+ )
17
+ self.cnn_tail = nn.Sequential(
18
+ *list(list(cnn.children())[4].children()
19
+ )[1:], *list(cnn.children())[5:-2]
20
+ )
21
+
22
+ self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
23
+ self.bn1 = nn.BatchNorm2d(num_features=256)
24
+
25
+ self.fc1 = nn.Linear(2048 * 7 * 7, 256)
26
+ self.fc2 = nn.Linear(256, 7 * 7)
27
+
28
+ self.cls_fc = nn.Linear(256, 3)
29
+
30
+ self.criterion = nn.CrossEntropyLoss()
31
+
32
+ def forward(self, x):
33
+ # Input: [-1, 6, 224, 224]
34
+ real = x[:, :3, :, :]
35
+ fake = x[:, 3:, :, :]
36
+
37
+ # Push both images through pretrained backbone
38
+ real_features = F.relu(self.cnn_head(real)) # [-1, 64, 56, 56]
39
+ fake_features = F.relu(self.cnn_head(fake)) # [-1, 64, 56, 56]
40
+
41
+ # [-1, 128, 56, 56]
42
+ combined = torch.cat((real_features, fake_features), 1)
43
+
44
+ x = self.conv1(combined) # [-1, 256, 56, 56]
45
+ x = self.bn1(x)
46
+ x = F.relu(x)
47
+
48
+ x = self.cnn_tail(x)
49
+ x = x.view(-1, 2048 * 7 * 7)
50
+
51
+ # Final feature [-1, 256]
52
+ d = F.relu(self.fc1(x))
53
+
54
+ # Heatmap [-1, 49]
55
+ grid = self.fc2(d)
56
+
57
+ # Classifier [-1, 1]
58
+ cl = self.cls_fc(d)
59
+
60
+ return grid, cl