John6666 committed on
Commit b47fcc1
Parent: 8865e09

Upload 5 files

Files changed (5)
  1. app.py +60 -118
  2. output.py +0 -1
  3. tagger.py +47 -103
  4. utils.py +4 -35
  5. v2.py +65 -4
app.py CHANGED
@@ -323,33 +323,32 @@ logger.setLevel(logging.DEBUG)
 
 from v2 import (
     V2UI,
-    ALL_MODELS,
+    parse_upsampling_output,
+    V2_ALL_MODELS,
 )
 from utils import (
     gradio_copy_text,
     COPY_ACTION_JS,
-    ASPECT_RATIO_OPTIONS,
-    RATING_OPTIONS,
-    LENGTH_OPTIONS,
-    IDENTITY_OPTIONS
+    V2_ASPECT_RATIO_OPTIONS,
+    V2_RATING_OPTIONS,
+    V2_LENGTH_OPTIONS,
+    V2_IDENTITY_OPTIONS
 )
 from tagger import (
     predict_tags,
-    parse_upsampling_output,
     convert_danbooru_to_e621_prompt,
+    remove_specific_prompt,
     insert_recom_prompt,
+    compose_prompt_to_copy,
     translate_prompt,
 )
 def description_ui():
     gr.Markdown(
         """
         ## Danbooru Tags Transformer V2 Demo with WD Tagger
-        It’s a modification of [p1atdev's Danbooru Tags Transformer V2 Demo](https://huggingface.co/spaces/p1atdev/danbooru-tags-transformer-v2) and [p1atdev's WD Tagger with 🤗 transformers](https://huggingface.co/spaces/p1atdev/wd-tagger-transformers).
-
-        Models:
-        - [p1atdev/wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf)
-        - [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft) (Mixtral architecture)
-        - [dart-v2-sft](https://huggingface.co/p1atdev/dart-v2-sft) (Mistral architecture)
+        Image => Prompt => Upsampled longer prompt
+        - Mod of p1atdev's [Danbooru Tags Transformer V2 Demo](https://huggingface.co/spaces/p1atdev/danbooru-tags-transformer-v2) and [WD Tagger with 🤗 transformers](https://huggingface.co/spaces/p1atdev/wd-tagger-transformers).
+        - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf), [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft), [dart-v2-sft](https://huggingface.co/p1atdev/dart-v2-sft)
         """
     )
 
@@ -704,11 +703,11 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
     character_dbt = gr.Textbox(lines=1, placeholder="kafuu chino, ...", label="Character names", scale=2)
     series_dbt = gr.Textbox(lines=1, placeholder="Is the order a rabbit?, ...", label="Series names", scale=2)
     generate_db_random_button = gr.Button(value="Generate random prompt from character", size="sm", variant="secondary")
-    model_name_dbt = gr.Dropdown(label="Model", choices=list(ALL_MODELS.keys()), value=list(ALL_MODELS.keys())[0], visible=False)
-    rating_dbt = gr.Radio(label="Rating", choices=list(RATING_OPTIONS), value="explicit", visible=False)
-    aspect_ratio_dbt = gr.Radio(label="Aspect ratio", choices=list(ASPECT_RATIO_OPTIONS), value="square", visible=False)
-    length_dbt = gr.Radio(label="Length", choices=list(LENGTH_OPTIONS), value="very_long", visible=False)
-    identity_dbt = gr.Radio(label="Keep identity", choices=list(IDENTITY_OPTIONS), value="lax", visible=False)
+    model_name_dbt = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0], visible=False)
+    rating_dbt = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit", visible=False)
+    aspect_ratio_dbt = gr.Radio(label="Aspect ratio", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
+    length_dbt = gr.Radio(label="Length", choices=list(V2_LENGTH_OPTIONS), value="very_long", visible=False)
+    identity_dbt = gr.Radio(label="Keep identity", choices=list(V2_IDENTITY_OPTIONS), value="lax", visible=False)
     ban_tags_dbt = gr.Textbox(label="Ban tags", placeholder="alternate costumen, ...", value="futanari, censored, furry, furrification", visible=False)
     elapsed_time_dbt = gr.Markdown(label="Elapsed time", value="", visible=False)
     copy_button_dbt = gr.Button(value="Copy to clipboard", visible=False)
@@ -737,8 +736,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
     )
 
     with gr.Accordion("Generation settings", open=False, visible=True):
-        steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=30, label="Steps")
-        cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7.5, label="CFG")
+        steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=28, label="Steps")
+        cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7.0, label="CFG")
         sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
         img_width_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Width")
         img_height_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Height")
@@ -963,7 +962,6 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
         length_dbt,
         identity_dbt,
         ban_tags_dbt,
-        prompt_type_gui,
     ]
 
     insert_prompt_gui.change(
@@ -983,7 +981,7 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
         inputs=[
             *v2b.input_components,
         ],
-        outputs=[prompt_gui, elapsed_time_dbt, copy_button_dbt],
+        outputs=[prompt_gui, elapsed_time_dbt, copy_button_dbt, copy_button_dbt],
     )
 
     translate_prompt_button.click(translate_prompt, inputs=[prompt_gui], outputs=[prompt_gui])
@@ -1094,102 +1092,41 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
     with gr.Group():
         input_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"])
         with gr.Accordion(label="Advanced options", open=False):
-            general_threshold = gr.Slider(
-                label="Threshold",
-                minimum=0.0,
-                maximum=1.0,
-                value=0.3,
-                step=0.01,
-                interactive=True,
-            )
-            character_threshold = gr.Slider(
-                label="Character threshold",
-                minimum=0.0,
-                maximum=1.0,
-                value=0.8,
-                step=0.01,
-                interactive=True,
-            )
-            keep_tags = gr.Radio(
-                label="Remove tags leaving only the following",
-                choices=["body", "dress", "all"],
-                value="body",
-            )
-        generate_from_image_btn = gr.Button(value="Generate input tags from image", variant="primary")
+            general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
+            character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
+            keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="body")
+        generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
 
     with gr.Group():
-        input_character = gr.Textbox(
-            label="Character tags",
-            placeholder="hatsune miku",
-        )
-        input_copyright = gr.Textbox(
-            label="Copyright tags",
-            placeholder="vocaloid",
-        )
-        input_general = gr.TextArea(
-            label="General tags",
-            lines=4,
-            placeholder="1girl, solo, ...",
-            value="",
-        )
+        input_character = gr.Textbox(label="Character tags", placeholder="hatsune miku")
+        input_copyright = gr.Textbox(label="Copyright tags", placeholder="vocaloid")
+        input_general = gr.TextArea(label="General tags", lines=4, placeholder="1girl, ...", value="")
         input_tags_to_copy = gr.Textbox(value="", visible=False)
-        copy_input_btn = gr.Button(
-            value="Copy to clipboard",
-            interactive=False,
-        )
-        translate_input_prompt_button = gr.Button(value="Translate prompt to English", variant="secondary")
-        tag_type = gr.Radio(
-            label="Output tag conversion",
-            info="danbooru for Animagine, e621 for Pony.",
-            choices=["danbooru", "e621"],
-            value="danbooru",
-        )
-        input_rating = gr.Radio(
-            label="Rating",
-            choices=list(RATING_OPTIONS),
-            value="explicit",
-        )
+        copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+        translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
+        tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
+        input_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit")
         with gr.Accordion(label="Advanced options", open=False):
-            input_aspect_ratio = gr.Radio(
-                label="Aspect ratio",
-                info="The aspect ratio of the image.",
-                choices=list(ASPECT_RATIO_OPTIONS),
-                value="square",
-            )
-            input_length = gr.Radio(
-                label="Length",
-                info="The total length of the tags.",
-                choices=list(LENGTH_OPTIONS),
-                value="very_long",
-            )
-            input_identity = gr.Radio(
-                label="Keep identity",
-                info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.",
-                choices=list(IDENTITY_OPTIONS),
-                value="lax",
-            )
-            input_ban_tags = gr.Textbox(
-                label="Ban tags",
-                info="Tags to ban from the output.",
-                placeholder="alternate costumen, ...",
-                value="futanari, censored, furry, furrification"
-            )
-            model_name = gr.Dropdown(
-                label="Model",
-                choices=list(ALL_MODELS.keys()),
-                value=list(ALL_MODELS.keys())[0],
-            )
-
-        generate_btn = gr.Button(value="GENERATE TAGS", variant="primary")
+            input_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square")
+            input_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="very_long")
+            input_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
+            input_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
+            model_name = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
+            dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
+            recom_animagine = gr.Textbox(label="Animagine reccomended prompt", value="Animagine", visible=False)
+            recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
+
+        generate_btn = gr.Button(value="GENERATE TAGS", size="lg", variant="primary")
 
     with gr.Group():
-        output_text = gr.TextArea(label="Output tags", interactive=False)
-        copy_btn = gr.Button(
-            value="Copy to clipboard",
-            interactive=False,
-        )
-        elapsed_time_md = gr.Markdown(label="Elapsed time", value="")
-
+        output_text = gr.TextArea(label="Output tags", interactive=False, show_copy_button=True)
+        copy_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+        elapsed_time_md = gr.Markdown(label="Elapsed time", value="", visible=False)
+
+    with gr.Group():
+        output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
+        copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+
     description_ui()
 
     v2.input_components = [
@@ -1202,7 +1139,6 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
         input_length,
         input_identity,
         input_ban_tags,
-        tag_type,
     ]
 
     translate_input_prompt_button.click(translate_prompt, inputs=[input_general], outputs=[input_general])
@@ -1211,27 +1147,33 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
 
     generate_from_image_btn.click(
        predict_tags,
-        inputs=[input_image, general_threshold, character_threshold, keep_tags],
+        inputs=[input_image, general_threshold, character_threshold],
        outputs=[
            input_copyright,
            input_character,
            input_general,
-            input_tags_to_copy,
            copy_input_btn,
        ],
-    )
-
-    copy_input_btn.click(gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS)
+    ).then(remove_specific_prompt, inputs=[input_general, keep_tags], outputs=[input_general])
+    copy_input_btn.click(compose_prompt_to_copy, inputs=[input_character, input_copyright, input_general], outputs=[input_tags_to_copy]).then(
+        gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS,
+    )
 
    generate_btn.click(
        parse_upsampling_output(v2.on_generate),
        inputs=[
            *v2.input_components,
        ],
-        outputs=[output_text, elapsed_time_md, copy_btn],
+        outputs=[output_text, elapsed_time_md, copy_btn, copy_btn_pony],
+    ).then(
+        convert_danbooru_to_e621_prompt, inputs=[output_text, tag_type], outputs=[output_text_pony],
+    ).then(
+        insert_recom_prompt, inputs=[output_text, dummy_np, recom_animagine], outputs=[output_text, dummy_np],
+    ).then(
+        insert_recom_prompt, inputs=[output_text_pony, dummy_np, recom_pony], outputs=[output_text_pony, dummy_np],
    )
-
    copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
+    copy_btn_pony.click(gradio_copy_text, inputs=[output_text_pony], js=COPY_ACTION_JS)
 
    with gr.Accordion("Examples", open=True, visible=True):
        gr.Examples(
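The functional core of the app.py change is in the last hunk: post-processing moves out of the callbacks and into chained Gradio events, with `generate_btn.click(...)` feeding `.then(...)` steps that derive the Pony/e621 variant and inject the recommended prompts. A minimal sketch of that chaining pattern (the component names and stand-in functions below are hypothetical, not the Space's actual code):

```python
import gradio as gr

def generate(prompt: str) -> str:
    # stand-in for parse_upsampling_output(v2.on_generate)
    return prompt + ", upsampled"

def to_e621(prompt: str) -> str:
    # stand-in for convert_danbooru_to_e621_prompt
    return prompt.replace(" ", "_")

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input tags")
    out = gr.TextArea(label="Output tags")
    out_pony = gr.TextArea(label="Output tags (e621 style)")
    btn = gr.Button("Generate")

    # Each .then() runs after the previous step completes and reads the
    # already-updated component values, so the e621 variant is derived from
    # the freshly generated output instead of being recomputed upstream.
    btn.click(generate, inputs=[inp], outputs=[out]).then(
        to_e621, inputs=[out], outputs=[out_pony]
    )

demo.launch()
```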
output.py CHANGED
@@ -12,6 +12,5 @@ class UpsamplingOutput:
     aspect_ratio_tag: str
     length_tag: str
     identity_tag: str
-    tag_type: str
 
     elapsed_time: float = 0.0
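With `tag_type` gone, `UpsamplingOutput` becomes a plain carrier for the upsampler's results; tag-style conversion now happens downstream in app.py. A sketch of the resulting dataclass, reconstructed from the fields referenced in this hunk and in `v2.gen_prompt_text()` (field order and completeness are assumptions):

```python
from dataclasses import dataclass

@dataclass
class UpsamplingOutput:
    # fields consumed by v2.gen_prompt_text() in this commit
    upsampled_tags: str
    copyright_tags: str
    character_tags: str
    general_tags: str
    rating_tag: str
    # fields shown in the output.py hunk above
    aspect_ratio_tag: str
    length_tag: str
    identity_tag: str
    elapsed_time: float = 0.0
```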
tagger.py CHANGED
@@ -1,6 +1,4 @@
 from PIL import Image
-from typing import Callable
-
 import torch
 
 from transformers import (
@@ -11,15 +9,10 @@ from transformers import (
 import gradio as gr
 import spaces # ZERO GPU
 
-from output import UpsamplingOutput
-
-
 MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
 MODEL_NAME = MODEL_NAMES[0]
 
-model = AutoModelForImageClassification.from_pretrained(
-    MODEL_NAME,
-)
+model = AutoModelForImageClassification.from_pretrained(MODEL_NAME, trust_remote_code=True)
 model.to("cuda" if torch.cuda.is_available() else "cpu")
 processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
@@ -43,22 +36,6 @@ RATING_MAP = {
     "questionable": "nsfw",
     "explicit": "explicit, nsfw",
 }
-NORMALIZE_RATING_TAG = {
-    "sfw": "",
-    "general": "",
-    "sensitive": "sensitive",
-    "nsfw": "nsfw",
-    "questionable": "nsfw",
-    "explicit": "nsfw, explicit",
-}
-NORMALIZE_RATING_TAG_E621 = {
-    "sfw": "rating_safe",
-    "general": "rating_safe",
-    "sensitive": "sensitive",
-    "nsfw": "nsfw, rating_explicit",
-    "questionable": "rating_questionable",
-    "explicit": "rating_explicit",
-}
 DANBOORU_TO_E621_RATING_MAP = {
     "safe": "rating_safe",
     "sensitive": "rating_safe",
@@ -160,6 +137,7 @@ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "
 
     rating_tags = sorted(set(rating_tags), key=rating_tags.index)
     rating_tags = [rating_tags[0]] if rating_tags else []
+    rating_tags = ["explicit, nsfw"] if rating_tags[0] == "explicit" else rating_tags
 
     output_prompt = ", ".join(people_tags + other_tags + rating_tags)
 
@@ -246,7 +224,7 @@ def get_tag_group_dict():
     return tag_group_dict
 
 
-def is_necessary(tag, keep_tags, group_dict):
+def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
     def is_dressed(tag):
         import re
         p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
@@ -257,27 +235,48 @@
         p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
         return p.search(tag)
 
+    un_tags = ['solo']
     group_list = ['people', 'age', 'pattern', 'place', 'hair', 'modifier', 'screen', 'animal', 'effect', 'situation', 'status', 'lighting', 'accesory', 'body', 'nsfw', 'camera', 'option', 'taste', 'other', 'detail', 'action', 'dress', 'character', 'face', 'costume', 'attribute', 'weather', 'temporary', 'gender', 'favorite', 'food', 'object', 'quality', 'expression', 'life', 'background']
     keep_group_dict = {
         "body": ['people', 'age', 'hair', 'body', 'character', 'face', 'gender'],
         "dress": ['people', 'age', 'hair', 'accesory', 'body', 'dress', 'character', 'face', 'costume', 'gender'],
         "all": ['people', 'age', 'pattern', 'place', 'hair', 'modifier', 'screen', 'animal', 'effect', 'situation', 'status', 'lighting', 'accesory', 'body', 'nsfw', 'camera', 'option', 'taste', 'other', 'detail', 'action', 'dress', 'character', 'face', 'costume', 'attribute', 'weather', 'temporary', 'gender', 'favorite', 'food', 'object', 'quality', 'expression', 'life', 'background']
     }
+
+    def is_necessary(tag, keep_tags, group_dict):
+        if keep_tags == "all":
+            return True
+        elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
+            return False
+        elif keep_tags == "body" and is_dressed(tag):
+            return False
+        elif is_background(tag):
+            return False
+        else:
+            return True
+
+    if keep_tags == "all": return input_prompt
     keep_group = keep_group_dict.get(keep_tags, ['people', 'age', 'hair', 'body', 'character', 'face', 'gender'])
     explicit_group = list(set(group_list) ^ set(keep_group))
-    if group_dict.get(tag.strip().replace("_", " "), "") in explicit_group:
-        return False
-    elif keep_tags == "body" and is_dressed(tag):
-        return False
-    elif is_background(tag):
-        return False
-    else:
-        return True
 
+    tags = input_prompt.split(",") if input_prompt else []
+    people_tags: list[str] = []
+    other_tags: list[str] = []
 
-def postprocess_results(
-    results: dict[str, float], general_threshold: float, character_threshold: float
-):
+    group_dict = get_tag_group_dict()
+    for tag in tags:
+        tag = tag.strip().replace("_", " ")
+        if tag in PEOPLE_TAGS:
+            people_tags.append(tag)
+        elif is_necessary(tag, keep_tags, group_dict):
+            other_tags.append(tag)
+
+    output_prompt = ", ".join(people_tags + other_tags)
+
+    return output_prompt
+
+
+def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
     results = {
         k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
     }
@@ -302,16 +301,15 @@
     return rating, character, general
 
 
-def gen_prompt(rating: list[str], character: list[str], general: list[str], keep_tags):
+def gen_prompt(rating: list[str], character: list[str], general: list[str]):
     people_tags: list[str] = []
     other_tags: list[str] = []
     rating_tag = RATING_MAP[rating[0]]
 
-    group_dict = get_tag_group_dict()
     for tag in general:
         if tag in PEOPLE_TAGS:
             people_tags.append(tag)
-        elif is_necessary(tag, keep_tags, group_dict):
+        else:
             other_tags.append(tag)
 
     all_tags = people_tags + other_tags
@@ -320,9 +318,7 @@
 
 
 @spaces.GPU()
-def predict_tags(
-    image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8, keep_tags = "all",
-):
+def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
     inputs = processor.preprocess(image, return_tensors="pt")
 
     outputs = model(**inputs.to(model.device, model.dtype))
@@ -339,7 +335,7 @@
     )
 
     prompt = gen_prompt(
-        list(rating.keys()), list(character.keys()), list(general.keys()), keep_tags
+        list(rating.keys()), list(character.keys()), list(general.keys())
     )
 
     output_series_tag = ""
@@ -349,65 +345,13 @@
     else:
         output_series_tag = ""
 
-    cprompt = ", ".join(character.keys())
-    cprompt = cprompt + ", " + output_series_tag if output_series_tag else cprompt
-    cprompt = cprompt + ", " + prompt if prompt else cprompt
-
-    return output_series_tag, ", ".join(character.keys()), prompt, cprompt, gr.update(interactive=True),
+    return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
 
 
-def gen_prompt_text(output: UpsamplingOutput):
-    # separate people tags (e.g. 1girl)
-    people_tags = []
-    other_general_tags = []
-
-    e621_dict = get_e621_dict() if output.tag_type == "e621" else {}
-    for tag in output.general_tags.split(","):
-        tag = tag.strip()
-
-        if tag in PEOPLE_TAGS:
-            if output.tag_type == "e621":
-                tag = danbooru_to_e621(tag, e621_dict)
-            people_tags.append(tag)
-        else:
-            if output.tag_type == "e621":
-                tag = danbooru_to_e621(tag, e621_dict)
-            other_general_tags.append(tag)
-
-    return ", ".join(
-        [
-            part.strip()
-            for part in [
-                *people_tags,
-                output.character_tags,
-                output.copyright_tags,
-                *other_general_tags,
-                output.upsampled_tags,
-                NORMALIZE_RATING_TAG_E621[output.rating_tag] if output.tag_type == "e621" else NORMALIZE_RATING_TAG[output.rating_tag],
-            ]
-            if part.strip() != ""
-        ]
-    )
-
-
-def elapsed_time_format(elapsed_time: float) -> str:
-    return f"Elapsed: {elapsed_time:.2f} seconds"
-
-
-def parse_upsampling_output(
-    upsampler: Callable[..., UpsamplingOutput],
-):
-    def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
-        output = upsampler(*args)
-
-        print(output)
-
-        return (
-            gen_prompt_text(output),
-            elapsed_time_format(output.elapsed_time),
-            gr.update(
-                interactive=True,
-            ),
-        )
-
-    return _parse_upsampling_output
+
+def compose_prompt_to_copy(character: str, series: str, general: str):
+    characters = character.split(",") if character else []
+    serieses = series.split(",") if series else []
+    generals = general.split(",") if general else []
+    tags = characters + serieses + generals
+    cprompt = ",".join(tags) if tags else ""
+    return cprompt
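After this change `predict_tags` only tags the image; group-based filtering lives in the standalone `remove_specific_prompt`, which app.py chains after it. A self-contained sketch of the inference path `predict_tags` wraps (the sigmoid multi-label readout and `id2label` usage are assumptions about the hf port's remote code; the input path and threshold are illustrative):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_NAME = "p1atdev/wd-swinv2-tagger-v3-hf"
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME, trust_remote_code=True)
model.to("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

image = Image.open("input.png").convert("RGB")  # hypothetical input file
inputs = processor.preprocess(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs.to(model.device, model.dtype))

# WD taggers are multi-label: apply sigmoid per tag, then threshold.
probs = outputs.logits.sigmoid()[0]
results = {model.config.id2label[i]: float(p) for i, p in enumerate(probs)}
general_tags = [tag for tag, p in results.items() if p > 0.3]
print(", ".join(general_tags))
```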
utils.py CHANGED
@@ -1,38 +1,15 @@
 import gradio as gr
 from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
 
-# from https://huggingface.co/spaces/cagliostrolab/animagine-xl-3.1/blob/main/config.py
-QUALITY_TAGS = {
-    "default": "(masterpiece), best quality, very aesthetic, perfect face",
-}
-NEGATIVE_PROMPT = {
-    "default": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
-}
 
-
-IMAGE_SIZE_OPTIONS = {
-    "1536x640": "<|aspect_ratio:ultra_wide|>",
-    "1344x768": "<|aspect_ratio:wide|>",
-    "1024x1024": "<|aspect_ratio:square|>",
-    "768x1344": "<|aspect_ratio:tall|>",
-    "640x1536": "<|aspect_ratio:ultra_tall|>",
-}
-IMAGE_SIZES = {
-    "1536x640": (1536, 640),
-    "1344x768": (1344, 768),
-    "1024x1024": (1024, 1024),
-    "768x1344": (768, 1344),
-    "640x1536": (640, 1536),
-}
-
-ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
+V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
     "ultra_wide",
     "wide",
     "square",
     "tall",
     "ultra_tall",
 ]
-RATING_OPTIONS: list[RatingTag] = [
+V2_RATING_OPTIONS: list[RatingTag] = [
     "sfw",
     "general",
     "sensitive",
@@ -40,28 +17,20 @@ RATING_OPTIONS: list[RatingTag] = [
     "questionable",
     "explicit",
 ]
-LENGTH_OPTIONS: list[LengthTag] = [
+V2_LENGTH_OPTIONS: list[LengthTag] = [
     "very_short",
     "short",
     "medium",
     "long",
     "very_long",
 ]
-IDENTITY_OPTIONS: list[IdentityTag] = [
+V2_IDENTITY_OPTIONS: list[IdentityTag] = [
     "none",
     "lax",
     "strict",
 ]
 
 
-PEOPLE_TAGS = [
-    *[f"1{x}" for x in ["girl", "boy", "other"]],
-    *[f"{i}girls" for i in range(2, 6)],
-    *[f"6+{x}s" for x in ["girl", "boy", "other"]],
-    "no humans",
-]
-
-
 # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
 def gradio_copy_text(_text: None):
     gr.Info("Copied!")
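utils.py now keeps only the renamed `V2_*` option lists and the clipboard helpers; the quality-tag and image-size tables are dropped. For reference, the copy buttons in app.py pair `gradio_copy_text` (a server-side no-op that pops a toast) with `COPY_ACTION_JS`, a client-side snippet that performs the actual clipboard write. The JS body below is an assumption based on the Qiita article referenced in the file, since the diff does not show it:

```python
import gradio as gr

def gradio_copy_text(_text: None):
    gr.Info("Copied!")  # the server side only shows the toast

# Hypothetical body; the real COPY_ACTION_JS is defined elsewhere in utils.py.
COPY_ACTION_JS = "(text) => { navigator.clipboard.writeText(text); }"

with gr.Blocks() as demo:
    text = gr.TextArea(value="1girl, solo")
    btn = gr.Button("Copy to clipboard")
    # js= runs in the browser with the component value, so the clipboard
    # write happens client side, where the clipboard actually lives.
    btn.click(gradio_copy_text, inputs=[text], js=COPY_ACTION_JS)

demo.launch()
```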
v2.py CHANGED
@@ -1,6 +1,7 @@
 import time
 import os
 import torch
+from typing import Callable
 
 from dartrs.v2 import (
     V2Model,
@@ -33,7 +34,7 @@ from output import UpsamplingOutput
 
 HF_TOKEN = os.getenv("HF_TOKEN", None)
 
-ALL_MODELS = {
+V2_ALL_MODELS = {
     "dart-v2-moe-sft": {
         "repo": "p1atdev/dart-v2-moe-sft",
         "type": "sft",
@@ -85,6 +86,67 @@ def generate_tags(
     return output
 
 
+def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
+    return (
+        [f"1{noun}"]
+        + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
+        + [f"{maximum+1}+{noun}s"]
+    )
+
+
+PEOPLE_TAGS = (
+    _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
+)
+
+
+def gen_prompt_text(output: UpsamplingOutput):
+    # separate people tags (e.g. 1girl)
+    people_tags = []
+    other_general_tags = []
+
+    for tag in output.general_tags.split(","):
+        tag = tag.strip()
+        if tag in PEOPLE_TAGS:
+            people_tags.append(tag)
+        else:
+            other_general_tags.append(tag)
+
+    return ", ".join(
+        [
+            part.strip()
+            for part in [
+                *people_tags,
+                output.character_tags,
+                output.copyright_tags,
+                *other_general_tags,
+                output.upsampled_tags,
+                output.rating_tag,
+            ]
+            if part.strip() != ""
+        ]
+    )
+
+
+def elapsed_time_format(elapsed_time: float) -> str:
+    return f"Elapsed: {elapsed_time:.2f} seconds"
+
+
+def parse_upsampling_output(
+    upsampler: Callable[..., UpsamplingOutput],
+):
+    def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
+        output = upsampler(*args)
+
+        return (
+            gen_prompt_text(output),
+            elapsed_time_format(output.elapsed_time),
+            gr.update(interactive=True),
+            gr.update(interactive=True),
+        )
+
+    return _parse_upsampling_output
+
+
 class V2UI:
     model_name: str | None = None
     model: V2Model
@@ -104,11 +166,10 @@ class V2UI:
         length_tag: LengthTag,
         identity_tag: IdentityTag,
         ban_tags: str,
-        tag_type: str,
         *args,
     ) -> UpsamplingOutput:
         if self.model_name is None or self.model_name != model_name:
-            models = prepare_models(ALL_MODELS[model_name])
+            models = prepare_models(V2_ALL_MODELS[model_name])
             self.model = models["model"]
             self.tokenizer = models["tokenizer"]
             self.model_name = model_name
@@ -149,5 +210,5 @@ class V2UI:
             length_tag=length_tag,
             identity_tag=identity_tag,
             elapsed_time=elapsed_time,
-            tag_type=tag_type,
         )
+
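`parse_upsampling_output` now lives in v2.py and returns four values (prompt text, elapsed time, and two button updates), matching the four-element `outputs=` lists wired up in app.py. The new `_people_tag` helper also closes a gap in the old utils.py `PEOPLE_TAGS`, which enumerated `2girls` through `5girls` but no counted forms for `boy` or `other`. Its expansion is deterministic from the definition above:

```python
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
    return (
        [f"1{noun}"]
        + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
        + [f"{maximum+1}+{noun}s"]
    )

print(_people_tag("girl"))
# ['1girl', '2girls', '3girls', '4girls', '5girls', '6+girls']
```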