John6666 commited on
Commit
27879cc
1 Parent(s): 9e9e011

Upload 10 files

Browse files
Files changed (4) hide show
  1. app.py +89 -97
  2. multit2i.py +88 -36
  3. tagger/fl2sd3longcap.py +10 -4
  4. tagger/tagger.py +12 -5
app.py CHANGED
@@ -1,92 +1,94 @@
1
  import gradio as gr
2
  from model import models
3
- from multit2i import (
4
- load_models, infer_fn, infer_rand_fn, save_gallery,
5
  change_model, warm_model, get_model_info_md, loaded_models,
6
  get_positive_prefix, get_positive_suffix, get_negative_prefix, get_negative_suffix,
7
- get_recom_prompt_type, set_recom_prompt_preset, get_tag_type,
8
- )
9
- from tagger.tagger import (
10
- predict_tags_wd, remove_specific_prompt, convert_danbooru_to_e621_prompt,
11
- insert_recom_prompt, compose_prompt_to_copy,
12
- )
13
  from tagger.fl2sd3longcap import predict_tags_fl2_sd3
14
  from tagger.v2 import V2_ALL_MODELS, v2_random_prompt
15
- from tagger.utils import (
16
- V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS,
17
- V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS,
18
- )
19
 
20
-
21
- max_images = 8
22
  load_models(models)
23
 
24
  css = """
25
  .model_info { text-align: center; }
26
- .output { width=112px; height=112px; !important; }
27
- .gallery { width=100%; min_height=768px; !important; }
28
  """
29
 
30
  with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
31
- with gr.Column():
32
- with gr.Group():
33
- model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
34
- model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_classes="model_info")
35
- with gr.Group():
36
- with gr.Accordion("Prompt from Image File", open=False):
37
- tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
38
- with gr.Accordion(label="Advanced options", open=False):
39
- tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
40
- tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
41
- tagger_tag_type = gr.Radio(label="Convert tags to", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
42
- tagger_recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
43
- tagger_keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
44
- tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
45
- tagger_generate_from_image = gr.Button(value="Generate Tags from Image")
46
- with gr.Row():
47
- v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2)
48
- v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
49
- random_prompt = gr.Button(value="Extend Prompt 🎲", size="sm", scale=1)
50
- clear_prompt = gr.Button(value="Clear Prompt 🗑️", size="sm", scale=1)
51
- prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
52
- neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
53
- with gr.Accordion("Advanced options", open=False):
54
- width = gr.Number(label="Width", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
55
- height = gr.Number(label="Height", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
56
- steps = gr.Number(label="Number of inference steps", info="If 0, the default value is used.", maximum=100, step=1, value=0)
57
- cfg = gr.Number(label="Guidance scale", info="If 0, the default value is used.", maximum=30.0, step=0.1, value=0)
58
- with gr.Accordion("Recommended Prompt", open=False):
59
- recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
60
  with gr.Row():
61
- positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
62
- positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
63
- negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[])
64
- negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"])
65
- with gr.Accordion("Prompt Transformer", open=False):
66
- v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
67
- v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
68
- v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
69
- v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
70
- v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
71
- v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
72
- v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
73
- v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
74
- image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=1)
75
- with gr.Row():
76
- run_button = gr.Button("Generate Image", scale=6)
77
- random_button = gr.Button("Random Model 🎲", scale=3)
78
- stop_button = gr.Button('Stop', interactive=False, scale=1)
79
- with gr.Column():
80
- with gr.Group():
 
 
81
  with gr.Row():
82
- output = [gr.Image(label='', elem_classes="output", type="filepath", format=".png",
83
- show_download_button=True, show_share_button=False, show_label=False,
84
- interactive=False, min_width=80, visible=True) for _ in range(max_images)]
85
- with gr.Group():
86
- results = gr.Gallery(label="Gallery", elem_classes="gallery", interactive=False, show_download_button=True, show_share_button=False,
87
- container=True, format="png", object_fit="cover", columns=2, rows=2)
88
- image_files = gr.Files(label="Download", interactive=False)
89
- clear_results = gr.Button("Clear Gallery / Download 🗑️")
 
 
 
 
 
 
 
 
 
90
  with gr.Column():
