John6666 committed
Commit ab4fc1e • 0 Parent(s)

Super-squash branch 'main' using huggingface_hub

Files changed (14)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +164 -0
  4. character_series_dict.csv +0 -0
  5. danbooru_e621.csv +0 -0
  6. fl2sd3longcap.py +74 -0
  7. output.py +16 -0
  8. pre-requirements.txt +1 -0
  9. requirements.txt +13 -0
  10. tag_group.csv +0 -0
  11. tagger.py +450 -0
  12. utils.py +45 -0
  13. v2.py +214 -0
  14. z3de621conv.py +68 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Danbooru Tags Transformer V2 with WD Tagger & Florence 2 SD3 Captioner
+ emoji: 📦🏃
+ colorFrom: yellow
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 4.37.2
+ header: mini
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,164 @@
+ from PIL import Image
+ import gradio as gr
+
+ from v2 import (
+     V2UI,
+     parse_upsampling_output,
+     V2_ALL_MODELS,
+ )
+ from utils import (
+     gradio_copy_text,
+     COPY_ACTION_JS,
+     V2_ASPECT_RATIO_OPTIONS,
+     V2_RATING_OPTIONS,
+     V2_LENGTH_OPTIONS,
+     V2_IDENTITY_OPTIONS,
+ )
+ from tagger import (
+     predict_tags_wd,
+     convert_danbooru_to_e621_prompt,
+     remove_specific_prompt,
+     insert_recom_prompt,
+     compose_prompt_to_copy,
+     translate_prompt,
+     sort_tags,
+ )
+ from z3de621conv import (
+     predict_tags_e621,
+ )
+ from fl2sd3longcap import (
+     predict_tags_fl2_sd3,
+ )
+
+
+ def description_ui():
+     gr.Markdown(
+         """
+ ## Danbooru Tags Transformer V2 Demo with WD Tagger
+ (Image =>) Prompt => Upsampled longer prompt
+ - Mod of p1atdev's [Danbooru Tags Transformer V2 Demo](https://huggingface.co/spaces/p1atdev/danbooru-tags-transformer-v2) and [WD Tagger with 🤗 transformers](https://huggingface.co/spaces/p1atdev/wd-tagger-transformers).
+ - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf), [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft), [dart-v2-sft](https://huggingface.co/p1atdev/dart-v2-sft), toynya's [Z3D-E621-Convnext](https://huggingface.co/toynya/Z3D-E621-Convnext), gokaygokay's [Florence-2-SD3-Captioner](https://huggingface.co/gokaygokay/Florence-2-SD3-Captioner)
+ """
+     )
+
+
+ def main():
+     v2 = V2UI()
+
+     with gr.Blocks() as ui:
+         description_ui()
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 with gr.Group():
+                     input_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
+                     with gr.Accordion(label="Advanced options", open=False):
+                         general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
+                         character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
+                         e621_threshold = gr.Slider(label="Threshold (Z3D-E621-Convnext)", minimum=0.0, maximum=1.0, value=0.5, step=0.01, interactive=True)
+                         input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
+                         recom_prompt = gr.Radio(label="Insert recommended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
+                         image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Z3D-E621-Convnext", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
+                         keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
+                     generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
+
+                 with gr.Group():
+                     input_character = gr.Textbox(label="Character tags", placeholder="hatsune miku")
+                     input_copyright = gr.Textbox(label="Copyright tags", placeholder="vocaloid")
+                     input_general = gr.TextArea(label="General tags", lines=4, placeholder="1girl, ...", value="")
+                     input_tags_to_copy = gr.Textbox(value="", visible=False)
+                     copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+                     translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
+                     tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
+                     input_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit")
+                     with gr.Accordion(label="Advanced options", open=False):
+                         input_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square")
+                         input_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="very_long")
+                         input_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the details of the subject in the prompt, choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
+                         input_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costume, ...", value="censored")
+                         model_name = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
+                     dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
+                     recom_animagine = gr.Textbox(label="Animagine recommended prompt", value="Animagine", visible=False)
+                     recom_pony = gr.Textbox(label="Pony recommended prompt", value="Pony", visible=False)
+
+                 generate_btn = gr.Button(value="GENERATE TAGS", size="lg", variant="primary")
+
+                 with gr.Group():
+                     output_text = gr.TextArea(label="Output tags", interactive=False, show_copy_button=True)
+                     copy_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+                     elapsed_time_md = gr.Markdown(label="Elapsed time", value="", visible=False)
+
+                 with gr.Group():
+                     output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
+                     copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+
+         v2.input_components = [
+             model_name,
+             input_copyright,
+             input_character,
+             input_general,
+             input_rating,
+             input_aspect_ratio,
+             input_length,
+             input_identity,
+             input_ban_tags,
+         ]
+
+         translate_input_prompt_button.click(translate_prompt, inputs=[input_general], outputs=[input_general])
+         translate_input_prompt_button.click(translate_prompt, inputs=[input_character], outputs=[input_character])
+         translate_input_prompt_button.click(translate_prompt, inputs=[input_copyright], outputs=[input_copyright])
+
+         generate_from_image_btn.click(
+             predict_tags_wd,
+             inputs=[input_image, input_general, image_algorithms, general_threshold, character_threshold],
+             outputs=[
+                 input_copyright,
+                 input_character,
+                 input_general,
+                 copy_input_btn,
+             ],
+         ).then(
+             predict_tags_e621,
+             inputs=[input_image, input_general, image_algorithms, e621_threshold],
+             outputs=[input_general],
+         ).then(
+             predict_tags_fl2_sd3,
+             inputs=[input_image, input_general, image_algorithms],
+             outputs=[input_general],
+         ).then(
+             remove_specific_prompt, inputs=[input_general, keep_tags], outputs=[input_general],
+         ).then(
+             convert_danbooru_to_e621_prompt, inputs=[input_general, input_tag_type], outputs=[input_general],
+         ).then(
+             sort_tags, inputs=[input_general], outputs=[input_general],
+         ).then(
+             insert_recom_prompt, inputs=[input_general, dummy_np, recom_prompt], outputs=[input_general, dummy_np],
+         )
+         copy_input_btn.click(compose_prompt_to_copy, inputs=[input_character, input_copyright, input_general], outputs=[input_tags_to_copy]).then(
+             gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS,
+         )
+
+         generate_btn.click(
+             parse_upsampling_output(v2.on_generate),
+             inputs=[
+                 *v2.input_components,
+             ],
+             outputs=[output_text, elapsed_time_md, copy_btn, copy_btn_pony],
+         ).then(
+             sort_tags, inputs=[output_text], outputs=[output_text],
+         ).then(
+             convert_danbooru_to_e621_prompt, inputs=[output_text, tag_type], outputs=[output_text_pony],
+         ).then(
+             insert_recom_prompt, inputs=[output_text, dummy_np, recom_animagine], outputs=[output_text, dummy_np],
+         ).then(
+             insert_recom_prompt, inputs=[output_text_pony, dummy_np, recom_pony], outputs=[output_text_pony, dummy_np],
+         )
+         copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
+         copy_btn_pony.click(gradio_copy_text, inputs=[output_text_pony], js=COPY_ACTION_JS)
+
+     ui.launch()
+
+
+ if __name__ == "__main__":
+     main()
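For reference, the chained `.then()` callbacks on `generate_from_image_btn` amount to a fixed image-to-tags pipeline. A minimal sketch of the same flow as a plain function, using the modules this app imports (the image path is hypothetical):

```python
from PIL import Image
from tagger import (predict_tags_wd, remove_specific_prompt,
                    convert_danbooru_to_e621_prompt, sort_tags)
from z3de621conv import predict_tags_e621
from fl2sd3longcap import predict_tags_fl2_sd3

image = Image.open("sample.png")  # hypothetical file
algos = ["Use WD Tagger"]  # any subset of the three algorithm checkboxes

# Each step is a no-op unless its algorithm is selected in `algos`
series, characters, general, _ = predict_tags_wd(image, "", algos)
general = predict_tags_e621(image, general, algos)
general = predict_tags_fl2_sd3(image, general, algos)
general = remove_specific_prompt(general, "all")               # keep_tags
general = convert_danbooru_to_e621_prompt(general, "danbooru")  # input_tag_type
print(sort_tags(general))
```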
character_series_dict.csv ADDED
The diff for this file is too large to render. See raw diff
 
danbooru_e621.csv ADDED
The diff for this file is too large to render. See raw diff
 
fl2sd3longcap.py ADDED
@@ -0,0 +1,74 @@
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ import spaces
+ import re
+ from PIL import Image
+
+ # Runtime install of flash-attn (Spaces workaround); the CUDA build step is skipped
+ import subprocess
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).eval()
+ fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
+
+
+ def fl_modify_caption(caption: str) -> str:
+     """
+     Removes specific prefixes from captions if present, otherwise returns the original caption.
+     Args:
+         caption (str): A string containing a caption.
+     Returns:
+         str: The caption with the prefix removed if it was present, or the original caption.
+     """
+     # Define the prefixes to remove
+     prefix_substrings = [
+         ('captured from ', ''),
+         ('captured at ', '')
+     ]
+
+     # Create a regex pattern to match any of the prefixes
+     pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
+     replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
+
+     # Function to replace a matched prefix with its corresponding replacement
+     def replace_fn(match):
+         return replacers[match.group(0).lower()]
+
+     # Apply the regex to the caption
+     modified_caption = re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
+
+     # If the caption was modified, return the modified version; otherwise, return the original
+     return modified_caption if modified_caption != caption else caption
+
+
+ @spaces.GPU
+ def fl_run_example(image):
+     task_prompt = "<DESCRIPTION>"
+     prompt = task_prompt + "Describe this image in great detail."
+
+     # Ensure the image is in RGB mode
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     inputs = fl_processor(text=prompt, images=image, return_tensors="pt")
+     generated_ids = fl_model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=1024,
+         num_beams=3
+     )
+     generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
+     return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
+
+
+ def predict_tags_fl2_sd3(image: Image.Image, input_tags: str, algo: list[str]):
+     def to_list(s):
+         # Split on commas and drop empty items
+         return [x.strip() for x in s.split(",") if x.strip() != ""]
+
+     def list_uniq(l):
+         return sorted(set(l), key=l.index)
+
+     if "Use Florence-2-SD3-Long-Captioner" not in algo:
+         return input_tags
+     tag_list = list_uniq(to_list(input_tags) + to_list(fl_run_example(image) + ", "))
+     # to_list() already drops empty items, so no stray "" needs removing here
+     return ", ".join(tag_list)
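A quick usage sketch (the image file is hypothetical; caption content depends on the model):

```python
from PIL import Image

img = Image.open("sample.png")  # hypothetical file
# Merge the Florence-2 caption into an existing tag string
print(predict_tags_fl2_sd3(img, "1girl, solo", ["Use Florence-2-SD3-Long-Captioner"]))
# If the algorithm is not selected, the input tags pass through unchanged
print(predict_tags_fl2_sd3(img, "1girl, solo", ["Use WD Tagger"]))
```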
output.py ADDED
@@ -0,0 +1,16 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class UpsamplingOutput:
+     upsampled_tags: str
+
+     copyright_tags: str
+     character_tags: str
+     general_tags: str
+     rating_tag: str
+     aspect_ratio_tag: str
+     length_tag: str
+     identity_tag: str
+
+     elapsed_time: float = 0.0
pre-requirements.txt ADDED
@@ -0,0 +1 @@
+ pip>=23.0.0
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch
+ torchvision
+ accelerate
+ transformers
+ optimum[onnxruntime]
+ spaces
+ dartrs
+ httpx==0.13.3
+ httpcore
+ googletrans==4.0.0rc1
+ numpy
+ onnxruntime-gpu
+ timm
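Note: the `httpx==0.13.3` pin appears to exist for `googletrans==4.0.0rc1`, which relies on the old httpx/httpcore API; the `httpcore.SyncHTTPTransport` shim in `tagger.py` works around the same incompatibility.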
tag_group.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger.py ADDED
@@ -0,0 +1,450 @@
+ from PIL import Image
+ import torch
+ import gradio as gr
+ import spaces  # ZERO GPU
+ import re
+
+ from transformers import (
+     AutoImageProcessor,
+     AutoModelForImageClassification,
+ )
+
+ WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
+ WD_MODEL_NAME = WD_MODEL_NAMES[0]
+
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
+ wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
+
+
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
+     # e.g. "1girl", "2girls", ..., "5girls", "6+girls"
+     return (
+         [f"1{noun}"]
+         + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
+         + [f"{maximum+1}+{noun}s"]
+     )
+
+
+ PEOPLE_TAGS = (
+     _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
+ )
+
+
+ RATING_MAP = {
+     "general": "safe",
+     "sensitive": "sensitive",
+     "questionable": "nsfw",
+     "explicit": "explicit, nsfw",
+ }
+ DANBOORU_TO_E621_RATING_MAP = {
+     "safe": "rating_safe",
+     "sensitive": "rating_safe",
+     "nsfw": "rating_explicit",
+     "explicit, nsfw": "rating_explicit",
+     "explicit": "rating_explicit",
+     "rating:safe": "rating_safe",
+     "rating:general": "rating_safe",
+     "rating:sensitive": "rating_safe",
+     "rating:questionable, nsfw": "rating_explicit",
+     "rating:explicit, nsfw": "rating_explicit",
+ }
+
+
+ def load_dict_from_csv(filename):
+     # Build a {key: value} mapping from a two-column CSV file
+     with open(filename, 'r', encoding="utf-8") as f:
+         lines = f.readlines()
+     d = {}
+     for line in lines:
+         parts = line.strip().split(',')
+         d[parts[0]] = parts[1]
+     return d
+
+
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
+
+
+ def character_list_to_series_list(character_list):
+     output_series_tag = []
+     series_tag = ""
+     series_dict = anime_series_dict
+     for tag in character_list:
+         series_tag = series_dict.get(tag, "")
+         if tag.endswith(")"):
+             # Extract the series from a "character (series)" style tag
+             tags = tag.split("(")
+             character_tag = "(".join(tags[:-1])
+             if character_tag.endswith(" "):
+                 character_tag = character_tag[:-1]
+             series_tag = tags[-1].replace(")", "")
+
+         if series_tag:
+             output_series_tag.append(series_tag)
+
+     return output_series_tag
+
+
+ def danbooru_to_e621(dtag, e621_dict):
+     def d_to_e(match, e621_dict):
+         dtag = match.group(0)
+         etag = e621_dict.get(dtag.strip().replace("_", " "), "")
+         if etag:
+             return etag
+         else:
+             return dtag
+
+     # Replace up to two word-like spans with their e621 equivalents
+     tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
+
+     return tag
+
+
+ danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
+
+
+ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
+     if prompt_type == "danbooru": return input_prompt
+     tags = input_prompt.split(",") if input_prompt else []
+     people_tags: list[str] = []
+     other_tags: list[str] = []
+     rating_tags: list[str] = []
+
+     e621_dict = danbooru_to_e621_dict
+     for tag in tags:
+         tag = tag.strip().replace("_", " ")
+         tag = danbooru_to_e621(tag, e621_dict)
+         if tag in PEOPLE_TAGS:
+             people_tags.append(tag)
+         elif tag in DANBOORU_TO_E621_RATING_MAP:
+             rating_tags.append(DANBOORU_TO_E621_RATING_MAP[tag])
+         else:
+             other_tags.append(tag)
+
+     # Deduplicate and keep only the first rating tag
+     rating_tags = sorted(set(rating_tags), key=rating_tags.index)
+     rating_tags = [rating_tags[0]] if rating_tags else []
+     rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
+
+     output_prompt = ", ".join(people_tags + other_tags + rating_tags)
+
+     return output_prompt
+
+
+ def translate_prompt(prompt: str = ""):
+     def translate_to_english(prompt):
+         import httpcore
+         # Shim for googletrans 4.0.0rc1 on newer httpcore versions
+         setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
+         from googletrans import Translator
+         translator = Translator()
+         try:
+             translated_prompt = translator.translate(prompt, src='auto', dest='en').text
+             return translated_prompt
+         except Exception:
+             return prompt
+
+     def is_japanese(s):
+         import unicodedata
+         for ch in s:
+             name = unicodedata.name(ch, "")
+             if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
+                 return True
+         return False
+
+     def to_list(s):
+         return [x.strip() for x in s.split(",")]
+
+     prompts = to_list(prompt)
+     outputs = []
+     for p in prompts:
+         p = translate_to_english(p) if is_japanese(p) else p
+         outputs.append(p)
+
+     return ", ".join(outputs)
+
+
+ def translate_prompt_to_ja(prompt: str = ""):
+     def translate_to_japanese(prompt):
+         import httpcore
+         # Shim for googletrans 4.0.0rc1 on newer httpcore versions
+         setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
+         from googletrans import Translator
+         translator = Translator()
+         try:
+             translated_prompt = translator.translate(prompt, src='en', dest='ja').text
+             return translated_prompt
+         except Exception:
+             return prompt
+
+     def is_japanese(s):
+         import unicodedata
+         for ch in s:
+             name = unicodedata.name(ch, "")
+             if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
+                 return True
+         return False
+
+     def to_list(s):
+         return [x.strip() for x in s.split(",")]
+
+     prompts = to_list(prompt)
+     outputs = []
+     for p in prompts:
+         p = translate_to_japanese(p) if not is_japanese(p) else p
+         outputs.append(p)
+
+     return ", ".join(outputs)
+
+
+ def tags_to_ja(itag, dic):
+     def t_to_j(match, dic):
+         tag = match.group(0)
+         ja = dic.get(tag.strip().replace("_", " "), "")
+         if ja:
+             return ja
+         else:
+             return tag
+
+     tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, dic), itag, 2)
+
+     return tag
+
+
+ def convert_tags_to_ja(input_prompt: str = ""):
+     tags = input_prompt.split(",") if input_prompt else []
+     out_tags = []
+
+     tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
+     dic = tags_to_ja_dict
+     for tag in tags:
+         tag = tag.strip().replace("_", " ")
+         tag = tags_to_ja(tag, dic)
+         out_tags.append(tag)
+
+     return ", ".join(out_tags)
+
+
+ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
+     def to_list(s):
+         # Split on commas and drop empty items
+         return [x.strip() for x in s.split(",") if x.strip() != ""]
+
+     def list_sub(a, b):
+         return [e for e in a if e not in b]
+
+     def list_uniq(l):
+         return sorted(set(l), key=l.index)
+
+     animagine_ps = to_list("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
+     animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
+     pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
+     pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
+     prompts = to_list(prompt)
+     neg_prompts = to_list(neg_prompt)
+
+     # Strip any previously inserted recommended tags before re-inserting
+     prompts = list_sub(prompts, animagine_ps + pony_ps)
+     neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
+
+     last_empty_p = [""] if not prompts and type != "None" else []
+     last_empty_np = [""] if not neg_prompts and type != "None" else []
+
+     if type == "Animagine":
+         prompts = prompts + animagine_ps
+         neg_prompts = neg_prompts + animagine_nps
+     elif type == "Pony":
+         prompts = prompts + pony_ps
+         neg_prompts = neg_prompts + pony_nps
+
+     prompt = ", ".join(list_uniq(prompts) + last_empty_p)
+     neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
+
+     return prompt, neg_prompt
+
+
+ tag_group_dict = load_dict_from_csv('tag_group.csv')
+
+
+ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
+     def is_dressed(tag):
+         p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
+         return p.search(tag)
+
+     def is_background(tag):
+         p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
+         return p.search(tag)
+
+     un_tags = ['solo']
+     group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
+     keep_group_dict = {
+         "body": ['groups', 'body_parts'],
+         "dress": ['groups', 'body_parts', 'attire'],
+         "all": group_list,
+     }
+
+     def is_necessary(tag, keep_tags, group_dict):
+         if keep_tags == "all":
+             return True
+         elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
+             return False
+         elif keep_tags == "body" and is_dressed(tag):
+             return False
+         elif is_background(tag):
+             return False
+         else:
+             return True
+
+     if keep_tags == "all": return input_prompt
+     keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
+     explicit_group = list(set(group_list) ^ set(keep_group))
+
+     tags = input_prompt.split(",") if input_prompt else []
+     people_tags: list[str] = []
+     other_tags: list[str] = []
+
+     group_dict = tag_group_dict
+     for tag in tags:
+         tag = tag.strip().replace("_", " ")
+         if tag in PEOPLE_TAGS:
+             people_tags.append(tag)
+         elif is_necessary(tag, keep_tags, group_dict):
+             other_tags.append(tag)
+
+     output_prompt = ", ".join(people_tags + other_tags)
+
+     return output_prompt
+
+
+ def sort_taglist(tags: list[str]):
+     if not tags: return []
+     character_tags: list[str] = []
+     series_tags: list[str] = []
+     people_tags: list[str] = []
+     group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
+     group_tags = {}
+     other_tags: list[str] = []
+     rating_tags: list[str] = []
+
+     group_dict = tag_group_dict
+     group_set = set(group_dict.keys())
+     character_set = set(anime_series_dict.keys())
+     series_set = set(anime_series_dict.values())
+     rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
+
+     for tag in tags:
+         tag = tag.strip().replace("_", " ")
+         if tag in PEOPLE_TAGS:
+             people_tags.append(tag)
+         elif tag in rating_set:
+             rating_tags.append(tag)
+         elif tag in group_set:
+             elem = group_dict[tag]
+             group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
+         elif tag in character_set:
+             character_tags.append(tag)
+         elif tag in series_set:
+             series_tags.append(tag)
+         else:
+             other_tags.append(tag)
+
+     output_group_tags: list[str] = []
+     for k in group_list:
+         output_group_tags.extend(group_tags.get(k, []))
+
+     rating_tags = [rating_tags[0]] if rating_tags else []
+     rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
+
+     output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
+
+     return output_tags
+
+
+ def sort_tags(tags: str):
+     if not tags: return ""
+     taglist: list[str] = []
+     for tag in tags.split(","):
+         taglist.append(tag.strip())
+     taglist = list(filter(lambda x: x != "", taglist))
+     return ", ".join(sort_taglist(taglist))
+
+
+ def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
+     results = {
+         k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
+     }
+
+     rating = {}
+     character = {}
+     general = {}
+
+     for k, v in results.items():
+         if k.startswith("rating:"):
+             rating[k.replace("rating:", "")] = v
+             continue
+         elif k.startswith("character:"):
+             character[k.replace("character:", "")] = v
+             continue
+
+         general[k] = v
+
+     character = {k: v for k, v in character.items() if v >= character_threshold}
+     general = {k: v for k, v in general.items() if v >= general_threshold}
+
+     return rating, character, general
+
+
+ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
+     people_tags: list[str] = []
+     other_tags: list[str] = []
+     rating_tag = RATING_MAP[rating[0]]  # currently unused in the returned prompt
+
+     for tag in general:
+         if tag in PEOPLE_TAGS:
+             people_tags.append(tag)
+         else:
+             other_tags.append(tag)
+
+     all_tags = people_tags + other_tags
+
+     return ", ".join(all_tags)
+
+
+ @spaces.GPU()
+ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
+     inputs = wd_processor.preprocess(image, return_tensors="pt")
+
+     outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
+     logits = torch.sigmoid(outputs.logits[0])  # take the first logits
+
+     # get probabilities
+     results = {
+         wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
+     }
+
+     # rating, character, general
+     rating, character, general = postprocess_results(
+         results, general_threshold, character_threshold
+     )
+
+     prompt = gen_prompt(
+         list(rating.keys()), list(character.keys()), list(general.keys())
+     )
+
+     output_series_list = character_list_to_series_list(character.keys())
+     output_series_tag = output_series_list[0] if output_series_list else ""
+
+     return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
+
+
+ def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
+     if "Use WD Tagger" not in algo and len(algo) != 0:
+         return "", "", input_tags, gr.update(interactive=True)
+     return predict_tags(image, general_threshold, character_threshold)
+
+
+ def compose_prompt_to_copy(character: str, series: str, general: str):
+     characters = character.split(",") if character else []
+     serieses = series.split(",") if series else []
+     generals = general.split(",") if general else []
+     tags = characters + serieses + generals
+     cprompt = ",".join(tags) if tags else ""
+     return cprompt
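A brief usage sketch of the conversion and sorting helpers; the inputs are hypothetical and the exact outputs depend on the bundled CSV mappings (danbooru_e621.csv, tag_group.csv, character_series_dict.csv):

```python
prompt = "hatsune miku, vocaloid, 1girl, twintails, explicit"

print(convert_danbooru_to_e621_prompt(prompt, "e621"))
# e.g. "1girl, ..., rating_explicit" (mapping-dependent)

print(sort_tags(prompt))
# reordered: character / series / people / grouped / other / rating

print(insert_recom_prompt("1girl", "", "Animagine"))
# ("1girl, anime artwork, anime style, ...", "lowres, (bad), ...")
```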
utils.py ADDED
@@ -0,0 +1,45 @@
+ import gradio as gr
+ from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
+
+
+ V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
+     "ultra_wide",
+     "wide",
+     "square",
+     "tall",
+     "ultra_tall",
+ ]
+ V2_RATING_OPTIONS: list[RatingTag] = [
+     "sfw",
+     "general",
+     "sensitive",
+     "nsfw",
+     "questionable",
+     "explicit",
+ ]
+ V2_LENGTH_OPTIONS: list[LengthTag] = [
+     "very_short",
+     "short",
+     "medium",
+     "long",
+     "very_long",
+ ]
+ V2_IDENTITY_OPTIONS: list[IdentityTag] = [
+     "none",
+     "lax",
+     "strict",
+ ]
+
+
+ # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
+ def gradio_copy_text(_text: str | None = None):
+     # The clipboard write happens client-side in COPY_ACTION_JS; this only shows a toast
+     gr.Info("Copied!")
+
+
+ COPY_ACTION_JS = """\
+ (inputs, _outputs) => {
+   // inputs is the string value of the input_text
+   if (inputs.trim() !== "") {
+     navigator.clipboard.writeText(inputs);
+   }
+ }"""
v2.py ADDED
@@ -0,0 +1,214 @@
+ import time
+ import os
+ import torch
+ from typing import Callable
+
+ from dartrs.v2 import (
+     V2Model,
+     MixtralModel,
+     MistralModel,
+     compose_prompt,
+     LengthTag,
+     AspectRatioTag,
+     RatingTag,
+     IdentityTag,
+ )
+ from dartrs.dartrs import DartTokenizer
+ from dartrs.utils import get_generation_config
+
+
+ import gradio as gr
+ from gradio.components import Component
+
+ try:
+     import spaces
+ except ImportError:
+     # Fallback stub so the module also runs outside ZERO GPU Spaces
+     class spaces:
+         def GPU(*args, **kwargs):
+             return lambda x: x
+
+
+ from output import UpsamplingOutput
+
+
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
+
+ V2_ALL_MODELS = {
+     "dart-v2-moe-sft": {
+         "repo": "p1atdev/dart-v2-moe-sft",
+         "type": "sft",
+         "class": MixtralModel,
+     },
+     "dart-v2-sft": {
+         "repo": "p1atdev/dart-v2-sft",
+         "type": "sft",
+         "class": MistralModel,
+     },
+ }
+
+
+ def prepare_models(model_config: dict):
+     model_name = model_config["repo"]
+     tokenizer = DartTokenizer.from_pretrained(model_name, auth_token=HF_TOKEN)
+     model = model_config["class"].from_pretrained(model_name, auth_token=HF_TOKEN)
+
+     return {
+         "tokenizer": tokenizer,
+         "model": model,
+     }
+
+
+ def normalize_tags(tokenizer: DartTokenizer, tags: str):
+     """Just remove unk tokens."""
+     return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
+
+
+ @torch.no_grad()
+ def generate_tags(
+     model: V2Model,
+     tokenizer: DartTokenizer,
+     prompt: str,
+     ban_token_ids: list[int],
+ ):
+     output = model.generate(
+         get_generation_config(
+             prompt,
+             tokenizer=tokenizer,
+             temperature=1,
+             top_p=0.9,
+             top_k=100,
+             max_new_tokens=256,
+             ban_token_ids=ban_token_ids,
+         ),
+     )
+
+     return output
+
+
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
+     return (
+         [f"1{noun}"]
+         + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
+         + [f"{maximum+1}+{noun}s"]
+     )
+
+
+ PEOPLE_TAGS = (
+     _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
+ )
+
+
+ def gen_prompt_text(output: UpsamplingOutput):
+     # separate people tags (e.g. 1girl)
+     people_tags = []
+     other_general_tags = []
+
+     for tag in output.general_tags.split(","):
+         tag = tag.strip()
+         if tag in PEOPLE_TAGS:
+             people_tags.append(tag)
+         else:
+             other_general_tags.append(tag)
+
+     return ", ".join(
+         [
+             part.strip()
+             for part in [
+                 *people_tags,
+                 output.character_tags,
+                 output.copyright_tags,
+                 *other_general_tags,
+                 output.upsampled_tags,
+                 output.rating_tag,
+             ]
+             if part.strip() != ""
+         ]
+     )
+
+
+ def elapsed_time_format(elapsed_time: float) -> str:
+     return f"Elapsed: {elapsed_time:.2f} seconds"
+
+
+ def parse_upsampling_output(
+     upsampler: Callable[..., UpsamplingOutput],
+ ):
+     def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
+         output = upsampler(*args)
+
+         return (
+             gen_prompt_text(output),
+             elapsed_time_format(output.elapsed_time),
+             gr.update(interactive=True),
+             gr.update(interactive=True),
+         )
+
+     return _parse_upsampling_output
+
+
+ class V2UI:
+     model_name: str | None = None
+     model: V2Model
+     tokenizer: DartTokenizer
+
+     input_components: list[Component] = []
+     generate_btn: gr.Button
+
+     def on_generate(
+         self,
+         model_name: str,
+         copyright_tags: str,
+         character_tags: str,
+         general_tags: str,
+         rating_tag: RatingTag,
+         aspect_ratio_tag: AspectRatioTag,
+         length_tag: LengthTag,
+         identity_tag: IdentityTag,
+         ban_tags: str,
+         *args,
+     ) -> UpsamplingOutput:
+         # (Re)load the model only when the selection changes
+         if self.model_name is None or self.model_name != model_name:
+             models = prepare_models(V2_ALL_MODELS[model_name])
+             self.model = models["model"]
+             self.tokenizer = models["tokenizer"]
+             self.model_name = model_name
+
+         # normalize tags
+         # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
+         # character_tags = normalize_tags(self.tokenizer, character_tags)
+         # general_tags = normalize_tags(self.tokenizer, general_tags)
+
+         ban_token_ids = self.tokenizer.encode(ban_tags.strip())
+
+         prompt = compose_prompt(
+             prompt=general_tags,
+             copyright=copyright_tags,
+             character=character_tags,
+             rating=rating_tag,
+             aspect_ratio=aspect_ratio_tag,
+             length=length_tag,
+             identity=identity_tag,
+         )
+
+         start = time.time()
+         upsampled_tags = generate_tags(
+             self.model,
+             self.tokenizer,
+             prompt,
+             ban_token_ids,
+         )
+         elapsed_time = time.time() - start
+
+         return UpsamplingOutput(
+             upsampled_tags=upsampled_tags,
+             copyright_tags=copyright_tags,
+             character_tags=character_tags,
+             general_tags=general_tags,
+             rating_tag=rating_tag,
+             aspect_ratio_tag=aspect_ratio_tag,
+             length_tag=length_tag,
+             identity_tag=identity_tag,
+             elapsed_time=elapsed_time,
+         )
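A sketch of driving the upsampler without the UI. The model and tokenizer are downloaded on first use; the option strings come from the lists in utils.py, and the tag values are hypothetical:

```python
v2 = V2UI()
out = v2.on_generate(
    "dart-v2-moe-sft",   # model_name, a key of V2_ALL_MODELS
    "vocaloid",          # copyright tags
    "hatsune miku",      # character tags
    "1girl, twintails",  # general tags
    "sfw",               # rating   (V2_RATING_OPTIONS)
    "square",            # aspect ratio (V2_ASPECT_RATIO_OPTIONS)
    "long",              # length   (V2_LENGTH_OPTIONS)
    "lax",               # identity (V2_IDENTITY_OPTIONS)
    "",                  # ban tags
)
print(out.upsampled_tags, out.elapsed_time)
```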
z3de621conv.py ADDED
@@ -0,0 +1,68 @@
+ import huggingface_hub
+ from PIL import Image
+ from pathlib import Path
+ import csv
+ import numpy as np
+ import spaces
+
+ import onnxruntime as rt
+ e621_model_path = Path(huggingface_hub.snapshot_download('toynya/Z3D-E621-Convnext'))
+ e621_model_session = rt.InferenceSession(str(e621_model_path / 'model.onnx'), providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+ with open(e621_model_path / 'tags-selected.csv', mode='r', encoding='utf-8') as file:
+     csv_reader = csv.DictReader(file)
+     e621_model_tags = [row['name'].strip() for row in csv_reader]
+
+
+ def prepare_image_e621(image: Image.Image, target_size: int):
+     # Pad image to square
+     image_shape = image.size
+     max_dim = max(image_shape)
+     pad_left = (max_dim - image_shape[0]) // 2
+     pad_top = (max_dim - image_shape[1]) // 2
+
+     padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+     padded_image.paste(image, (pad_left, pad_top))
+
+     # Resize
+     if max_dim != target_size:
+         padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
+
+     # Convert to numpy array
+     # Based on the ONNX graph, the model appears to expect inputs in the range of 0-255
+     image_array = np.asarray(padded_image, dtype=np.float32)
+
+     # Convert PIL-native RGB to BGR
+     image_array = image_array[:, :, ::-1]
+
+     return np.expand_dims(image_array, axis=0)
+
+
+ @spaces.GPU
+ def predict_e621(image: Image.Image, threshold: float = 0.3):
+     image_array = prepare_image_e621(image, 448)
+     input_name = 'input_1:0'
+     output_name = 'predictions_sigmoid'
+
+     result = e621_model_session.run([output_name], {input_name: image_array})
+     result = result[0][0]
+
+     scores = {e621_model_tags[i]: result[i] for i in range(len(result))}
+     predicted_tags = [tag for tag, score in scores.items() if score > threshold]
+     tag_string = ', '.join(predicted_tags).replace("_", " ")
+
+     return tag_string
+
+
+ def predict_tags_e621(image: Image.Image, input_tags: str, algo: list[str], threshold: float = 0.3):
+     def to_list(s):
+         # Split on commas and drop empty items
+         return [x.strip() for x in s.split(",") if x.strip() != ""]
+
+     def list_uniq(l):
+         return sorted(set(l), key=l.index)
+
+     if "Use Z3D-E621-Convnext" not in algo:
+         return input_tags
+     tag_list = list_uniq(to_list(input_tags) + to_list(predict_e621(image, threshold)))
+     return ", ".join(tag_list)
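Usage sketch (the file path is hypothetical; the tag vocabulary comes from tags-selected.csv):

```python
from PIL import Image

img = Image.open("sample.jpg").convert("RGB")  # hypothetical file
# Direct inference at a 0.5 confidence threshold
print(predict_e621(img, threshold=0.5))
# Or merged with existing tags, as app.py does
print(predict_tags_e621(img, "1girl", ["Use Z3D-E621-Convnext"], threshold=0.5))
```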