SmilingWolf committed (verified)
Commit b682a57 · 1 parent: 68f9639

Upload 6 files

Files changed (6):
  1. .gitattributes +2 -10
  2. .gitignore +1 -0
  3. README.md +31 -5
  4. app.py +350 -0
  5. power.jpg +0 -0
  6. requirements.txt +5 -0
.gitattributes CHANGED
@@ -1,35 +1,27 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ images
README.md CHANGED
@@ -1,12 +1,38 @@
  ---
- title: Wd Tagger
- emoji: 🐨
- colorFrom: purple
- colorTo: blue
+ title: WaifuDiffusion Tagger
+ emoji: 💬
+ colorFrom: blue
+ colorTo: red
  sdk: gradio
  sdk_version: 5.17.1
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Configuration
+
+ `title`: _string_
+ Display title for the Space
+
+ `emoji`: _string_
+ Space emoji (emoji-only character allowed)
+
+ `colorFrom`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `colorTo`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `sdk`: _string_
+ Can be either `gradio`, `streamlit`, or `static`
+
+ `sdk_version` : _string_
+ Only applicable for `streamlit` SDK.
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+ `app_file`: _string_
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
+ Path is relative to the root of the repository.
+
+ `pinned`: _boolean_
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,350 @@
+ import argparse
+ import os
+
+ import gradio as gr
+ import huggingface_hub
+ import numpy as np
+ import onnxruntime as rt
+ import pandas as pd
+ from PIL import Image
+
+ TITLE = "WaifuDiffusion Tagger"
+ DESCRIPTION = """
+ Demo for the WaifuDiffusion tagger models
+
+ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
+ """
+
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
+
+ # Dataset v3 series of models:
+ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
+ VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
+
+ # Dataset v2 series of models:
+ MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
+
+ # IdolSankaku series of models:
+ EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
+ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
+
+ # Files to download from the repos
+ MODEL_FILENAME = "model.onnx"
+ LABEL_FILENAME = "selected_tags.csv"
+
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
+ kaomojis = [
+     "0_0",
+     "(o)_(o)",
+     "+_+",
+     "+_-",
+     "._.",
+     "<o>_<o>",
+     "<|>_<|>",
+     "=_=",
+     ">_<",
+     "3_3",
+     "6_9",
+     ">_o",
+     "@_@",
+     "^_^",
+     "o_o",
+     "u_u",
+     "x_x",
+     "|_|",
+     "||_||",
+ ]
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--score-slider-step", type=float, default=0.05)
+     parser.add_argument("--score-general-threshold", type=float, default=0.35)
+     parser.add_argument("--score-character-threshold", type=float, default=0.85)
+     return parser.parse_args()
+
+
+ def load_labels(dataframe) -> list[str]:
+     name_series = dataframe["name"]
+     name_series = name_series.map(
+         lambda x: x.replace("_", " ") if x not in kaomojis else x
+     )
+     tag_names = name_series.tolist()
+
+     rating_indexes = list(np.where(dataframe["category"] == 9)[0])
+     general_indexes = list(np.where(dataframe["category"] == 0)[0])
+     character_indexes = list(np.where(dataframe["category"] == 4)[0])
+     return tag_names, rating_indexes, general_indexes, character_indexes
+
+
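For illustration only (not part of the commit): with `load_labels` and `kaomojis` from this file in scope, and assuming `selected_tags.csv` uses category 9 for ratings, 0 for general tags, and 4 for character tags (which is what the code above relies on), a toy dataframe shows how the tag list gets split. The tag names below are invented.

```python
import pandas as pd

# Hypothetical miniature stand-in for selected_tags.csv (names are made up).
toy_df = pd.DataFrame(
    {
        "name": ["general", "sensitive", "long_hair", "^_^", "power_(chainsaw_man)"],
        "category": [9, 9, 0, 0, 4],
    }
)

tag_names, rating_idx, general_idx, character_idx = load_labels(toy_df)
# tag_names: underscores become spaces, except for kaomoji entries like "^_^"
# rating_idx    -> positions [0, 1] (category 9)
# general_idx   -> positions [2, 3] (category 0)
# character_idx -> position  [4]    (category 4)
```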
+ def mcut_threshold(probs):
+     """
+     Maximum Cut Thresholding (MCut)
+     Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
+     for Multi-label Classification. In 11th International Symposium, IDA 2012
+     (pp. 172-183).
+     """
+     sorted_probs = probs[probs.argsort()[::-1]]
+     difs = sorted_probs[:-1] - sorted_probs[1:]
+     t = difs.argmax()
+     thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
+     return thresh
+
+
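A quick worked example of the MCut rule above (illustrative only, not part of the commit): the threshold is placed in the middle of the largest gap between consecutive sorted scores, so no fixed cutoff has to be picked by hand.

```python
import numpy as np

probs = np.array([0.91, 0.62, 0.20, 0.15, 0.05])  # made-up tag scores
# Gaps between consecutive sorted scores: 0.29, 0.42, 0.05, 0.10.
# The largest gap is between 0.62 and 0.20, so mcut_threshold returns
# (0.62 + 0.20) / 2 ≈ 0.41 and only the first two tags would survive.
print(mcut_threshold(probs))
```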
+ class Predictor:
+     def __init__(self):
+         self.model_target_size = None
+         self.last_loaded_repo = None
+
+     def download_model(self, model_repo):
+         csv_path = huggingface_hub.hf_hub_download(
+             model_repo,
+             LABEL_FILENAME,
+             use_auth_token=HF_TOKEN,
+         )
+         model_path = huggingface_hub.hf_hub_download(
+             model_repo,
+             MODEL_FILENAME,
+             use_auth_token=HF_TOKEN,
+         )
+         return csv_path, model_path
+
+     def load_model(self, model_repo):
+         if model_repo == self.last_loaded_repo:
+             return
+
+         csv_path, model_path = self.download_model(model_repo)
+
+         tags_df = pd.read_csv(csv_path)
+         sep_tags = load_labels(tags_df)
+
+         self.tag_names = sep_tags[0]
+         self.rating_indexes = sep_tags[1]
+         self.general_indexes = sep_tags[2]
+         self.character_indexes = sep_tags[3]
+
+         model = rt.InferenceSession(model_path)
+         _, height, width, _ = model.get_inputs()[0].shape
+         self.model_target_size = height
+
+         self.last_loaded_repo = model_repo
+         self.model = model
+
+     def prepare_image(self, image):
+         target_size = self.model_target_size
+
+         canvas = Image.new("RGBA", image.size, (255, 255, 255))
+         canvas.alpha_composite(image)
+         image = canvas.convert("RGB")
+
+         # Pad image to square
+         image_shape = image.size
+         max_dim = max(image_shape)
+         pad_left = (max_dim - image_shape[0]) // 2
+         pad_top = (max_dim - image_shape[1]) // 2
+
+         padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+         padded_image.paste(image, (pad_left, pad_top))
+
+         # Resize
+         if max_dim != target_size:
+             padded_image = padded_image.resize(
+                 (target_size, target_size),
+                 Image.BICUBIC,
+             )
+
+         # Convert to numpy array
+         image_array = np.asarray(padded_image, dtype=np.float32)
+
+         # Convert PIL-native RGB to BGR
+         image_array = image_array[:, :, ::-1]
+
+         return np.expand_dims(image_array, axis=0)
+
+     def predict(
+         self,
+         image,
+         model_repo,
+         general_thresh,
+         general_mcut_enabled,
+         character_thresh,
+         character_mcut_enabled,
+     ):
+         self.load_model(model_repo)
+
+         image = self.prepare_image(image)
+
+         input_name = self.model.get_inputs()[0].name
+         label_name = self.model.get_outputs()[0].name
+         preds = self.model.run([label_name], {input_name: image})[0]
+
+         labels = list(zip(self.tag_names, preds[0].astype(float)))
+
+         # First 4 labels are actually ratings: pick one with argmax
+         ratings_names = [labels[i] for i in self.rating_indexes]
+         rating = dict(ratings_names)
+
+         # Then we have general tags: pick any where prediction confidence > threshold
+         general_names = [labels[i] for i in self.general_indexes]
+
+         if general_mcut_enabled:
+             general_probs = np.array([x[1] for x in general_names])
+             general_thresh = mcut_threshold(general_probs)
+
+         general_res = [x for x in general_names if x[1] > general_thresh]
+         general_res = dict(general_res)
+
+         # Everything else is characters: pick any where prediction confidence > threshold
+         character_names = [labels[i] for i in self.character_indexes]
+
+         if character_mcut_enabled:
+             character_probs = np.array([x[1] for x in character_names])
+             character_thresh = mcut_threshold(character_probs)
+             character_thresh = max(0.15, character_thresh)
+
+         character_res = [x for x in character_names if x[1] > character_thresh]
+         character_res = dict(character_res)
+
+         sorted_general_strings = sorted(
+             general_res.items(),
+             key=lambda x: x[1],
+             reverse=True,
+         )
+         sorted_general_strings = [x[0] for x in sorted_general_strings]
+         sorted_general_strings = (
+             ", ".join(sorted_general_strings).replace("(", r"\(").replace(")", r"\)")
+         )
+
+         return sorted_general_strings, rating, character_res, general_res
+
+
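A rough sketch of driving the `Predictor` above outside the Gradio UI (illustrative only, not part of the commit; `example.png` is a placeholder path, and the first call downloads the model and label files from the Hub):

```python
from PIL import Image

predictor = Predictor()

# The Gradio input is declared with image_mode="RGBA", and prepare_image()
# alpha-composites onto a white canvas, so convert to RGBA here as well.
img = Image.open("example.png").convert("RGBA")  # placeholder path

tag_string, rating, characters, general = predictor.predict(
    img,
    SWINV2_MODEL_DSV3_REPO,
    general_thresh=0.35,          # defaults taken from parse_args()
    general_mcut_enabled=False,
    character_thresh=0.85,
    character_mcut_enabled=False,
)
print(tag_string)  # comma-separated general tags with parentheses escaped
print(rating)      # dict of rating tags and their scores
```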
+ def main():
+     args = parse_args()
+
+     predictor = Predictor()
+
+     dropdown_list = [
+         SWINV2_MODEL_DSV3_REPO,
+         CONV_MODEL_DSV3_REPO,
+         VIT_MODEL_DSV3_REPO,
+         VIT_LARGE_MODEL_DSV3_REPO,
+         EVA02_LARGE_MODEL_DSV3_REPO,
+         # ---
+         MOAT_MODEL_DSV2_REPO,
+         SWIN_MODEL_DSV2_REPO,
+         CONV_MODEL_DSV2_REPO,
+         CONV2_MODEL_DSV2_REPO,
+         VIT_MODEL_DSV2_REPO,
+         # ---
+         SWINV2_MODEL_IS_DSV1_REPO,
+         EVA02_LARGE_MODEL_IS_DSV1_REPO,
+     ]
+
+     with gr.Blocks(title=TITLE) as demo:
+         with gr.Column():
+             gr.Markdown(
+                 value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
+             )
+             gr.Markdown(value=DESCRIPTION)
+             with gr.Row():
+                 with gr.Column(variant="panel"):
+                     image = gr.Image(type="pil", image_mode="RGBA", label="Input")
+                     model_repo = gr.Dropdown(
+                         dropdown_list,
+                         value=SWINV2_MODEL_DSV3_REPO,
+                         label="Model",
+                     )
+                     with gr.Row():
+                         general_thresh = gr.Slider(
+                             0,
+                             1,
+                             step=args.score_slider_step,
+                             value=args.score_general_threshold,
+                             label="General Tags Threshold",
+                             scale=3,
+                         )
+                         general_mcut_enabled = gr.Checkbox(
+                             value=False,
+                             label="Use MCut threshold",
+                             scale=1,
+                         )
+                     with gr.Row():
+                         character_thresh = gr.Slider(
+                             0,
+                             1,
+                             step=args.score_slider_step,
+                             value=args.score_character_threshold,
+                             label="Character Tags Threshold",
+                             scale=3,
+                         )
+                         character_mcut_enabled = gr.Checkbox(
+                             value=False,
+                             label="Use MCut threshold",
+                             scale=1,
+                         )
+                     with gr.Row():
+                         clear = gr.ClearButton(
+                             components=[
+                                 image,
+                                 model_repo,
+                                 general_thresh,
+                                 general_mcut_enabled,
+                                 character_thresh,
+                                 character_mcut_enabled,
+                             ],
+                             variant="secondary",
+                             size="lg",
+                         )
+                         submit = gr.Button(value="Submit", variant="primary", size="lg")
+                 with gr.Column(variant="panel"):
+                     sorted_general_strings = gr.Textbox(label="Output (string)")
+                     rating = gr.Label(label="Rating")
+                     character_res = gr.Label(label="Output (characters)")
+                     general_res = gr.Label(label="Output (tags)")
+                     clear.add(
+                         [
+                             sorted_general_strings,
+                             rating,
+                             character_res,
+                             general_res,
+                         ]
+                     )
+
+             submit.click(
+                 predictor.predict,
+                 inputs=[
+                     image,
+                     model_repo,
+                     general_thresh,
+                     general_mcut_enabled,
+                     character_thresh,
+                     character_mcut_enabled,
+                 ],
+                 outputs=[sorted_general_strings, rating, character_res, general_res],
+             )
+
+             gr.Examples(
+                 [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
+                 inputs=[
+                     image,
+                     model_repo,
+                     general_thresh,
+                     general_mcut_enabled,
+                     character_thresh,
+                     character_mcut_enabled,
+                 ],
+             )
+
+     demo.queue(max_size=10)
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     main()
power.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ pillow
+ onnxruntime
+ huggingface-hub
+ pandas
+ numpy