Upload 12 files
- README.md +1 -1
- app.py +39 -54
- fl2basepromptgen.py +6 -4
- fl2sd3longcap.py +4 -2
- promptenhancer.py +6 -3
- requirements.txt +1 -1
- tagger.py +62 -19
- utils.py +5 -0
README.md
CHANGED

@@ -4,7 +4,7 @@
 colorFrom: blue
 colorTo: yellow
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.39.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py
CHANGED

@@ -1,17 +1,9 @@
-from PIL import Image
 import gradio as gr
+import spaces
 
-
-from v2 import (
-    V2_ALL_MODELS,
-)
 from utils import (
     gradio_copy_text,
     COPY_ACTION_JS,
-    V2_ASPECT_RATIO_OPTIONS,
-    V2_RATING_OPTIONS,
-    V2_LENGTH_OPTIONS,
-    V2_IDENTITY_OPTIONS
 )
 from tagger import (
     predict_tags_wd,
@@ -20,21 +12,22 @@ from tagger import (
     insert_recom_prompt,
     compose_prompt_to_copy,
     translate_prompt,
+    select_random_character,
 )
-from fl2sd3longcap import (
-    predict_tags_fl2_sd3,
-)
-from fl2basepromptgen import (
-    predict_tags_fl2_base_prompt_gen,
-)
+from fl2sd3longcap import predict_tags_fl2_sd3
+from fl2basepromptgen import predict_tags_fl2_base_prompt_gen
 from promptenhancer import prompt_enhancer
 
-
 def description_ui():
     gr.Markdown(
         """
        ## Prompt Enhancer with WD Tagger & SD3 Long Captioner
        (Image =>) Prompt => Upsampled longer prompt
+        """
+    )
+def description_ui2():
+    gr.Markdown(
+        """
        - It's a mod. Original Spaces: p1atdev's [WD Tagger with 🤗 transformers](https://huggingface.co/spaces/p1atdev/wd-tagger-transformers),\
        gokaygokay's [Prompt-Enhancer](https://huggingface.co/spaces/gokaygokay/Prompt-Enhancer) /\
        [Florence-2-SD3-Captioner](https://huggingface.co/spaces/gokaygokay/Florence-2-SD3-Captioner).
@@ -46,62 +39,54 @@
         """
     )
 
-
 def main():
-
     with gr.Blocks() as ui:
         description_ui()
-
-
-
-        with gr.
-
-
-
-
-
-
-
-
-
-
-        with gr.Group():
+        with gr.Column():
+            with gr.Group():
+                input_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
+                with gr.Accordion(label="Advanced options", open=False):
+                    general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
+                    character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
+                    input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
+                    recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
+                    keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
+                image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner", "Use Florence-2-base-PromptGen"], label="Algorithms", value=["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"])
+                generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
+            with gr.Group():
+                with gr.Row():
                     input_character = gr.Textbox(label="Character tags", placeholder="hatsune miku")
                     input_copyright = gr.Textbox(label="Copyright tags", placeholder="vocaloid")
-
-
-
-
-
-
-
-
-
-
-
-            input_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax", visible=False)
-            input_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored", visible=False)
-            model_name = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0], visible=False)
-            dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
-            recom_animagine = gr.Textbox(label="Animagine reccomended prompt", value="Animagine", visible=False)
-            recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
-
+                random_character = gr.Button(value="Random character 🎲", size="sm")
+                input_general = gr.TextArea(label="General tags", lines=4, placeholder="1girl, ...", value="")
+                input_tags_to_copy = gr.Textbox(value="", visible=False)
+                copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+                translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
+                prompt_enhancer_model = gr.Radio(["Medium", "Long"], label="Model Choice", value="Long", info="Enhance your prompts with Medium or Long answers")
+                with gr.Accordion(label="Advanced options", open=False, visible=False):
+                    tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
+                    dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
+                    recom_animagine = gr.Textbox(label="Animagine reccomended prompt", value="Animagine", visible=False)
+                    recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
             generate_btn = gr.Button(value="GENERATE TAGS", size="lg", variant="primary")
-
+            with gr.Row():
                 with gr.Group():
                     output_text = gr.TextArea(label="Output tags", interactive=False, show_copy_button=True)
                     copy_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
-            elapsed_time_md = gr.Markdown(label="Elapsed time", value="", visible=False)
-
                 with gr.Group():
                     output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
                     copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+        description_ui2()
+
+        random_character.click(select_random_character, [input_copyright, input_character], [input_copyright, input_character], queue=False)
 
         translate_input_prompt_button.click(translate_prompt, [input_general], [input_general], queue=False)
         translate_input_prompt_button.click(translate_prompt, [input_character], [input_character], queue=False)
         translate_input_prompt_button.click(translate_prompt, [input_copyright], [input_copyright], queue=False)
 
         generate_from_image_btn.click(
+            lambda: ("", "", ""), None, [input_copyright, input_character, input_general], queue=False,
+        ).success(
            predict_tags_wd,
            [input_image, input_general, image_algorithms, general_threshold, character_threshold],
            [input_copyright, input_character, input_general, copy_input_btn],
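Note on the new event wiring: `generate_from_image_btn.click(...)` now chains a clearing step and the tagger via Gradio's `.success()`, which runs the second function only if the first completed without error. A minimal standalone sketch of the same pattern (the component names here are illustrative, not taken from the Space):

```python
import gradio as gr

def predict(image):
    # Stand-in for the real tagger; returns the three tag fields.
    return "copyright tags", "character tags", "general tags"

with gr.Blocks() as demo:
    image = gr.Image(type="pil")
    copyright_box = gr.Textbox(label="Copyright tags")
    character_box = gr.Textbox(label="Character tags")
    general_box = gr.Textbox(label="General tags")
    btn = gr.Button("Generate")
    # Step 1 clears the outputs; .success() runs the tagger only if step 1 succeeded.
    btn.click(
        lambda: ("", "", ""), None,
        [copyright_box, character_box, general_box], queue=False,
    ).success(
        predict, [image],
        [copyright_box, character_box, general_box],
    )

demo.launch()
```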
fl2basepromptgen.py
CHANGED

@@ -1,11 +1,13 @@
 from transformers import AutoProcessor, AutoModelForCausalLM
 import spaces
 from PIL import Image
+import torch
 
-
-
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-
+device = "cuda" if torch.cuda.is_available() else "cpu"
+fl_model = AutoModelForCausalLM.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True).to(device).eval()
 fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
 
 
@@ -18,7 +20,7 @@ def fl_run(image):
     if image.mode != "RGB":
         image = image.convert("RGB")
 
-    inputs = fl_processor(text=prompt, images=image, return_tensors="pt")
+    inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
     generated_ids = fl_model.generate(
         input_ids=inputs["input_ids"],
         pixel_values=inputs["pixel_values"],
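The device-placement pattern added here loads the model once with `.to(device).eval()` and moves each processor batch to the same device before `generate()`; calling `.to(device)` on the returned `BatchEncoding` moves every tensor in the batch at once. A condensed sketch of the pattern as used in this file (the task prompt string is an assumption for illustration):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load once, on the chosen device, in inference mode.
model = AutoModelForCausalLM.from_pretrained(
    "MiaoshouAI/Florence-2-base-PromptGen", trust_remote_code=True
).to(device).eval()
processor = AutoProcessor.from_pretrained(
    "MiaoshouAI/Florence-2-base-PromptGen", trust_remote_code=True
)

def caption(image: Image.Image, prompt: str = "<MORE_DETAILED_CAPTION>") -> str:
    if image.mode != "RGB":
        image = image.convert("RGB")
    # .to(device) on the BatchEncoding moves input_ids and pixel_values together.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=256,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```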
fl2sd3longcap.py
CHANGED

@@ -2,11 +2,13 @@ from transformers import AutoProcessor, AutoModelForCausalLM
 import spaces
 import re
 from PIL import Image
+import torch
 
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-
+device = "cuda" if torch.cuda.is_available() else "cpu"
+fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to(device).eval()
 fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
 
 
@@ -48,7 +50,7 @@ def fl_run_example(image):
     if image.mode != "RGB":
         image = image.convert("RGB")
 
-    inputs = fl_processor(text=prompt, images=image, return_tensors="pt")
+    inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = fl_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
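Both captioner modules install `flash-attn` at import time; `FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE` tells the package's setup to skip compiling CUDA kernels during install, the usual trick on Spaces where build time is limited. One caveat with the committed one-liner: passing `env={...}` replaces the entire environment of the child process. A variant that merges the flag into the inherited environment instead (an assumption, not what the commit does):

```python
import os
import subprocess

# Keep PATH and CUDA variables visible to pip while adding the skip-build flag.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
    check=False,  # a failed install should not crash the Space at import time
)
```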
promptenhancer.py
CHANGED

@@ -2,10 +2,13 @@ import spaces
 import gradio as gr
 from transformers import pipeline
 import re
+import torch
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def load_models():
-    enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=
-    enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=
+    enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
+    enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
     return enhancer_medium, enhancer_long
 
 enhancer_medium, enhancer_long = load_models()

@@ -39,4 +42,4 @@ def prompt_enhancer(character: str, series: str, general: str, model_choice: str
     output = enhance_prompt(cprompt, model_choice)
     prompt = cprompt + ", " + output
 
-    return prompt, gr.update(interactive=True), gr.update(interactive=True)
+    return prompt, gr.update(interactive=True), gr.update(interactive=True)
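`pipeline(..., device=device)` accepts a device string such as `"cuda"` or `"cpu"` (as well as a GPU index or a `torch.device`), so computing the string once covers both enhancers. A short usage sketch:

```python
import torch
from transformers import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# The enhancers are summarization pipelines that rewrite a short prompt into a longer one.
enhancer_medium = pipeline(
    "summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device
)

result = enhancer_medium("1girl, hatsune miku, vocaloid", max_length=100)
print(result[0]["summary_text"])
```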
requirements.txt
CHANGED

@@ -1,4 +1,4 @@
-torch
+torch==2.2.0
 torchvision
 accelerate
 transformers
tagger.py
CHANGED

@@ -1,12 +1,13 @@
 from PIL import Image
 import torch
 import gradio as gr
-import spaces
-
+import spaces
 from transformers import (
     AutoImageProcessor,
     AutoModelForImageClassification,
 )
+from pathlib import Path
+
 
 WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
 WD_MODEL_NAME = WD_MODEL_NAMES[0]
@@ -30,12 +31,15 @@ PEOPLE_TAGS = (
 
 
 RATING_MAP = {
+    "sfw": "safe",
     "general": "safe",
     "sensitive": "sensitive",
     "questionable": "nsfw",
     "explicit": "explicit, nsfw",
 }
 DANBOORU_TO_E621_RATING_MAP = {
+    "sfw": "rating_safe",
+    "general": "rating_safe",
     "safe": "rating_safe",
     "sensitive": "rating_safe",
     "nsfw": "rating_explicit",
@@ -49,6 +53,34 @@
 }
 
 
+# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
+kaomojis = [
+    "0_0",
+    "(o)_(o)",
+    "+_+",
+    "+_-",
+    "._.",
+    "<o>_<o>",
+    "<|>_<|>",
+    "=_=",
+    ">_<",
+    "3_3",
+    "6_9",
+    ">_o",
+    "@_@",
+    "^_^",
+    "o_o",
+    "u_u",
+    "x_x",
+    "|_|",
+    "||_||",
+]
+
+
+def replace_underline(x: str):
+    return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
+
+
 def to_list(s):
     return [x.strip() for x in s.split(",") if not s == ""]
 
@@ -62,9 +94,16 @@ def list_uniq(l):
 
 
 def load_dict_from_csv(filename):
-    with open(filename, 'r', encoding="utf-8") as f:
-        lines = f.readlines()
     dict = {}
+    if not Path(filename).exists():
+        if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
+        else: return dict
+    try:
+        with open(filename, 'r', encoding="utf-8") as f:
+            lines = f.readlines()
+    except Exception:
+        print(f"Failed to open dictionary file: {filename}")
+        return dict
     for line in lines:
         parts = line.strip().split(',')
         dict[parts[0]] = parts[1]
@@ -94,7 +133,8 @@ def character_list_to_series_list(character_list):
 
 
 def select_random_character(series: str, character: str):
-    from random import randrange
+    from random import seed, randrange
+    seed()
     character_list = list(anime_series_dict.keys())
     character = character_list[randrange(len(character_list) - 1)]
     series = anime_series_dict.get(character.split(",")[0].strip(), "")
@@ -104,7 +144,7 @@
 def danbooru_to_e621(dtag, e621_dict):
     def d_to_e(match, e621_dict):
         dtag = match.group(0)
-        etag = e621_dict.get(dtag
+        etag = e621_dict.get(replace_underline(dtag), "")
         if etag:
             return etag
         else:
@@ -112,7 +152,6 @@ def danbooru_to_e621(dtag, e621_dict):
 
     import re
     tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
-
     return tag
 
 
@@ -128,7 +167,7 @@ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "
 
     e621_dict = danbooru_to_e621_dict
     for tag in tags:
-        tag = tag
+        tag = replace_underline(tag)
         tag = danbooru_to_e621(tag, e621_dict)
         if tag in PEOPLE_TAGS:
             people_tags.append(tag)
@@ -156,6 +195,7 @@ def translate_prompt(prompt: str = ""):
         translated_prompt = translator.translate(prompt, src='auto', dest='en').text
         return translated_prompt
     except Exception as e:
+        print(e)
         return prompt
 
     def is_japanese(s):
@@ -188,6 +228,7 @@ def translate_prompt_to_ja(prompt: str = ""):
         translated_prompt = translator.translate(prompt, src='en', dest='ja').text
         return translated_prompt
     except Exception as e:
+        print(e)
         return prompt
 
     def is_japanese(s):
@@ -213,7 +254,7 @@ def translate_prompt_to_ja(prompt: str = ""):
 def tags_to_ja(itag, dict):
     def t_to_j(match, dict):
         tag = match.group(0)
-        ja = dict.get(tag
+        ja = dict.get(replace_underline(tag), "")
         if ja:
             return ja
         else:
@@ -232,7 +273,7 @@ def convert_tags_to_ja(input_prompt: str = ""):
     tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
     dict = tags_to_ja_dict
     for tag in tags:
-        tag = tag
+        tag = replace_underline(tag)
         tag = tags_to_ja(tag, dict)
         out_tags.append(tag)
 
@@ -242,13 +283,13 @@ def convert_tags_to_ja(input_prompt: str = ""):
 enable_auto_recom_prompt = True
 
 
-animagine_ps = to_list("
+animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
 animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
-pony_ps = to_list("
-pony_nps = to_list("source_pony,
+pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
+pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
 other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
 other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
-default_ps = to_list("
+default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
 default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
 def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
     global enable_auto_recom_prompt
@@ -281,6 +322,7 @@ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "Non
 def load_model_prompt_dict():
     import json
     dict = {}
+    path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
     try:
         with open('model_dict.json', encoding='utf-8') as f:
             dict = json.load(f)
@@ -359,7 +401,7 @@ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
 
     group_dict = tag_group_dict
     for tag in tags:
-        tag = tag
+        tag = replace_underline(tag)
         if tag in PEOPLE_TAGS:
             people_tags.append(tag)
         elif is_necessary(tag, keep_tags, group_dict):
@@ -387,7 +429,7 @@ def sort_taglist(tags: list[str]):
     rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
 
     for tag in tags:
-        tag = tag
+        tag = replace_underline(tag)
         if tag in PEOPLE_TAGS:
             people_tags.append(tag)
         elif tag in rating_set:
@@ -488,12 +530,13 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
         output_series_tag = output_series_list[0]
     else:
         output_series_tag = ""
-    return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
+    return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
 
 
-def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
+def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
+                    character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
     if not "Use WD Tagger" in algo and len(algo) != 0:
-        return
+        return input_series, input_character, input_tags, gr.update(interactive=True)
     return predict_tags(image, general_threshold, character_threshold)
 
 
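The core of this change is `replace_underline`: WD-style tags use underscores (`long_hair`), but kaomoji tags are made of underscores (`^_^`, `>_<`), so a blanket `replace("_", " ")` would destroy them. The helper is now applied wherever a tag enters a conversion loop. A quick illustration using the same definitions:

```python
kaomojis = ["0_0", "^_^", ">_<", "x_x"]  # abbreviated; the file lists 19 of them

def replace_underline(x: str):
    # Swap underscores for spaces, except in kaomoji tags.
    return x.strip().replace("_", " ") if x not in kaomojis else x.strip()

print(replace_underline("long_hair"))  # -> "long hair"
print(replace_underline("^_^"))        # -> "^_^" (kaomoji kept intact)
```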
utils.py
CHANGED

@@ -43,3 +43,8 @@ COPY_ACTION_JS = """\
     navigator.clipboard.writeText(inputs);
   }
 }"""
+
+
+def gradio_copy_prompt(prompt: str):
+    gr.Info("Copied!")
+    return prompt