Spaces:
Running
on
Zero
Running
on
Zero
Upload 5 files
Browse files
app.py
CHANGED
@@ -323,33 +323,32 @@ logger.setLevel(logging.DEBUG)
|
|
323 |
|
324 |
from v2 import (
|
325 |
V2UI,
|
326 |
-
|
|
|
327 |
)
|
328 |
from utils import (
|
329 |
gradio_copy_text,
|
330 |
COPY_ACTION_JS,
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
)
|
336 |
from tagger import (
|
337 |
predict_tags,
|
338 |
-
parse_upsampling_output,
|
339 |
convert_danbooru_to_e621_prompt,
|
|
|
340 |
insert_recom_prompt,
|
|
|
341 |
translate_prompt,
|
342 |
)
|
343 |
def description_ui():
|
344 |
gr.Markdown(
|
345 |
"""
|
346 |
## Danbooru Tags Transformer V2 Demo with WD Tagger
|
347 |
-
|
348 |
-
|
349 |
-
Models:
|
350 |
-
- [p1atdev/wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf)
|
351 |
-
- [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft) (Mixtral architecture)
|
352 |
-
- [dart-v2-sft](https://huggingface.co/p1atdev/dart-v2-sft) (Mistral architecture)
|
353 |
"""
|
354 |
)
|
355 |
|
@@ -704,11 +703,11 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
|
|
704 |
character_dbt = gr.Textbox(lines=1, placeholder="kafuu chino, ...", label="Character names", scale=2)
|
705 |
series_dbt = gr.Textbox(lines=1, placeholder="Is the order a rabbit?, ...", label="Series names", scale=2)
|
706 |
generate_db_random_button = gr.Button(value="Generate random prompt from character", size="sm", variant="secondary")
|
707 |
-
model_name_dbt = gr.Dropdown(label="Model", choices=list(
|
708 |
-
rating_dbt = gr.Radio(label="Rating", choices=list(
|
709 |
-
aspect_ratio_dbt = gr.Radio(label="Aspect ratio", choices=list(
|
710 |
-
length_dbt = gr.Radio(label="Length", choices=list(
|
711 |
-
identity_dbt = gr.Radio(label="Keep identity", choices=list(
|
712 |
ban_tags_dbt = gr.Textbox(label="Ban tags", placeholder="alternate costumen, ...", value="futanari, censored, furry, furrification", visible=False)
|
713 |
elapsed_time_dbt = gr.Markdown(label="Elapsed time", value="", visible=False)
|
714 |
copy_button_dbt = gr.Button(value="Copy to clipboard", visible=False)
|
@@ -737,8 +736,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
|
|
737 |
)
|
738 |
|
739 |
with gr.Accordion("Generation settings", open=False, visible=True):
|
740 |
-
steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=
|
741 |
-
cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7.
|
742 |
sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
|
743 |
img_width_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Width")
|
744 |
img_height_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Height")
|
@@ -963,7 +962,6 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
|
|
963 |
length_dbt,
|
964 |
identity_dbt,
|
965 |
ban_tags_dbt,
|
966 |
-
prompt_type_gui,
|
967 |
]
|
968 |
|
969 |
insert_prompt_gui.change(
|
@@ -983,7 +981,7 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
|
|
983 |
inputs=[
|
984 |
*v2b.input_components,
|
985 |
],
|
986 |
-
outputs=[prompt_gui, elapsed_time_dbt, copy_button_dbt],
|
987 |
)
|
988 |
|
989 |
translate_prompt_button.click(translate_prompt, inputs=[prompt_gui], outputs=[prompt_gui])
|
@@ -1094,102 +1092,41 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
|
|
1094 |
with gr.Group():
|
1095 |
input_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"])
|
1096 |
with gr.Accordion(label="Advanced options", open=False):
|
1097 |
-
general_threshold = gr.Slider(
|
1098 |
-
|
1099 |
-
|
1100 |
-
|
1101 |
-
value=0.3,
|
1102 |
-
step=0.01,
|
1103 |
-
interactive=True,
|
1104 |
-
)
|
1105 |
-
character_threshold = gr.Slider(
|
1106 |
-
label="Character threshold",
|
1107 |
-
minimum=0.0,
|
1108 |
-
maximum=1.0,
|
1109 |
-
value=0.8,
|
1110 |
-
step=0.01,
|
1111 |
-
interactive=True,
|
1112 |
-
)
|
1113 |
-
keep_tags = gr.Radio(
|
1114 |
-
label="Remove tags leaving only the following",
|
1115 |
-
choices=["body", "dress", "all"],
|
1116 |
-
value="body",
|
1117 |
-
)
|
1118 |
-
generate_from_image_btn = gr.Button(value="Generate input tags from image", variant="primary")
|
1119 |
|
1120 |
with gr.Group():
|
1121 |
-
input_character = gr.Textbox(
|
1122 |
-
|
1123 |
-
|
1124 |
-
)
|
1125 |
-
input_copyright = gr.Textbox(
|
1126 |
-
label="Copyright tags",
|
1127 |
-
placeholder="vocaloid",
|
1128 |
-
)
|
1129 |
-
input_general = gr.TextArea(
|
1130 |
-
label="General tags",
|
1131 |
-
lines=4,
|
1132 |
-
placeholder="1girl, solo, ...",
|
1133 |
-
value="",
|
1134 |
-
)
|
1135 |
input_tags_to_copy = gr.Textbox(value="", visible=False)
|
1136 |
-
copy_input_btn = gr.Button(
|
1137 |
-
|
1138 |
-
|
1139 |
-
)
|
1140 |
-
translate_input_prompt_button = gr.Button(value="Translate prompt to English", variant="secondary")
|
1141 |
-
tag_type = gr.Radio(
|
1142 |
-
label="Output tag conversion",
|
1143 |
-
info="danbooru for Animagine, e621 for Pony.",
|
1144 |
-
choices=["danbooru", "e621"],
|
1145 |
-
value="danbooru",
|
1146 |
-
)
|
1147 |
-
input_rating = gr.Radio(
|
1148 |
-
label="Rating",
|
1149 |
-
choices=list(RATING_OPTIONS),
|
1150 |
-
value="explicit",
|
1151 |
-
)
|
1152 |
with gr.Accordion(label="Advanced options", open=False):
|
1153 |
-
input_aspect_ratio = gr.Radio(
|
1154 |
-
|
1155 |
-
|
1156 |
-
|
1157 |
-
|
1158 |
-
)
|
1159 |
-
|
1160 |
-
|
1161 |
-
|
1162 |
-
|
1163 |
-
value="very_long",
|
1164 |
-
)
|
1165 |
-
input_identity = gr.Radio(
|
1166 |
-
label="Keep identity",
|
1167 |
-
info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.",
|
1168 |
-
choices=list(IDENTITY_OPTIONS),
|
1169 |
-
value="lax",
|
1170 |
-
)
|
1171 |
-
input_ban_tags = gr.Textbox(
|
1172 |
-
label="Ban tags",
|
1173 |
-
info="Tags to ban from the output.",
|
1174 |
-
placeholder="alternate costumen, ...",
|
1175 |
-
value="futanari, censored, furry, furrification"
|
1176 |
-
)
|
1177 |
-
model_name = gr.Dropdown(
|
1178 |
-
label="Model",
|
1179 |
-
choices=list(ALL_MODELS.keys()),
|
1180 |
-
value=list(ALL_MODELS.keys())[0],
|
1181 |
-
)
|
1182 |
-
|
1183 |
-
generate_btn = gr.Button(value="GENERATE TAGS", variant="primary")
|
1184 |
|
1185 |
with gr.Group():
|
1186 |
-
output_text = gr.TextArea(label="Output tags", interactive=False)
|
1187 |
-
copy_btn = gr.Button(
|
1188 |
-
|
1189 |
-
|
1190 |
-
|
1191 |
-
|
1192 |
-
|
|
|
1193 |
description_ui()
|
1194 |
|
1195 |
v2.input_components = [
|
@@ -1202,7 +1139,6 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
|
|
1202 |
input_length,
|
1203 |
input_identity,
|
1204 |
input_ban_tags,
|
1205 |
-
tag_type,
|
1206 |
]
|
1207 |
|
1208 |
translate_input_prompt_button.click(translate_prompt, inputs=[input_general], outputs=[input_general])
|
@@ -1211,27 +1147,33 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
|
|
1211 |
|
1212 |
generate_from_image_btn.click(
|
1213 |
predict_tags,
|
1214 |
-
inputs=[input_image, general_threshold, character_threshold
|
1215 |
outputs=[
|
1216 |
input_copyright,
|
1217 |
input_character,
|
1218 |
input_general,
|
1219 |
-
input_tags_to_copy,
|
1220 |
copy_input_btn,
|
1221 |
],
|
|
|
|
|
|
|
1222 |
)
|
1223 |
|
1224 |
-
copy_input_btn.click(gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS)
|
1225 |
-
|
1226 |
generate_btn.click(
|
1227 |
parse_upsampling_output(v2.on_generate),
|
1228 |
inputs=[
|
1229 |
*v2.input_components,
|
1230 |
],
|
1231 |
-
outputs=[output_text, elapsed_time_md, copy_btn],
|
|
|
|
|
|
|
|
|
|
|
|
|
1232 |
)
|
1233 |
-
|
1234 |
copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
|
|
|
1235 |
|
1236 |
with gr.Accordion("Examples", open=True, visible=True):
|
1237 |
gr.Examples(
|
|
|
323 |
|
324 |
from v2 import (
|
325 |
V2UI,
|
326 |
+
parse_upsampling_output,
|
327 |
+
V2_ALL_MODELS,
|
328 |
)
|
329 |
from utils import (
|
330 |
gradio_copy_text,
|
331 |
COPY_ACTION_JS,
|
332 |
+
V2_ASPECT_RATIO_OPTIONS,
|
333 |
+
V2_RATING_OPTIONS,
|
334 |
+
V2_LENGTH_OPTIONS,
|
335 |
+
V2_IDENTITY_OPTIONS
|
336 |
)
|
337 |
from tagger import (
|
338 |
predict_tags,
|
|
|
339 |
convert_danbooru_to_e621_prompt,
|
340 |
+
remove_specific_prompt,
|
341 |
insert_recom_prompt,
|
342 |
+
compose_prompt_to_copy,
|
343 |
translate_prompt,
|
344 |
)
|
345 |
def description_ui():
|
346 |
gr.Markdown(
|
347 |
"""
|
348 |
## Danbooru Tags Transformer V2 Demo with WD Tagger
|
349 |
+
Image => Prompt => Upsampled longer prompt
|
350 |
+
- Mod of p1atdev's [Danbooru Tags Transformer V2 Demo](https://huggingface.co/spaces/p1atdev/danbooru-tags-transformer-v2) and [WD Tagger with 🤗 transformers](https://huggingface.co/spaces/p1atdev/wd-tagger-transformers).
|
351 |
+
- Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf), [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft), [dart-v2-sft](https://huggingface.co/p1atdev/dart-v2-sft)
|
|
|
|
|
|
|
352 |
"""
|
353 |
)
|
354 |
|
|
|
703 |
character_dbt = gr.Textbox(lines=1, placeholder="kafuu chino, ...", label="Character names", scale=2)
|
704 |
series_dbt = gr.Textbox(lines=1, placeholder="Is the order a rabbit?, ...", label="Series names", scale=2)
|
705 |
generate_db_random_button = gr.Button(value="Generate random prompt from character", size="sm", variant="secondary")
|
706 |
+
model_name_dbt = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0], visible=False)
|
707 |
+
rating_dbt = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit", visible=False)
|
708 |
+
aspect_ratio_dbt = gr.Radio(label="Aspect ratio", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
|
709 |
+
length_dbt = gr.Radio(label="Length", choices=list(V2_LENGTH_OPTIONS), value="very_long", visible=False)
|
710 |
+
identity_dbt = gr.Radio(label="Keep identity", choices=list(V2_IDENTITY_OPTIONS), value="lax", visible=False)
|
711 |
ban_tags_dbt = gr.Textbox(label="Ban tags", placeholder="alternate costumen, ...", value="futanari, censored, furry, furrification", visible=False)
|
712 |
elapsed_time_dbt = gr.Markdown(label="Elapsed time", value="", visible=False)
|
713 |
copy_button_dbt = gr.Button(value="Copy to clipboard", visible=False)
|
|
|
736 |
)
|
737 |
|
738 |
with gr.Accordion("Generation settings", open=False, visible=True):
|
739 |
+
steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=28, label="Steps")
|
740 |
+
cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7.0, label="CFG")
|
741 |
sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
|
742 |
img_width_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Width")
|
743 |
img_height_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Height")
|
|
|
962 |
length_dbt,
|
963 |
identity_dbt,
|
964 |
ban_tags_dbt,
|
|
|
965 |
]
|
966 |
|
967 |
insert_prompt_gui.change(
|
|
|
981 |
inputs=[
|
982 |
*v2b.input_components,
|
983 |
],
|
984 |
+
outputs=[prompt_gui, elapsed_time_dbt, copy_button_dbt, copy_button_dbt],
|
985 |
)
|
986 |
|
987 |
translate_prompt_button.click(translate_prompt, inputs=[prompt_gui], outputs=[prompt_gui])
|
|
|
1092 |
with gr.Group():
|
1093 |
input_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"])
|
1094 |
with gr.Accordion(label="Advanced options", open=False):
|
1095 |
+
general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
|
1096 |
+
character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
|
1097 |
+
keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="body")
|
1098 |
+
generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1099 |
|
1100 |
with gr.Group():
|
1101 |
+
input_character = gr.Textbox(label="Character tags", placeholder="hatsune miku")
|
1102 |
+
input_copyright = gr.Textbox(label="Copyright tags", placeholder="vocaloid")
|
1103 |
+
input_general = gr.TextArea(label="General tags", lines=4, placeholder="1girl, ...", value="")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1104 |
input_tags_to_copy = gr.Textbox(value="", visible=False)
|
1105 |
+
copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
|
1106 |
+
translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
|
1107 |
+
tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
|
1108 |
+
input_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1109 |
with gr.Accordion(label="Advanced options", open=False):
|
1110 |
+
input_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square")
|
1111 |
+
input_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="very_long")
|
1112 |
+
input_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
|
1113 |
+
input_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
|
1114 |
+
model_name = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
|
1115 |
+
dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
|
1116 |
+
recom_animagine = gr.Textbox(label="Animagine reccomended prompt", value="Animagine", visible=False)
|
1117 |
+
recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
|
1118 |
+
|
1119 |
+
generate_btn = gr.Button(value="GENERATE TAGS", size="lg", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1120 |
|
1121 |
with gr.Group():
|
1122 |
+
output_text = gr.TextArea(label="Output tags", interactive=False, show_copy_button=True)
|
1123 |
+
copy_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
|
1124 |
+
elapsed_time_md = gr.Markdown(label="Elapsed time", value="", visible=False)
|
1125 |
+
|
1126 |
+
with gr.Group():
|
1127 |
+
output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
|
1128 |
+
copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
|
1129 |
+
|
1130 |
description_ui()
|
1131 |
|
1132 |
v2.input_components = [
|
|
|
1139 |
input_length,
|
1140 |
input_identity,
|
1141 |
input_ban_tags,
|
|
|
1142 |
]
|
1143 |
|
1144 |
translate_input_prompt_button.click(translate_prompt, inputs=[input_general], outputs=[input_general])
|
|
|
1147 |
|
1148 |
generate_from_image_btn.click(
|
1149 |
predict_tags,
|
1150 |
+
inputs=[input_image, general_threshold, character_threshold],
|
1151 |
outputs=[
|
1152 |
input_copyright,
|
1153 |
input_character,
|
1154 |
input_general,
|
|
|
1155 |
copy_input_btn,
|
1156 |
],
|
1157 |
+
).then(remove_specific_prompt, inputs=[input_general, keep_tags], outputs=[input_general])
|
1158 |
+
copy_input_btn.click(compose_prompt_to_copy, inputs=[input_character, input_copyright, input_general], outputs=[input_tags_to_copy]).then(
|
1159 |
+
gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS,
|
1160 |
)
|
1161 |
|
|
|
|
|
1162 |
generate_btn.click(
|
1163 |
parse_upsampling_output(v2.on_generate),
|
1164 |
inputs=[
|
1165 |
*v2.input_components,
|
1166 |
],
|
1167 |
+
outputs=[output_text, elapsed_time_md, copy_btn, copy_btn_pony],
|
1168 |
+
).then(
|
1169 |
+
convert_danbooru_to_e621_prompt, inputs=[output_text, tag_type], outputs=[output_text_pony],
|
1170 |
+
).then(
|
1171 |
+
insert_recom_prompt, inputs=[output_text, dummy_np, recom_animagine], outputs=[output_text, dummy_np],
|
1172 |
+
).then(
|
1173 |
+
insert_recom_prompt, inputs=[output_text_pony, dummy_np, recom_pony], outputs=[output_text_pony, dummy_np],
|
1174 |
)
|
|
|
1175 |
copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
|
1176 |
+
copy_btn_pony.click(gradio_copy_text, inputs=[output_text_pony], js=COPY_ACTION_JS)
|
1177 |
|
1178 |
with gr.Accordion("Examples", open=True, visible=True):
|
1179 |
gr.Examples(
|
output.py
CHANGED
@@ -12,6 +12,5 @@ class UpsamplingOutput:
|
|
12 |
aspect_ratio_tag: str
|
13 |
length_tag: str
|
14 |
identity_tag: str
|
15 |
-
tag_type: str
|
16 |
|
17 |
elapsed_time: float = 0.0
|
|
|
12 |
aspect_ratio_tag: str
|
13 |
length_tag: str
|
14 |
identity_tag: str
|
|
|
15 |
|
16 |
elapsed_time: float = 0.0
|
tagger.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1 |
from PIL import Image
|
2 |
-
from typing import Callable
|
3 |
-
|
4 |
import torch
|
5 |
|
6 |
from transformers import (
|
@@ -11,15 +9,10 @@ from transformers import (
|
|
11 |
import gradio as gr
|
12 |
import spaces # ZERO GPU
|
13 |
|
14 |
-
from output import UpsamplingOutput
|
15 |
-
|
16 |
-
|
17 |
MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
18 |
MODEL_NAME = MODEL_NAMES[0]
|
19 |
|
20 |
-
model = AutoModelForImageClassification.from_pretrained(
|
21 |
-
MODEL_NAME,
|
22 |
-
)
|
23 |
model.to("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
25 |
|
@@ -43,22 +36,6 @@ RATING_MAP = {
|
|
43 |
"questionable": "nsfw",
|
44 |
"explicit": "explicit, nsfw",
|
45 |
}
|
46 |
-
NORMALIZE_RATING_TAG = {
|
47 |
-
"sfw": "",
|
48 |
-
"general": "",
|
49 |
-
"sensitive": "sensitive",
|
50 |
-
"nsfw": "nsfw",
|
51 |
-
"questionable": "nsfw",
|
52 |
-
"explicit": "nsfw, explicit",
|
53 |
-
}
|
54 |
-
NORMALIZE_RATING_TAG_E621 = {
|
55 |
-
"sfw": "rating_safe",
|
56 |
-
"general": "rating_safe",
|
57 |
-
"sensitive": "sensitive",
|
58 |
-
"nsfw": "nsfw, rating_explicit",
|
59 |
-
"questionable": "rating_questionable",
|
60 |
-
"explicit": "rating_explicit",
|
61 |
-
}
|
62 |
DANBOORU_TO_E621_RATING_MAP = {
|
63 |
"safe": "rating_safe",
|
64 |
"sensitive": "rating_safe",
|
@@ -160,6 +137,7 @@ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "
|
|
160 |
|
161 |
rating_tags = sorted(set(rating_tags), key=rating_tags.index)
|
162 |
rating_tags = [rating_tags[0]] if rating_tags else []
|
|
|
163 |
|
164 |
output_prompt = ", ".join(people_tags + other_tags + rating_tags)
|
165 |
|
@@ -246,7 +224,7 @@ def get_tag_group_dict():
|
|
246 |
return tag_group_dict
|
247 |
|
248 |
|
249 |
-
def
|
250 |
def is_dressed(tag):
|
251 |
import re
|
252 |
p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
|
@@ -257,27 +235,48 @@ def is_necessary(tag, keep_tags, group_dict):
|
|
257 |
p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
|
258 |
return p.search(tag)
|
259 |
|
|
|
260 |
group_list = ['people', 'age', 'pattern', 'place', 'hair', 'modifier', 'screen', 'animal', 'effect', 'situation', 'status', 'lighting', 'accesory', 'body', 'nsfw', 'camera', 'option', 'taste', 'other', 'detail', 'action', 'dress', 'character', 'face', 'costume', 'attribute', 'weather', 'temporary', 'gender', 'favorite', 'food', 'object', 'quality', 'expression', 'life', 'background']
|
261 |
keep_group_dict = {
|
262 |
"body": ['people', 'age', 'hair', 'body', 'character', 'face', 'gender'],
|
263 |
"dress": ['people', 'age', 'hair', 'accesory', 'body', 'dress', 'character', 'face', 'costume', 'gender'],
|
264 |
"all": ['people', 'age', 'pattern', 'place', 'hair', 'modifier', 'screen', 'animal', 'effect', 'situation', 'status', 'lighting', 'accesory', 'body', 'nsfw', 'camera', 'option', 'taste', 'other', 'detail', 'action', 'dress', 'character', 'face', 'costume', 'attribute', 'weather', 'temporary', 'gender', 'favorite', 'food', 'object', 'quality', 'expression', 'life', 'background']
|
265 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
keep_group = keep_group_dict.get(keep_tags, ['people', 'age', 'hair', 'body', 'character', 'face', 'gender'])
|
267 |
explicit_group = list(set(group_list) ^ set(keep_group))
|
268 |
-
if group_dict.get(tag.strip().replace("_", " "), "") in explicit_group:
|
269 |
-
return False
|
270 |
-
elif keep_tags == "body" and is_dressed(tag):
|
271 |
-
return False
|
272 |
-
elif is_background(tag):
|
273 |
-
return False
|
274 |
-
else:
|
275 |
-
return True
|
276 |
|
|
|
|
|
|
|
277 |
|
278 |
-
|
279 |
-
|
280 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
results = {
|
282 |
k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
|
283 |
}
|
@@ -302,16 +301,15 @@ def postprocess_results(
|
|
302 |
return rating, character, general
|
303 |
|
304 |
|
305 |
-
def gen_prompt(rating: list[str], character: list[str], general: list[str]
|
306 |
people_tags: list[str] = []
|
307 |
other_tags: list[str] = []
|
308 |
rating_tag = RATING_MAP[rating[0]]
|
309 |
|
310 |
-
group_dict = get_tag_group_dict()
|
311 |
for tag in general:
|
312 |
if tag in PEOPLE_TAGS:
|
313 |
people_tags.append(tag)
|
314 |
-
|
315 |
other_tags.append(tag)
|
316 |
|
317 |
all_tags = people_tags + other_tags
|
@@ -320,9 +318,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str], keep
|
|
320 |
|
321 |
|
322 |
@spaces.GPU()
|
323 |
-
def predict_tags(
|
324 |
-
image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8, keep_tags = "all",
|
325 |
-
):
|
326 |
inputs = processor.preprocess(image, return_tensors="pt")
|
327 |
|
328 |
outputs = model(**inputs.to(model.device, model.dtype))
|
@@ -339,7 +335,7 @@ def predict_tags(
|
|
339 |
)
|
340 |
|
341 |
prompt = gen_prompt(
|
342 |
-
list(rating.keys()), list(character.keys()), list(general.keys())
|
343 |
)
|
344 |
|
345 |
output_series_tag = ""
|
@@ -349,65 +345,13 @@ def predict_tags(
|
|
349 |
else:
|
350 |
output_series_tag = ""
|
351 |
|
352 |
-
|
353 |
-
cprompt = cprompt + ", " + output_series_tag if output_series_tag else cprompt
|
354 |
-
cprompt = cprompt + ", " + prompt if prompt else cprompt
|
355 |
-
|
356 |
-
return output_series_tag, ", ".join(character.keys()), prompt, cprompt, gr.update(interactive=True),
|
357 |
|
358 |
|
359 |
-
def
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
tag = tag.strip()
|
367 |
-
|
368 |
-
if tag in PEOPLE_TAGS:
|
369 |
-
if output.tag_type == "e621":
|
370 |
-
tag = danbooru_to_e621(tag, e621_dict)
|
371 |
-
people_tags.append(tag)
|
372 |
-
else:
|
373 |
-
if output.tag_type == "e621":
|
374 |
-
tag = danbooru_to_e621(tag, e621_dict)
|
375 |
-
other_general_tags.append(tag)
|
376 |
-
|
377 |
-
return ", ".join(
|
378 |
-
[
|
379 |
-
part.strip()
|
380 |
-
for part in [
|
381 |
-
*people_tags,
|
382 |
-
output.character_tags,
|
383 |
-
output.copyright_tags,
|
384 |
-
*other_general_tags,
|
385 |
-
output.upsampled_tags,
|
386 |
-
NORMALIZE_RATING_TAG_E621[output.rating_tag] if output.tag_type == "e621" else NORMALIZE_RATING_TAG[output.rating_tag],
|
387 |
-
]
|
388 |
-
if part.strip() != ""
|
389 |
-
]
|
390 |
-
)
|
391 |
-
|
392 |
-
|
393 |
-
def elapsed_time_format(elapsed_time: float) -> str:
|
394 |
-
return f"Elapsed: {elapsed_time:.2f} seconds"
|
395 |
-
|
396 |
-
|
397 |
-
def parse_upsampling_output(
|
398 |
-
upsampler: Callable[..., UpsamplingOutput],
|
399 |
-
):
|
400 |
-
def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
|
401 |
-
output = upsampler(*args)
|
402 |
-
|
403 |
-
print(output)
|
404 |
-
|
405 |
-
return (
|
406 |
-
gen_prompt_text(output),
|
407 |
-
elapsed_time_format(output.elapsed_time),
|
408 |
-
gr.update(
|
409 |
-
interactive=True,
|
410 |
-
),
|
411 |
-
)
|
412 |
-
|
413 |
-
return _parse_upsampling_output
|
|
|
1 |
from PIL import Image
|
|
|
|
|
2 |
import torch
|
3 |
|
4 |
from transformers import (
|
|
|
9 |
import gradio as gr
|
10 |
import spaces # ZERO GPU
|
11 |
|
|
|
|
|
|
|
12 |
MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
13 |
MODEL_NAME = MODEL_NAMES[0]
|
14 |
|
15 |
+
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
|
|
|
|
16 |
model.to("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
18 |
|
|
|
36 |
"questionable": "nsfw",
|
37 |
"explicit": "explicit, nsfw",
|
38 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
DANBOORU_TO_E621_RATING_MAP = {
|
40 |
"safe": "rating_safe",
|
41 |
"sensitive": "rating_safe",
|
|
|
137 |
|
138 |
rating_tags = sorted(set(rating_tags), key=rating_tags.index)
|
139 |
rating_tags = [rating_tags[0]] if rating_tags else []
|
140 |
+
rating_tags = ["explicit, nsfw"] if rating_tags[0] == "explicit" else rating_tags
|
141 |
|
142 |
output_prompt = ", ".join(people_tags + other_tags + rating_tags)
|
143 |
|
|
|
224 |
return tag_group_dict
|
225 |
|
226 |
|
227 |
+
def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
|
228 |
def is_dressed(tag):
|
229 |
import re
|
230 |
p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
|
|
|
235 |
p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
|
236 |
return p.search(tag)
|
237 |
|
238 |
+
un_tags = ['solo']
|
239 |
group_list = ['people', 'age', 'pattern', 'place', 'hair', 'modifier', 'screen', 'animal', 'effect', 'situation', 'status', 'lighting', 'accesory', 'body', 'nsfw', 'camera', 'option', 'taste', 'other', 'detail', 'action', 'dress', 'character', 'face', 'costume', 'attribute', 'weather', 'temporary', 'gender', 'favorite', 'food', 'object', 'quality', 'expression', 'life', 'background']
|
240 |
keep_group_dict = {
|
241 |
"body": ['people', 'age', 'hair', 'body', 'character', 'face', 'gender'],
|
242 |
"dress": ['people', 'age', 'hair', 'accesory', 'body', 'dress', 'character', 'face', 'costume', 'gender'],
|
243 |
"all": ['people', 'age', 'pattern', 'place', 'hair', 'modifier', 'screen', 'animal', 'effect', 'situation', 'status', 'lighting', 'accesory', 'body', 'nsfw', 'camera', 'option', 'taste', 'other', 'detail', 'action', 'dress', 'character', 'face', 'costume', 'attribute', 'weather', 'temporary', 'gender', 'favorite', 'food', 'object', 'quality', 'expression', 'life', 'background']
|
244 |
}
|
245 |
+
|
246 |
+
def is_necessary(tag, keep_tags, group_dict):
|
247 |
+
if keep_tags == "all":
|
248 |
+
return True
|
249 |
+
elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
|
250 |
+
return False
|
251 |
+
elif keep_tags == "body" and is_dressed(tag):
|
252 |
+
return False
|
253 |
+
elif is_background(tag):
|
254 |
+
return False
|
255 |
+
else:
|
256 |
+
return True
|
257 |
+
|
258 |
+
if keep_tags == "all": return input_prompt
|
259 |
keep_group = keep_group_dict.get(keep_tags, ['people', 'age', 'hair', 'body', 'character', 'face', 'gender'])
|
260 |
explicit_group = list(set(group_list) ^ set(keep_group))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
|
262 |
+
tags = input_prompt.split(",") if input_prompt else []
|
263 |
+
people_tags: list[str] = []
|
264 |
+
other_tags: list[str] = []
|
265 |
|
266 |
+
group_dict = get_tag_group_dict()
|
267 |
+
for tag in tags:
|
268 |
+
tag = tag.strip().replace("_", " ")
|
269 |
+
if tag in PEOPLE_TAGS:
|
270 |
+
people_tags.append(tag)
|
271 |
+
elif is_necessary(tag, keep_tags, group_dict):
|
272 |
+
other_tags.append(tag)
|
273 |
+
|
274 |
+
output_prompt = ", ".join(people_tags + other_tags)
|
275 |
+
|
276 |
+
return output_prompt
|
277 |
+
|
278 |
+
|
279 |
+
def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
|
280 |
results = {
|
281 |
k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
|
282 |
}
|
|
|
301 |
return rating, character, general
|
302 |
|
303 |
|
304 |
+
def gen_prompt(rating: list[str], character: list[str], general: list[str]):
|
305 |
people_tags: list[str] = []
|
306 |
other_tags: list[str] = []
|
307 |
rating_tag = RATING_MAP[rating[0]]
|
308 |
|
|
|
309 |
for tag in general:
|
310 |
if tag in PEOPLE_TAGS:
|
311 |
people_tags.append(tag)
|
312 |
+
else:
|
313 |
other_tags.append(tag)
|
314 |
|
315 |
all_tags = people_tags + other_tags
|
|
|
318 |
|
319 |
|
320 |
@spaces.GPU()
|
321 |
+
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
|
|
|
|
322 |
inputs = processor.preprocess(image, return_tensors="pt")
|
323 |
|
324 |
outputs = model(**inputs.to(model.device, model.dtype))
|
|
|
335 |
)
|
336 |
|
337 |
prompt = gen_prompt(
|
338 |
+
list(rating.keys()), list(character.keys()), list(general.keys())
|
339 |
)
|
340 |
|
341 |
output_series_tag = ""
|
|
|
345 |
else:
|
346 |
output_series_tag = ""
|
347 |
|
348 |
+
return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
|
|
|
|
|
|
|
|
|
349 |
|
350 |
|
351 |
+
def compose_prompt_to_copy(character: str, series: str, general: str):
|
352 |
+
characters = character.split(",") if character else []
|
353 |
+
serieses = series.split(",") if series else []
|
354 |
+
generals = general.split(",") if general else []
|
355 |
+
tags = characters + serieses + generals
|
356 |
+
cprompt = ",".join(tags) if tags else ""
|
357 |
+
return cprompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
CHANGED
@@ -1,38 +1,15 @@
|
|
1 |
import gradio as gr
|
2 |
from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
|
3 |
|
4 |
-
# from https://huggingface.co/spaces/cagliostrolab/animagine-xl-3.1/blob/main/config.py
|
5 |
-
QUALITY_TAGS = {
|
6 |
-
"default": "(masterpiece), best quality, very aesthetic, perfect face",
|
7 |
-
}
|
8 |
-
NEGATIVE_PROMPT = {
|
9 |
-
"default": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
|
10 |
-
}
|
11 |
|
12 |
-
|
13 |
-
IMAGE_SIZE_OPTIONS = {
|
14 |
-
"1536x640": "<|aspect_ratio:ultra_wide|>",
|
15 |
-
"1344x768": "<|aspect_ratio:wide|>",
|
16 |
-
"1024x1024": "<|aspect_ratio:square|>",
|
17 |
-
"768x1344": "<|aspect_ratio:tall|>",
|
18 |
-
"640x1536": "<|aspect_ratio:ultra_tall|>",
|
19 |
-
}
|
20 |
-
IMAGE_SIZES = {
|
21 |
-
"1536x640": (1536, 640),
|
22 |
-
"1344x768": (1344, 768),
|
23 |
-
"1024x1024": (1024, 1024),
|
24 |
-
"768x1344": (768, 1344),
|
25 |
-
"640x1536": (640, 1536),
|
26 |
-
}
|
27 |
-
|
28 |
-
ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
|
29 |
"ultra_wide",
|
30 |
"wide",
|
31 |
"square",
|
32 |
"tall",
|
33 |
"ultra_tall",
|
34 |
]
|
35 |
-
|
36 |
"sfw",
|
37 |
"general",
|
38 |
"sensitive",
|
@@ -40,28 +17,20 @@ RATING_OPTIONS: list[RatingTag] = [
|
|
40 |
"questionable",
|
41 |
"explicit",
|
42 |
]
|
43 |
-
|
44 |
"very_short",
|
45 |
"short",
|
46 |
"medium",
|
47 |
"long",
|
48 |
"very_long",
|
49 |
]
|
50 |
-
|
51 |
"none",
|
52 |
"lax",
|
53 |
"strict",
|
54 |
]
|
55 |
|
56 |
|
57 |
-
PEOPLE_TAGS = [
|
58 |
-
*[f"1{x}" for x in ["girl", "boy", "other"]],
|
59 |
-
*[f"{i}girls" for i in range(2, 6)],
|
60 |
-
*[f"6+{x}s" for x in ["girl", "boy", "other"]],
|
61 |
-
"no humans",
|
62 |
-
]
|
63 |
-
|
64 |
-
|
65 |
# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
|
66 |
def gradio_copy_text(_text: None):
|
67 |
gr.Info("Copied!")
|
|
|
1 |
import gradio as gr
|
2 |
from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
"ultra_wide",
|
7 |
"wide",
|
8 |
"square",
|
9 |
"tall",
|
10 |
"ultra_tall",
|
11 |
]
|
12 |
+
V2_RATING_OPTIONS: list[RatingTag] = [
|
13 |
"sfw",
|
14 |
"general",
|
15 |
"sensitive",
|
|
|
17 |
"questionable",
|
18 |
"explicit",
|
19 |
]
|
20 |
+
V2_LENGTH_OPTIONS: list[LengthTag] = [
|
21 |
"very_short",
|
22 |
"short",
|
23 |
"medium",
|
24 |
"long",
|
25 |
"very_long",
|
26 |
]
|
27 |
+
V2_IDENTITY_OPTIONS: list[IdentityTag] = [
|
28 |
"none",
|
29 |
"lax",
|
30 |
"strict",
|
31 |
]
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
|
35 |
def gradio_copy_text(_text: None):
|
36 |
gr.Info("Copied!")
|
v2.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import time
|
2 |
import os
|
3 |
import torch
|
|
|
4 |
|
5 |
from dartrs.v2 import (
|
6 |
V2Model,
|
@@ -33,7 +34,7 @@ from output import UpsamplingOutput
|
|
33 |
|
34 |
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
35 |
|
36 |
-
|
37 |
"dart-v2-moe-sft": {
|
38 |
"repo": "p1atdev/dart-v2-moe-sft",
|
39 |
"type": "sft",
|
@@ -85,6 +86,67 @@ def generate_tags(
|
|
85 |
return output
|
86 |
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
class V2UI:
|
89 |
model_name: str | None = None
|
90 |
model: V2Model
|
@@ -104,11 +166,10 @@ class V2UI:
|
|
104 |
length_tag: LengthTag,
|
105 |
identity_tag: IdentityTag,
|
106 |
ban_tags: str,
|
107 |
-
tag_type: str,
|
108 |
*args,
|
109 |
) -> UpsamplingOutput:
|
110 |
if self.model_name is None or self.model_name != model_name:
|
111 |
-
models = prepare_models(
|
112 |
self.model = models["model"]
|
113 |
self.tokenizer = models["tokenizer"]
|
114 |
self.model_name = model_name
|
@@ -149,5 +210,5 @@ class V2UI:
|
|
149 |
length_tag=length_tag,
|
150 |
identity_tag=identity_tag,
|
151 |
elapsed_time=elapsed_time,
|
152 |
-
tag_type=tag_type,
|
153 |
)
|
|
|
|
1 |
import time
|
2 |
import os
|
3 |
import torch
|
4 |
+
from typing import Callable
|
5 |
|
6 |
from dartrs.v2 import (
|
7 |
V2Model,
|
|
|
34 |
|
35 |
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
36 |
|
37 |
+
V2_ALL_MODELS = {
|
38 |
"dart-v2-moe-sft": {
|
39 |
"repo": "p1atdev/dart-v2-moe-sft",
|
40 |
"type": "sft",
|
|
|
86 |
return output
|
87 |
|
88 |
|
89 |
+
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
90 |
+
return (
|
91 |
+
[f"1{noun}"]
|
92 |
+
+ [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
|
93 |
+
+ [f"{maximum+1}+{noun}s"]
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
PEOPLE_TAGS = (
|
98 |
+
_people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
|
99 |
+
)
|
100 |
+
|
101 |
+
|
102 |
+
def gen_prompt_text(output: UpsamplingOutput):
|
103 |
+
# separate people tags (e.g. 1girl)
|
104 |
+
people_tags = []
|
105 |
+
other_general_tags = []
|
106 |
+
|
107 |
+
for tag in output.general_tags.split(","):
|
108 |
+
tag = tag.strip()
|
109 |
+
if tag in PEOPLE_TAGS:
|
110 |
+
people_tags.append(tag)
|
111 |
+
else:
|
112 |
+
other_general_tags.append(tag)
|
113 |
+
|
114 |
+
return ", ".join(
|
115 |
+
[
|
116 |
+
part.strip()
|
117 |
+
for part in [
|
118 |
+
*people_tags,
|
119 |
+
output.character_tags,
|
120 |
+
output.copyright_tags,
|
121 |
+
*other_general_tags,
|
122 |
+
output.upsampled_tags,
|
123 |
+
output.rating_tag,
|
124 |
+
]
|
125 |
+
if part.strip() != ""
|
126 |
+
]
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
def elapsed_time_format(elapsed_time: float) -> str:
|
131 |
+
return f"Elapsed: {elapsed_time:.2f} seconds"
|
132 |
+
|
133 |
+
|
134 |
+
def parse_upsampling_output(
|
135 |
+
upsampler: Callable[..., UpsamplingOutput],
|
136 |
+
):
|
137 |
+
def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
|
138 |
+
output = upsampler(*args)
|
139 |
+
|
140 |
+
return (
|
141 |
+
gen_prompt_text(output),
|
142 |
+
elapsed_time_format(output.elapsed_time),
|
143 |
+
gr.update(interactive=True),
|
144 |
+
gr.update(interactive=True),
|
145 |
+
)
|
146 |
+
|
147 |
+
return _parse_upsampling_output
|
148 |
+
|
149 |
+
|
150 |
class V2UI:
|
151 |
model_name: str | None = None
|
152 |
model: V2Model
|
|
|
166 |
length_tag: LengthTag,
|
167 |
identity_tag: IdentityTag,
|
168 |
ban_tags: str,
|
|
|
169 |
*args,
|
170 |
) -> UpsamplingOutput:
|
171 |
if self.model_name is None or self.model_name != model_name:
|
172 |
+
models = prepare_models(V2_ALL_MODELS[model_name])
|
173 |
self.model = models["model"]
|
174 |
self.tokenizer = models["tokenizer"]
|
175 |
self.model_name = model_name
|
|
|
210 |
length_tag=length_tag,
|
211 |
identity_tag=identity_tag,
|
212 |
elapsed_time=elapsed_time,
|
|
|
213 |
)
|
214 |
+
|