John6666 p1atdev commited on
Commit
ae039af
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files

Co-authored-by: p1atdev <p1atdev@users.noreply.huggingface.co>

Files changed (6) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +276 -0
  4. characterfull.txt +0 -0
  5. danbooru_e621.csv +0 -0
  6. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Wd Tagger Transformers
3
+ emoji: 😻
4
+ colorFrom: pink
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 4.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ import torch
4
+
5
+ from transformers import (
6
+ AutoImageProcessor,
7
+ AutoModelForImageClassification,
8
+ )
9
+
10
+ import gradio as gr
11
+ import spaces # ZERO GPU
12
+
13
+ MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
14
+ MODEL_NAME = MODEL_NAMES[0]
15
+
16
+ model = AutoModelForImageClassification.from_pretrained(
17
+ MODEL_NAME,
18
+ )
19
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
20
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
21
+
22
+ # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
23
+ def gradio_copy_text(_text: None):
24
+ gr.Info("Copied!")
25
+
26
+ COPY_ACTION_JS = """\
27
+ (inputs, _outputs) => {
28
+ // inputs is the string value of the input_text
29
+ if (inputs.trim() !== "") {
30
+ navigator.clipboard.writeText(inputs);
31
+ }
32
+ }"""
33
+
34
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
35
+ return (
36
+ [f"1{noun}"]
37
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
38
+ + [f"{maximum+1}+{noun}s"]
39
+ )
40
+
41
+
42
+ PEOPLE_TAGS = (
43
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
44
+ )
45
+ RATING_MAP = {
46
+ "general": "safe",
47
+ "sensitive": "sensitive",
48
+ "questionable": "nsfw",
49
+ "explicit": "explicit, nsfw",
50
+ }
51
+ RATING_MAP_E621 = {
52
+ "general": "rating_safe",
53
+ "sensitive": "rating_safe",
54
+ "questionable": "rating_questionable",
55
+ "explicit": "rating_explicit",
56
+ }
57
+
58
+ DESCRIPTION_MD = """
59
+ # WD Tagger with 🤗 transformers
60
+ Currently supports the following model(s):
61
+ - [p1atdev/wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf)
62
+
63
+ """.strip()
64
+
65
+
66
+ def character_list_to_series_list(character_list):
67
+ def get_series_dict():
68
+ import re
69
+
70
+ with open('characterfull.txt', 'r') as f:
71
+ lines = f.readlines()
72
+
73
+ series_dict = {}
74
+ for line in lines:
75
+ parts = line.strip().split(', ')
76
+ if len(parts) >= 3:
77
+ name = parts[-2].replace("\\", "")
78
+ if name.endswith(")"):
79
+ names = name.split("(")
80
+ character_name = "(".join(names[:-1])
81
+ if character_name.endswith(" "):
82
+ name = character_name[:-1]
83
+ series = re.sub(r'\\[()]', '', parts[-1])
84
+ series_dict[name] = series
85
+
86
+ return series_dict
87
+
88
+ output_series_tag = []
89
+ series_tag = ""
90
+ series_dict = get_series_dict()
91
+ for tag in character_list:
92
+ series_tag = series_dict.get(tag, "")
93
+ if tag.endswith(")"):
94
+ tags = tag.split("(")
95
+ character_tag = "(".join(tags[:-1])
96
+ if character_tag.endswith(" "):
97
+ character_tag = character_tag[:-1]
98
+ series_tag = tags[-1].replace(")", "")
99
+
100
+ if series_tag:
101
+ output_series_tag.append(series_tag)
102
+
103
+ return output_series_tag
104
+
105
+
106
+ def get_e621_dict():
107
+ with open('danbooru_e621.csv', 'r', encoding="utf-8") as f:
108
+ lines = f.readlines()
109
+
110
+ e621_dict = {}
111
+ for line in lines:
112
+ parts = line.strip().split(',')
113
+ e621_dict[parts[0]] = parts[1]
114
+
115
+ return e621_dict
116
+
117
+
118
+ def danbooru_to_e621(dtag, e621_dict):
119
+ def d_to_e(match, e621_dict):
120
+ dtag = match.group(0)
121
+ etag = e621_dict.get(dtag.strip().replace("_", " "), "")
122
+ if etag:
123
+ return etag
124
+ else:
125
+ return dtag
126
+
127
+ import re
128
+ tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
129
+
130
+ return tag
131
+
132
+ def postprocess_results(
133
+ results: dict[str, float], general_threshold: float, character_threshold: float
134
+ ):
135
+ results = {
136
+ k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
137
+ }
138
+
139
+ rating = {}
140
+ character = {}
141
+ general = {}
142
+
143
+ for k, v in results.items():
144
+ if k.startswith("rating:"):
145
+ rating[k.replace("rating:", "")] = v
146
+ continue
147
+ elif k.startswith("character:"):
148
+ character[k.replace("character:", "")] = v
149
+ continue
150
+
151
+ general[k] = v
152
+
153
+ character = {k: v for k, v in character.items() if v >= character_threshold}
154
+ general = {k: v for k, v in general.items() if v >= general_threshold}
155
+
156
+ return rating, character, general
157
+
158
+
159
+ def animagine_prompt(rating: list[str], character: list[str], general: list[str], tag_type):
160
+ people_tags: list[str] = []
161
+ other_tags: list[str] = []
162
+ if tag_type == "e621":
163
+ rating_tag = RATING_MAP_E621[rating[0]]
164
+ else:
165
+ rating_tag = RATING_MAP[rating[0]]
166
+
167
+ e621_dict = get_e621_dict()
168
+ for tag in general:
169
+ if tag_type == "e621":
170
+ tag = danbooru_to_e621(tag, e621_dict)
171
+ if tag in PEOPLE_TAGS:
172
+ people_tags.append(tag)
173
+ else:
174
+ other_tags.append(tag)
175
+
176
+ output_series_tag = character_list_to_series_list(character)
177
+
178
+ all_tags = people_tags + character + output_series_tag + other_tags + [rating_tag]
179
+
180
+ return ", ".join(all_tags)
181
+
182
+
183
+ @spaces.GPU(enable_queue=True)
184
+ def predict_tags(
185
+ image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8, tag_type = "danbooru"
186
+ ):
187
+ inputs = processor.preprocess(image, return_tensors="pt")
188
+
189
+ outputs = model(**inputs.to(model.device, model.dtype))
190
+ logits = torch.sigmoid(outputs.logits[0]) # take the first logits
191
+
192
+ # get probabilities
193
+ results = {
194
+ model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
195
+ }
196
+
197
+ # rating, character, general
198
+ rating, character, general = postprocess_results(
199
+ results, general_threshold, character_threshold
200
+ )
201
+
202
+ prompt = animagine_prompt(
203
+ list(rating.keys()), list(character.keys()), list(general.keys()), tag_type
204
+ )
205
+
206
+ return rating, character, general, prompt, gr.update(interactive=True,)
207
+
208
+
209
+ def demo():
210
+ with gr.Blocks() as ui:
211
+ gr.Markdown(DESCRIPTION_MD)
212
+
213
+ with gr.Row():
214
+ with gr.Column():
215
+ input_image = gr.Image(label="Input image", type="pil")
216
+
217
+ with gr.Group():
218
+ general_threshold = gr.Slider(
219
+ label="Threshold",
220
+ minimum=0.0,
221
+ maximum=1.0,
222
+ value=0.3,
223
+ step=0.01,
224
+ interactive=True,
225
+ )
226
+ character_threshold = gr.Slider(
227
+ label="Character threshold",
228
+ minimum=0.0,
229
+ maximum=1.0,
230
+ value=0.8,
231
+ step=0.01,
232
+ interactive=True,
233
+ )
234
+ tag_type = gr.Radio(
235
+ label="Output tag conversion",
236
+ info="danbooru for Animagine, e621 for Pony.",
237
+ choices=["danbooru", "e621"],
238
+ value="danbooru",
239
+ )
240
+
241
+ _model_radio = gr.Dropdown(
242
+ choices=MODEL_NAMES,
243
+ label="Model",
244
+ value=MODEL_NAMES[0],
245
+ interactive=True,
246
+ )
247
+
248
+ start_btn = gr.Button(value="Start", variant="primary")
249
+
250
+ with gr.Column():
251
+
252
+ with gr.Group():
253
+ prompt_text = gr.TextArea(label="Prompt", interactive=False)
254
+ copy_btn = gr.Button(value="Copy to clipboard", interactive=False)
255
+
256
+ rating_tags_label = gr.Label(label="Rating tags")
257
+ character_tags_label = gr.Label(label="Character tags")
258
+ general_tags_label = gr.Label(label="General tags")
259
+
260
+ start_btn.click(
261
+ predict_tags,
262
+ inputs=[input_image, general_threshold, character_threshold, tag_type],
263
+ outputs=[
264
+ rating_tags_label,
265
+ character_tags_label,
266
+ general_tags_label,
267
+ prompt_text,
268
+ copy_btn,
269
+ ],
270
+ )
271
+ copy_btn.click(gradio_copy_text, inputs=[prompt_text], js=COPY_ACTION_JS)
272
+
273
+ return ui
274
+
275
+ if __name__ == "__main__":
276
+ demo().queue().launch()
characterfull.txt ADDED
The diff for this file is too large to render. See raw diff
 
danbooru_e621.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ accelerate
4
+ transformers
5
+ spaces