RomanShnurov commited on
Commit
f3b2c5b
1 Parent(s): 295487b

add new synthetic detector

Browse files
Files changed (6) hide show
  1. .gitignore +160 -0
  2. app.py +32 -82
  3. model_classes.py +51 -0
  4. model_loader.py +59 -0
  5. model_transforms.py +25 -0
  6. models/synthetic_detector_v2.pt +3 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
app.py CHANGED
@@ -1,93 +1,40 @@
1
- import os
2
- os.system("python -m pip install --upgrade pip")
3
- os.system("pip install git+https://github.com/rwightman/pytorch-image-models")
4
- os.system("pip install git+https://github.com/huggingface/huggingface_hub")
5
-
6
  import gradio as gr
7
- import timm
8
- import torch
9
- from torch import nn
10
  from torch.nn import functional as F
11
- import torchvision
12
-
13
-
14
- class Model200M(torch.nn.Module):
15
- def __init__(self):
16
- super().__init__()
17
- self.model = timm.create_model('convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384', pretrained=False,
18
- num_classes=0)
19
-
20
- self.clf = nn.Sequential(
21
- nn.Linear(1536, 128),
22
- nn.ReLU(inplace=True),
23
- nn.Linear(128, 2))
24
-
25
- def forward(self, image):
26
- image_features = self.model(image)
27
- return self.clf(image_features)
28
-
29
-
30
- class Model5M(torch.nn.Module):
31
- def __init__(self):
32
- super().__init__()
33
- self.model = timm.create_model('timm/tf_mobilenetv3_large_100.in1k', pretrained=False, num_classes=0)
34
-
35
- self.clf = nn.Sequential(
36
- nn.Linear(1280, 128),
37
- nn.ReLU(inplace=True),
38
- nn.Linear(128, 2))
39
-
40
- def forward(self, image):
41
- image_features = self.model(image)
42
- return self.clf(image_features)
43
-
44
- def load_model(name: str):
45
- model = Model200M() if "200M" in name else Model5M()
46
- ckpt = torch.load(name, map_location=torch.device('cpu'))
47
- model.load_state_dict(ckpt)
48
- model.eval()
49
- return model
50
 
51
- model_list = {
52
- 'midjourney_200M': load_model('models/midjourney200M.pt'),
53
- 'diffusions_200M': load_model('models/diffusions200M.pt'),
54
- 'midjourney_5M': load_model('models/midjourney5M.pt'),
55
- 'diffusions_5M': load_model('models/diffusions5M.pt')
56
- }
57
-
58
- tfm = torchvision.transforms.Compose([
59
- torchvision.transforms.Resize((640, 640)),
60
- torchvision.transforms.ToTensor(),
61
- torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
62
- std=[0.229, 0.224, 0.225]),
63
- ])
64
-
65
- tfm_small = torchvision.transforms.Compose([
66
- torchvision.transforms.Resize((224, 224)),
67
- torchvision.transforms.ToTensor(),
68
- torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
69
- std=[0.229, 0.224, 0.225]),
70
- ])
71
-
72
-
73
- def predict_from_model(model, img_1):
74
- y = model.forward(img_1[None, ...])
75
  y_1 = F.softmax(y, dim=1)[:, 1].cpu().detach().numpy()
76
  y_2 = F.softmax(y, dim=1)[:, 0].cpu().detach().numpy()
77
  return {'created by AI': y_1.tolist(),
78
  'created by human': y_2.tolist()}
79
 
 
 
 
 
80
 
81
  def predict(raw_image, model_name):
82
- img_1 = tfm(raw_image)
83
- img_2 = tfm_small(raw_image)
84
-
85
- if model_name not in model_list:
86
  return {'error': [0.]}
87
 
88
- model = model_list[model_name]
89
- img = img_1 if "200M" in model_name else img_2
90
- return predict_from_model(model, img)
 
 
 
 
 
 
91
 