91
  examples = gr.Examples(
92
  examples = [
@@ -115,13 +117,13 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
115
  img_i = gr.Number(i, visible=False)
116
  image_num.change(lambda i, n: gr.update(visible = (i < n)), [img_i, image_num], o, show_api=False)
117
  gen_event = gr.on(triggers=[run_button.click, prompt.submit],
118
- fn=lambda i, n, m, t1, t2, n1, n2, n3, n4, l1, l2, l3, l4: infer_fn(m, t1, t2, n1, n2, n3, n4, l1, l2, l3, l4) if (i < n) else None,
119
- inputs=[img_i, image_num, model_name, prompt, neg_prompt, height, width, steps, cfg,
120
  positive_prefix, positive_suffix, negative_prefix, negative_suffix],
121
  outputs=[o], queue=True, show_api=False)
122
  gen_event2 = gr.on(triggers=[random_button.click],
123
- fn=lambda i, n, m, t1, t2, n1, n2, n3, n4, l1, l2, l3, l4: infer_rand_fn(m, t1, t2, n1, n2, n3, n4, l1, l2, l3, l4) if (i < n) else None,
124
- inputs=[img_i, image_num, model_name, prompt, neg_prompt, height, width, steps, cfg,
125
  positive_prefix, positive_suffix, negative_prefix, negative_suffix],
126
  outputs=[o], queue=True, show_api=False)
127
  o.change(save_gallery, [o, results], [results, image_files], show_api=False)
@@ -135,29 +137,19 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
135
  random_prompt.click(
136
  v2_random_prompt, [prompt, v2_series, v2_character, v2_rating, v2_aspect_ratio, v2_length,
137
  v2_identity, v2_ban_tags, v2_model], [prompt, v2_series, v2_character], show_api=False,
138
- ).success(
139
- get_tag_type, [positive_prefix, positive_suffix, negative_prefix, negative_suffix], [v2_tag_type], queue=False, show_api=False
140
- ).success(
141
- convert_danbooru_to_e621_prompt, [prompt, v2_tag_type], [prompt], queue=False, show_api=False,
142
- )
143
- tagger_generate_from_image.click(
144
- lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,
145
  ).success(
146
  predict_tags_wd,
147
  [tagger_image, prompt, tagger_algorithms, tagger_general_threshold, tagger_character_threshold],
148
  [v2_series, v2_character, prompt, v2_copy],
149
  show_api=False,
150
- ).success(
151
- predict_tags_fl2_sd3, [tagger_image, prompt, tagger_algorithms], [prompt], show_api=False,
152
- ).success(
153
- remove_specific_prompt, [prompt, tagger_keep_tags], [prompt], queue=False, show_api=False,
154
- ).success(
155
- convert_danbooru_to_e621_prompt, [prompt, tagger_tag_type], [prompt], queue=False, show_api=False,
156
- ).success(
157
- insert_recom_prompt, [prompt, neg_prompt, tagger_recom_prompt], [prompt, neg_prompt], queue=False, show_api=False,
158
- ).success(
159
- compose_prompt_to_copy, [v2_character, v2_series, prompt], [prompt], queue=False, show_api=False,
160
- )
161
 
162
- demo.queue()
163
- demo.launch()
 
1
  import gradio as gr
2
  from model import models
3
+ from multit2i import (load_models, infer_fn, infer_rand_fn, save_gallery,
 
4
  change_model, warm_model, get_model_info_md, loaded_models,
5
  get_positive_prefix, get_positive_suffix, get_negative_prefix, get_negative_suffix,
6
+ get_recom_prompt_type, set_recom_prompt_preset, get_tag_type)
7
+ from tagger.tagger import (predict_tags_wd, remove_specific_prompt, convert_danbooru_to_e621_prompt,
8
+ insert_recom_prompt, compose_prompt_to_copy)
 
 
 
9
  from tagger.fl2sd3longcap import predict_tags_fl2_sd3
10
  from tagger.v2 import V2_ALL_MODELS, v2_random_prompt
11
+ from tagger.utils import (V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS,
12
+ V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
 
 
13
 
14
+ max_images = 6
15
+ MAX_SEED = 2**32-1
16
  load_models(models)
17
 
18
  css = """
19
  .model_info { text-align: center; }
20
+ .output { width=112px; height=112px; max_width=112px; max_height=112px; !important; }
21
+ .gallery { min_width=512px; min_height=512px; max_height=1024px; !important; }
22
  """
23
 
24
  with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
25
+ with gr.Row():
26
+ with gr.Column(scale=10):
27
+ with gr.Group():
28
+ with gr.Accordion("Prompt from Image File", open=False):
29
+ tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
30
+ with gr.Accordion(label="Advanced options", open=False):
31
+ with gr.Row():
32
+ tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
33
+ tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
34
+ tagger_tag_type = gr.Radio(label="Convert tags to", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
35
+ with gr.Row():
36
+ tagger_recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
37
+ tagger_keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
38
+ tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
39
+ tagger_generate_from_image = gr.Button(value="Generate Tags from Image", variant="secondary")
40
+ with gr.Accordion("Prompt Transformer", open=False):
41
+ with gr.Row():
42
+ v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
43
+ v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
44
+ v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
45
+ with gr.Row():
46
+ v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
47
+ v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
48
+ v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
49
+ v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
50
+ v2_copy = gr.Button(value="Copy to clipboard", variant="secondary", size="sm", interactive=False)
 
 
 
51
  with gr.Row():
52
+ v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2)
53
+ v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
54
+ random_prompt = gr.Button(value="Extend Prompt 🎲", variant="secondary", size="sm", scale=1)
55
+ clear_prompt = gr.Button(value="Clear Prompt 🗑️", variant="secondary", size="sm", scale=1)
56
+ prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
57
+ with gr.Accordion("Advanced options", open=False):
58
+ neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="")
59
+ with gr.Row():
60
+ width = gr.Slider(label="Width", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
61
+ height = gr.Slider(label="Height", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
62
+ with gr.Row():
63
+ steps = gr.Slider(label="Number of inference steps", info="If 0, the default value is used.", maximum=100, step=1, value=0)
64
+ cfg = gr.Slider(label="Guidance scale", info="If 0, the default value is used.", maximum=30.0, step=0.1, value=0)
65
+ seed = gr.Slider(label="Seed", info="Randomize Seed if -1.", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
66
+ recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
67
+ with gr.Row():
68
+ positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
69
+ positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
70
+ negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[])
71
+ negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"])
72
+
73
+ image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=1)
74
  with gr.Row():
75
+ run_button = gr.Button("Generate Image", variant="primary", scale=6)
76
+ random_button = gr.Button("Random Model 🎲", variant="secondary", scale=3)
77
+ stop_button = gr.Button('Stop', variant="stop", interactive=False, scale=1)
78
+ with gr.Group():
79
+ model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
80
+ model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_classes="model_info")
81
+ with gr.Column(scale=10):
82
+ with gr.Group():
83
+ with gr.Row():
84
+ output = [gr.Image(label='', elem_classes="output", type="filepath", format="png",
85
+ show_download_button=True, show_share_button=False, show_label=False,
86
+ interactive=False, min_width=80, visible=True) for _ in range(max_images)]
87
+ with gr.Group():
88
+ results = gr.Gallery(label="Gallery", elem_classes="gallery", interactive=False, show_download_button=True, show_share_button=False,
89
+ container=True, format="png", object_fit="cover", columns=2, rows=2)
90
+ image_files = gr.Files(label="Download", interactive=False)
91
+ clear_results = gr.Button("Clear Gallery / Download 🗑️", variant="secondary")
92
  with gr.Column():
93
  examples = gr.Examples(
94
  examples = [
 
117
  img_i = gr.Number(i, visible=False)
118
  image_num.change(lambda i, n: gr.update(visible = (i < n)), [img_i, image_num], o, show_api=False)
119
  gen_event = gr.on(triggers=[run_button.click, prompt.submit],
120
+ fn=lambda i, n, m, t1, t2, n1, n2, n3, n4, n5, l1, l2, l3, l4: infer_fn(m, t1, t2, n1, n2, n3, n4, n5, l1, l2, l3, l4) if (i < n) else None,
121
+ inputs=[img_i, image_num, model_name, prompt, neg_prompt, height, width, steps, cfg, seed,
122
  positive_prefix, positive_suffix, negative_prefix, negative_suffix],
123
  outputs=[o], queue=True, show_api=False)
124
  gen_event2 = gr.on(triggers=[random_button.click],
125
+ fn=lambda i, n, m, t1, t2, n1, n2, n3, n4, n5, l1, l2, l3, l4: infer_rand_fn(m, t1, t2, n1, n2, n3, n4, n5, l1, l2, l3, l4) if (i < n) else None,
126
+ inputs=[img_i, image_num, model_name, prompt, neg_prompt, height, width, steps, cfg, seed,
127
  positive_prefix, positive_suffix, negative_prefix, negative_suffix],
128
  outputs=[o], queue=True, show_api=False)
129
  o.change(save_gallery, [o, results], [results, image_files], show_api=False)
 
137
  random_prompt.click(
138
  v2_random_prompt, [prompt, v2_series, v2_character, v2_rating, v2_aspect_ratio, v2_length,
139
  v2_identity, v2_ban_tags, v2_model], [prompt, v2_series, v2_character], show_api=False,
140
+ ).success(get_tag_type, [positive_prefix, positive_suffix, negative_prefix, negative_suffix], [v2_tag_type], queue=False, show_api=False
141
+ ).success(convert_danbooru_to_e621_prompt, [prompt, v2_tag_type], [prompt], queue=False, show_api=False)
142
+ tagger_generate_from_image.click(lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,
 
 
 
 
143
  ).success(
144
  predict_tags_wd,
145
  [tagger_image, prompt, tagger_algorithms, tagger_general_threshold, tagger_character_threshold],
146
  [v2_series, v2_character, prompt, v2_copy],
147
  show_api=False,
148
+ ).success(predict_tags_fl2_sd3, [tagger_image, prompt, tagger_algorithms], [prompt], show_api=False,
149
+ ).success(remove_specific_prompt, [prompt, tagger_keep_tags], [prompt], queue=False, show_api=False,
150
+ ).success(convert_danbooru_to_e621_prompt, [prompt, tagger_tag_type], [prompt], queue=False, show_api=False,
151
+ ).success(insert_recom_prompt, [prompt, neg_prompt, tagger_recom_prompt], [prompt, neg_prompt], queue=False, show_api=False,
152
+ ).success(compose_prompt_to_copy, [v2_character, v2_series, prompt], [prompt], queue=False, show_api=False)
 
 
 
 
 
 
153
 
154
+ demo.queue(default_concurrency_limit=200, max_size=200)
155
+ demo.launch(max_threads=400)
multit2i.py CHANGED
@@ -3,8 +3,10 @@ import asyncio
3
  from threading import RLock
4
  from pathlib import Path
5
  from huggingface_hub import InferenceClient
 
6
 
7
 
 
8
  server_timeout = 600
9
  inference_timeout = 300
10
 
@@ -31,22 +33,43 @@ def is_repo_name(s):
31
  return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
32
 
33
 
34
- def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  from huggingface_hub import HfApi
36
- api = HfApi()
37
  default_tags = ["diffusers"]
38
  if not sort: sort = "last_modified"
 
39
  models = []
40
  try:
41
- model_infos = api.list_models(author=author, pipeline_tag="text-to-image",
42
- tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit * 5)
43
  except Exception as e:
44
  print(f"Error: Failed to list models.")
45
  print(e)
46
  return models
47
  for model in model_infos:
48
- if not model.private and not model.gated:
49
- if not_tag and not_tag in model.tags: continue
 
50
  models.append(model.id)
51
  if len(models) == limit: break
52
  return models
@@ -54,23 +77,24 @@ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="l
54
 
55
  def get_t2i_model_info_dict(repo_id: str):
56
  from huggingface_hub import HfApi
57
- api = HfApi()
58
  info = {"md": "None"}
59
  try:
60
  if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
61
- model = api.model_info(repo_id=repo_id)
62
  except Exception as e:
63
  print(f"Error: Failed to get {repo_id}'s info.")
64
  print(e)
65
  return info
66
- if model.private or model.gated: return info
67
  try:
68
  tags = model.tags
69
  except Exception as e:
70
  print(e)
71
  return info
72
  if not 'diffusers' in model.tags: return info
73
- if 'diffusers:StableDiffusionXLPipeline' in tags: info["ver"] = "SDXL"
 
74
  elif 'diffusers:StableDiffusionPipeline' in tags: info["ver"] = "SD1.5"
75
  elif 'diffusers:StableDiffusion3Pipeline' in tags: info["ver"] = "SD3"
76
  else: info["ver"] = "Other"
@@ -118,20 +142,23 @@ def save_gallery(image_path: str | None, images: list[tuple] | None):
118
 
119
  # https://github.com/gradio-app/gradio/blob/main/gradio/external.py
120
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
121
- def load_from_model(model_name: str, hf_token: str = None):
 
122
  import httpx
123
  import huggingface_hub
124
- from gradio.exceptions import ModelNotFoundError
125
  model_url = f"https://huggingface.co/{model_name}"
126
  api_url = f"https://api-inference.huggingface.co/models/{model_name}"
127
  print(f"Fetching model from: {model_url}")
128
 
129
- headers = {"Authorization": f"Bearer {hf_token}"} if hf_token is not None else {}
130
  response = httpx.request("GET", api_url, headers=headers)
131
  if response.status_code != 200:
132
  raise ModelNotFoundError(
133
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
134
  )
 
 
135
  headers["X-Wait-For-Model"] = "true"
136
  client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
137
  token=hf_token, timeout=server_timeout)
@@ -140,7 +167,14 @@ def load_from_model(model_name: str, hf_token: str = None):
140
  fn = client.text_to_image
141
 
142
  def query_huggingface_inference_endpoints(*data, **kwargs):
143
- return fn(*data, **kwargs)
 
 
 
 
 
 
 
144
 
145
  interface_info = {
146
  "fn": query_huggingface_inference_endpoints,
@@ -156,7 +190,7 @@ def load_model(model_name: str):
156
  global model_info_dict
157
  if model_name in loaded_models.keys(): return loaded_models[model_name]
158
  try:
159
- loaded_models[model_name] = load_from_model(model_name)
160
  print(f"Loaded: {model_name}")
161
  except Exception as e:
162
  if model_name in loaded_models.keys(): del loaded_models[model_name]
@@ -179,12 +213,12 @@ def load_model_api(model_name: str):
179
  if model_name in loaded_models.keys(): return loaded_models[model_name]
180
  try:
181
  client = InferenceClient(timeout=5)
182
- status = client.get_model_status(model_name)
183
  if status is None or status.framework != "diffusers" or status.state not in ["Loadable", "Loaded"]:
184
  print(f"Failed to load by API: {model_name}")
185
  return None
186
  else:
187
- loaded_models[model_name] = InferenceClient(model_name, timeout=server_timeout)
188
  print(f"Loaded by API: {model_name}")
189
  except Exception as e:
190
  if model_name in loaded_models.keys(): del loaded_models[model_name]
@@ -329,49 +363,58 @@ def warm_model(model_name: str):
329
 
330
  # https://huggingface.co/docs/api-inference/detailed_parameters
331
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
332
- def infer_body(client: InferenceClient | gr.Interface, prompt: str, neg_prompt: str | None = None,
333
  height: int | None = None, width: int | None = None,
334
- steps: int | None = None, cfg: int | None = None):
335
  png_path = "image.png"
336
  kwargs = {}
337
  if height is not None and height >= 256: kwargs["height"] = height
338
  if width is not None and width >= 256: kwargs["width"] = width
339
  if steps is not None and steps >= 1: kwargs["num_inference_steps"] = steps
340
  if cfg is not None and cfg > 0: cfg = kwargs["guidance_scale"] = cfg
 
341
  try:
342
  if isinstance(client, InferenceClient):
343
- image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
344
  elif isinstance(client, gr.Interface):
345
- image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
346
  else: return None
 
347
  image.save(png_path)
348
  return str(Path(png_path).resolve())
349
  except Exception as e:
350
  print(e)
351
- return None
352
 
353
 
354
  async def infer(model_name: str, prompt: str, neg_prompt: str | None = None,
355
  height: int | None = None, width: int | None = None,
356
- steps: int | None = None, cfg: int | None = None,
357
  save_path: str | None = None, timeout: float = inference_timeout):
358
  import random
359
  noise = ""
360
- rand = random.randint(1, 500)
361
- for i in range(rand):
362
- noise += " "
 
363
  model = load_model(model_name)
364
  if not model: return None
365
  task = asyncio.create_task(asyncio.to_thread(infer_body, model, f"{prompt} {noise}", neg_prompt,
366
- height, width, steps, cfg))
367
  await asyncio.sleep(0)
368
  try:
369
  result = await asyncio.wait_for(task, timeout=timeout)
370
- except (Exception, asyncio.TimeoutError) as e:
371
  print(e)
372
  print(f"Task timed out: {model_name}")
373
  if not task.done(): task.cancel()
374
  result = None
 
 
 
 
 
 
375
  if task.done() and result is not None:
376
  with lock:
377
  image = rename_image(result, model_name, save_path)
@@ -379,27 +422,32 @@ async def infer(model_name: str, prompt: str, neg_prompt: str | None = None,
379
  return None
380
 
381
 
 
382
  def infer_fn(model_name: str, prompt: str, neg_prompt: str | None = None, height: int | None = None,
383
- width: int | None = None, steps: int | None = None, cfg: int | None = None,
384
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
385
  if model_name == 'NA':
386
  return None
387
  try:
388
- prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
 
389
  loop = asyncio.new_event_loop()
 
 
390
  result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
391
- steps, cfg, save_path, inference_timeout))
392
  except (Exception, asyncio.CancelledError) as e:
393
  print(e)
394
- print(f"Task aborted: {model_name}")
395
  result = None
 
396
  finally:
397
  loop.close()
398
  return result
399
 
400
 
401
  def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str | None = None, height: int | None = None,
402
- width: int | None = None, steps: int | None = None, cfg: int | None = None,
403
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
404
  import random
405
  if model_name_dummy == 'NA':
@@ -407,14 +455,18 @@ def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str | None = N
407
  random.seed()
408
  model_name = random.choice(list(loaded_models.keys()))
409
  try:
410
- prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
 
411
  loop = asyncio.new_event_loop()
 
 
412
  result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
413
- steps, cfg, save_path, inference_timeout))
414
  except (Exception, asyncio.CancelledError) as e:
415
  print(e)
416
- print(f"Task aborted: {model_name}")
417
  result = None
 
418
  finally:
419
  loop.close()
420
  return result
 
3
  from threading import RLock
4
  from pathlib import Path
5
  from huggingface_hub import InferenceClient
6
+ import os
7
 
8
 
9
+ HF_TOKEN = os.environ.get("HF_TOKEN") if os.environ.get("HF_TOKEN") else None # If private or gated models aren't used, ENV setting is unnecessary.
10
  server_timeout = 600
11
  inference_timeout = 300
12
 
 
33
  return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
34
 
35
 
36
+ def get_status(model_name: str):
37
+ from huggingface_hub import InferenceClient
38
+ client = InferenceClient(token=HF_TOKEN, timeout=10)
39
+ return client.get_model_status(model_name)
40
+
41
+
42
+ def is_loadable(model_name: str, force_gpu: bool = False):
43
+ try:
44
+ status = get_status(model_name)
45
+ except Exception as e:
46
+ print(e)
47
+ print(f"Couldn't load {model_name}.")
48
+ return False
49
+ gpu_state = isinstance(status.compute_type, dict) and "gpu" in status.compute_type.keys()
50
+ if status is None or status.state not in ["Loadable", "Loaded"] or (force_gpu and not gpu_state):
51
+ print(f"Couldn't load {model_name}. Model state:'{status.state}', GPU:{gpu_state}")
52
+ return status is not None and status.state in ["Loadable", "Loaded"] and (not force_gpu or gpu_state)
53
+
54
+
55
+ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=False, check_status=False):
56
  from huggingface_hub import HfApi
57
+ api = HfApi(token=HF_TOKEN)
58
  default_tags = ["diffusers"]
59
  if not sort: sort = "last_modified"
60
+ limit = limit * 20 if check_status and force_gpu else limit * 5
61
  models = []
62
  try:
63
+ model_infos = api.list_models(author=author, task="text-to-image",
64
+ tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit)
65
  except Exception as e:
66
  print(f"Error: Failed to list models.")
67
  print(e)
68
  return models
69
  for model in model_infos:
70
+ if not model.private and not model.gated or HF_TOKEN is not None:
71
+ loadable = is_loadable(model.id, force_gpu) if check_status else True
72
+ if not_tag and not_tag in model.tags or not loadable: continue
73
  models.append(model.id)
74
  if len(models) == limit: break
75
  return models
 
77
 
78
  def get_t2i_model_info_dict(repo_id: str):
79
  from huggingface_hub import HfApi
80
+ api = HfApi(token=HF_TOKEN)
81
  info = {"md": "None"}
82
  try:
83
  if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
84
+ model = api.model_info(repo_id=repo_id, token=HF_TOKEN)
85
  except Exception as e:
86
  print(f"Error: Failed to get {repo_id}'s info.")
87
  print(e)
88
  return info
89
+ if model.private or model.gated and HF_TOKEN is None: return info
90
  try:
91
  tags = model.tags
92
  except Exception as e:
93
  print(e)
94
  return info
95
  if not 'diffusers' in model.tags: return info
96
+ if 'diffusers:FluxPipeline' in tags: info["ver"] = "FLUX.1"
97
+ elif 'diffusers:StableDiffusionXLPipeline' in tags: info["ver"] = "SDXL"
98
  elif 'diffusers:StableDiffusionPipeline' in tags: info["ver"] = "SD1.5"
99
  elif 'diffusers:StableDiffusion3Pipeline' in tags: info["ver"] = "SD3"
100
  else: info["ver"] = "Other"
 
142
 
143
  # https://github.com/gradio-app/gradio/blob/main/gradio/external.py
144
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
145
+ from typing import Literal
146
+ def load_from_model(model_name: str, hf_token: str | Literal[False] | None = None):
147
  import httpx
148
  import huggingface_hub
149
+ from gradio.exceptions import ModelNotFoundError, TooManyRequestsError
150
  model_url = f"https://huggingface.co/{model_name}"
151
  api_url = f"https://api-inference.huggingface.co/models/{model_name}"
152
  print(f"Fetching model from: {model_url}")
153
 
154
+ headers = ({} if hf_token in [False, None] else {"Authorization": f"Bearer {hf_token}"})
155
  response = httpx.request("GET", api_url, headers=headers)
156
  if response.status_code != 200:
157
  raise ModelNotFoundError(
158
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
159
  )
160
+ p = response.json().get("pipeline_tag")
161
+ if p != "text-to-image": raise ModelNotFoundError(f"This model isn't for text-to-image or unsupported: {model_name}.")
162
  headers["X-Wait-For-Model"] = "true"
163
  client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
164
  token=hf_token, timeout=server_timeout)
 
