Upload 6 files
Browse files- wdv3-timm-main/.gitignore +253 -0
- wdv3-timm-main/.vscode/settings.json +94 -0
- wdv3-timm-main/README.md +84 -0
- wdv3-timm-main/requirements.txt +11 -0
- wdv3-timm-main/setup.sh +24 -0
- wdv3-timm-main/wdv3_timm.py +203 -0
@@ -0,0 +1,253 @@
1 |
# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
2 |
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python
3 |
4 |
### Linux ###
5 |
6 |
7 |
# temporary files which can be created if a process still has a handle open of a deleted file
8 |
9 |
10 |
# KDE directory preferences
11 |
12 |
13 |
# Linux trash folder which might appear on any partition or disk
14 |
15 |
16 |
# .nfs files are created when an open file is removed but is still being accessed
17 |
18 |
19 |
### macOS ###
20 |
# General
21 |
22 |
23 |
24 |
25 |
# Icon must end with two \r
26 |
27 |
28 |
29 |
# Thumbnails
30 |
31 |
32 |
# Files that might appear in the root of a volume
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
# Directories potentially created on remote AFP share
42 |
43 |
44 |
Network Trash Folder
45 |
Temporary Items
46 |
47 |
48 |
### Python ###
49 |
# Byte-compiled / optimized / DLL files
50 |
51 |
52 |
53 |
54 |
# C extensions
55 |
56 |
57 |
# Distribution / packaging
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
# PyInstaller
78 |
# Usually these files are written by a python script from a template
79 |
# before PyInstaller builds the exe, so as to inject date/other infos into it.
80 |
81 |
82 |
83 |
# Installer logs
84 |
85 |
86 |
87 |
# Unit test / coverage reports
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
# Translations
103 |
104 |
105 |
106 |
# Django stuff:
107 |
108 |
109 |
110 |
111 |
112 |
# Flask stuff:
113 |
114 |
115 |
116 |
# Scrapy stuff:
117 |
118 |
119 |
# Sphinx documentation
120 |
121 |
122 |
# PyBuilder
123 |
124 |
125 |
126 |
# Jupyter Notebook
127 |
128 |
129 |
# IPython
130 |
131 |
132 |
133 |
# pyenv
134 |
# For a library or package, you might want to ignore these files since the code is
135 |
# intended to run in multiple environments; otherwise, check them in:
136 |
# .python-version
137 |
138 |
# pipenv
139 |
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
140 |
# However, in case of collaboration, if having platform-specific dependencies or dependencies
141 |
# having no cross-platform support, pipenv may install dependencies that don't work, or not
142 |
# install all needed dependencies.
143 |
144 |
145 |
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
146 |
147 |
148 |
# Celery stuff
149 |
150 |
151 |
152 |
# SageMath parsed files
153 |
154 |
155 |
# Environments
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
# Spyder project settings
165 |
166 |
167 |
168 |
# Rope project settings
169 |
170 |
171 |
# mkdocs documentation
172 |
173 |
174 |
# mypy
175 |
176 |
177 |
178 |
179 |
# Pyre type checker
180 |
181 |
182 |
# pytype static type analyzer
183 |
184 |
185 |
# Cython debug symbols
186 |
187 |
188 |
### VisualStudioCode ###
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
# Local History for Visual Studio Code
197 |
198 |
199 |
### VisualStudioCode Patch ###
200 |
# Ignore all local history of files
201 |
202 |
203 |
204 |
### Windows ###
205 |
# Windows thumbnail cache files
206 |
207 |
208 |
209 |
210 |
211 |
# Dump file
212 |
213 |
214 |
# Folder config file
215 |
216 |
217 |
# Recycle Bin used on file shares
218 |
219 |
220 |
# Windows Installer files
221 |
222 |
223 |
224 |
225 |
226 |
227 |
# Windows shortcuts
228 |
229 |
230 |
# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
231 |
232 |
# temp and misc
233 |
234 |
235 |
236 |
# direnv
237 |
238 |
239 |
240 |
# dotenv
241 |
242 |
243 |
244 |
# temp files
245 |
246 |
247 |
248 |
# but keep examples
249 |
250 |
251 |
# input images and heatmap outputs
252 |
253 |
@@ -0,0 +1,94 @@
1 |
2 |
"editor.insertSpaces": true,
3 |
"editor.tabSize": 4,
4 |
"files.trimTrailingWhitespace": true,
5 |
"editor.rulers": [100, 120],
6 |
7 |
"files.associations": {
8 |
"*.yaml": "yaml"
9 |
10 |
"files.exclude": {
11 |
"**/.git": true,
12 |
"**/.svn": true,
13 |
"**/.hg": true,
14 |
"**/CVS": true,
15 |
"**/.DS_Store": true,
16 |
"**/Thumbs.db": true,
17 |
"**/.ruff_cache": true,
18 |
"**/__pycache__": true,
19 |
"**/*.egg-info": true
20 |
21 |
22 |
"[shellscript]": {
23 |
"files.eol": "\n",
24 |
"editor.tabSize": 4,
25 |
"editor.detectIndentation": false
26 |
27 |
28 |
"[python]": {
29 |
"editor.wordBasedSuggestions": "off",
30 |
"editor.formatOnSave": true,
31 |
"editor.defaultFormatter": "charliermarsh.ruff",
32 |
"editor.codeActionsOnSave": {
33 |
"source.organizeImports": "always"
34 |
35 |
36 |
"python.analysis.include": ["./src", "./scripts", "./tests"],
37 |
38 |
"[json]": {
39 |
"editor.defaultFormatter": "esbenp.prettier-vscode",
40 |
"editor.detectIndentation": false,
41 |
"editor.formatOnSaveMode": "file",
42 |
"editor.formatOnSave": true,
43 |
"editor.tabSize": 2
44 |
45 |
"[jsonc]": {
46 |
"editor.defaultFormatter": "esbenp.prettier-vscode",
47 |
"editor.detectIndentation": false,
48 |
"editor.formatOnSaveMode": "file",
49 |
"editor.formatOnSave": true,
50 |
"editor.tabSize": 2
51 |
52 |
53 |
"[toml]": {
54 |
"editor.tabSize": 2,
55 |
"editor.detectIndentation": false,
56 |
"editor.formatOnSave": true,
57 |
"editor.formatOnSaveMode": "file",
58 |
"editor.defaultFormatter": "tamasfe.even-better-toml",
59 |
"editor.rulers": [80, 100]
60 |
61 |
"evenBetterToml.formatter.columnWidth": 88,
62 |
63 |
"[yaml]": {
64 |
"editor.detectIndentation": false,
65 |
"editor.tabSize": 2,
66 |
"editor.formatOnSave": true,
67 |
"editor.formatOnSaveMode": "file",
68 |
"diffEditor.ignoreTrimWhitespace": false,
69 |
"editor.defaultFormatter": "redhat.vscode-yaml"
70 |
71 |
"yaml.format.bracketSpacing": true,
72 |
"yaml.format.proseWrap": "preserve",
73 |
"yaml.format.singleQuote": false,
74 |
"yaml.format.printWidth": 110,
75 |
76 |
"[hcl]": {
77 |
"editor.detectIndentation": false,
78 |
"editor.formatOnSave": true,
79 |
"editor.formatOnSaveMode": "file",
80 |
"editor.defaultFormatter": "fredwangwang.vscode-hcl-format"
81 |
82 |
83 |
"[markdown]": {
84 |
"files.trimTrailingWhitespace": false
85 |
86 |
87 |
"css.lint.validProperties": ["dock", "content-align", "content-justify"],
88 |
"[css]": {
89 |
"editor.formatOnSave": true
90 |
91 |
92 |
"remote.autoForwardPorts": false,
93 |
"remote.autoForwardPortsSource": "process"
94 |
@@ -0,0 +1,84 @@
1 |
# wdv3-timm
2 |
3 |
A small example showing how to use `timm` to run the WD Tagger V3 models.
4 |
5 |
## How To Use
6 |
7 |
1. clone the repository and enter the directory:
8 |
9 |
git clone https://github.com/neggles/wdv3-timm.git
10 |
cd wdv3-timm
11 |
12 |
13 |
2. Create a virtual environment and install the Python requirements.
14 |
15 |
If you're using Linux, you can use the provided script:
16 |
17 |
bash setup.sh
18 |
19 |
20 |
Or if you're on Windows (or just want to do it manually), you can do the following:
21 |
22 |
# Create virtual environment
23 |
python3.10 -m venv .venv
24 |
# Activate it (on Windows, run `.venv\Scripts\activate` instead of the `source` line below)
25 |
source .venv/bin/activate
26 |
# Upgrade pip/setuptools/wheel
27 |
python -m pip install -U pip setuptools wheel
28 |
# At this point, you can optionally install PyTorch manually (e.g. if you are not using an NVIDIA GPU)
29 |
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
30 |
# Install requirements
31 |
python -m pip install -r requirements.txt
32 |
33 |
34 |
3. Run the example script, picking one of the 3 models to use:
35 |
36 |
python wdv3_timm.py <swinv2|convnext|vit> path/to/image.png
37 |
38 |
39 |
Example output from `python wdv3_timm.py vit a_picture_of_ganyu.png`:
40 |
41 |
Loading model 'vit' from 'SmilingWolf/wd-vit-tagger-v3'...
42 |
Loading tag list...
43 |
Creating data transform...
44 |
Loading image and preprocessing...
45 |
Running inference...
46 |
Processing results...
47 |
48 |
Caption: 1girl, horns, solo, bell, ahoge, colored_skin, blue_skin, neck_bell, looking_at_viewer, purple_eyes, upper_body, blonde_hair, long_hair, goat_horns, blue_hair, off_shoulder, sidelocks, bare_shoulders, alternate_costume, shirt, black_shirt, cowbell, ganyu_(genshin_impact)
49 |
50 |
Tags: 1girl, horns, solo, bell, ahoge, colored skin, blue skin, neck bell, looking at viewer, purple eyes, upper body, blonde hair, long hair, goat horns, blue hair, off shoulder, sidelocks, bare shoulders, alternate costume, shirt, black shirt, cowbell, ganyu \(genshin impact\)
51 |
52 |
53 |
general: 0.827
54 |
sensitive: 0.199
55 |
questionable: 0.001
56 |
explicit: 0.001
57 |
58 |
Character tags (threshold=0.75):
59 |
ganyu_(genshin_impact): 0.991
60 |
61 |
General tags (threshold=0.35):
62 |
1girl: 0.996
63 |
horns: 0.950
64 |
solo: 0.947
65 |
bell: 0.918
66 |
ahoge: 0.897
67 |
colored_skin: 0.881
68 |
blue_skin: 0.872
69 |
neck_bell: 0.854
70 |
looking_at_viewer: 0.817
71 |
purple_eyes: 0.734
72 |
upper_body: 0.615
73 |
blonde_hair: 0.609
74 |
long_hair: 0.607
75 |
goat_horns: 0.524
76 |
blue_hair: 0.496
77 |
off_shoulder: 0.472
78 |
sidelocks: 0.470
79 |
bare_shoulders: 0.464
80 |
alternate_costume: 0.437
81 |
shirt: 0.427
82 |
black_shirt: 0.417
83 |
cowbell: 0.415
84 |
@@ -0,0 +1,11 @@
1 |
2 |
3 |
4 |
5 |
pillow >= 9.5.0
6 |
simple-parsing >= 0.1.5
7 |
timm @ git+https://github.com/huggingface/pytorch-image-models@main#egg=timm
8 |
9 |
torch >= 2.0.0
10 |
11 |
@@ -0,0 +1,24 @@
1 |
#!/usr/bin/env bash
# Create the project's virtual environment (if it does not exist yet) and
# install the Python requirements into it.
set -euo pipefail

# get the folder this script is in and make sure we're in it
script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P)
cd "${script_dir}"

# create the venv only when it does not exist yet, so re-running is cheap
if [[ ! -d .venv ]]; then
    echo "Creating virtual environment..."
    python3.10 -m venv .venv
fi

# activate the venv
source .venv/bin/activate

# upgrade pip
python -m pip install -U pip setuptools wheel

# install requirements
python -m pip install -r requirements.txt

echo "Setup complete. Run 'source .venv/bin/activate' to enter the virtual environment."
exit 0
@@ -0,0 +1,203 @@
1 |
from dataclasses import dataclass
2 |
from pathlib import Path
3 |
from typing import Optional
4 |
5 |
import numpy as np
6 |
import pandas as pd
7 |
import timm
8 |
import torch
9 |
from huggingface_hub import hf_hub_download
10 |
from huggingface_hub.utils import HfHubHTTPError
11 |
from PIL import Image
12 |
from simple_parsing import field, parse_known_args
13 |
from timm.data import create_transform, resolve_data_config
14 |
from torch import Tensor, nn
15 |
from torch.nn import functional as F
16 |
17 |
import json
18 |
19 |
# Run on the first CUDA device when one is available, otherwise on the CPU.
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Short model names accepted on the command line -> Hugging Face Hub repository ids.
MODEL_REPO_MAP = {
    "vit": "SmilingWolf/wd-vit-tagger-v3",
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
}
24 |
25 |
26 |
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode, flattening any transparency onto white.

    Args:
        image: Source image in any PIL mode.

    Returns:
        The image itself when it is already RGB, otherwise a converted copy.
        Images with an alpha channel are composited over a white background.
    """
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        # Draw the picture onto the canvas; without this step the result
        # would be nothing but the blank white background.
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
36 |
37 |
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Centre ``image`` on a white square whose side equals its longest edge.

    Args:
        image: Source image; it is pasted unchanged onto the new canvas.

    Returns:
        A new RGB image of size ``(side, side)`` with white padding.
    """
    width, height = image.size
    # the longest edge decides the side length of the square
    side = max(image.size)
    # top-left corner that centres the original inside the square
    offset = ((side - width) // 2, (side - height) // 2)
    squared = Image.new("RGB", (side, side), (255, 255, 255))
    squared.paste(image, offset)
    return squared
45 |
46 |
47 |
@dataclass
class LabelData:
    """Tag names plus the row indices that belong to each tag category."""

    # tag names, in the same order as the rows of selected_tags.csv
    names: list[str]
    # row indices of the rating tags (category 9 in the CSV)
    rating: list[np.int64]
    # row indices of the general tags (category 0 in the CSV)
    general: list[np.int64]
    # row indices of the character tags (category 4 in the CSV)
    character: list[np.int64]
52 |
53 |
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a Hub repository and index its tags.

    Args:
        repo_id: Hugging Face Hub repository that holds the tag list.
        revision: Optional revision passed through to ``hf_hub_download``.
        token: Optional access token passed through to ``hf_hub_download``.

    Returns:
        A ``LabelData`` holding every tag name and, per category, the row
        indices of the tags in that category.

    Raises:
        FileNotFoundError: If the CSV cannot be downloaded from ``repo_id``.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
        csv_path = Path(csv_path).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    # np.where(...)[0] yields row positions, which are later used to index
    # the list of (name, probability) pairs built from the model output
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )

    return tag_data
75 |
76 |
def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-label probabilities into captions and per-category score maps.

    Args:
        probs: One probability per entry of ``labels.names``; must expose
            ``.numpy()``.
        labels: Tag names and the indices belonging to each category.
        gen_threshold: General tags are kept only when their score is
            strictly greater than this value.
        char_threshold: Character tags are kept only when their score is
            strictly greater than this value.

    Returns:
        A tuple ``(caption, taglist, rating_labels, char_labels, gen_labels)``:
        ``caption`` is the comma-separated raw tag names, ``taglist`` the same
        text with underscores replaced by spaces and parentheses
        backslash-escaped, and the three dicts map tag name to score (the
        general and character dicts are ordered by descending score).
    """
    # Pair every label name with its predicted probability
    scored = list(zip(labels.names, probs.numpy()))

    # Rating labels are always reported, whatever their score
    rating_labels = dict(scored[i] for i in labels.rating)

    # General labels, pick any where prediction confidence > threshold
    gen_labels = {
        name: score for name, score in (scored[i] for i in labels.general) if score > gen_threshold
    }
    gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

    # Character labels, pick any where prediction confidence > threshold
    char_labels = {
        name: score
        for name, score in (scored[i] for i in labels.character)
        if score > char_threshold
    }
    char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))

    # General tags first, then character tags; each group keeps its own
    # descending-score order (the groups are not merged into one ranking)
    combined_names = [*gen_labels, *char_labels]

    # Convert to a string suitable for use as a training caption
    caption = ", ".join(combined_names)
    # Raw strings keep the literal backslash without relying on an invalid
    # "\(" escape sequence
    taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

    return caption, taglist, rating_labels, char_labels, gen_labels
107 |
108 |
109 |
@dataclass
class ScriptOptions:
    """Command-line options for the tagging script."""

    # path of the image to tag (positional argument)
    image_file: Path = field(positional=True)
    # short model name; expected to be a key of MODEL_REPO_MAP
    model: str = field(default="vit")
    # minimum score for a general tag to be reported
    gen_threshold: float = field(default=0.35)
    # minimum score for a character tag to be reported
    char_threshold: float = field(default=0.75)
114 |
115 |
def _resolve_thresholds(opts: ScriptOptions, config_path: Path = Path("config.json")):
    """Return ``(gen_threshold, char_threshold)`` for this run.

    Values found in ``config_path`` (keys ``general_threshold`` and
    ``character_threshold``) take precedence; when the file does not exist,
    or a key is absent, the corresponding command-line option is used.
    """
    gen_threshold, char_threshold = opts.gen_threshold, opts.char_threshold
    try:
        with open(config_path, "r", encoding="utf-8") as config_file:
            config_data = json.load(config_file)
    except FileNotFoundError:
        # the config file is optional; fall back to the command-line values
        return gen_threshold, char_threshold
    return (
        config_data.get("general_threshold", gen_threshold),
        config_data.get("character_threshold", char_threshold),
    )


def main(opts: ScriptOptions):
    """Tag a single image with one of the WD Tagger V3 models and print the result.

    Args:
        opts: Parsed script options (image path, model name, thresholds).

    Raises:
        ValueError: If ``opts.model`` is not a key of ``MODEL_REPO_MAP``.
        FileNotFoundError: If ``opts.image_file`` is not an existing file.
    """
    repo_id = MODEL_REPO_MAP.get(opts.model)
    if repo_id is None:
        # fail with a clear message instead of a TypeError on "hf-hub:" + None
        raise ValueError(f"Unknown model name '{opts.model}'")
    image_path = Path(opts.image_file).resolve()
    if not image_path.is_file():
        raise FileNotFoundError(f"Image file not found: {image_path}")

    print(f"Loading model '{opts.model}' from '{repo_id}'...")
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    # apply the downloaded weights; fetching them alone does not change the model
    model.load_state_dict(state_dict)

    print("Loading tag list...")
    labels: LabelData = load_labels_hf(repo_id=repo_id)

    print("Creating data transform...")
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    print("Loading image and preprocessing...")
    # open the image in a context manager so the file handle is released
    with Image.open(image_path) as img_file:
        # ensure image is RGB
        img_input: Image.Image = pil_ensure_rgb(img_file)
        # pad to square with white background
        img_input = pil_pad_square(img_input)
    # run the model's input transform to convert to tensor and rescale
    inputs: Tensor = transform(img_input).unsqueeze(0)
    # NCHW image RGB to BGR
    inputs = inputs[:, [2, 1, 0]]

    print("Running inference...")
    with torch.inference_mode():
        # move model to GPU, if available
        if torch_device.type != "cpu":
            model = model.to(torch_device)
            inputs = inputs.to(torch_device)
        # run the model
        outputs = model.forward(inputs)
        # apply the final activation function (timm doesn't support doing this internally)
        outputs = F.sigmoid(outputs)
        # move inputs, outputs, and model back to the CPU if they are on the GPU
        if torch_device.type != "cpu":
            inputs = inputs.to("cpu")
            outputs = outputs.to("cpu")
            model = model.to("cpu")

    print("Processing results...")
    # read the thresholds from config.json, falling back to the CLI options
    gen_threshold, char_threshold = _resolve_thresholds(opts)

    caption, taglist, ratings, character, general = get_tags(
        # drop the batch dimension added by unsqueeze(0) above
        probs=outputs.squeeze(0),
        labels=labels,
        gen_threshold=gen_threshold,
        char_threshold=char_threshold,
    )

    print("--------")
    print(f"Caption: {caption}")
    print("--------")
    print(f"Tags: {taglist}")

    print("--------")
    print("Ratings:")
    for k, v in ratings.items():
        print(f"  {k}: {v:.3f}")

    print("--------")
    print(f"Character tags (threshold={char_threshold}):")
    for k, v in character.items():
        print(f"  {k}: {v:.3f}")

    print("--------")
    print(f"General tags (threshold={gen_threshold}):")
    for k, v in general.items():
        print(f"  {k}: {v:.3f}")
194 |
195 |
196 |
197 |
198 |
if __name__ == "__main__":
    # the second element returned by parse_known_args (presumably the
    # arguments it did not recognise) is deliberately ignored
    opts, _ = parse_known_args(ScriptOptions)
    if opts.model not in MODEL_REPO_MAP:
        print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
        raise ValueError(f"Unknown model name '{opts.model}'")
    # options are valid: run the tagger
    main(opts)
203 |