pengdaqian commited on
Commit
d39fc00
1 Parent(s): 92bbee3
Files changed (9) hide show
  1. .gitignore +479 -0
  2. Dockerfile +10 -0
  3. dbimutils.py +69 -0
  4. img_label.py +236 -0
  5. img_nsfw.py +83 -0
  6. main.py +73 -0
  7. model.py +48 -0
  8. requirements.txt +15 -0
  9. test.jpeg +0 -0
.gitignore ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Ignore Visual Studio temporary files, build results, and
2
+ ## files generated by popular Visual Studio add-ons.
3
+ ##
4
+ ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
5
+
6
+ # User-specific files
7
+ *.rsuser
8
+ *.suo
9
+ *.user
10
+ *.userosscache
11
+ *.sln.docstates
12
+
13
+ # User-specific files (MonoDevelop/Xamarin Studio)
14
+ *.userprefs
15
+
16
+ # Mono auto generated files
17
+ mono_crash.*
18
+
19
+ # Build results
20
+ [Dd]ebug/
21
+ [Dd]ebugPublic/
22
+ [Rr]elease/
23
+ [Rr]eleases/
24
+ x64/
25
+ x86/
26
+ [Ww][Ii][Nn]32/
27
+ [Aa][Rr][Mm]/
28
+ [Aa][Rr][Mm]64/
29
+ bld/
30
+ [Bb]in/
31
+ [Oo]bj/
32
+ [Ll]og/
33
+ [Ll]ogs/
34
+
35
+ # Visual Studio 2015/2017 cache/options directory
36
+ .vs/
37
+ # Uncomment if you have tasks that create the project's static files in wwwroot
38
+ #wwwroot/
39
+
40
+ # Visual Studio 2017 auto generated files
41
+ Generated\ Files/
42
+
43
+ # MSTest test Results
44
+ [Tt]est[Rr]esult*/
45
+ [Bb]uild[Ll]og.*
46
+
47
+ # NUnit
48
+ *.VisualState.xml
49
+ TestResult.xml
50
+ nunit-*.xml
51
+
52
+ # Build Results of an ATL Project
53
+ [Dd]ebugPS/
54
+ [Rr]eleasePS/
55
+ dlldata.c
56
+
57
+ # Benchmark Results
58
+ BenchmarkDotNet.Artifacts/
59
+
60
+ # .NET
61
+ project.lock.json
62
+ project.fragment.lock.json
63
+ artifacts/
64
+
65
+ # Tye
66
+ .tye/
67
+
68
+ # ASP.NET Scaffolding
69
+ ScaffoldingReadMe.txt
70
+
71
+ # StyleCop
72
+ StyleCopReport.xml
73
+
74
+ # Files built by Visual Studio
75
+ *_i.c
76
+ *_p.c
77
+ *_h.h
78
+ *.ilk
79
+ *.meta
80
+ *.obj
81
+ *.iobj
82
+ *.pch
83
+ *.pdb
84
+ *.ipdb
85
+ *.pgc
86
+ *.pgd
87
+ *.rsp
88
+ *.sbr
89
+ *.tlb
90
+ *.tli
91
+ *.tlh
92
+ *.tmp
93
+ *.tmp_proj
94
+ *_wpftmp.csproj
95
+ *.log
96
+ *.tlog
97
+ *.vspscc
98
+ *.vssscc
99
+ .builds
100
+ *.pidb
101
+ *.svclog
102
+ *.scc
103
+
104
+ # Chutzpah Test files
105
+ _Chutzpah*
106
+
107
+ # Visual C++ cache files
108
+ ipch/
109
+ *.aps
110
+ *.ncb
111
+ *.opendb
112
+ *.opensdf
113
+ *.sdf
114
+ *.cachefile
115
+ *.VC.db
116
+ *.VC.VC.opendb
117
+
118
+ # Visual Studio profiler
119
+ *.psess
120
+ *.vsp
121
+ *.vspx
122
+ *.sap
123
+
124
+ # Visual Studio Trace Files
125
+ *.e2e
126
+
127
+ # TFS 2012 Local Workspace
128
+ $tf/
129
+
130
+ # Guidance Automation Toolkit
131
+ *.gpState
132
+
133
+ # ReSharper is a .NET coding add-in
134
+ _ReSharper*/
135
+ *.[Rr]e[Ss]harper
136
+ *.DotSettings.user
137
+
138
+ # TeamCity is a build add-in
139
+ _TeamCity*
140
+
141
+ # DotCover is a Code Coverage Tool
142
+ *.dotCover
143
+
144
+ # AxoCover is a Code Coverage Tool
145
+ .axoCover/*
146
+ !.axoCover/settings.json
147
+
148
+ # Coverlet is a free, cross platform Code Coverage Tool
149
+ coverage*.json
150
+ coverage*.xml
151
+ coverage*.info
152
+
153
+ # Visual Studio code coverage results
154
+ *.coverage
155
+ *.coveragexml
156
+
157
+ # NCrunch
158
+ _NCrunch_*
159
+ .*crunch*.local.xml
160
+ nCrunchTemp_*
161
+
162
+ # MightyMoose
163
+ *.mm.*
164
+ AutoTest.Net/
165
+
166
+ # Web workbench (sass)
167
+ .sass-cache/
168
+
169
+ # Installshield output folder
170
+ [Ee]xpress/
171
+
172
+ # DocProject is a documentation generator add-in
173
+ DocProject/buildhelp/
174
+ DocProject/Help/*.HxT
175
+ DocProject/Help/*.HxC
176
+ DocProject/Help/*.hhc
177
+ DocProject/Help/*.hhk
178
+ DocProject/Help/*.hhp
179
+ DocProject/Help/Html2
180
+ DocProject/Help/html
181
+
182
+ # Click-Once directory
183
+ publish/
184
+
185
+ # Publish Web Output
186
+ *.[Pp]ublish.xml
187
+ *.azurePubxml
188
+ # Note: Comment the next line if you want to checkin your web deploy settings,
189
+ # but database connection strings (with potential passwords) will be unencrypted
190
+ *.pubxml
191
+ *.publishproj
192
+
193
+ # Microsoft Azure Web App publish settings. Comment the next line if you want to
194
+ # checkin your Azure Web App publish settings, but sensitive information contained
195
+ # in these scripts will be unencrypted
196
+ PublishScripts/
197
+
198
+ # NuGet Packages
199
+ *.nupkg
200
+ # NuGet Symbol Packages
201
+ *.snupkg
202
+ # The packages folder can be ignored because of Package Restore
203
+ **/[Pp]ackages/*
204
+ # except build/, which is used as an MSBuild target.
205
+ !**/[Pp]ackages/build/
206
+ # Uncomment if necessary however generally it will be regenerated when needed
207
+ #!**/[Pp]ackages/repositories.config
208
+ # NuGet v3's project.json files produces more ignorable files
209
+ *.nuget.props
210
+ *.nuget.targets
211
+
212
+ # Microsoft Azure Build Output
213
+ csx/
214
+ *.build.csdef
215
+
216
+ # Microsoft Azure Emulator
217
+ ecf/
218
+ rcf/
219
+
220
+ # Windows Store app package directories and files
221
+ AppPackages/
222
+ BundleArtifacts/
223
+ Package.StoreAssociation.xml
224
+ _pkginfo.txt
225
+ *.appx
226
+ *.appxbundle
227
+ *.appxupload
228
+
229
+ # Visual Studio cache files
230
+ # files ending in .cache can be ignored
231
+ *.[Cc]ache
232
+ # but keep track of directories ending in .cache
233
+ !?*.[Cc]ache/
234
+
235
+ # Others
236
+ ClientBin/
237
+ ~$*
238
+ *~
239
+ *.dbmdl
240
+ *.dbproj.schemaview
241
+ *.jfm
242
+ *.pfx
243
+ *.publishsettings
244
+ orleans.codegen.cs
245
+
246
+ # Including strong name files can present a security risk
247
+ # (https://github.com/github/gitignore/pull/2483#issue-259490424)
248
+ #*.snk
249
+
250
+ # Since there are multiple workflows, uncomment next line to ignore bower_components
251
+ # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
252
+ #bower_components/
253
+
254
+ # RIA/Silverlight projects
255
+ Generated_Code/
256
+
257
+ # Backup & report files from converting an old project file
258
+ # to a newer Visual Studio version. Backup files are not needed,
259
+ # because we have git ;-)
260
+ _UpgradeReport_Files/
261
+ Backup*/
262
+ UpgradeLog*.XML
263
+ UpgradeLog*.htm
264
+ ServiceFabricBackup/
265
+ *.rptproj.bak
266
+
267
+ # SQL Server files
268
+ *.mdf
269
+ *.ldf
270
+ *.ndf
271
+
272
+ # Business Intelligence projects
273
+ *.rdl.data
274
+ *.bim.layout
275
+ *.bim_*.settings
276
+ *.rptproj.rsuser
277
+ *- [Bb]ackup.rdl
278
+ *- [Bb]ackup ([0-9]).rdl
279
+ *- [Bb]ackup ([0-9][0-9]).rdl
280
+
281
+ # Microsoft Fakes
282
+ FakesAssemblies/
283
+
284
+ # GhostDoc plugin setting file
285
+ *.GhostDoc.xml
286
+
287
+ # Node.js Tools for Visual Studio
288
+ .ntvs_analysis.dat
289
+ node_modules/
290
+
291
+ # Visual Studio 6 build log
292
+ *.plg
293
+
294
+ # Visual Studio 6 workspace options file
295
+ *.opt
296
+
297
+ # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
298
+ *.vbw
299
+
300
+ # Visual Studio 6 auto-generated project file (contains which files were open etc.)
301
+ *.vbp
302
+
303
+ # Visual Studio 6 workspace and project file (working project files containing files to include in project)
304
+ *.dsw
305
+ *.dsp
306
+
307
+ # Visual Studio 6 technical files
308
+ *.ncb
309
+ *.aps
310
+
311
+ # Visual Studio LightSwitch build output
312
+ **/*.HTMLClient/GeneratedArtifacts
313
+ **/*.DesktopClient/GeneratedArtifacts
314
+ **/*.DesktopClient/ModelManifest.xml
315
+ **/*.Server/GeneratedArtifacts
316
+ **/*.Server/ModelManifest.xml
317
+ _Pvt_Extensions
318
+
319
+ # Paket dependency manager
320
+ .paket/paket.exe
321
+ paket-files/
322
+
323
+ # FAKE - F# Make
324
+ .fake/
325
+
326
+ # CodeRush personal settings
327
+ .cr/personal
328
+
329
+ # Python Tools for Visual Studio (PTVS)
330
+ __pycache__/
331
+ *.pyc
332
+
333
+ # Cake - Uncomment if you are using it
334
+ # tools/**
335
+ # !tools/packages.config
336
+
337
+ # Tabs Studio
338
+ *.tss
339
+
340
+ # Telerik's JustMock configuration file
341
+ *.jmconfig
342
+
343
+ # BizTalk build output
344
+ *.btp.cs
345
+ *.btm.cs
346
+ *.odx.cs
347
+ *.xsd.cs
348
+
349
+ # OpenCover UI analysis results
350
+ OpenCover/
351
+
352
+ # Azure Stream Analytics local run output
353
+ ASALocalRun/
354
+
355
+ # MSBuild Binary and Structured Log
356
+ *.binlog
357
+
358
+ # NVidia Nsight GPU debugger configuration file
359
+ *.nvuser
360
+
361
+ # MFractors (Xamarin productivity tool) working folder
362
+ .mfractor/
363
+
364
+ # Local History for Visual Studio
365
+ .localhistory/
366
+
367
+ # Visual Studio History (VSHistory) files
368
+ .vshistory/
369
+
370
+ # BeatPulse healthcheck temp database
371
+ healthchecksdb
372
+
373
+ # Backup folder for Package Reference Convert tool in Visual Studio 2017
374
+ MigrationBackup/
375
+
376
+ # Ionide (cross platform F# VS Code tools) working folder
377
+ .ionide/
378
+
379
+ # Fody - auto-generated XML schema
380
+ FodyWeavers.xsd
381
+
382
+ # VS Code files for those working on multiple tools
383
+ .vscode/*
384
+ !.vscode/settings.json
385
+ !.vscode/tasks.json
386
+ !.vscode/launch.json
387
+ !.vscode/extensions.json
388
+ *.code-workspace
389
+
390
+ # Local History for Visual Studio Code
391
+ .history/
392
+
393
+ # Windows Installer files from build outputs
394
+ *.cab
395
+ *.msi
396
+ *.msix
397
+ *.msm
398
+ *.msp
399
+
400
+ # JetBrains Rider
401
+ *.sln.iml
402
+
403
+ ##
404
+ ## Visual studio for Mac
405
+ ##
406
+
407
+
408
+ # globs
409
+ Makefile.in
410
+ *.userprefs
411
+ *.usertasks
412
+ config.make
413
+ config.status
414
+ aclocal.m4
415
+ install-sh
416
+ autom4te.cache/
417
+ *.tar.gz
418
+ tarballs/
419
+ test-results/
420
+
421
+ # Mac bundle stuff
422
+ *.dmg
423
+ *.app
424
+
425
+ # content below from: https://github.com/github/gitignore/blob/master/Global/macOS.gitignore
426
+ # General
427
+ .DS_Store
428
+ .AppleDouble
429
+ .LSOverride
430
+
431
+ # Icon must end with two \r
432
+ Icon
433
+
434
+
435
+ # Thumbnails
436
+ ._*
437
+
438
+ # Files that might appear in the root of a volume
439
+ .DocumentRevisions-V100
440
+ .fseventsd
441
+ .Spotlight-V100
442
+ .TemporaryItems
443
+ .Trashes
444
+ .VolumeIcon.icns
445
+ .com.apple.timemachine.donotpresent
446
+
447
+ # Directories potentially created on remote AFP share
448
+ .AppleDB
449
+ .AppleDesktop
450
+ Network Trash Folder
451
+ Temporary Items
452
+ .apdisk
453
+
454
+ # content below from: https://github.com/github/gitignore/blob/master/Global/Windows.gitignore
455
+ # Windows thumbnail cache files
456
+ Thumbs.db
457
+ ehthumbs.db
458
+ ehthumbs_vista.db
459
+
460
+ # Dump file
461
+ *.stackdump
462
+
463
+ # Folder config file
464
+ [Dd]esktop.ini
465
+
466
+ # Recycle Bin used on file shares
467
+ $RECYCLE.BIN/
468
+
469
+ # Windows Installer files
470
+ *.cab
471
+ *.msi
472
+ *.msix
473
+ *.msm
474
+ *.msp
475
+
476
+ # Windows shortcuts
477
+ *.lnk
478
+ *.db
479
+ .idea/
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt requirements.txt
6
+ RUN pip3 install -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ CMD ["sh", "-c", "python3 main.py"]
dbimutils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DanBooru IMage Utility functions
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ import requests
7
+ import io
8
+
9
+ def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
10
+ if img.endswith(".gif"):
11
+ img = Image.open(img)
12
+ img = img.convert("RGB")
13
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
14
+ else:
15
+ img = cv2.imread(img, flag)
16
+ return img
17
+
18
+
19
+ def smart_24bit(img):
20
+ if img.dtype is np.dtype(np.uint16):
21
+ img = (img / 257).astype(np.uint8)
22
+
23
+ if len(img.shape) == 2:
24
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
25
+ elif img.shape[2] == 4:
26
+ trans_mask = img[:, :, 3] == 0
27
+ img[trans_mask] = [255, 255, 255, 255]
28
+ img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
29
+ return img
30
+
31
+
32
+ def make_square(img, target_size):
33
+ old_size = img.shape[:2]
34
+ desired_size = max(old_size)
35
+ desired_size = max(desired_size, target_size)
36
+
37
+ delta_w = desired_size - old_size[1]
38
+ delta_h = desired_size - old_size[0]
39
+ top, bottom = delta_h // 2, delta_h - (delta_h // 2)
40
+ left, right = delta_w // 2, delta_w - (delta_w // 2)
41
+
42
+ color = [255, 255, 255]
43
+ new_im = cv2.copyMakeBorder(
44
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
45
+ )
46
+ return new_im
47
+
48
+
49
+ def smart_resize(img, size):
50
+ # Assumes the image has already gone through make_square
51
+ if img.shape[0] > size:
52
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
53
+ elif img.shape[0] < size:
54
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
55
+ return img
56
+
57
+ headers = {
58
+ "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:47.0) Gecko/20100101 Firefox/47.0"
59
+ }
60
+
61
+ def read_img_from_url(image):
62
+ if isinstance(image, str):
63
+ if image.startswith("http"):
64
+ res = requests.get(image, headers=headers)
65
+ rawimage = Image.open(io.BytesIO(res.content))
66
+ return rawimage
67
+ else:
68
+ raise Exception("Invalid image type")
69
+ return image
img_label.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import io
5
+ import urllib
6
+ from typing import Tuple, List, Any
7
+
8
+ import huggingface_hub
9
+ import onnxruntime as rt
10
+ import pandas as pd
11
+ import numpy as np
12
+ import PIL.Image
13
+ import requests
14
+
15
+ import dbimutils
16
+ import piexif
17
+ import piexif.helper
18
+ from urllib.request import urlopen
19
+
20
+ import model
21
+
22
+ HF_TOKEN = ""
23
+ SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
24
+ CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
25
+ CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
26
+ VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
27
+ MODEL_FILENAME = "model.onnx"
28
+ LABEL_FILENAME = "selected_tags.csv"
29
+
30
+
31
+ def change_model(model_name):
32
+ global loaded_models
33
+
34
+ if model_name == "SwinV2":
35
+ model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
36
+ elif model_name == "ConvNext":
37
+ model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
38
+ elif model_name == "ConvNextV2":
39
+ model = load_model(CONV2_MODEL_REPO, MODEL_FILENAME)
40
+ elif model_name == "ViT":
41
+ model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
42
+
43
+ loaded_models[model_name] = model
44
+ return loaded_models[model_name]
45
+
46
+
47
+ def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
48
+ path = huggingface_hub.hf_hub_download(
49
+ model_repo, model_filename, use_auth_token=HF_TOKEN
50
+ )
51
+ model = rt.InferenceSession(path)
52
+ return model
53
+
54
+
55
+ def load_labels() -> tuple[list[Any], list[Any], list[Any], list[Any]]:
56
+ path = huggingface_hub.hf_hub_download(
57
+ CONV2_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
58
+ )
59
+ df = pd.read_csv(path)
60
+
61
+ tag_names = df["name"].tolist()
62
+ rating_indexes = list(np.where(df["category"] == 9)[0])
63
+ general_indexes = list(np.where(df["category"] == 0)[0])
64
+ character_indexes = list(np.where(df["category"] == 4)[0])
65
+ return tag_names, rating_indexes, general_indexes, character_indexes
66
+
67
+
68
+ def predict(
69
+ image: PIL.Image.Image,
70
+ model_name: str,
71
+ general_threshold: float,
72
+ character_threshold: float,
73
+ tag_names: list[str],
74
+ rating_indexes: list[np.int64],
75
+ general_indexes: list[np.int64],
76
+ character_indexes: list[np.int64],
77
+ ):
78
+ global loaded_models
79
+
80
+ if isinstance(image, str):
81
+ rawimage = dbimutils.read_img_from_url(image)
82
+ elif isinstance(image, PIL.Image.Image):
83
+ rawimage = image
84
+ else:
85
+ raise Exception("Invalid image type")
86
+
87
+ image = rawimage
88
+
89
+ model = loaded_models[model_name]
90
+ if model is None:
91
+ model = change_model(model_name)
92
+
93
+ _, height, width, _ = model.get_inputs()[0].shape
94
+
95
+ # Alpha to white
96
+ image = image.convert("RGBA")
97
+ new_image = PIL.Image.new("RGBA", image.size, "WHITE")
98
+ new_image.paste(image, mask=image)
99
+ image = new_image.convert("RGB")
100
+ image = np.asarray(image)
101
+
102
+ # PIL RGB to OpenCV BGR
103
+ image = image[:, :, ::-1]
104
+
105
+ image = dbimutils.make_square(image, height)
106
+ image = dbimutils.smart_resize(image, height)
107
+ image = image.astype(np.float32)
108
+ image = np.expand_dims(image, 0)
109
+
110
+ input_name = model.get_inputs()[0].name
111
+ label_name = model.get_outputs()[0].name
112
+ probs = model.run([label_name], {input_name: image})[0]
113
+
114
+ labels = list(zip(tag_names, probs[0].astype(float)))
115
+
116
+ # First 4 labels are actually ratings: pick one with argmax
117
+ ratings_names = [labels[i] for i in rating_indexes]
118
+ rating = dict(ratings_names)
119
+
120
+ # Then we have general tags: pick any where prediction confidence > threshold
121
+ general_names = [labels[i] for i in general_indexes]
122
+ general_res = [x for x in general_names if x[1] > general_threshold]
123
+ general_res = dict(general_res)
124
+
125
+ # Everything else is characters: pick any where prediction confidence > threshold
126
+ character_names = [labels[i] for i in character_indexes]
127
+ character_res = [x for x in character_names if x[1] > character_threshold]
128
+ character_res = dict(character_res)
129
+
130
+ b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
131
+ a = (
132
+ ", ".join(list(b.keys()))
133
+ .replace("_", " ")
134
+ .replace("(", "\(")
135
+ .replace(")", "\)")
136
+ )
137
+ c = ", ".join(list(b.keys()))
138
+
139
+ items = rawimage.info
140
+ geninfo = ""
141
+
142
+ if "exif" in rawimage.info:
143
+ exif = piexif.load(rawimage.info["exif"])
144
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
145
+ try:
146
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
147
+ except ValueError:
148
+ exif_comment = exif_comment.decode("utf8", errors="ignore")
149
+
150
+ items["exif comment"] = exif_comment
151
+ geninfo = exif_comment
152
+
153
+ for field in [
154
+ "jfif",
155
+ "jfif_version",
156
+ "jfif_unit",
157
+ "jfif_density",
158
+ "dpi",
159
+ "exif",
160
+ "loop",
161
+ "background",
162
+ "timestamp",
163
+ "duration",
164
+ ]:
165
+ items.pop(field, None)
166
+
167
+ geninfo = items.get("parameters", geninfo)
168
+
169
+ for key, text in items.items():
170
+ print(key)
171
+ print(text)
172
+
173
+ print("geninfo", geninfo)
174
+ print("a", a)
175
+ print("c", c)
176
+ print("rating", rating)
177
+ print("character_res", character_res)
178
+ print("general_res", general_res)
179
+
180
+ character_res = list(filter(lambda x: x['confidence'] > 0.4, [{'tag': tag, 'confidence': score}
181
+ for tag, score in character_res.items()]))
182
+
183
+ general_res = list(filter(lambda x: x['confidence'] > 0.4, [{'tag': tag, 'confidence': score}
184
+ for tag, score in general_res.items()]))
185
+
186
+ return {'a': a, 'c': c, 'rating': rating, 'character_res': character_res, 'general_res': general_res}
187
+
188
+
189
+ def label_img(
190
+ image: PIL.Image.Image | str,
191
+ model: str,
192
+ # model: (["SwinV2", "ConvNext", "ConvNextV2", "ViT"], value="ConvNextV2", label="Model"),
193
+ l_score_general_threshold: float,
194
+ l_score_character_threshold: float,
195
+ ):
196
+ if isinstance(image, str) and image.startswith("http"):
197
+ image = dbimutils.read_img_from_url(image)
198
+
199
+ global loaded_models
200
+ loaded_models = {"SwinV2": None, "ConvNext": None, "ConvNextV2": None, "ViT": None}
201
+
202
+ change_model("ConvNextV2")
203
+
204
+ tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
205
+
206
+ func = functools.partial(
207
+ predict,
208
+ tag_names=tag_names,
209
+ rating_indexes=rating_indexes,
210
+ general_indexes=general_indexes,
211
+ character_indexes=character_indexes,
212
+ )
213
+
214
+ return func(
215
+ image=image, model_name=model,
216
+ general_threshold=l_score_general_threshold,
217
+ character_threshold=l_score_character_threshold,
218
+ )
219
+
220
+
221
+ def write_image_tag(img_id: int, is_valid: bool, tags: List[model.ImageTag], callback_url: str):
222
+ model.ImageScanCallbackRequest(img_id=img_id, is_valid=is_valid, tags=tags)
223
+
224
+
225
+ if __name__ == "__main__":
226
+ score_slider_step = 0.05
227
+ score_general_threshold = 0.35
228
+ score_character_threshold = 0.85
229
+
230
+ ret = label_img(
231
+ image='https://pub-9747017e9ec54620bfbe2385f14fe4d7.r2.dev/cnGirlYcy_v10_people_network_nannansleep/cnGirlYcy_v10_people_network_nannansleep_r_1679670778_0.png',
232
+ model="SwinV2",
233
+ l_score_general_threshold=score_general_threshold,
234
+ l_score_character_threshold=score_character_threshold,
235
+ )
236
+ print(ret)
img_nsfw.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ concepts = ['sexual', 'nude', 'sex', '18+', 'naked', 'nsfw', 'porn', 'dick', 'vagina', 'naked person (approximation)',
2
+ 'explicit content', 'uncensored', 'fuck', 'nipples', 'nipples (approximation)', 'naked breasts', 'areola']
3
+ special_concepts = ["small girl (approximation)", "young child", "young girl"]
4
+
5
+ import dbimutils
6
+
7
+
8
+ def init_nsfw_pipe():
9
+ import torch
10
+ from diffusers import StableDiffusionPipeline
11
+ from torch import nn
12
+
13
+ # make sure you're logged in with `huggingface-cli login`
14
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16",
15
+ torch_dtype=torch.float16)
16
+ pipe = pipe.to('cuda')
17
+
18
+ def cosine_distance(image_embeds, text_embeds):
19
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
20
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
21
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
22
+
23
+ @torch.no_grad()
24
+ def forward_ours(self, clip_input, images):
25
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
26
+ image_embeds = self.visual_projection(pooled_output)
27
+
28
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
29
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
30
+
31
+ result = []
32
+ batch_size = image_embeds.shape[0]
33
+ for i in range(batch_size):
34
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
35
+
36
+ # increase this value to create a stronger `nfsw` filter
37
+ # at the cost of increasing the possibility of filtering benign images
38
+ adjustment = 0.0
39
+
40
+ for concet_idx in range(len(special_cos_dist[0])):
41
+ concept_cos = special_cos_dist[i][concet_idx]
42
+ concept_threshold = self.special_care_embeds_weights[concet_idx].item()
43
+ result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
44
+ if result_img["special_scores"][concet_idx] > 0:
45
+ result_img["special_care"].append({"tag": special_concepts[concet_idx],
46
+ "confidence": result_img["special_scores"][concet_idx]})
47
+ adjustment = 0.01
48
+ print("Special concept matched:", special_concepts[concet_idx])
49
+
50
+ for concet_idx in range(len(cos_dist[0])):
51
+ concept_cos = cos_dist[i][concet_idx]
52
+ concept_threshold = self.concept_embeds_weights[concet_idx].item()
53
+ result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
54
+ # print("no-special", concet_idx, concepts[concet_idx], concept_threshold, round(concept_cos - concept_threshold + adjustment, 3))
55
+ if result_img["concept_scores"][concet_idx] > 0:
56
+ result_img["bad_concepts"].append({"tag": concepts[concet_idx],
57
+ "confidence": result_img["concept_scores"][concet_idx]})
58
+ print("NSFW concept found:", concepts[concet_idx])
59
+
60
+ special_tags = list(filter(lambda x: x['confidence'] > 0.4, result_img['special_care']))
61
+ bad_tags = list(filter(lambda x: x['confidence'] > 0.4, result_img['bad_concepts']))
62
+
63
+ result.append({"special_tags": special_tags,
64
+ "bad_tags": bad_tags, })
65
+
66
+ return images, result
67
+
68
+ from functools import partial
69
+ pipe.safety_checker.forward = partial(forward_ours, self=pipe.safety_checker)
70
+
71
+ return pipe
72
+
73
+
74
+ def check_nsfw(img, pipe):
75
+ if isinstance(img, str):
76
+ img = dbimutils.read_img_from_url(img)
77
+ safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt").to("cuda")
78
+ else:
79
+ safety_checker_input = pipe.feature_extractor(images=img, return_tensors="pt").to("cuda")
80
+ from torch.cuda.amp import autocast
81
+ with autocast():
82
+ _, nsfw_tags = pipe.safety_checker.forward(clip_input=safety_checker_input.pixel_values, images=img)
83
+ return nsfw_tags
main.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List
2
+
3
+ import requests
4
+ import uvicorn
5
+ from fastapi import BackgroundTasks, FastAPI
6
+ import img_label
7
+ from img_nsfw import init_nsfw_pipe, check_nsfw
8
+ import model
9
+
10
+ app = FastAPI()
11
+
12
+
13
+ def write_scan_img_result(image_id: int, scans: List[int], img: str, callback: str):
14
+ score_general_threshold = 0.35
15
+ score_character_threshold = 0.85
16
+
17
+ nsfw_tags = []
18
+ img_tags = []
19
+ if 0 in scans:
20
+ nsfw_tags = check_nsfw(img, pipe)
21
+ if 1 in scans:
22
+ img_tags = img_label.label_img(
23
+ image=img,
24
+ model="SwinV2",
25
+ l_score_general_threshold=score_general_threshold,
26
+ l_score_character_threshold=score_character_threshold,
27
+ )['general_res']
28
+ print(nsfw_tags)
29
+ print(img_tags)
30
+ img_tags = list(map(lambda x: model.ImageTag(tag=x['tag'], confidence=x['confidence']), img_tags))
31
+
32
+ callBackReq = model.ImageScanCallbackRequest(id=image_id, isValid=True, tags=img_tags)
33
+ try:
34
+
35
+ requests.post(callback, json=callBackReq.dict())
36
+ except Exception as ex:
37
+ print(ex)
38
+
39
+ nsfw_tags = map(lambda x: model.ImageScanTag(type="Moderation", confidence=x['confidence']), nsfw_tags)
40
+
41
+ ret = model.ImageScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=nsfw_tags)
42
+ return ret
43
+
44
+
45
+ def write_scan_model_result(model_name: str, callback: str):
46
+ pass
47
+
48
+
49
+ # @app.post("/model-scan")
50
+ # async def send_notification(email: str, background_tasks: BackgroundTasks):
51
+ # background_tasks.add_task(write_scan_model_result, email, callback="")
52
+ # return {"message": "Notification sent in the background"}
53
+
54
+
55
+ @app.post("/image-scan")
56
+ async def image_scan_handler(req: model.ImageScanRequest, background_tasks: BackgroundTasks):
57
+ if not req.wait:
58
+ background_tasks.add_task(write_scan_img_result,
59
+ image_id=req.imageId,
60
+ scans=req.scans,
61
+ img=req.url, callback=req.callbackUrl)
62
+ return model.ImageScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])
63
+ else:
64
+ ret = write_scan_img_result(image_id=req.imageId,
65
+ scans=req.scans,
66
+ img=req.url, callback=req.callbackUrl)
67
+ return ret
68
+
69
+
70
+ if __name__ == "__main__":
71
+ global pipe
72
+ pipe = init_nsfw_pipe()
73
+ uvicorn.run(app, host="0.0.0.0", port=6006)
model.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class ImageScanRequest(BaseModel):
7
+ imageId: int
8
+ url: str
9
+ wait: bool
10
+ scans: List[int]
11
+ callbackUrl: str
12
+
13
+
14
+ class ImageScanTag(BaseModel):
15
+ type: str
16
+ name: str
17
+
18
+
19
+ class ImageScanResponse(BaseModel):
20
+ ok: bool
21
+ error: str
22
+ deleted: bool
23
+ blockedFor: List[str]
24
+ tags: List[ImageScanTag]
25
+
26
+
27
+ class ImageTag(BaseModel):
28
+ tag: str
29
+ id: Optional[int]
30
+ confidence: int
31
+
32
+
33
+ class ImageScanCallbackRequest(BaseModel):
34
+ id: int
35
+ isValid: bool
36
+ tags: List[ImageTag]
37
+
38
+
39
+ class ModelScanRequest(BaseModel):
40
+ callbackUrl: str
41
+ fileUrl: str
42
+ lowPriority: bool
43
+ tasks: List[str]
44
+
45
+
46
+ class ModelScanResponse(BaseModel):
47
+ ok: bool
48
+ error: str
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ uvicorn[standard]
2
+ fastapi
3
+ huggingface_hub
4
+ pandas
5
+ Pillow
6
+ onnxruntime
7
+ numpy
8
+ piexif
9
+ opencv-python-headless
10
+ diffusers
11
+ transformers
12
+ scipy
13
+ ftfy
14
+ torch
15
+ requests
test.jpeg ADDED