167
  fn = client.text_to_image
168
 
169
  def query_huggingface_inference_endpoints(*data, **kwargs):
170
+ try:
171
+ data = fn(*data, **kwargs) # type: ignore
172
+ except huggingface_hub.utils.HfHubHTTPError as e:
173
+ if "429" in str(e):
174
+ raise TooManyRequestsError() from e
175
+ except Exception as e:
176
+ raise Exception() from e
177
+ return data
178
 
179
  interface_info = {
180
  "fn": query_huggingface_inference_endpoints,
 
190
  global model_info_dict
191
  if model_name in loaded_models.keys(): return loaded_models[model_name]
192
  try:
193
+ loaded_models[model_name] = load_from_model(model_name, hf_token=HF_TOKEN)
194
  print(f"Loaded: {model_name}")
195
  except Exception as e:
196
  if model_name in loaded_models.keys(): del loaded_models[model_name]
 
213
  if model_name in loaded_models.keys(): return loaded_models[model_name]
214
  try:
215
  client = InferenceClient(timeout=5)
216
+ status = client.get_model_status(model_name, token=HF_TOKEN)
217
  if status is None or status.framework != "diffusers" or status.state not in ["Loadable", "Loaded"]:
218
  print(f"Failed to load by API: {model_name}")
219
  return None
220
  else:
221
+ loaded_models[model_name] = InferenceClient(model_name, token=HF_TOKEN, timeout=server_timeout)
222
  print(f"Loaded by API: {model_name}")
223
  except Exception as e:
224
  if model_name in loaded_models.keys(): del loaded_models[model_name]
 
363
 
364
  # https://huggingface.co/docs/api-inference/detailed_parameters
365
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
366
+ def infer_body(client: InferenceClient | gr.Interface | object, prompt: str, neg_prompt: str | None = None,
367
  height: int | None = None, width: int | None = None,
368
+ steps: int | None = None, cfg: int | None = None, seed: int = -1):
369
  png_path = "image.png"
370
  kwargs = {}
371
  if height is not None and height >= 256: kwargs["height"] = height
372
  if width is not None and width >= 256: kwargs["width"] = width
373
  if steps is not None and steps >= 1: kwargs["num_inference_steps"] = steps
374
  if cfg is not None and cfg > 0: cfg = kwargs["guidance_scale"] = cfg
375
+ if seed >= 0: kwargs["seed"] = seed
376
  try:
377
  if isinstance(client, InferenceClient):
378
+ image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
379
  elif isinstance(client, gr.Interface):
380
+ image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
381
  else: return None
382
+ if isinstance(image, tuple): return None
383
  image.save(png_path)
384
  return str(Path(png_path).resolve())
385
  except Exception as e:
386
  print(e)
387
+ raise Exception() from e
388
 
389
 
390
  async def infer(model_name: str, prompt: str, neg_prompt: str | None = None,
391
  height: int | None = None, width: int | None = None,
392
+ steps: int | None = None, cfg: int | None = None, seed: int = -1,
393
  save_path: str | None = None, timeout: float = inference_timeout):
394
  import random
395
  noise = ""
396
+ if seed < 0:
397
+ rand = random.randint(1, 500)
398
+ for i in range(rand):
399
+ noise += " "
400
  model = load_model(model_name)
401
  if not model: return None
402
  task = asyncio.create_task(asyncio.to_thread(infer_body, model, f"{prompt} {noise}", neg_prompt,
403
+ height, width, steps, cfg, seed))
404
  await asyncio.sleep(0)
405
  try:
406
  result = await asyncio.wait_for(task, timeout=timeout)
407
+ except asyncio.TimeoutError as e:
408
  print(e)
409
  print(f"Task timed out: {model_name}")
410
  if not task.done(): task.cancel()
411
  result = None
412
+ raise Exception(f"Task timed out: {model_name}") from e
413
+ except Exception as e:
414
+ print(e)
415
+ if not task.done(): task.cancel()
416
+ result = None
417
+ raise Exception() from e
418
  if task.done() and result is not None:
419
  with lock:
420
  image = rename_image(result, model_name, save_path)
 
422
  return None
423
 
424
 
425
+ # https://github.com/aio-libs/pytest-aiohttp/issues/8 # also AsyncInferenceClient is buggy.
426
  def infer_fn(model_name: str, prompt: str, neg_prompt: str | None = None, height: int | None = None,
427
+ width: int | None = None, steps: int | None = None, cfg: int | None = None, seed: int = -1,
428
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
429
  if model_name == 'NA':
430
  return None
431
  try:
432
+ loop = asyncio.get_running_loop()
433
+ except Exception:
434
  loop = asyncio.new_event_loop()
435
+ try:
436
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
437
  result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
438
+ steps, cfg, seed, save_path, inference_timeout))
439
  except (Exception, asyncio.CancelledError) as e:
