Leyo commited on
Commit
410687f
1 Parent(s): a85c08c

update reqs + small refactor

Browse files
Files changed (2) hide show
  1. app_dialogue.py +119 -118
  2. requirements.txt +1 -3
app_dialogue.py CHANGED
@@ -546,6 +546,125 @@ def expand_layout():
546
  return gr.Column(scale=2), gr.Gallery(height=682)
547
 
548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  textbox = gr.Textbox(
550
  placeholder="Upload an image and ask the AI to create a meme!",
551
  show_label=False,
@@ -764,115 +883,6 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
764
  with gr.Row():
765
  chatbot.render()
766
 
767
- def generate_meme(
768
- i,
769
- client,
770
- query,
771
- image,
772
- font_meme_text,
773
- all_caps_meme_text,
774
- text_at_the_top,
775
- generation_args,
776
- ):
777
- try:
778
- text = client.generate(prompt=query, **generation_args).generated_text
779
- except Exception as e:
780
- logger.error(f"Error {e} while generating meme text")
781
- text = ""
782
- if image is not None and text != "":
783
- meme_image = make_meme_image(
784
- image=image,
785
- text=text,
786
- font_meme_text=font_meme_text,
787
- all_caps_meme_text=all_caps_meme_text,
788
- text_at_the_top=text_at_the_top,
789
- )
790
- return meme_image
791
- else:
792
- return None
793
-
794
- def model_inference(
795
- model_selector,
796
- system_prompt,
797
- user_prompt_str,
798
- chat_history,
799
- image,
800
- decoding_strategy,
801
- temperature,
802
- max_new_tokens,
803
- repetition_penalty,
804
- top_p,
805
- all_caps_meme_text,
806
- text_at_the_top,
807
- font_meme_text,
808
- ):
809
- chat_history = []
810
- if user_prompt_str.strip() == "" and image is None:
811
- return "", None, chat_history
812
-
813
- system_prompt = ast.literal_eval(system_prompt)
814
- (
815
- formated_prompt_list,
816
- user_prompt_list,
817
- ) = format_user_prompt_with_im_history_and_system_conditioning(
818
- system_prompt=system_prompt,
819
- current_user_prompt_str=user_prompt_str.strip(),
820
- current_image=image,
821
- history=chat_history,
822
- )
823
-
824
- client_endpoint = API_PATHS[model_selector]
825
- client = Client(
826
- base_url=client_endpoint,
827
- headers={"x-use-cache": "0", "Authorization": f"Bearer {API_TOKEN}"},
828
- timeout=45,
829
- )
830
-
831
- # Common parameters to all decoding strategies
832
- # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
833
- generation_args = {
834
- "max_new_tokens": max_new_tokens,
835
- "repetition_penalty": repetition_penalty,
836
- "stop_sequences": EOS_STRINGS,
837
- }
838
-
839
- assert decoding_strategy in [
840
- "Greedy",
841
- "Top P Sampling",
842
- ]
843
- if decoding_strategy == "Greedy":
844
- generation_args["do_sample"] = False
845
- elif decoding_strategy == "Top P Sampling":
846
- generation_args["temperature"] = temperature
847
- generation_args["do_sample"] = True
848
- generation_args["top_p"] = top_p
849
-
850
- chat_history.append([prompt_list_to_markdown(user_prompt_list), ""])
851
-
852
- query = prompt_list_to_tgi_input(formated_prompt_list)
853
- all_meme_images = []
854
- with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
855
- futures = [
856
- executor.submit(
857
- generate_meme,
858
- i,
859
- client,
860
- query,
861
- image,
862
- font_meme_text,
863
- all_caps_meme_text,
864
- text_at_the_top,
865
- generation_args,
866
- )
867
- for i in range(4)
868
- ]
869
-
870
- for future in concurrent.futures.as_completed(futures):
871
- meme_image = future.result(timeout=45)
872
- if meme_image:
873
- all_meme_images.append(meme_image)
874
- return user_prompt_str, all_meme_images, chat_history
875
-
876
  gr.on(
877
  triggers=[
878
  textbox.submit,
@@ -906,15 +916,6 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
906
  outputs=[textbox, generated_memes_gallery, chatbot],
907
  )
908
 
909
- def remove_last_turn(chat_history):
910
- if len(chat_history) == 0:
911
- return chat_history, "", ""
912
- last_interaction = chat_history[-1]
913
- chat_history = chat_history[:-1]
914
- chat_update = chat_history
915
- text_update = last_interaction[0]
916
- return chat_update, text_update, ""
917
-
918
  regenerate_btn.click(
919
  fn=remove_last_turn,
920
  inputs=chatbot,
 
546
  return gr.Column(scale=2), gr.Gallery(height=682)
547
 
548
 
549
+ def generate_meme(
550
+ client,
551
+ query,
552
+ image,
553
+ font_meme_text,
554
+ all_caps_meme_text,
555
+ text_at_the_top,
556
+ generation_args,
557
+ ):
558
+ try:
559
+ text = client.generate(prompt=query, **generation_args).generated_text
560
+ except Exception as e:
561
+ logger.error(f"Error {e} while generating meme text")
562
+ text = ""
563
+ if image is not None and text != "":
564
+ meme_image = make_meme_image(
565
+ image=image,
566
+ text=text,
567
+ font_meme_text=font_meme_text,
568
+ all_caps_meme_text=all_caps_meme_text,
569
+ text_at_the_top=text_at_the_top,
570
+ )
571
+ return meme_image
572
+ else:
573
+ return None
574
+
575
+
576
+ def model_inference(
577
+ model_selector,
578
+ system_prompt,
579
+ user_prompt_str,
580
+ chat_history,
581
+ image,
582
+ decoding_strategy,
583
+ temperature,
584
+ max_new_tokens,
585
+ repetition_penalty,
586
+ top_p,
587
+ all_caps_meme_text,
588
+ text_at_the_top,
589
+ font_meme_text,
590
+ ):
591
+ chat_history = []
592
+ if user_prompt_str.strip() == "" and image is None:
593
+ return "", None, chat_history
594
+
595
+ system_prompt = ast.literal_eval(system_prompt)
596
+ (
597
+ formated_prompt_list,
598
+ user_prompt_list,
599
+ ) = format_user_prompt_with_im_history_and_system_conditioning(
600
+ system_prompt=system_prompt,
601
+ current_user_prompt_str=user_prompt_str.strip(),
602
+ current_image=image,
603
+ history=chat_history,
604
+ )
605
+
606
+ client_endpoint = API_PATHS[model_selector]
607
+ client = Client(
608
+ base_url=client_endpoint,
609
+ headers={"x-use-cache": "0", "Authorization": f"Bearer {API_TOKEN}"},
610
+ timeout=45,
611
+ )
612
+
613
+ # Common parameters to all decoding strategies
614
+ # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
615
+ generation_args = {
616
+ "max_new_tokens": max_new_tokens,
617
+ "repetition_penalty": repetition_penalty,
618
+ "stop_sequences": EOS_STRINGS,
619
+ }
620
+
621
+ assert decoding_strategy in [
622
+ "Greedy",
623
+ "Top P Sampling",
624
+ ]
625
+ if decoding_strategy == "Greedy":
626
+ generation_args["do_sample"] = False
627
+ elif decoding_strategy == "Top P Sampling":
628
+ generation_args["temperature"] = temperature
629
+ generation_args["do_sample"] = True
630
+ generation_args["top_p"] = top_p
631
+
632
+ chat_history.append([prompt_list_to_markdown(user_prompt_list), ""])
633
+
634
+ query = prompt_list_to_tgi_input(formated_prompt_list)
635
+ all_meme_images = []
636
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
637
+ futures = [
638
+ executor.submit(
639
+ generate_meme,
640
+ client,
641
+ query,
642
+ image,
643
+ font_meme_text,
644
+ all_caps_meme_text,
645
+ text_at_the_top,
646
+ generation_args,
647
+ )
648
+ for i in range(4)
649
+ ]
650
+
651
+ for future in concurrent.futures.as_completed(futures):
652
+ meme_image = future.result(timeout=45)
653
+ if meme_image:
654
+ all_meme_images.append(meme_image)
655
+ return user_prompt_str, all_meme_images, chat_history
656
+
657
+
658
+ def remove_last_turn(chat_history):
659
+ if len(chat_history) == 0:
660
+ return chat_history, "", ""
661
+ last_interaction = chat_history[-1]
662
+ chat_history = chat_history[:-1]
663
+ chat_update = chat_history
664
+ text_update = last_interaction[0]
665
+ return chat_update, text_update, ""
666
+
667
+
668
  textbox = gr.Textbox(
669
  placeholder="Upload an image and ask the AI to create a meme!",
670
  show_label=False,
 
883
  with gr.Row():
884
  chatbot.render()
885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
  gr.on(
887
  triggers=[
888
  textbox.submit,
 
916
  outputs=[textbox, generated_memes_gallery, chatbot],
917
  )
918
 
 
 
 
 
 
 
 
 
 
919
  regenerate_btn.click(
920
  fn=remove_last_turn,
921
  inputs=chatbot,
requirements.txt CHANGED
@@ -9,10 +9,8 @@ opencv-python
9
  numpy
10
  accelerate
11
  joblib
12
- deepspeed
13
  parameterized
14
  einops
15
  pynvml
16
  sentencepiece
17
- text_generation
18
- https://gradio-builds.s3.amazonaws.com/2060bfe3e7eb57fb9b5c8695ebfc900469263d1f/gradio-3.46.0-py3-none-any.whl
 
9
  numpy
10
  accelerate
11
  joblib
 
12
  parameterized
13
  einops
14
  pynvml
15
  sentencepiece
16
+ text_generation