John6666 committed
Commit 375b410
Parent: 6a659da

Upload 2 files

Files changed (2):
  1. app.py +18 -4
  2. multit2i.py +32 -11
app.py CHANGED
@@ -3,6 +3,7 @@ from multit2i import (
     load_models,
     find_model_list,
     infer_multi,
+    infer_multi_random,
     save_gallery_images,
     change_model,
     get_model_info_md,
@@ -51,19 +52,22 @@ css = """"""
 
 with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
     with gr.Column():
-        model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
-        model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]))
         with gr.Accordion("Advanced settings", open=False):
-            image_num = gr.Slider(label="Number of Images", minimum=1, maximum=8, value=1, step=1)
             with gr.Accordion("Recommended Prompt"):
                 recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
                 positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
                 positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
                 negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[], visible=False)
                 negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"], visible=False)
+        with gr.Group():
+            model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0])
+            model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]))
         prompt = gr.Text(label="Prompt", lines=1, max_lines=8, placeholder="1girl, solo, ...")
         neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
-        run_button = gr.Button("Generate Image")
+        with gr.Row():
+            run_button = gr.Button("Generate Image", scale=2)
+            random_button = gr.Button("Random Model 🎲", scale=1)
+            image_num = gr.Number(label="Number of images", minimum=1, maximum=16, value=1, step=1, interactive=True, scale=1)
         results = gr.Gallery(label="Gallery", interactive=False, show_download_button=True, show_share_button=False,
                              container=True, format="png", object_fit="contain")
         image_files = gr.Files(label="Download", interactive=False)
@@ -99,6 +103,16 @@ This is due to the time it takes for Gradio to generate an example image to cach
         show_progress="full",
         show_api=True,
     ).success(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
+    gr.on(
+        triggers=[random_button.click],
+        fn=infer_multi_random,
+        inputs=[prompt, neg_prompt, results, image_num,
+                positive_prefix, positive_suffix, negative_prefix, negative_suffix],
+        outputs=[results],
+        queue=True,
+        show_progress="full",
+        show_api=True,
+    ).success(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
     clear_results.click(lambda: (None, None), None, [results, image_files], queue=False, show_api=False)
     recom_prompt_preset.change(set_recom_prompt_preset, [recom_prompt_preset],
                                [positive_prefix, positive_suffix, negative_prefix, negative_suffix], queue=False, show_api=False)
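
Note on the hunk above: gr.on() binds one handler to a list of triggers, and the returned event chains a .success() callback that runs only when the handler finished without raising. A minimal self-contained sketch of the pattern, assuming Gradio 4.x (which the diff's own use of gr.on implies); the generate/postprocess handlers are hypothetical stand-ins, not this app's functions:

import gradio as gr

def generate(prompt: str) -> str:      # hypothetical stand-in handler
    return f"generated: {prompt}"

def postprocess(text: str) -> str:     # runs only if generate() succeeded
    return text.upper()

with gr.Blocks() as demo:
    prompt = gr.Text(label="Prompt")
    button = gr.Button("Run")
    out = gr.Text(label="Output")
    # gr.on registers one handler for several triggers; .success() chains
    # a follow-up that fires only when the first handler raised no error,
    # mirroring the infer -> save_gallery_images chain in app.py.
    gr.on(
        triggers=[button.click, prompt.submit],
        fn=generate,
        inputs=[prompt],
        outputs=[out],
    ).success(postprocess, [out], [out], queue=False)

demo.launch()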
multit2i.py CHANGED
@@ -74,7 +74,7 @@ def get_t2i_model_info_dict(repo_id: str):
     info["last_modified"] = model.last_modified.strftime("lastmod: %Y-%m-%d")
     un_tags = ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']
     descs = [info["ver"]] + list_sub(info["tags"], un_tags) + [f'DLs: {info["downloads"]}'] + [f'❤: {info["likes"]}'] + [info["last_modified"]]
-    info["md"] = f'Model Info: {", ".join(descs)} [Model Repo]({info["url"]})'
+    info["md"] = f'    Model Info: {", ".join(descs)} [Model Repo]({info["url"]})'
     return info
 
 
@@ -126,6 +126,7 @@ def load_model(model_name: str):
 
 
 async def async_load_models(models: list, limit: int=5):
+    from tqdm.asyncio import tqdm_asyncio
     sem = asyncio.Semaphore(limit)
     async def async_load_model(model: str):
         async with sem:
@@ -134,22 +135,23 @@ async def async_load_models(models: list, limit: int=5):
             except Exception as e:
                 print(e)
     tasks = [asyncio.create_task(async_load_model(model)) for model in models]
-    return await asyncio.wait(tasks)
+    return await tqdm_asyncio.gather(*tasks)
 
 
 def load_models(models: list, limit: int=5):
-    loop = asyncio.get_event_loop()
+    loop = asyncio.new_event_loop()
     try:
         loop.run_until_complete(async_load_models(models, limit))
     except Exception as e:
         print(e)
         pass
-    loop.close()
+    finally:
+        loop.close()
 
 
 positive_prefix = {
     "Pony": to_list("score_9, score_8_up, score_7_up"),
-    "Pony Anime": to_list("source_anime, score_9, score_8_up, score_7_up"),
+    "Pony Anime": to_list("source_anime, anime, score_9, score_8_up, score_7_up"),
 }
 positive_suffix = {
     "Common": to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres"),
@@ -161,7 +163,7 @@ negative_prefix = {
     "Pony Real": to_list("score_6, score_5, score_4, source_anime, source_pony, source_furry, source_cartoon"),
 }
 negative_suffix = {
-    "Common": to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"),
+    "Common": to_list("lowres, (bad), bad hands, bad feet, text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"),
     "Pony Anime": to_list("busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends"),
     "Pony Real": to_list("ugly, airbrushed, simple background, cgi, cartoon, anime"),
 }
@@ -248,7 +250,7 @@ def change_model(model_name: str):
     return get_model_info_md(model_name)
 
 
-def infer(prompt: str, neg_prompt: str, model_name: str, progress=gr.Progress(track_tqdm=True)):
+def infer(prompt: str, neg_prompt: str, model_name: str):
     from PIL import Image
     import random
     seed = ""
@@ -260,18 +262,37 @@ def infer(prompt: str, neg_prompt: str, model_name: str, progress=gr.Progress(tr
         model = load_model(model_name)
         if not model: return (Image.Image(), None)
         image_path = model(prompt + seed)
-        image = Image.open(image_path).convert('RGB')
+        image = Image.open(image_path).convert('RGBA')
     except Exception as e:
         print(e)
         return (Image.Image(), None)
     return (image, caption)
 
 
-def infer_multi(prompt: str, neg_prompt: str, results: list, image_num: float, model_name: str,
+async def infer_multi(prompt: str, neg_prompt: str, results: list, image_num: float, model_name: str,
                 pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
+    import asyncio
     image_num = int(image_num)
     images = results if results else []
     prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
-    for i in range(image_num):
-        images.append(infer(prompt, neg_prompt, model_name, recom_prompt))
+    tasks = [asyncio.to_thread(infer, prompt, neg_prompt, model_name) for i in range(image_num)]
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+    for result in results:
+        images.append(result)
+        yield images
+
+
+async def infer_multi_random(prompt: str, neg_prompt: str, results: list, image_num: float,
+                pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
+    import asyncio
+    import random
+    image_num = int(image_num)
+    images = results if results else []
+    random.seed()
+    model_names = random.choices(list(loaded_models.keys()), k = image_num)
+    prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
+    tasks = [asyncio.to_thread(infer, prompt, neg_prompt, model_name) for model_name in model_names]
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+    for result in results:
+        images.append(result)
         yield images
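
Note on the rewritten handlers above: both follow the same fan-out shape, wrapping each blocking infer() call in asyncio.to_thread() so the requests overlap, gathering the results with return_exceptions=True so one failure cannot cancel the rest, then yielding the growing image list from an async generator. A standalone sketch of the pattern, with slow_infer() as a hypothetical stand-in for the blocking inference call:

import asyncio
import time

def slow_infer(prompt: str) -> str:      # hypothetical blocking call
    time.sleep(0.5)                      # stands in for the HTTP round trip
    return f"image for {prompt!r}"

async def infer_many(prompt: str, n: int):
    # to_thread() runs each blocking call in a worker thread so the n
    # requests overlap; a failed slot holds its exception object instead
    # of aborting the whole batch.
    tasks = [asyncio.to_thread(slow_infer, prompt) for _ in range(n)]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    images = []
    for r in results:
        images.append(r)
        yield images                     # Gradio renders each yielded list

async def main():
    async for batch in infer_many("1girl, solo", 3):
        print(len(batch), "image(s) ready")

asyncio.run(main())

One caveat of this shape: gather() resolves only after every task finishes, so the yields arrive in one burst at the end; per-image streaming would need asyncio.as_completed() instead.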