440
  print(e)
441
+ print(f"Task aborted: {model_name}, Error: {e}")
442
  result = None
443
+ raise gr.Error(f"Task aborted: {model_name}, Error: {e}")
444
  finally:
445
  loop.close()
446
  return result
447
 
448
 
449
  def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str | None = None, height: int | None = None,
450
+ width: int | None = None, steps: int | None = None, cfg: int | None = None, seed: int = -1,
451
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
452
  import random
453
  if model_name_dummy == 'NA':
 
455
  random.seed()
456
  model_name = random.choice(list(loaded_models.keys()))
457
  try:
458
+ loop = asyncio.get_running_loop()
459
+ except Exception:
460
  loop = asyncio.new_event_loop()
461
+ try:
462
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
463
  result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
464
+ steps, cfg, seed, save_path, inference_timeout))
465
  except (Exception, asyncio.CancelledError) as e:
466
  print(e)
467
+ print(f"Task aborted: {model_name}, Error: {e}")
468
  result = None
469
+ raise gr.Error(f"Task aborted: {model_name}, Error: {e}")
470
  finally:
471
  loop.close()
472
  return result
tagger/fl2sd3longcap.py CHANGED
@@ -1,5 +1,5 @@
1
- from transformers import AutoProcessor, AutoModelForCausalLM
2
  import spaces
 
