Spaces:
Running
Running
Super-squash branch 'main' using huggingface_hub
Browse filesCo-authored-by: p1atdev <p1atdev@users.noreply.huggingface.co>
- .gitattributes +35 -0
- README.md +13 -0
- app.py +276 -0
- characterfull.txt +0 -0
- danbooru_e621.csv +0 -0
- requirements.txt +5 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Wd Tagger Transformers
|
3 |
+
emoji: 😻
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.36.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image

import torch

from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
)

import gradio as gr
import spaces  # ZERO GPU

# Supported tagger checkpoints; the UI dropdown is populated from this list.
MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
MODEL_NAME = MODEL_NAMES[0]

# Load the classification model once at import time; HF Spaces keeps the
# process warm between requests, so this cost is paid only on startup.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
)
model.to("cuda" if torch.cuda.is_available() else "cpu")
# trust_remote_code: the image-processor implementation ships inside the
# model repository rather than the transformers package.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
21 |
+
|
22 |
+
# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
def gradio_copy_text(_text: str | None):
    """Show a "Copied!" toast after the copy button is pressed.

    The clipboard write itself happens client-side (see the ``js=`` handler
    wired to the copy button); this server-side callback only surfaces the
    confirmation toast. ``_text`` receives the prompt string from the
    textbox and is intentionally unused — the previous annotation of
    ``None`` misstated its type.
    """
    gr.Info("Copied!")
