Leyo commited on
Commit
78d83ec
1 Parent(s): 5cc0179

add multiprocessing

Browse files
Files changed (1) hide show
  1. app_dialogue.py +43 -14
app_dialogue.py CHANGED
@@ -10,6 +10,7 @@ from typing import List, Optional, Tuple
10
  from urllib.parse import urlparse
11
  from PIL import Image, ImageDraw, ImageFont
12
 
 
13
  import random
14
  import gradio as gr
15
  import PIL
@@ -777,6 +778,28 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
777
  with gr.Row():
778
  chatbot.render()
779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
  def model_inference(
781
  model_selector,
782
  system_prompt,
@@ -849,21 +872,27 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
849
 
850
  query = prompt_list_to_tgi_input(formated_prompt_list)
851
  all_meme_images = []
852
- for i in range(4):
853
- text = client.generate(prompt=query, **generation_args).generated_text
854
- if image is not None and text != "":
855
- meme_image = make_meme_image(
856
- image=image,
857
- text=text,
858
- font_meme_text=font_meme_text,
859
- all_caps_meme_text=all_caps_meme_text,
860
- text_at_the_top=text_at_the_top,
 
 
 
861
  )
862
- meme_image = pil_to_temp_file(meme_image)
863
- all_meme_images.append(meme_image)
864
- yield user_prompt_str, all_meme_images, chat_history
865
- if i == 3:
866
- return
 
 
 
867
 
868
  gr.on(
869
  triggers=[textbox.submit, imagebox.upload, submit_btn.click],
 
10
  from urllib.parse import urlparse
11
  from PIL import Image, ImageDraw, ImageFont
12
 
13
+ import concurrent.futures
14
  import random
15
  import gradio as gr
16
  import PIL
 
778
  with gr.Row():
779
  chatbot.render()
780
 
781
+ def generate_meme(
782
+ i,
783
+ client,
784
+ query,
785
+ image,
786
+ font_meme_text,
787
+ all_caps_meme_text,
788
+ text_at_the_top,
789
+ generation_args,
790
+ ):
791
+ text = client.generate(prompt=query, **generation_args).generated_text
792
+ if image is not None and text != "":
793
+ meme_image = make_meme_image(
794
+ image=image,
795
+ text=text,
796
+ font_meme_text=font_meme_text,
797
+ all_caps_meme_text=all_caps_meme_text,
798
+ text_at_the_top=text_at_the_top,
799
+ )
800
+ meme_image = pil_to_temp_file(meme_image)
801
+ return meme_image
802
+
803
  def model_inference(
804
  model_selector,
805
  system_prompt,
 
872
 
873
  query = prompt_list_to_tgi_input(formated_prompt_list)
874
  all_meme_images = []
875
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
876
+ futures = [
877
+ executor.submit(
878
+ generate_meme,
879
+ i,
880
+ client,
881
+ query,
882
+ image,
883
+ font_meme_text,
884
+ all_caps_meme_text,
885
+ text_at_the_top,
886
+ generation_args,
887
  )
888
+ for i in range(4)
889
+ ]
890
+
891
+ for future in concurrent.futures.as_completed(futures):
892
+ meme_image = future.result()
893
+ if meme_image:
894
+ all_meme_images.append(meme_image)
895
+ return user_prompt_str, all_meme_images, chat_history
896
 
897
  gr.on(
898
  triggers=[textbox.submit, imagebox.upload, submit_btn.click],