3
  import re
4
  from PIL import Image
5
  import torch
@@ -8,9 +8,13 @@ import subprocess
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
- fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to(device).eval()
12
- fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
13
 
 
 
 
 
 
 
14
 
15
  def fl_modify_caption(caption: str) -> str:
16
  """
@@ -41,7 +45,7 @@ def fl_modify_caption(caption: str) -> str:
41
  return modified_caption if modified_caption != caption else caption
42
 
43
 
44
- @spaces.GPU
45
  def fl_run_example(image):
46
  task_prompt = "<DESCRIPTION>"
47
  prompt = task_prompt + "Describe this image in great detail."
@@ -50,6 +54,7 @@ def fl_run_example(image):
50
  if image.mode != "RGB":
51
  image = image.convert("RGB")
52
 
 
53
  inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
54
  generated_ids = fl_model.generate(
55
  input_ids=inputs["input_ids"],
@@ -57,6 +62,7 @@ def fl_run_example(image):
57
  max_new_tokens=1024,
58
  num_beams=3
59
  )
 
60
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
61
  parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
62
  return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
 
 
1
  import spaces
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
  import re
4
  from PIL import Image
5
  import torch
 
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
11
 
12
+ try:
13
+ fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to("cpu").eval()
14
+ fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
15
+ except Exception as e:
16
+ print(e)
17
+ fl_model = fl_processor = None
18
 
19
  def fl_modify_caption(caption: str) -> str:
20
  """
 
