John6666 committed
Commit 1731cc9
1 Parent(s): 240160f

Upload 3 files

Files changed (3)
  1. app.py +63 -6
  2. fl2sd3longcap.py +75 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -410,20 +410,24 @@ from utils import (
  V2_IDENTITY_OPTIONS
  )
  from tagger import (
- predict_tags,
+ predict_tags_wd,
  convert_danbooru_to_e621_prompt,
  remove_specific_prompt,
  insert_recom_prompt,
  compose_prompt_to_copy,
  translate_prompt,
  )
+ from fl2sd3longcap import (
+ predict_tags_fl2_sd3,
+ )
  def description_ui():
  gr.Markdown(
  """
  ## Danbooru Tags Transformer V2 Demo with WD Tagger
  (Image =>) Prompt => Upsampled longer prompt
  - Mod of p1atdev's [Danbooru Tags Transformer V2 Demo](https://huggingface.co/spaces/p1atdev/danbooru-tags-transformer-v2) and [WD Tagger with 🤗 transformers](https://huggingface.co/spaces/p1atdev/wd-tagger-transformers).
- - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf), [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft)
+ - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf), [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft)\
+ , gokaygokay's [Florence-2-SD3-Captioner](https://huggingface.co/gokaygokay/Florence-2-SD3-Captioner)
  """
  )
  ## END MOD
@@ -861,6 +865,7 @@ with gr.Blocks(theme="NoCrypt/miku", elem_id="main", css=CSS) as app:
  tag_type_gui = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
  recom_prompt_gui = gr.Radio(label="Insert recommended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
  keep_tags_gui = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
+ image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
  generate_from_image_btn_gui = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
  with gr.Group():
  prompt_gui = gr.Textbox(lines=6, placeholder="1girl, solo, ...", label="Prompt", show_copy_button=True)
@@ -1275,6 +1280,49 @@ with gr.Blocks(theme="NoCrypt/miku", elem_id="main", css=CSS) as app:
  "Classic",
  "Nearest",
  ],
+ [
+ "1girl, oomuro sakurako, yuru yuri, official art, anime style, school uniform, masterpiece, best quality, very aesthetic, absurdres",
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
+ 1,
+ 40,
+ 7.5,
+ True,
+ -1,
+ None,
+ 1.0,
+ None,
+ 1.0,
+ None,
+ 1.0,
+ None,
+ 1.0,
+ None,
+ 1.0,
+ "Euler",
+ 1024,
+ 1024,
+ "Raelina/Rae-Diffusion-XL-V2",
+ "vaes/sdxl.vae.safetensors", # vae
+ "txt2img",
+ None, # img control
+ "Canny", # preprocessor
+ 512, # preproc resolution
+ 1024, # img resolution
+ None, # Style prompt
+ None, # Style json
+ None, # img Mask
+ 0.35, # strength
+ 100, # low th canny
+ 200, # high th canny
+ 0.1, # value mstd
+ 0.1, # distance mstd
+ 1.0, # cn scale
+ 0., # cn start
+ 1., # cn end
+ False, # ti
+ "Classic",
+ "Nearest",
+ ],
  [
  "yoshida yuuko, machikado mazoku, 1girl, solo, demon horns,horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
  "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
@@ -1453,14 +1501,18 @@ with gr.Blocks(theme="NoCrypt/miku", elem_id="main", css=CSS) as app:
  optimization_gui.change(set_optimization, [optimization_gui, steps_gui, cfg_gui, sampler_gui, clip_skip_gui, lora1_gui, lora_scale_1_gui], [steps_gui, cfg_gui, sampler_gui, clip_skip_gui, lora1_gui, lora_scale_1_gui])

  generate_from_image_btn_gui.click(
- predict_tags,
- inputs=[input_image_gui, general_threshold_gui, character_threshold_gui],
+ predict_tags_wd,
+ inputs=[input_image_gui, prompt_gui, image_algorithms, general_threshold_gui, character_threshold_gui],
  outputs=[
  series_dbt,
  character_dbt,
  prompt_gui,
  copy_button_dbt,
  ],
+ ).then(
+ predict_tags_fl2_sd3,
+ inputs=[input_image_gui, prompt_gui, image_algorithms],
+ outputs=[prompt_gui],
  ).then(
  compose_prompt_to_copy, inputs=[character_dbt, series_dbt, prompt_gui], outputs=[prompt_gui]
  ).then(
@@ -1639,6 +1691,7 @@ with gr.Blocks(theme="NoCrypt/miku", elem_id="main", css=CSS) as app:
  character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
  input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
  recom_prompt = gr.Radio(label="Insert recommended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
+ image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
  keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
  generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")

@@ -1691,14 +1744,18 @@ with gr.Blocks(theme="NoCrypt/miku", elem_id="main", css=CSS) as app:
  translate_input_prompt_button.click(translate_prompt, inputs=[input_copyright], outputs=[input_copyright])

  generate_from_image_btn.click(
- predict_tags,
- inputs=[input_image, general_threshold, character_threshold],
+ predict_tags_wd,
+ inputs=[input_image, input_general, image_algorithms, general_threshold, character_threshold],
  outputs=[
  input_copyright,
  input_character,
  input_general,
  copy_input_btn,
  ],
+ ).then(
+ predict_tags_fl2_sd3,
+ inputs=[input_image, input_general, image_algorithms],
+ outputs=[input_general],
  ).then(
  remove_specific_prompt, inputs=[input_general, keep_tags], outputs=[input_general],
  ).then(
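The wiring above runs the two taggers as a chained Gradio event: the click handler calls `predict_tags_wd` first, then `.then()` feeds the updated prompt and the `image_algorithms` selection into `predict_tags_fl2_sd3`, which only appends a caption when the Florence-2 option is checked. Below is a minimal, self-contained sketch of that pattern; the stub functions and widget defaults are placeholders for illustration, not the Space's actual implementations.

```python
# Minimal sketch of the two-stage tagging chain (placeholder stubs, not the
# Space's real predict_tags_wd / predict_tags_fl2_sd3 implementations).
import gradio as gr

def predict_tags_wd(image, prompt, algos, general_th, character_th):
    # Placeholder: the real function runs the WD SwinV2 tagger when enabled.
    if image is None or "Use WD Tagger" not in algos:
        return prompt
    return ", ".join(filter(None, [prompt, "wd_tags_here"]))

def predict_tags_fl2_sd3(image, prompt, algos):
    # Placeholder: the real function appends a Florence-2-SD3 caption when enabled.
    if image is None or "Use Florence-2-SD3-Long-Captioner" not in algos:
        return prompt
    return ", ".join(filter(None, [prompt, "florence2_caption_here"]))

with gr.Blocks() as demo:
    input_image = gr.Image(type="pil", label="Input image")
    image_algorithms = gr.CheckboxGroup(
        ["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"],
        label="Algorithms", value=["Use WD Tagger"],
    )
    general_threshold = gr.Slider(0.0, 1.0, value=0.3, label="General threshold")
    character_threshold = gr.Slider(0.0, 1.0, value=0.8, label="Character threshold")
    prompt = gr.Textbox(label="Prompt")
    btn = gr.Button("GENERATE TAGS FROM IMAGE")
    # First handler writes WD tags into the prompt box; the chained handler then
    # appends the Florence-2 caption to whatever the first step produced.
    btn.click(
        predict_tags_wd,
        inputs=[input_image, prompt, image_algorithms, general_threshold, character_threshold],
        outputs=[prompt],
    ).then(
        predict_tags_fl2_sd3,
        inputs=[input_image, prompt, image_algorithms],
        outputs=[prompt],
    )

# demo.launch()  # uncomment to serve the sketch locally
```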
fl2sd3longcap.py ADDED
@@ -0,0 +1,75 @@
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ import spaces
+ import re
+ from PIL import Image
+ import torch
+
+ import subprocess
+ # Install FlashAttention at runtime (common ZeroGPU Spaces pattern); skip the CUDA build.
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+
+ fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', torch_dtype=torch.float16, attn_implementation="flash_attention_2", trust_remote_code=True).to("cuda").eval()
+ fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', torch_dtype=torch.float16, attn_implementation="flash_attention_2", trust_remote_code=True)
+
+
+ def fl_modify_caption(caption: str) -> str:
+     """
+     Removes specific prefixes from captions if present, otherwise returns the original caption.
+     Args:
+         caption (str): A string containing a caption.
+     Returns:
+         str: The caption with the prefix removed if it was present, or the original caption.
+     """
+     # Define the prefixes to remove
+     prefix_substrings = [
+         ('captured from ', ''),
+         ('captured at ', '')
+     ]
+
+     # Create a regex pattern to match any of the prefixes
+     pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
+     replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
+
+     # Function to replace matched prefix with its corresponding replacement
+     def replace_fn(match):
+         return replacers[match.group(0).lower()]
+
+     # Apply the regex to the caption
+     modified_caption = re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
+
+     # If the caption was modified, return the modified version; otherwise, return the original
+     return modified_caption if modified_caption != caption else caption
+
+ @spaces.GPU
+ def fl_run_example(image):
+     task_prompt = "<DESCRIPTION>"
+     prompt = task_prompt + "Describe this image in great detail."
+
+     # Ensure the image is in RGB mode
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to("cuda")
+     generated_ids = fl_model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=1024,
+         num_beams=3
+     )
+     generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
+     return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
+
+
+ def predict_tags_fl2_sd3(image: Image.Image, input_tags: str, algo: list[str]):
+     # Split a comma-separated tag string into a list, dropping empty entries
+     def to_list(s):
+         return [x.strip() for x in s.split(",") if x.strip() != ""]
+
+     # De-duplicate while preserving the original order
+     def list_uniq(l):
+         return sorted(set(l), key=l.index)
+
+     if "Use Florence-2-SD3-Long-Captioner" not in algo:
+         return input_tags
+     tag_list = list_uniq(to_list(input_tags) + to_list(fl_run_example(image)))
+     return ", ".join(tag_list)
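For reference, `predict_tags_fl2_sd3` can also be exercised outside the Gradio app. The snippet below is a hypothetical standalone call, assuming a CUDA GPU and the ZeroGPU `spaces` package (the module loads the captioner onto `cuda` at import time); `example.png` and the tag string are placeholder inputs.

```python
# Hypothetical standalone use of fl2sd3longcap.predict_tags_fl2_sd3.
# Assumes a CUDA GPU and the `spaces` package; "example.png" is a placeholder path.
from PIL import Image

from fl2sd3longcap import predict_tags_fl2_sd3

image = Image.open("example.png")              # placeholder input image
existing_tags = "1girl, solo, school uniform"  # placeholder tag string

# With the Florence-2 option enabled, the caption phrases are de-duplicated
# and appended to the existing tag string.
merged = predict_tags_fl2_sd3(image, existing_tags, ["Use Florence-2-SD3-Long-Captioner"])
print(merged)

# With the option disabled, the input tags are returned unchanged.
print(predict_tags_fl2_sd3(image, existing_tags, ["Use WD Tagger"]))
```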
requirements.txt CHANGED
@@ -12,4 +12,5 @@ huggingface_hub
  diffusers
  httpx==0.13.3
  httpcore
- googletrans==4.0.0rc1
+ googletrans==4.0.0rc1
+ timm