Upload 6 files
Browse files- wdv3-timm-main/.gitignore +253 -0
- wdv3-timm-main/.vscode/settings.json +94 -0
- wdv3-timm-main/README.md +84 -0
- wdv3-timm-main/requirements.txt +11 -0
- wdv3-timm-main/setup.sh +24 -0
- wdv3-timm-main/wdv3_timm.py +203 -0
wdv3-timm-main/.gitignore
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
|
2 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python
|
3 |
+
|
4 |
+
### Linux ###
|
5 |
+
*~
|
6 |
+
|
7 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
8 |
+
.fuse_hidden*
|
9 |
+
|
10 |
+
# KDE directory preferences
|
11 |
+
.directory
|
12 |
+
|
13 |
+
# Linux trash folder which might appear on any partition or disk
|
14 |
+
.Trash-*
|
15 |
+
|
16 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
17 |
+
.nfs*
|
18 |
+
|
19 |
+
### macOS ###
|
20 |
+
# General
|
21 |
+
.DS_Store
|
22 |
+
.AppleDouble
|
23 |
+
.LSOverride
|
24 |
+
|
25 |
+
# Icon must end with two \r
|
26 |
+
Icon
|
27 |
+
|
28 |
+
|
29 |
+
# Thumbnails
|
30 |
+
._*
|
31 |
+
|
32 |
+
# Files that might appear in the root of a volume
|
33 |
+
.DocumentRevisions-V100
|
34 |
+
.fseventsd
|
35 |
+
.Spotlight-V100
|
36 |
+
.TemporaryItems
|
37 |
+
.Trashes
|
38 |
+
.VolumeIcon.icns
|
39 |
+
.com.apple.timemachine.donotpresent
|
40 |
+
|
41 |
+
# Directories potentially created on remote AFP share
|
42 |
+
.AppleDB
|
43 |
+
.AppleDesktop
|
44 |
+
Network Trash Folder
|
45 |
+
Temporary Items
|
46 |
+
.apdisk
|
47 |
+
|
48 |
+
### Python ###
|
49 |
+
# Byte-compiled / optimized / DLL files
|
50 |
+
__pycache__/
|
51 |
+
*.py[cod]
|
52 |
+
*$py.class
|
53 |
+
|
54 |
+
# C extensions
|
55 |
+
*.so
|
56 |
+
|
57 |
+
# Distribution / packaging
|
58 |
+
.Python
|
59 |
+
build/
|
60 |
+
develop-eggs/
|
61 |
+
dist/
|
62 |
+
downloads/
|
63 |
+
eggs/
|
64 |
+
.eggs/
|
65 |
+
lib/
|
66 |
+
lib64/
|
67 |
+
parts/
|
68 |
+
sdist/
|
69 |
+
var/
|
70 |
+
wheels/
|
71 |
+
share/python-wheels/
|
72 |
+
*.egg-info/
|
73 |
+
.installed.cfg
|
74 |
+
*.egg
|
75 |
+
MANIFEST
|
76 |
+
|
77 |
+
# PyInstaller
|
78 |
+
# Usually these files are written by a python script from a template
|
79 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
80 |
+
*.manifest
|
81 |
+
*.spec
|
82 |
+
|
83 |
+
# Installer logs
|
84 |
+
pip-log.txt
|
85 |
+
pip-delete-this-directory.txt
|
86 |
+
|
87 |
+
# Unit test / coverage reports
|
88 |
+
htmlcov/
|
89 |
+
.tox/
|
90 |
+
.nox/
|
91 |
+
.coverage
|
92 |
+
.coverage.*
|
93 |
+
.cache
|
94 |
+
nosetests.xml
|
95 |
+
coverage.xml
|
96 |
+
*.cover
|
97 |
+
*.py,cover
|
98 |
+
.hypothesis/
|
99 |
+
.pytest_cache/
|
100 |
+
cover/
|
101 |
+
|
102 |
+
# Translations
|
103 |
+
*.mo
|
104 |
+
*.pot
|
105 |
+
|
106 |
+
# Django stuff:
|
107 |
+
*.log
|
108 |
+
local_settings.py
|
109 |
+
db.sqlite3
|
110 |
+
db.sqlite3-journal
|
111 |
+
|
112 |
+
# Flask stuff:
|
113 |
+
instance/
|
114 |
+
.webassets-cache
|
115 |
+
|
116 |
+
# Scrapy stuff:
|
117 |
+
.scrapy
|
118 |
+
|
119 |
+
# Sphinx documentation
|
120 |
+
docs/_build/
|
121 |
+
|
122 |
+
# PyBuilder
|
123 |
+
.pybuilder/
|
124 |
+
target/
|
125 |
+
|
126 |
+
# Jupyter Notebook
|
127 |
+
.ipynb_checkpoints
|
128 |
+
|
129 |
+
# IPython
|
130 |
+
profile_default/
|
131 |
+
ipython_config.py
|
132 |
+
|
133 |
+
# pyenv
|
134 |
+
# For a library or package, you might want to ignore these files since the code is
|
135 |
+
# intended to run in multiple environments; otherwise, check them in:
|
136 |
+
# .python-version
|
137 |
+
|
138 |
+
# pipenv
|
139 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
140 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
141 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
142 |
+
# install all needed dependencies.
|
143 |
+
#Pipfile.lock
|
144 |
+
|
145 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
146 |
+
__pypackages__/
|
147 |
+
|
148 |
+
# Celery stuff
|
149 |
+
celerybeat-schedule
|
150 |
+
celerybeat.pid
|
151 |
+
|
152 |
+
# SageMath parsed files
|
153 |
+
*.sage.py
|
154 |
+
|
155 |
+
# Environments
|
156 |
+
.env
|
157 |
+
.venv
|
158 |
+
env/
|
159 |
+
venv/
|
160 |
+
ENV/
|
161 |
+
env.bak/
|
162 |
+
venv.bak/
|
163 |
+
|
164 |
+
# Spyder project settings
|
165 |
+
.spyderproject
|
166 |
+
.spyproject
|
167 |
+
|
168 |
+
# Rope project settings
|
169 |
+
.ropeproject
|
170 |
+
|
171 |
+
# mkdocs documentation
|
172 |
+
/site
|
173 |
+
|
174 |
+
# mypy
|
175 |
+
.mypy_cache/
|
176 |
+
.dmypy.json
|
177 |
+
dmypy.json
|
178 |
+
|
179 |
+
# Pyre type checker
|
180 |
+
.pyre/
|
181 |
+
|
182 |
+
# pytype static type analyzer
|
183 |
+
.pytype/
|
184 |
+
|
185 |
+
# Cython debug symbols
|
186 |
+
cython_debug/
|
187 |
+
|
188 |
+
### VisualStudioCode ###
|
189 |
+
.vscode/*
|
190 |
+
!.vscode/settings.json
|
191 |
+
!.vscode/tasks.json
|
192 |
+
!.vscode/launch.json
|
193 |
+
!.vscode/extensions.json
|
194 |
+
*.code-workspace
|
195 |
+
|
196 |
+
# Local History for Visual Studio Code
|
197 |
+
.history/
|
198 |
+
|
199 |
+
### VisualStudioCode Patch ###
|
200 |
+
# Ignore all local history of files
|
201 |
+
.history
|
202 |
+
.ionide
|
203 |
+
|
204 |
+
### Windows ###
|
205 |
+
# Windows thumbnail cache files
|
206 |
+
Thumbs.db
|
207 |
+
Thumbs.db:encryptable
|
208 |
+
ehthumbs.db
|
209 |
+
ehthumbs_vista.db
|
210 |
+
|
211 |
+
# Dump file
|
212 |
+
*.stackdump
|
213 |
+
|
214 |
+
# Folder config file
|
215 |
+
[Dd]esktop.ini
|
216 |
+
|
217 |
+
# Recycle Bin used on file shares
|
218 |
+
$RECYCLE.BIN/
|
219 |
+
|
220 |
+
# Windows Installer files
|
221 |
+
*.cab
|
222 |
+
*.msi
|
223 |
+
*.msix
|
224 |
+
*.msm
|
225 |
+
*.msp
|
226 |
+
|
227 |
+
# Windows shortcuts
|
228 |
+
*.lnk
|
229 |
+
|
230 |
+
# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
|
231 |
+
|
232 |
+
# temp and misc
|
233 |
+
/misc/
|
234 |
+
/temp/
|
235 |
+
|
236 |
+
# direnv
|
237 |
+
.envrc
|
238 |
+
.envrc.*
|
239 |
+
|
240 |
+
# dotenv
|
241 |
+
.env
|
242 |
+
.env.*
|
243 |
+
|
244 |
+
# temp files
|
245 |
+
**/tmp_*.*
|
246 |
+
**/*.tmp.*
|
247 |
+
|
248 |
+
# but keep examples
|
249 |
+
!*.example
|
250 |
+
|
251 |
+
# input images and heatmap outputs
|
252 |
+
/images/
|
253 |
+
/heatmaps/
|
wdv3-timm-main/.vscode/settings.json
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"editor.insertSpaces": true,
|
3 |
+
"editor.tabSize": 4,
|
4 |
+
"files.trimTrailingWhitespace": true,
|
5 |
+
"editor.rulers": [100, 120],
|
6 |
+
|
7 |
+
"files.associations": {
|
8 |
+
"*.yaml": "yaml"
|
9 |
+
},
|
10 |
+
"files.exclude": {
|
11 |
+
"**/.git": true,
|
12 |
+
"**/.svn": true,
|
13 |
+
"**/.hg": true,
|
14 |
+
"**/CVS": true,
|
15 |
+
"**/.DS_Store": true,
|
16 |
+
"**/Thumbs.db": true,
|
17 |
+
"**/.ruff_cache": true,
|
18 |
+
"**/__pycache__": true,
|
19 |
+
"**/*.egg-info": true
|
20 |
+
},
|
21 |
+
|
22 |
+
"[shellscript]": {
|
23 |
+
"files.eol": "\n",
|
24 |
+
"editor.tabSize": 4,
|
25 |
+
"editor.detectIndentation": false
|
26 |
+
},
|
27 |
+
|
28 |
+
"[python]": {
|
29 |
+
"editor.wordBasedSuggestions": "off",
|
30 |
+
"editor.formatOnSave": true,
|
31 |
+
"editor.defaultFormatter": "charliermarsh.ruff",
|
32 |
+
"editor.codeActionsOnSave": {
|
33 |
+
"source.organizeImports": "always"
|
34 |
+
}
|
35 |
+
},
|
36 |
+
"python.analysis.include": ["./src", "./scripts", "./tests"],
|
37 |
+
|
38 |
+
"[json]": {
|
39 |
+
"editor.defaultFormatter": "esbenp.prettier-vscode",
|
40 |
+
"editor.detectIndentation": false,
|
41 |
+
"editor.formatOnSaveMode": "file",
|
42 |
+
"editor.formatOnSave": true,
|
43 |
+
"editor.tabSize": 2
|
44 |
+
},
|
45 |
+
"[jsonc]": {
|
46 |
+
"editor.defaultFormatter": "esbenp.prettier-vscode",
|
47 |
+
"editor.detectIndentation": false,
|
48 |
+
"editor.formatOnSaveMode": "file",
|
49 |
+
"editor.formatOnSave": true,
|
50 |
+
"editor.tabSize": 2
|
51 |
+
},
|
52 |
+
|
53 |
+
"[toml]": {
|
54 |
+
"editor.tabSize": 2,
|
55 |
+
"editor.detectIndentation": false,
|
56 |
+
"editor.formatOnSave": true,
|
57 |
+
"editor.formatOnSaveMode": "file",
|
58 |
+
"editor.defaultFormatter": "tamasfe.even-better-toml",
|
59 |
+
"editor.rulers": [80, 100]
|
60 |
+
},
|
61 |
+
"evenBetterToml.formatter.columnWidth": 88,
|
62 |
+
|
63 |
+
"[yaml]": {
|
64 |
+
"editor.detectIndentation": false,
|
65 |
+
"editor.tabSize": 2,
|
66 |
+
"editor.formatOnSave": true,
|
67 |
+
"editor.formatOnSaveMode": "file",
|
68 |
+
"diffEditor.ignoreTrimWhitespace": false,
|
69 |
+
"editor.defaultFormatter": "redhat.vscode-yaml"
|
70 |
+
},
|
71 |
+
"yaml.format.bracketSpacing": true,
|
72 |
+
"yaml.format.proseWrap": "preserve",
|
73 |
+
"yaml.format.singleQuote": false,
|
74 |
+
"yaml.format.printWidth": 110,
|
75 |
+
|
76 |
+
"[hcl]": {
|
77 |
+
"editor.detectIndentation": false,
|
78 |
+
"editor.formatOnSave": true,
|
79 |
+
"editor.formatOnSaveMode": "file",
|
80 |
+
"editor.defaultFormatter": "fredwangwang.vscode-hcl-format"
|
81 |
+
},
|
82 |
+
|
83 |
+
"[markdown]": {
|
84 |
+
"files.trimTrailingWhitespace": false
|
85 |
+
},
|
86 |
+
|
87 |
+
"css.lint.validProperties": ["dock", "content-align", "content-justify"],
|
88 |
+
"[css]": {
|
89 |
+
"editor.formatOnSave": true
|
90 |
+
},
|
91 |
+
|
92 |
+
"remote.autoForwardPorts": false,
|
93 |
+
"remote.autoForwardPortsSource": "process"
|
94 |
+
}
|
wdv3-timm-main/README.md
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# wdv3-timm
|
2 |
+
|
3 |
+
small example thing showing how to use `timm` to run the WD Tagger V3 models.
|
4 |
+
|
5 |
+
## How To Use
|
6 |
+
|
7 |
+
1. clone the repository and enter the directory:
|
8 |
+
```sh
|
9 |
+
git clone https://github.com/neggles/wdv3-timm.git
|
10 |
+
cd wd3-timm
|
11 |
+
```
|
12 |
+
|
13 |
+
2. Create a virtual environment and install the Python requirements.
|
14 |
+
|
15 |
+
If you're using Linux, you can use the provided script:
|
16 |
+
```sh
|
17 |
+
bash setup.sh
|
18 |
+
```
|
19 |
+
|
20 |
+
Or if you're on Windows (or just want to do it manually), you can do the following:
|
21 |
+
```sh
|
22 |
+
# Create virtual environment
|
23 |
+
python3.10 -m venv .venv
|
24 |
+
# Activate it
|
25 |
+
source .venv/bin/activate
|
26 |
+
# Upgrade pip/setuptools/wheel
|
27 |
+
python -m pip install -U pip setuptools wheel
|
28 |
+
# At this point, optionally you can install PyTorch manually (e.g. if you are not using an nVidia GPU)
|
29 |
+
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
30 |
+
# Install requirements
|
31 |
+
python -m pip install -r requirements.txt
|
32 |
+
```
|
33 |
+
|
34 |
+
3. Run the example script, picking one of the 3 models to use:
|
35 |
+
```sh
|
36 |
+
python wdv3_timm.py <swinv2|convnext|vit> path/to/image.png
|
37 |
+
```
|
38 |
+
|
39 |
+
Example output from `python wdv3_timm.py vit a_picture_of_ganyu.png`:
|
40 |
+
```sh
|
41 |
+
Loading model 'vit' from 'SmilingWolf/wd-vit-tagger-v3'...
|
42 |
+
Loading tag list...
|
43 |
+
Creating data transform...
|
44 |
+
Loading image and preprocessing...
|
45 |
+
Running inference...
|
46 |
+
Processing results...
|
47 |
+
--------
|
48 |
+
Caption: 1girl, horns, solo, bell, ahoge, colored_skin, blue_skin, neck_bell, looking_at_viewer, purple_eyes, upper_body, blonde_hair, long_hair, goat_horns, blue_hair, off_shoulder, sidelocks, bare_shoulders, alternate_costume, shirt, black_shirt, cowbell, ganyu_(genshin_impact)
|
49 |
+
--------
|
50 |
+
Tags: 1girl, horns, solo, bell, ahoge, colored skin, blue skin, neck bell, looking at viewer, purple eyes, upper body, blonde hair, long hair, goat horns, blue hair, off shoulder, sidelocks, bare shoulders, alternate costume, shirt, black shirt, cowbell, ganyu \(genshin impact\)
|
51 |
+
--------
|
52 |
+
Ratings:
|
53 |
+
general: 0.827
|
54 |
+
sensitive: 0.199
|
55 |
+
questionable: 0.001
|
56 |
+
explicit: 0.001
|
57 |
+
--------
|
58 |
+
Character tags (threshold=0.75):
|
59 |
+
ganyu_(genshin_impact): 0.991
|
60 |
+
--------
|
61 |
+
General tags (threshold=0.35):
|
62 |
+
1girl: 0.996
|
63 |
+
horns: 0.950
|
64 |
+
solo: 0.947
|
65 |
+
bell: 0.918
|
66 |
+
ahoge: 0.897
|
67 |
+
colored_skin: 0.881
|
68 |
+
blue_skin: 0.872
|
69 |
+
neck_bell: 0.854
|
70 |
+
looking_at_viewer: 0.817
|
71 |
+
purple_eyes: 0.734
|
72 |
+
upper_body: 0.615
|
73 |
+
blonde_hair: 0.609
|
74 |
+
long_hair: 0.607
|
75 |
+
goat_horns: 0.524
|
76 |
+
blue_hair: 0.496
|
77 |
+
off_shoulder: 0.472
|
78 |
+
sidelocks: 0.470
|
79 |
+
bare_shoulders: 0.464
|
80 |
+
alternate_costume: 0.437
|
81 |
+
shirt: 0.427
|
82 |
+
black_shirt: 0.417
|
83 |
+
cowbell: 0.415
|
84 |
+
```
|
wdv3-timm-main/requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers
|
2 |
+
huggingface-hub
|
3 |
+
numpy
|
4 |
+
pandas
|
5 |
+
pillow >= 9.5.0
|
6 |
+
simple-parsing >= 0.1.5
|
7 |
+
timm @ git+https://github.com/huggingface/pytorch-image-models@main#egg=timm
|
8 |
+
tokenizers
|
9 |
+
torch >= 2.0.0
|
10 |
+
torchvision
|
11 |
+
transformers
|
wdv3-timm-main/setup.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
set -euo pipefail
|
3 |
+
|
4 |
+
# get the folder this script is in and make sure we're in it
|
5 |
+
script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P)
|
6 |
+
cd "${script_dir}"
|
7 |
+
|
8 |
+
# make venv if not exist
|
9 |
+
if [[ ! -d .venv ]]; then
|
10 |
+
echo "Creating virtual environment..."
|
11 |
+
python3.10 -m venv .venv
|
12 |
+
fi
|
13 |
+
|
14 |
+
# activate the venv
|
15 |
+
source .venv/bin/activate
|
16 |
+
|
17 |
+
# upgrade pip
|
18 |
+
python -m pip install -U pip setuptools wheel
|
19 |
+
|
20 |
+
# install requirements
|
21 |
+
python -m pip install -r requirements.txt
|
22 |
+
|
23 |
+
echo "Setup complete. Run 'source .venv/bin/activate' to enter the virtual environment."
|
24 |
+
exit 0
|
wdv3-timm-main/wdv3_timm.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import timm
|
8 |
+
import torch
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
from huggingface_hub.utils import HfHubHTTPError
|
11 |
+
from PIL import Image
|
12 |
+
from simple_parsing import field, parse_known_args
|
13 |
+
from timm.data import create_transform, resolve_data_config
|
14 |
+
from torch import Tensor, nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
|
17 |
+
import json
|
18 |
+
|
19 |
+
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
MODEL_REPO_MAP = {
|
21 |
+
"vit": "SmilingWolf/wd-vit-tagger-v3",
|
22 |
+
"swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
|
23 |
+
"convnext": "SmilingWolf/wd-convnext-tagger-v3",
|
24 |
+
}
|
25 |
+
|
26 |
+
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
|
27 |
+
# convert to RGB/RGBA if not already (deals with palette images etc.)
|
28 |
+
if image.mode not in ["RGB", "RGBA"]:
|
29 |
+
image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
|
30 |
+
# convert RGBA to RGB with white background
|
31 |
+
if image.mode == "RGBA":
|
32 |
+
canvas = Image.new("RGBA", image.size, (255, 255, 255))
|
33 |
+
canvas.alpha_composite(image)
|
34 |
+
image = canvas.convert("RGB")
|
35 |
+
return image
|
36 |
+
|
37 |
+
def pil_pad_square(image: Image.Image) -> Image.Image:
|
38 |
+
w, h = image.size
|
39 |
+
# get the largest dimension so we can pad to a square
|
40 |
+
px = max(image.size)
|
41 |
+
# pad to square with white background
|
42 |
+
canvas = Image.new("RGB", (px, px), (255, 255, 255))
|
43 |
+
canvas.paste(image, ((px - w) // 2, (px - h) // 2))
|
44 |
+
return canvas
|
45 |
+
|
46 |
+
@dataclass
|
47 |
+
class LabelData:
|
48 |
+
names: list[str]
|
49 |
+
rating: list[np.int64]
|
50 |
+
general: list[np.int64]
|
51 |
+
character: list[np.int64]
|
52 |
+
|
53 |
+
def load_labels_hf(
|
54 |
+
repo_id: str,
|
55 |
+
revision: Optional[str] = None,
|
56 |
+
token: Optional[str] = None,
|
57 |
+
) -> LabelData:
|
58 |
+
try:
|
59 |
+
csv_path = hf_hub_download(
|
60 |
+
repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
|
61 |
+
)
|
62 |
+
csv_path = Path(csv_path).resolve()
|
63 |
+
except HfHubHTTPError as e:
|
64 |
+
raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
|
65 |
+
|
66 |
+
df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
|
67 |
+
tag_data = LabelData(
|
68 |
+
names=df["name"].tolist(),
|
69 |
+
rating=list(np.where(df["category"] == 9)[0]),
|
70 |
+
general=list(np.where(df["category"] == 0)[0]),
|
71 |
+
character=list(np.where(df["category"] == 4)[0]),
|
72 |
+
)
|
73 |
+
|
74 |
+
return tag_data
|
75 |
+
|
76 |
+
def get_tags(
|
77 |
+
probs: Tensor,
|
78 |
+
labels: LabelData,
|
79 |
+
gen_threshold: float,
|
80 |
+
char_threshold: float,
|
81 |
+
):
|
82 |
+
# Convert indices+probs to labels
|
83 |
+
probs = list(zip(labels.names, probs.numpy()))
|
84 |
+
|
85 |
+
# First 4 labels are actually ratings
|
86 |
+
rating_labels = dict([probs[i] for i in labels.rating])
|
87 |
+
|
88 |
+
# General labels, pick any where prediction confidence > threshold
|
89 |
+
gen_labels = [probs[i] for i in labels.general]
|
90 |
+
gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
|
91 |
+
gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
|
92 |
+
|
93 |
+
# Character labels, pick any where prediction confidence > threshold
|
94 |
+
char_labels = [probs[i] for i in labels.character]
|
95 |
+
char_labels = dict([x for x in char_labels if x[1] > char_threshold])
|
96 |
+
char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
|
97 |
+
|
98 |
+
# Combine general and character labels, sort by confidence
|
99 |
+
combined_names = [x for x in gen_labels]
|
100 |
+
combined_names.extend([x for x in char_labels])
|
101 |
+
|
102 |
+
# Convert to a string suitable for use as a training caption
|
103 |
+
caption = ", ".join(combined_names)
|
104 |
+
taglist = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")
|
105 |
+
|
106 |
+
return caption, taglist, rating_labels, char_labels, gen_labels
|
107 |
+
|
108 |
+
@dataclass
|
109 |
+
class ScriptOptions:
|
110 |
+
image_file: Path = field(positional=True)
|
111 |
+
model: str = field(default="vit")
|
112 |
+
gen_threshold: float = field(default=0.35)
|
113 |
+
char_threshold: float = field(default=0.75)
|
114 |
+
|
115 |
+
def main(opts: ScriptOptions):
|
116 |
+
repo_id = MODEL_REPO_MAP.get(opts.model)
|
117 |
+
image_path = Path(opts.image_file).resolve()
|
118 |
+
if not image_path.is_file():
|
119 |
+
raise FileNotFoundError(f"Image file not found: {image_path}")
|
120 |
+
|
121 |
+
print(f"Loading model '{opts.model}' from '{repo_id}'...")
|
122 |
+
model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
|
123 |
+
state_dict = timm.models.load_state_dict_from_hf(repo_id)
|
124 |
+
model.load_state_dict(state_dict)
|
125 |
+
|
126 |
+
print("Loading tag list...")
|
127 |
+
labels: LabelData = load_labels_hf(repo_id=repo_id)
|
128 |
+
|
129 |
+
print("Creating data transform...")
|
130 |
+
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
|
131 |
+
|
132 |
+
print("Loading image and preprocessing...")
|
133 |
+
# get image
|
134 |
+
img_input: Image.Image = Image.open(image_path)
|
135 |
+
# ensure image is RGB
|
136 |
+
img_input = pil_ensure_rgb(img_input)
|
137 |
+
# pad to square with white background
|
138 |
+
img_input = pil_pad_square(img_input)
|
139 |
+
# run the model's input transform to convert to tensor and rescale
|
140 |
+
inputs: Tensor = transform(img_input).unsqueeze(0)
|
141 |
+
# NCHW image RGB to BGR
|
142 |
+
inputs = inputs[:, [2, 1, 0]]
|
143 |
+
|
144 |
+
print("Running inference...")
|
145 |
+
with torch.inference_mode():
|
146 |
+
# move model to GPU, if available
|
147 |
+
if torch_device.type != "cpu":
|
148 |
+
model = model.to(torch_device)
|
149 |
+
inputs = inputs.to(torch_device)
|
150 |
+
# run the model
|
151 |
+
outputs = model.forward(inputs)
|
152 |
+
# apply the final activation function (timm doesn't support doing this internally)
|
153 |
+
outputs = F.sigmoid(outputs)
|
154 |
+
# move inputs, outputs, và model về CPU nếu đang ở trên GPU
|
155 |
+
if torch_device.type != "cpu":
|
156 |
+
inputs = inputs.to("cpu")
|
157 |
+
outputs = outputs.to("cpu")
|
158 |
+
model = model.to("cpu")
|
159 |
+
|
160 |
+
print("Processing results...")
|
161 |
+
# Đọc giá trị từ config.json
|
162 |
+
with open('config.json', 'r') as config_file:
|
163 |
+
config_data = json.load(config_file)
|
164 |
+
|
165 |
+
gen_threshold = config_data.get('general_threshold', 0.35)
|
166 |
+
char_threshold = config_data.get('character_threshold', 0.75)
|
167 |
+
|
168 |
+
caption, taglist, ratings, character, general = get_tags(
|
169 |
+
probs=outputs.squeeze(0),
|
170 |
+
labels=labels,
|
171 |
+
gen_threshold=gen_threshold,
|
172 |
+
char_threshold=char_threshold,
|
173 |
+
)
|
174 |
+
|
175 |
+
print("--------")
|
176 |
+
print(f"Caption: {caption}")
|
177 |
+
print("--------")
|
178 |
+
print(f"Tags: {taglist}")
|
179 |
+
|
180 |
+
print("--------")
|
181 |
+
print("Ratings:")
|
182 |
+
for k, v in ratings.items():
|
183 |
+
print(f" {k}: {v:.3f}")
|
184 |
+
|
185 |
+
print("--------")
|
186 |
+
print(f"Character tags (threshold={char_threshold}):")
|
187 |
+
for k, v in character.items():
|
188 |
+
print(f" {k}: {v:.3f}")
|
189 |
+
|
190 |
+
print("--------")
|
191 |
+
print(f"General tags (threshold={gen_threshold}):")
|
192 |
+
for k, v in general.items():
|
193 |
+
print(f" {k}: {v:.3f}")
|
194 |
+
|
195 |
+
print("Done!")
|
196 |
+
|
197 |
+
|
198 |
+
if __name__ == "__main__":
|
199 |
+
opts, _ = parse_known_args(ScriptOptions)
|
200 |
+
if opts.model not in MODEL_REPO_MAP:
|
201 |
+
print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
|
202 |
+
raise ValueError(f"Unknown model name '{opts.model}'")
|
203 |
+
main(opts)
|