45
  return modified_caption if modified_caption != caption else caption
46
 
47
 
48
+ @spaces.GPU(duration=30)
49
  def fl_run_example(image):
50
  task_prompt = "<DESCRIPTION>"
51
  prompt = task_prompt + "Describe this image in great detail."
 
54
  if image.mode != "RGB":
55
  image = image.convert("RGB")
56
 
57
+ fl_model.to(device)
58
  inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
59
  generated_ids = fl_model.generate(
60
  input_ids=inputs["input_ids"],
 
62
  max_new_tokens=1024,
63
  num_beams=3
64
  )
65
+ fl_model.to("cpu")
66
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
67
  parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
68
  return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
tagger/tagger.py CHANGED
@@ -1,7 +1,7 @@
 
1
  from PIL import Image
2
  import torch
3
  import gradio as gr
4
- import spaces
5
  from transformers import (
6
  AutoImageProcessor,
7
  AutoModelForImageClassification,
@@ -12,10 +12,15 @@ from pathlib import Path
12
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
 
15
- wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
16
- wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
17
- wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
18
 
 
 
 
 
 
 
19
 
20
  def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
21
  return (
@@ -506,7 +511,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
506
  return ", ".join(all_tags)
507
 
508
 
509
- @spaces.GPU()
510
  def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
511
  inputs = wd_processor.preprocess(image, return_tensors="pt")
512
 
@@ -514,9 +519,11 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
514
  logits = torch.sigmoid(outputs.logits[0]) # take the first logits
515
 
516
  # get probabilities
 
517
  results = {
518
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
519
  }
 
520
  # rating, character, general
521
  rating, character, general = postprocess_results(
522
  results, general_threshold, character_threshold
 
1
+ import spaces
2
  from PIL import Image
3
  import torch
4
  import gradio as gr
 
5
  from transformers import (
6
  AutoImageProcessor,
7
  AutoModelForImageClassification,
 
12
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
 
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ default_device = device
 
17
 
18
+ try:
19
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
20
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
21
+ except Exception as e:
22
+ print(e)
23
+ wd_model = wd_processor = None
24
 
25
  def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
26
  return (
 
511
  return ", ".join(all_tags)
512
 
513
 
514
+ @spaces.GPU(duration=30)
515
  def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
516
  inputs = wd_processor.preprocess(image, return_tensors="pt")
517
 
 
519
  logits = torch.sigmoid(outputs.logits[0]) # take the first logits
520
 
521
  # get probabilities
522
+ if device != default_device: wd_model.to(device=device)
523
  results = {
524
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
525
  }
526
+ if device != default_device: wd_model.to(device=default_device)
527
  # rating, character, general
528
  rating, character, general = postprocess_results(
529
  results, general_threshold, character_threshold