John6666 committed
Commit 862d3ae · verified · Parent: d9301f2

Upload 12 files

Files changed (8):
  1. README.md +1 -1
  2. app.py +39 -54
  3. fl2basepromptgen.py +6 -4
  4. fl2sd3longcap.py +4 -2
  5. promptenhancer.py +6 -3
  6. requirements.txt +1 -1
  7. tagger.py +62 -19
  8. utils.py +5 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸƒπŸ“¦
 colorFrom: blue
 colorTo: yellow
 sdk: gradio
-sdk_version: 4.37.2
+sdk_version: 4.39.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py CHANGED
@@ -1,17 +1,9 @@
-from PIL import Image
 import gradio as gr
+import spaces
 
-
-from v2 import (
-    V2_ALL_MODELS,
-)
 from utils import (
     gradio_copy_text,
     COPY_ACTION_JS,
-    V2_ASPECT_RATIO_OPTIONS,
-    V2_RATING_OPTIONS,
-    V2_LENGTH_OPTIONS,
-    V2_IDENTITY_OPTIONS
 )
 from tagger import (
     predict_tags_wd,
@@ -20,21 +12,22 @@ from tagger import (
     insert_recom_prompt,
     compose_prompt_to_copy,
     translate_prompt,
+    select_random_character,
 )
-from fl2sd3longcap import (
-    predict_tags_fl2_sd3,
-)
-from fl2basepromptgen import (
-    predict_tags_fl2_base_prompt_gen,
-)
+from fl2sd3longcap import predict_tags_fl2_sd3
+from fl2basepromptgen import predict_tags_fl2_base_prompt_gen
 from promptenhancer import prompt_enhancer
 
-
 def description_ui():
     gr.Markdown(
         """
         ## Prompt Enhancer with WD Tagger & SD3 Long Captioner
         (Image =>) Prompt => Upsampled longer prompt
+        """
+    )
+def description_ui2():
+    gr.Markdown(
+        """
     - It's a mod. Original Spaces: p1atdev's [WD Tagger with πŸ€— transformers](https://huggingface.co/spaces/p1atdev/wd-tagger-transformers),\
     gokaygokay's [Prompt-Enhancer](https://huggingface.co/spaces/gokaygokay/Prompt-Enhancer) /\
     [Florence-2-SD3-Captioner](https://huggingface.co/spaces/gokaygokay/Florence-2-SD3-Captioner).
@@ -46,62 +39,54 @@ def description_ui():
         """
     )
 
-
 def main():
-
     with gr.Blocks() as ui:
         description_ui()
-
-        with gr.Row():
-            with gr.Column(scale=2):
-                with gr.Group():
-                    input_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
-                    with gr.Accordion(label="Advanced options", open=False):
-                        general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
-                        character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
-                        input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
-                        recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
-                        keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
-                    image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner", "Use Florence-2-base-PromptGen"], label="Algorithms", value=["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"])
-                    generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
-
-                with gr.Group():
+        with gr.Column():
+            with gr.Group():
+                input_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
+                with gr.Accordion(label="Advanced options", open=False):
+                    general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
+                    character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
+                    input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
+                    recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
+                    keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
+                image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner", "Use Florence-2-base-PromptGen"], label="Algorithms", value=["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"])
+                generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
+            with gr.Group():
+                with gr.Row():
                     input_character = gr.Textbox(label="Character tags", placeholder="hatsune miku")
                     input_copyright = gr.Textbox(label="Copyright tags", placeholder="vocaloid")
-                    input_general = gr.TextArea(label="General tags", lines=4, placeholder="1girl, ...", value="")
-                    input_tags_to_copy = gr.Textbox(value="", visible=False)
-                    copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
-                    translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
-                    prompt_enhancer_model = gr.Radio(["Medium", "Long"], label="Model Choice", value="Long", info="Enhance your prompts with Medium or Long answers")
-
-                    with gr.Accordion(label="Advanced options", open=False, visible=False):
-                        tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
-                        input_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit", visible=False)
-                        input_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
-                        input_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="very_long", visible=False)
-                        input_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax", visible=False)
-                        input_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored", visible=False)
-                        model_name = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0], visible=False)
-                        dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
-                        recom_animagine = gr.Textbox(label="Animagine reccomended prompt", value="Animagine", visible=False)
-                        recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
-
+                    random_character = gr.Button(value="Random character 🎲", size="sm")
+                input_general = gr.TextArea(label="General tags", lines=4, placeholder="1girl, ...", value="")
+                input_tags_to_copy = gr.Textbox(value="", visible=False)
+                copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+                translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
+                prompt_enhancer_model = gr.Radio(["Medium", "Long"], label="Model Choice", value="Long", info="Enhance your prompts with Medium or Long answers")
+                with gr.Accordion(label="Advanced options", open=False, visible=False):
+                    tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
+                    dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
+                    recom_animagine = gr.Textbox(label="Animagine reccomended prompt", value="Animagine", visible=False)
+                    recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)
 
         generate_btn = gr.Button(value="GENERATE TAGS", size="lg", variant="primary")
-
+        with gr.Row():
             with gr.Group():
                 output_text = gr.TextArea(label="Output tags", interactive=False, show_copy_button=True)
                 copy_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
-            elapsed_time_md = gr.Markdown(label="Elapsed time", value="", visible=False)
-
             with gr.Group():
                 output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
                 copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
+        description_ui2()
+
+        random_character.click(select_random_character, [input_copyright, input_character], [input_copyright, input_character], queue=False)
 
         translate_input_prompt_button.click(translate_prompt, [input_general], [input_general], queue=False)
         translate_input_prompt_button.click(translate_prompt, [input_character], [input_character], queue=False)
         translate_input_prompt_button.click(translate_prompt, [input_copyright], [input_copyright], queue=False)
 
         generate_from_image_btn.click(
+            lambda: ("", "", ""), None, [input_copyright, input_character, input_general], queue=False,
+        ).success(
             predict_tags_wd,
            [input_image, input_general, image_algorithms, general_threshold, character_threshold],
            [input_copyright, input_character, input_general, copy_input_btn],
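
Note: the new event wiring first clears the three tag boxes via a lambda, then chains the tagger with .success(), so the slow step only runs if the clearing handler completed without error. A minimal sketch of this Gradio chaining pattern (hypothetical demo code, not the Space's code):

import gradio as gr

def clear():
    return ""  # reset the output before the slow step

def generate(text):
    return text.upper()  # stand-in for the real tagger call

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    btn = gr.Button("Run")
    # .success() fires only if the preceding .click() handler did not raise
    btn.click(clear, None, out, queue=False).success(generate, inp, out)

demo.launch()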
fl2basepromptgen.py CHANGED
@@ -1,11 +1,13 @@
 from transformers import AutoProcessor, AutoModelForCausalLM
 import spaces
 from PIL import Image
+import torch
 
-#import subprocess
-#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-fl_model = AutoModelForCausalLM.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True).eval()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+fl_model = AutoModelForCausalLM.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True).to(device).eval()
 fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
 
 
@@ -18,7 +20,7 @@ def fl_run(image):
     if image.mode != "RGB":
         image = image.convert("RGB")
 
-    inputs = fl_processor(text=prompt, images=image, return_tensors="pt")
+    inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
     generated_ids = fl_model.generate(
         input_ids=inputs["input_ids"],
         pixel_values=inputs["pixel_values"],
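
Note: this commit re-enables the runtime flash-attn install that was previously commented out. One caveat worth flagging: passing env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"} to subprocess.run replaces the child process's entire environment rather than extending it. A hedged variant that merges the current environment (an assumption about intent, not what the commit does):

import os
import subprocess

# Merge os.environ so PATH and friends survive; only the skip-build flag is added.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
)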
fl2sd3longcap.py CHANGED
@@ -2,11 +2,13 @@ from transformers import AutoProcessor, AutoModelForCausalLM
 import spaces
 import re
 from PIL import Image
+import torch
 
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).eval()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to(device).eval()
 fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
 
 
@@ -48,7 +50,7 @@ def fl_run_example(image):
     if image.mode != "RGB":
         image = image.convert("RGB")
 
-    inputs = fl_processor(text=prompt, images=image, return_tensors="pt")
+    inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
     generated_ids = fl_model.generate(
         input_ids=inputs["input_ids"],
         pixel_values=inputs["pixel_values"],
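
Note: both Florence-2 modules now follow the same device-placement pattern: pick the device once at import time, move the model there with .to(device), and move the processor's output (a BatchFeature, whose .to() relocates every contained tensor) to the same device before generate(). A generic sketch of the pattern, with a hypothetical model id standing in for the commit's models:

import torch
from transformers import AutoProcessor, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
# 'some/vision-lm' is a placeholder id, not one of the commit's models.
model = AutoModelForCausalLM.from_pretrained('some/vision-lm', trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained('some/vision-lm', trust_remote_code=True)

def caption(image, prompt):
    # BatchFeature.to(device) moves input_ids and pixel_values in one call,
    # keeping the inputs on the same device as the model.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=256,
        )
    return processor.batch_decode(ids, skip_special_tokens=True)[0]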
promptenhancer.py CHANGED
@@ -2,10 +2,13 @@ import spaces
 import gradio as gr
 from transformers import pipeline
 import re
+import torch
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def load_models():
-    enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=0)
-    enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=0)
+    enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
+    enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
     return enhancer_medium, enhancer_long
 
 enhancer_medium, enhancer_long = load_models()
@@ -39,4 +42,4 @@ def prompt_enhancer(character: str, series: str, general: str, model_choice: str
     output = enhance_prompt(cprompt, model_choice)
     prompt = cprompt + ", " + output
 
-    return prompt, gr.update(interactive=True), gr.update(interactive=True),
+    return prompt, gr.update(interactive=True), gr.update(interactive=True)
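
Note: the fix here relies on transformers' pipeline() accepting a device string ("cuda"/"cpu") as well as a GPU index, so device=device degrades gracefully on CPU-only hosts where the previous hard-coded device=0 would fail. Illustrative call against the loaded pipeline (input and parameters are examples, not the Space's values):

result = enhancer_long("a cat sitting on a windowsill", max_length=100)[0]["summary_text"]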
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-torch
+torch==2.2.0
 torchvision
 accelerate
 transformers
tagger.py CHANGED
@@ -1,12 +1,13 @@
 from PIL import Image
 import torch
 import gradio as gr
-import spaces # ZERO GPU
-
+import spaces
 from transformers import (
     AutoImageProcessor,
     AutoModelForImageClassification,
 )
+from pathlib import Path
+
 
 WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
 WD_MODEL_NAME = WD_MODEL_NAMES[0]
@@ -30,12 +31,15 @@ PEOPLE_TAGS = (
 
 
 RATING_MAP = {
+    "sfw": "safe",
     "general": "safe",
     "sensitive": "sensitive",
     "questionable": "nsfw",
     "explicit": "explicit, nsfw",
 }
 DANBOORU_TO_E621_RATING_MAP = {
+    "sfw": "rating_safe",
+    "general": "rating_safe",
     "safe": "rating_safe",
     "sensitive": "rating_safe",
     "nsfw": "rating_explicit",
@@ -49,6 +53,34 @@ DANBOORU_TO_E621_RATING_MAP = {
 }
 
 
+# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
+kaomojis = [
+    "0_0",
+    "(o)_(o)",
+    "+_+",
+    "+_-",
+    "._.",
+    "<o>_<o>",
+    "<|>_<|>",
+    "=_=",
+    ">_<",
+    "3_3",
+    "6_9",
+    ">_o",
+    "@_@",
+    "^_^",
+    "o_o",
+    "u_u",
+    "x_x",
+    "|_|",
+    "||_||",
+]
+
+
+def replace_underline(x: str):
+    return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
+
+
 def to_list(s):
     return [x.strip() for x in s.split(",") if not s == ""]
 
@@ -62,9 +94,16 @@ def list_uniq(l):
 
 
 def load_dict_from_csv(filename):
-    with open(filename, 'r', encoding="utf-8") as f:
-        lines = f.readlines()
     dict = {}
+    if not Path(filename).exists():
+        if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
+        else: return dict
+    try:
+        with open(filename, 'r', encoding="utf-8") as f:
+            lines = f.readlines()
+    except Exception:
+        print(f"Failed to open dictionary file: {filename}")
+        return dict
     for line in lines:
         parts = line.strip().split(',')
         dict[parts[0]] = parts[1]
@@ -94,7 +133,8 @@ def character_list_to_series_list(character_list):
 
 
 def select_random_character(series: str, character: str):
-    from random import randrange
+    from random import seed, randrange
+    seed()
     character_list = list(anime_series_dict.keys())
     character = character_list[randrange(len(character_list) - 1)]
     series = anime_series_dict.get(character.split(",")[0].strip(), "")
@@ -104,7 +144,7 @@ def select_random_character(series: str, character: str):
 def danbooru_to_e621(dtag, e621_dict):
     def d_to_e(match, e621_dict):
         dtag = match.group(0)
-        etag = e621_dict.get(dtag.strip().replace("_", " "), "")
+        etag = e621_dict.get(replace_underline(dtag), "")
         if etag:
             return etag
         else:
@@ -112,7 +152,6 @@ def danbooru_to_e621(dtag, e621_dict):
 
     import re
     tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
-
     return tag
 
 
@@ -128,7 +167,7 @@ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "
 
     e621_dict = danbooru_to_e621_dict
     for tag in tags:
-        tag = tag.strip().replace("_", " ")
+        tag = replace_underline(tag)
         tag = danbooru_to_e621(tag, e621_dict)
         if tag in PEOPLE_TAGS:
             people_tags.append(tag)
@@ -156,6 +195,7 @@ def translate_prompt(prompt: str = ""):
             translated_prompt = translator.translate(prompt, src='auto', dest='en').text
             return translated_prompt
         except Exception as e:
+            print(e)
             return prompt
 
     def is_japanese(s):
@@ -188,6 +228,7 @@ def translate_prompt_to_ja(prompt: str = ""):
             translated_prompt = translator.translate(prompt, src='en', dest='ja').text
             return translated_prompt
         except Exception as e:
+            print(e)
             return prompt
 
     def is_japanese(s):
@@ -213,7 +254,7 @@ def translate_prompt_to_ja(prompt: str = ""):
 def tags_to_ja(itag, dict):
     def t_to_j(match, dict):
         tag = match.group(0)
-        ja = dict.get(tag.strip().replace("_", " "), "")
+        ja = dict.get(replace_underline(tag), "")
         if ja:
             return ja
         else:
@@ -232,7 +273,7 @@ def convert_tags_to_ja(input_prompt: str = ""):
     tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
     dict = tags_to_ja_dict
     for tag in tags:
-        tag = tag.strip().replace("_", " ")
+        tag = replace_underline(tag)
         tag = tags_to_ja(tag, dict)
         out_tags.append(tag)
 
@@ -242,13 +283,13 @@ def convert_tags_to_ja(input_prompt: str = ""):
 
 enable_auto_recom_prompt = True
 
-animagine_ps = to_list("anime artwork, anime style, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
+animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
 animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
-pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
-pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
+pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
+pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
 other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
 other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
-default_ps = to_list("score_9, score_8_up, score_7_up, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
+default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
 default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
 def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
     global enable_auto_recom_prompt
@@ -281,6 +322,7 @@ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "Non
 def load_model_prompt_dict():
     import json
     dict = {}
+    path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
     try:
         with open('model_dict.json', encoding='utf-8') as f:
             dict = json.load(f)
@@ -359,7 +401,7 @@ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
 
     group_dict = tag_group_dict
     for tag in tags:
-        tag = tag.strip().replace("_", " ")
+        tag = replace_underline(tag)
         if tag in PEOPLE_TAGS:
             people_tags.append(tag)
         elif is_necessary(tag, keep_tags, group_dict):
@@ -387,7 +429,7 @@ def sort_taglist(tags: list[str]):
     rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
 
     for tag in tags:
-        tag = tag.strip().replace("_", " ")
+        tag = replace_underline(tag)
         if tag in PEOPLE_TAGS:
             people_tags.append(tag)
         elif tag in rating_set:
@@ -488,12 +530,13 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
         output_series_tag = output_series_list[0]
     else:
         output_series_tag = ""
-    return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
+    return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
 
 
-def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
+def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
+                    character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
     if not "Use WD Tagger" in algo and len(algo) != 0:
-        return "", "", input_tags, gr.update(interactive=True),
+        return input_series, input_character, input_tags, gr.update(interactive=True)
     return predict_tags(image, general_threshold, character_threshold)
 
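Note: replace_underline is the behavioral core of this tagger.py change: danbooru-style underscores become spaces everywhere except in kaomoji tags, whose underscores are part of the emoticon. Illustrative behavior:

from tagger import replace_underline

replace_underline("hatsune_miku")  # -> "hatsune miku"
replace_underline(">_<")           # -> ">_<"  (kaomoji keep their underscores)

Also worth flagging for a follow-up: load_model_prompt_dict() computes path with a './tagger/' fallback, but the subsequent open() still hard-codes 'model_dict.json', so the fallback is never used.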
utils.py CHANGED
@@ -43,3 +43,8 @@ COPY_ACTION_JS = """\
     navigator.clipboard.writeText(inputs);
   }
 }"""
+
+
+def gradio_copy_prompt(prompt: str):
+    gr.Info("Copied!")
+    return prompt
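
Note: the new gradio_copy_prompt helper pops a "Copied!" toast via gr.Info and passes the prompt through unchanged. Hypothetical wiring (not part of this commit) would route a prompt box through it so the toast fires on click:

import gradio as gr
from utils import gradio_copy_prompt

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    holder = gr.Textbox(visible=False)  # receives the passed-through value
    gr.Button("Copy").click(gradio_copy_prompt, prompt, holder)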