92
  general_examples = [
93
  ["images/general/img_1.jpg"],
@@ -125,8 +72,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
125
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_200'>midjourney200M</a>,
126
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_5'>midjourney5M</a>,
127
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_200'>diffusions200M</a>,
128
- <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_5'>diffusions5M</a>.<br>
129
- We provide several detectors for images generated by popular tools, such as Midjourney and Stable Diffusion.<br>
 
130
  Please refer to model cards for evaluation metrics and limitations.
131
  """
132
  )
@@ -134,7 +82,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
134
  with gr.Row():
135
  with gr.Column():
136
  image_input = gr.Image(type="pil")
137
- drop_down = gr.Dropdown(model_list.keys(), type="value", label="Model", value="diffusions_200M")
138
  with gr.Row():
139
  gr.ClearButton(components=[image_input])
140
  submit_button = gr.Button("Submit", variant="primary")
@@ -154,12 +102,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
154
  <h3>Models</h3>
155
  <p><code>*_200M</code> models are based on <code>convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384</code> with image size <code>640x640</code></p>
156
  <p><code>*_5M</code> models are based on <code>tf_mobilenetv3_large_100.in1k</code> with image size <code>224x224</code></p>
 
157
 
158
  <h3>Details</h3>
159
  <li>Model cards: <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_200'>midjourney200M</a>,
160
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_5'>midjourney5M</a>,
161
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_200'>diffusions200M</a>,
162
- <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_5'>diffusions5M</a>.
 
163
  </li>
164
  <li>License: CC-By-SA-3.0</li>
165
  """
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
2
  from torch.nn import functional as F
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ from model_loader import ModelType, type_to_transforms, type_to_loaded_model
5
+
6
+ def predict_from_model(model_type, raw_image):
7
+ tfm = type_to_transforms[model_type]
8
+ model = type_to_loaded_model[model_type]
9
+ img = tfm(raw_image)
10
+ y = None
11
+ if model_type == ModelType.SYNTHETIC_DETECTOR_V2:
12
+ y = model.forward(img.unsqueeze(0).to("cpu"))
13
+ else:
14
+ y = model.forward(img[None, ...])
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  y_1 = F.softmax(y, dim=1)[:, 1].cpu().detach().numpy()
16
  y_2 = F.softmax(y, dim=1)[:, 0].cpu().detach().numpy()
17
  return {'created by AI': y_1.tolist(),
18
  'created by human': y_2.tolist()}
19
 
20
+ def get_y(model_type, model, image):
21
+ if model_type == ModelType.SYNTHETIC_DETECTOR_V2:
22
+ return model.forward(image.unsqueeze(0).to("cpu"))
23
+ return model.forward(image[None, ...])
24
 
25
  def predict(raw_image, model_name):
26
+ if model_name not in ModelType.get_list():
 
 
 
27
  return {'error': [0.]}
28
 
29
+ model_type = ModelType[str(model_name).upper()].value
30
+ model = type_to_loaded_model[model_type]
31
+ tfm = type_to_transforms[model_type]
32
+ image = tfm(raw_image)
33
+ y = get_y(model_type, model, image)
34
+ y_1 = F.softmax(y, dim=1)[:, 1].cpu().detach().numpy()
35
+ y_2 = F.softmax(y, dim=1)[:, 0].cpu().detach().numpy()
36
+ return {'created by AI': y_1.tolist(),
37
+ 'created by human': y_2.tolist()}
38
 
39
  general_examples = [
40
  ["images/general/img_1.jpg"],
 
72
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_200'>midjourney200M</a>,
73
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_5'>midjourney5M</a>,
74
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_200'>diffusions200M</a>,
75
+ <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_5'>diffusions5M</a>,
76
+ <a href=''>synthetic_detector_v2</a>.
77
+ <br>We provide several detectors for images generated by popular tools, such as Midjourney and Stable Diffusion.<br>
78
  Please refer to model cards for evaluation metrics and limitations.
79
  """
80
  )
 
82
  with gr.Row():
83
  with gr.Column():
84
  image_input = gr.Image(type="pil")
85
+ drop_down = gr.Dropdown(ModelType.get_list(), type="value", label="Model", value=ModelType.SYNTHETIC_DETECTOR_V2)
86
  with gr.Row():
87
  gr.ClearButton(components=[image_input])
88
  submit_button = gr.Button("Submit", variant="primary")
 
102
  <h3>Models</h3>
103
  <p><code>*_200M</code> models are based on <code>convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384</code> with image size <code>640x640</code></p>
104
  <p><code>*_5M</code> models are based on <code>tf_mobilenetv3_large_100.in1k</code> with image size <code>224x224</code></p>
105
+ <p><code>synthetic_detector_2.0</code> models are based on <code>convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384</code> with image size <code>384x384</code></p>
106
 
107
  <h3>Details</h3>
108
  <li>Model cards: <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_200'>midjourney200M</a>,
109
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_5'>midjourney5M</a>,
110
  <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_200'>diffusions200M</a>,
111
+ <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_5'>diffusions5M</a>,
112
+ <a href=''>synthetic_detector_v2</a>.
113
  </li>
114
  <li>License: CC-By-SA-3.0</li>
115
  """
model_classes.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch
3
+ from torch import nn
4
+ import pytorch_lightning as pl
5
+ from pytorch_lightning.core.mixins import HyperparametersMixin
6
+
7
+ class Model200M(torch.nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.model = timm.create_model('convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384', pretrained=False,
11
+ num_classes=0)
12
+
13
+ self.clf = nn.Sequential(
14
+ nn.Linear(1536, 128),
15
+ nn.ReLU(inplace=True),
16
+ nn.Linear(128, 2))
17
+
18
+ def forward(self, image):
19
+ image_features = self.model(image)
20
+ return self.clf(image_features)
21
+
22
+
23
+ class Model5M(torch.nn.Module):
24
+ def __init__(self):
25
+ super().__init__()
26
+ self.model = timm.create_model('timm/tf_mobilenetv3_large_100.in1k', pretrained=False, num_classes=0)
27
+
28
+ self.clf = nn.Sequential(
29
+ nn.Linear(1280, 128),
30
+ nn.ReLU(inplace=True),
31
+ nn.Linear(128, 2))
32
+
33
+ def forward(self, image):
34
+ image_features = self.model(image)
35
+ return self.clf(image_features)
36
+
37
+
38
+ class SyntheticV2(pl.LightningModule, HyperparametersMixin):
39
+ def __init__(self):
40
+ super().__init__()
41
+ self.model = timm.create_model('convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384', pretrained=False,
42
+ num_classes=0)
43
+
44
+ self.clf = nn.Sequential(
45
+ nn.Linear(1536, 128),
46
+ nn.ReLU(inplace=True),
47
+ nn.Linear(128, 2))
48
+
49
+ def forward(self, image):
50
+ image_features = self.model(image)
51
+ return self.clf(image_features)
model_loader.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import torch
3
+
4
+ from model_classes import Model200M, Model5M, SyntheticV2
5
+ from model_transforms import transform_200M, transform_5M, transform_synthetic
6
+
7
+ class ModelType(str, Enum):
8
+ MIDJOURNEY_200M = "midjourney_200M"
9
+ DIFFUSIONS_200M = "diffusions_200M"
10
+ MIDJOURNEY_5M = "midjourney_5M"
11
+ DIFFUSIONS_5M = "diffusions_5M"
12
+ SYNTHETIC_DETECTOR_V2 = "synthetic_detector_v2"
13
+
14
+ def __str__(self):
15
+ return str(self.value)
16
+
17
+ @staticmethod
18
+ def get_list():
19
+ return [model_type.value for model_type in ModelType]
20
+
21
+ def load_model(value: ModelType):
22
+ model = type_to_class[value]
23
+ path = type_to_path[value]
24
+ ckpt = torch.load(path, map_location=torch.device('cpu'))
25
+ model.load_state_dict(ckpt)
26
+ model.eval()
27
+ return model
28
+
29
+ type_to_class = {
30
+ ModelType.MIDJOURNEY_200M : Model200M(),
31
+ ModelType.DIFFUSIONS_200M : Model200M(),
32
+ ModelType.MIDJOURNEY_5M : Model5M(),
33
+ ModelType.DIFFUSIONS_5M : Model5M(),
34
+ ModelType.SYNTHETIC_DETECTOR_V2 : SyntheticV2(),
35
+ }
36
+
37
+ type_to_path = {
38
+ ModelType.MIDJOURNEY_200M : 'models/midjourney200M.pt',
39
+ ModelType.DIFFUSIONS_200M : 'models/diffusions200M.pt',
40
+ ModelType.MIDJOURNEY_5M : 'models/midjourney5M.pt',
41
+ ModelType.DIFFUSIONS_5M : 'models/diffusions5M.pt',
42
+ ModelType.SYNTHETIC_DETECTOR_V2 : 'models/synthetic_detector_v2.pt',
43
+ }
44
+
45
+ type_to_loaded_model = {
46
+ ModelType.MIDJOURNEY_200M: load_model(ModelType.MIDJOURNEY_200M),
47
+ ModelType.DIFFUSIONS_200M: load_model(ModelType.DIFFUSIONS_200M),
48
+ ModelType.MIDJOURNEY_5M: load_model(ModelType.MIDJOURNEY_5M),
49
+ ModelType.DIFFUSIONS_5M: load_model(ModelType.DIFFUSIONS_5M),
50
+ ModelType.SYNTHETIC_DETECTOR_V2: load_model(ModelType.SYNTHETIC_DETECTOR_V2)
51
+ }
52
+
53
+ type_to_transforms = {
54
+ ModelType.MIDJOURNEY_200M: transform_200M,
55
+ ModelType.DIFFUSIONS_200M: transform_200M,
56
+ ModelType.MIDJOURNEY_5M: transform_5M,
57
+ ModelType.DIFFUSIONS_5M: transform_5M,
58
+ ModelType.SYNTHETIC_DETECTOR_V2: transform_synthetic
59
+ }
model_transforms.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torchvision
3
+
4
+ data_config = {'input_size': (3, 384, 384),
5
+ 'interpolation': 'bicubic',
6
+ 'mean': (0.48145466, 0.4578275, 0.40821073),
7
+ 'std': (0.26862954, 0.26130258, 0.27577711),
8
+ 'crop_pct': 1.0,
9
+ 'crop_mode': 'squash'}
10
+
11
+ transform_synthetic = timm.data.create_transform(**data_config, is_training=False)
12
+
13
+ transform_200M = torchvision.transforms.Compose([
14
+ torchvision.transforms.Resize((640, 640)),
15
+ torchvision.transforms.ToTensor(),
16
+ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
17
+ std=[0.229, 0.224, 0.225]),
18
+ ])
19
+
20
+ transform_5M = torchvision.transforms.Compose([
21
+ torchvision.transforms.Resize((224, 224)),
22
+ torchvision.transforms.ToTensor(),
23
+ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
24
+ std=[0.229, 0.224, 0.225]),
25
+ ])
models/synthetic_detector_v2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89a955ec54bddab759228757e437d300b6b86bbba9f45cfd5ecd0e3d7dec83a2
3
+ size 795263437