RomanShnurov committed
Commit fcfe2af • 1 Parent(s): 03bdc9c

initial commit

README.md CHANGED
@@ -1,13 +1,14 @@
 ---
-title: Sumsub Ffs Demo
-emoji: 🌍
-colorFrom: purple
-colorTo: indigo
+title: >-
+  For Fake's Sake: a set of models for detecting deepfakes, generated images and
+  synthetic images
+emoji: 🏠
+colorFrom: green
+colorTo: blue
 sdk: gradio
 sdk_version: 3.39.0
 app_file: app.py
 pinned: false
-license: cc-by-sa-3.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,160 @@
+import os
+os.system("python -m pip install --upgrade pip")
+os.system("pip install git+https://github.com/rwightman/pytorch-image-models")
+os.system("pip install git+https://github.com/huggingface/huggingface_hub")
+
+import gradio as gr
+import timm
+import torch
+from torch import nn
+from torch.nn import functional as F
+import torchvision
+
+
+class Model200M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = timm.create_model('convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384', pretrained=False,
+                                       num_classes=0)
+
+        self.clf = nn.Sequential(
+            nn.Linear(1536, 128),
+            nn.ReLU(inplace=True),
+            nn.Linear(128, 2))
+
+    def forward(self, image):
+        image_features = self.model(image)
+        return self.clf(image_features)
+
+
+class Model5M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = timm.create_model('timm/tf_mobilenetv3_large_100.in1k', pretrained=False, num_classes=0)
+
+        self.clf = nn.Sequential(
+            nn.Linear(1280, 128),
+            nn.ReLU(inplace=True),
+            nn.Linear(128, 2))
+
+    def forward(self, image):
+        image_features = self.model(image)
+        return self.clf(image_features)
+
+def load_model(name: str):
+    model = Model200M() if "200M" in name else Model5M()
+    ckpt = torch.load(name, map_location=torch.device('cpu'))
+    model.load_state_dict(ckpt)
+    model.eval()
+    return model
+
+model_list = {
+    'midjourney_200M': load_model('models/midjourney200M.pt'),
+    'diffusions_200M': load_model('models/diffusions200M.pt'),
+    'midjourney_5M': load_model('models/midjourney5M.pt'),
+    'diffusions_5M': load_model('models/diffusions5M.pt')
+}
+
+tfm = torchvision.transforms.Compose([
+    torchvision.transforms.Resize((640, 640)),
+    torchvision.transforms.ToTensor(),
+    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225]),
+])
+
+tfm_small = torchvision.transforms.Compose([
+    torchvision.transforms.Resize((224, 224)),
+    torchvision.transforms.ToTensor(),
+    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225]),
+])
+
+
+def predict_from_model(model, img_1):
+    y = model.forward(img_1[None, ...])
+    y_1 = F.softmax(y, dim=1)[:, 1].cpu().detach().numpy()
+    y_2 = F.softmax(y, dim=1)[:, 0].cpu().detach().numpy()
+    return {'created by AI': y_1.tolist(),
+            'created by human': y_2.tolist()}
+
+
+def predict(raw_image, model_name):
+    img_1 = tfm(raw_image)
+    img_2 = tfm_small(raw_image)
+
+    if model_name not in model_list:
+        return {'error': [0.]}
+
+    model = model_list[model_name]
+    img = img_1 if "200M" in model_name else img_2
+    return predict_from_model(model, img)
+
+general_examples = [
+    ["images/general/img_1.jpg"],
+    ["images/general/img_2.jpg"],
+    ["images/general/img_3.jpg"],
+    ["images/general/img_4.jpg"],
+    ["images/general/img_5.jpg"],
+    ["images/general/img_6.jpg"],
+    ["images/general/img_7.jpg"],
+    ["images/general/img_8.jpg"],
+    ["images/general/img_9.jpg"],
+    ["images/general/img_10.jpg"],
+]
+
+optic_examples = [
+    ["images/optic/img_1.jpg"],
+    ["images/optic/img_2.jpg"],
+    ["images/optic/img_3.jpg"],
+    ["images/optic/img_4.jpg"],
+    ["images/optic/img_5.jpg"],
+]
+
+famous_deepfake_examples = [
+    ["images/famous_deepfakes/img_1.jpg"],
+    ["images/famous_deepfakes/img_2.jpg"],
+    ["images/famous_deepfakes/img_3.jpg"],
+    ["images/famous_deepfakes/img_4.webp"],
+]
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        <h1 style="text-align: center;">For Fake's Sake: a set of models for detecting generated and synthetic images</h1>
+        This is a demo space for <a href='https://huggingface.co/Sumsub/aiornot'>synthetic image detectors</a>.
+
+        We provide several detectors for images generated by popular tools, such as Midjourney and Stable Diffusion.
+        """
+    )
+
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(type="pil")
+            drop_down = gr.Dropdown(model_list.keys(), type="value", label="Model", value="diffusions_200M")
+            with gr.Row():
+                gr.ClearButton(components=[image_input])
+                submit_button = gr.Button("Submit", variant="primary")
+        with gr.Column():
+            result_score = gr.Label(label='result', num_top_classes=2)
+    with gr.Tab("Examples"):
+        gr.Examples(examples=general_examples, inputs=image_input)
+    # with gr.Tab("More examples"):
+    #     gr.Examples(examples=optic_examples, inputs=image_input)
+    with gr.Tab("Widely known deepfakes"):
+        gr.Examples(examples=famous_deepfake_examples, inputs=image_input)
+
+    submit_button.click(predict, inputs=[image_input, drop_down], outputs=result_score)
+
+    gr.Markdown(
+        """
+        <h3>Models</h3>
+        <p><code>*_200M</code> models are based on <code>convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384</code> with image size <code>640x640</code></p>
+        <p><code>*_5M</code> models are based on <code>tf_mobilenetv3_large_100.in1k</code> with image size <code>224x224</code></p>
+
+        <h3>Details</h3>
+        <li>Model card: <a href='https://huggingface.co/Sumsub/aiornot'>aiornot</a></li>
+        <li>License: CC-By-SA-3.0</li>
+        """
+    )
+
+demo.launch()
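Aside, not part of the committed files: the pipeline that app.py wires into Gradio can also be exercised directly. The sketch below mirrors the Model5M architecture, the tfm_small preprocessing, and the softmax scoring from app.py to run one of the committed checkpoints on a single image; the image path `some_image.jpg` is a placeholder.

```python
# Minimal standalone sketch: score one image with the 5M Midjourney detector,
# without launching the Gradio UI. Paths mirror the repo layout committed above;
# "some_image.jpg" is a placeholder.
import timm
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import functional as F


class Model5M(nn.Module):
    """Same architecture as in app.py: MobileNetV3-Large backbone (1280-d) + 2-class head."""
    def __init__(self):
        super().__init__()
        self.model = timm.create_model('timm/tf_mobilenetv3_large_100.in1k',
                                       pretrained=False, num_classes=0)
        self.clf = nn.Sequential(nn.Linear(1280, 128),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(128, 2))

    def forward(self, image):
        return self.clf(self.model(image))


# Same preprocessing as tfm_small in app.py (224x224, ImageNet normalization)
tfm_small = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])

model = Model5M()
model.load_state_dict(torch.load('models/midjourney5M.pt', map_location='cpu'))
model.eval()

img = tfm_small(Image.open('some_image.jpg').convert('RGB'))
with torch.no_grad():
    probs = F.softmax(model(img[None, ...]), dim=1)[0]

# Class 1 = "created by AI", class 0 = "created by human", as in predict_from_model
print({'created by AI': float(probs[1]), 'created by human': float(probs[0])})
```

The `*_200M` checkpoints follow the same pattern with the ConvNeXt backbone, a 1536-dimensional feature vector, and 640x640 inputs.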
images/famous_deepfakes/img_1.jpg ADDED
images/famous_deepfakes/img_2.jpg ADDED
images/famous_deepfakes/img_3.jpg ADDED
images/famous_deepfakes/img_4.webp ADDED
images/general/img_1.jpg ADDED
images/general/img_10.jpg ADDED
images/general/img_2.jpg ADDED
images/general/img_3.jpg ADDED
images/general/img_4.jpg ADDED
images/general/img_5.jpg ADDED
images/general/img_6.jpg ADDED
images/general/img_7.jpg ADDED
images/general/img_8.jpg ADDED
images/general/img_9.jpg ADDED
images/optic/img_1.jpg ADDED
images/optic/img_2.jpg ADDED
images/optic/img_3.jpg ADDED
images/optic/img_4.jpg ADDED
images/optic/img_5.jpg ADDED
models/diffusions200M.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24e79ec2325e489dd83ca60f52e04aa214fdee81463b47e150ff09072c627169
+size 794891785
models/diffusions5M.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2ad16a05ba72f6a5fadae6a7788f2e0c370f02e24e7fc629f3fa526059c435b
+size 17339597
models/midjourney200M.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a689a102e0f54967d2e670a0652bd1d440bc7874ae7cdce77e2a1ccc7ad3f0f
+size 794891785
models/midjourney5M.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ece116e4c8e42c707d46e86574524d9eb1ac4f0e775cc377729a0bb64201e81
+size 17339597
requirements.txt ADDED
@@ -0,0 +1,3 @@
+timm==0.9.5
+torch==2.0.1
+torchvision==0.15.2