|
25 |
+
|
26 |
+
COPY_ACTION_JS = """\
|
27 |
+
(inputs, _outputs) => {
|
28 |
+
// inputs is the string value of the input_text
|
29 |
+
if (inputs.trim() !== "") {
|
30 |
+
navigator.clipboard.writeText(inputs);
|
31 |
+
}
|
32 |
+
}"""
|
33 |
+
|
34 |
+
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
35 |
+
return (
|
36 |
+
[f"1{noun}"]
|
37 |
+
+ [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
|
38 |
+
+ [f"{maximum+1}+{noun}s"]
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
# Every people-count tag the prompt builder treats specially (these are
# placed first in the assembled prompt).
PEOPLE_TAGS = (
    _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
)
# Model rating label -> Animagine-style (danbooru) rating tag.
RATING_MAP = {
    "general": "safe",
    "sensitive": "sensitive",
    "questionable": "nsfw",
    "explicit": "explicit, nsfw",
}
# Model rating label -> e621/Pony-style rating tag. Note "sensitive" maps
# down to rating_safe: e621 has no separate sensitive bucket.
RATING_MAP_E621 = {
    "general": "rating_safe",
    "sensitive": "rating_safe",
    "questionable": "rating_questionable",
    "explicit": "rating_explicit",
}

# Markdown rendered at the top of the Gradio UI.
DESCRIPTION_MD = """
# WD Tagger with 🤗 transformers
Currently supports the following model(s):
- [p1atdev/wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf)

""".strip()
|
64 |
+
|
65 |
+
|
66 |
+
def character_list_to_series_list(character_list):
    """Map predicted character tags to their series/copyright tags.

    For each character tag the series is looked up in ``characterfull.txt``;
    for tags of the form ``"name (series)"`` the parenthesised part of the
    tag itself takes precedence. Returns the list of non-empty series tags,
    in *character_list* order (duplicates are kept).
    """

    def get_series_dict():
        # Parse characterfull.txt: comma-separated lines whose last two
        # fields are the character name and the series. Lines with fewer
        # than three fields are skipped.
        import re

        with open('characterfull.txt', 'r') as f:
            lines = f.readlines()

        series_dict = {}
        for line in lines:
            parts = line.strip().split(', ')
            if len(parts) >= 3:
                name = parts[-2].replace("\\", "")
                if name.endswith(")"):
                    # Strip a trailing "(...)" qualifier from the name so the
                    # bare character name becomes the dict key.
                    names = name.split("(")
                    character_name = "(".join(names[:-1])
                    if character_name.endswith(" "):
                        # NOTE(review): assigns `name` from the trimmed
                        # prefix only when a space precedes "(" — when there
                        # is no space, the full "name(series)" string stays
                        # as the key. Possibly intentional; confirm against
                        # the data file.
                        name = character_name[:-1]
                # Drop escaped parentheses ("\(" / "\)") from the series.
                series = re.sub(r'\\[()]', '', parts[-1])
                series_dict[name] = series

        return series_dict

    output_series_tag = []
    series_tag = ""
    series_dict = get_series_dict()
    for tag in character_list:
        # Fall back to the lookup table; a "(series)" suffix in the tag
        # itself overrides it below.
        series_tag = series_dict.get(tag, "")
        if tag.endswith(")"):
            tags = tag.split("(")
            character_tag = "(".join(tags[:-1])
            if character_tag.endswith(" "):
                character_tag = character_tag[:-1]
            series_tag = tags[-1].replace(")", "")

        if series_tag:
            output_series_tag.append(series_tag)

    return output_series_tag
|
104 |
+
|
105 |
+
|
106 |
+
def get_e621_dict():
    """Load the danbooru -> e621 tag translation table.

    Reads ``danbooru_e621.csv`` (two comma-separated columns per line:
    danbooru tag, e621 tag) and returns it as a plain dict.
    """
    mapping = {}
    with open('danbooru_e621.csv', 'r', encoding="utf-8") as csv_file:
        for row in csv_file:
            columns = row.strip().split(',')
            mapping[columns[0]] = columns[1]
    return mapping
|
116 |
+
|
117 |
+
|
118 |
+
def danbooru_to_e621(dtag, e621_dict):
    """Translate a danbooru tag string into its e621 equivalent.

    Each run of word characters and spaces in *dtag* is looked up in
    *e621_dict* (after normalising underscores to spaces); unknown runs are
    left untouched. At most the first two runs are translated.
    """
    import re

    def _lookup(match):
        candidate = match.group(0)
        translated = e621_dict.get(candidate.strip().replace("_", " "), "")
        return translated or candidate

    return re.sub(r'[\w ]+', _lookup, dtag, count=2)
|
131 |
+
|
132 |
+
def postprocess_results(
    results: dict[str, float], general_threshold: float, character_threshold: float
):
    """Split raw label->score output into rating / character / general dicts.

    Labels prefixed ``"rating:"`` or ``"character:"`` go into their own
    buckets (prefix stripped); everything else is a general tag. Character
    and general buckets are filtered by their thresholds; ratings are kept
    unfiltered. All three dicts are ordered by descending score.
    """
    ordered = sorted(results.items(), key=lambda pair: pair[1], reverse=True)

    rating: dict[str, float] = {}
    character: dict[str, float] = {}
    general: dict[str, float] = {}

    for label, score in ordered:
        if label.startswith("rating:"):
            rating[label.replace("rating:", "")] = score
        elif label.startswith("character:"):
            character[label.replace("character:", "")] = score
        else:
            general[label] = score

    character = {
        name: score for name, score in character.items() if score >= character_threshold
    }
    general = {
        name: score for name, score in general.items() if score >= general_threshold
    }

    return rating, character, general
|
157 |
+
|
158 |
+
|
159 |
+
def animagine_prompt(rating: list[str], character: list[str], general: list[str], tag_type):
    """Assemble the final prompt string in Animagine tag ordering.

    Order: people tags, character tags, series tags, remaining general
    tags, rating tag. When *tag_type* is ``"e621"``, general tags are
    translated through the danbooru->e621 mapping and the e621 rating
    vocabulary is used.

    Args:
        rating: Rating labels sorted by confidence; only the first is used
            (must be non-empty).
        character: Predicted character tags.
        general: Predicted general tags.
        tag_type: ``"danbooru"`` or ``"e621"``.
    """
    people_tags: list[str] = []
    other_tags: list[str] = []
    if tag_type == "e621":
        rating_tag = RATING_MAP_E621[rating[0]]
    else:
        rating_tag = RATING_MAP[rating[0]]

    # Only read the (large) danbooru->e621 CSV when it is actually needed;
    # the previous body loaded it unconditionally on every call, including
    # for plain danbooru output where it was never consulted.
    e621_dict = get_e621_dict() if tag_type == "e621" else {}
    for tag in general:
        if tag_type == "e621":
            tag = danbooru_to_e621(tag, e621_dict)
        if tag in PEOPLE_TAGS:
            people_tags.append(tag)
        else:
            other_tags.append(tag)

    output_series_tag = character_list_to_series_list(character)

    all_tags = people_tags + character + output_series_tag + other_tags + [rating_tag]

    return ", ".join(all_tags)
|
181 |
+
|
182 |
+
|
183 |
+
@spaces.GPU(enable_queue=True)
def predict_tags(
    image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8, tag_type = "danbooru"
):
    """Run the tagger on *image*.

    Returns ``(rating, character, general, prompt, button_update)`` where
    the first three are label->score dicts, *prompt* is the assembled tag
    string, and *button_update* re-enables the copy button.
    """
    # Preprocess on CPU, then move the batch to the model's device/dtype.
    batch = processor.preprocess(image, return_tensors="pt")
    output = model(**batch.to(model.device, model.dtype))

    # Multi-label head: independent sigmoid per tag; take the first (only)
    # image in the batch.
    probabilities = torch.sigmoid(output.logits[0])
    scores = {
        model.config.id2label[index]: float(prob.float())
        for index, prob in enumerate(probabilities)
    }

    rating, character, general = postprocess_results(
        scores, general_threshold, character_threshold
    )

    prompt = animagine_prompt(
        list(rating.keys()), list(character.keys()), list(general.keys()), tag_type
    )

    return rating, character, general, prompt, gr.update(interactive=True,)
|
207 |
+
|
208 |
+
|
209 |
+
def demo():
    """Build the Gradio Blocks UI and return it (caller launches it)."""
    with gr.Blocks() as ui:
        gr.Markdown(DESCRIPTION_MD)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input image", type="pil")

                with gr.Group():
                    # Confidence cutoff for general tags.
                    general_threshold = gr.Slider(
                        label="Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.3,
                        step=0.01,
                        interactive=True,
                    )
                    # Confidence cutoff for character tags (stricter default).
                    character_threshold = gr.Slider(
                        label="Character threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.8,
                        step=0.01,
                        interactive=True,
                    )
                    tag_type = gr.Radio(
                        label="Output tag conversion",
                        info="danbooru for Animagine, e621 for Pony.",
                        choices=["danbooru", "e621"],
                        value="danbooru",
                    )

                # NOTE(review): this dropdown is not wired to any event —
                # predict_tags always uses the module-level model. It is
                # informational until more models are supported.
                _model_radio = gr.Dropdown(
                    choices=MODEL_NAMES,
                    label="Model",
                    value=MODEL_NAMES[0],
                    interactive=True,
                )

                start_btn = gr.Button(value="Start", variant="primary")

            with gr.Column():

                with gr.Group():
                    prompt_text = gr.TextArea(label="Prompt", interactive=False)
                    # Disabled until a prediction produces a prompt; the
                    # predict callback re-enables it via gr.update.
                    copy_btn = gr.Button(value="Copy to clipboard", interactive=False)

                rating_tags_label = gr.Label(label="Rating tags")
                character_tags_label = gr.Label(label="Character tags")
                general_tags_label = gr.Label(label="General tags")

        start_btn.click(
            predict_tags,
            inputs=[input_image, general_threshold, character_threshold, tag_type],
            outputs=[
                rating_tags_label,
                character_tags_label,
                general_tags_label,
                prompt_text,
                copy_btn,
            ],
        )
        # The clipboard write runs client-side (js=...); the Python callback
        # only shows the "Copied!" toast.
        copy_btn.click(gradio_copy_text, inputs=[prompt_text], js=COPY_ACTION_JS)

    return ui
|
274 |
+
|
275 |
+
if __name__ == "__main__":
    # Launch the Gradio app with request queuing enabled.
    demo().queue().launch()
|
characterfull.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
danbooru_e621.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
accelerate
|
4 |
+
transformers
|
5 |
+
spaces
|