echarlaix HF staff commited on
Commit
9659c11
1 Parent(s): 47178c0

Statically reshapes the SD modelto speed up inference

Browse files
Files changed (2) hide show
  1. app.py +16 -17
  2. requirements.txt +5 -1
app.py CHANGED
@@ -5,20 +5,21 @@ import re
5
 
6
  import torch
7
  from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline, set_seed
 
8
 
9
- import gradio as grad
10
- from diffusers import StableDiffusionPipeline
11
 
12
- tokenizer = AutoTokenizer.from_pretrained("shahp7575/gpt2-horoscopes")
13
- model = AutoModelWithLMHead.from_pretrained("shahp7575/gpt2-horoscopes")
 
 
 
 
 
 
 
14
 
15
  def fn(sign, cat):
16
- sign = "scorpio"
17
-
18
  prompt = f"<|category|> {cat} <|horoscope|> {sign}"
19
-
20
-
21
-
22
  prompt_encoded = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
23
 
24
  sample_outputs = model.generate(
@@ -29,18 +30,16 @@ def fn(sign, cat):
29
  top_p=0.95,
30
  temperature=0.95,
31
  num_beams=4,
32
- num_return_sequences=4,
33
  )
34
 
35
  final_out = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
36
  starting_text = " ".join(final_out.split(" ")[4:])
37
- pipe = pipeline("text-generation", model="Gustavosta/MagicPrompt-Stable-Diffusion", tokenizer="gpt2")
38
-
39
  seed = random.randint(100, 1000000)
40
  set_seed(seed)
41
- response = pipe(starting_text + " " + sign + " art.", max_length=(len(starting_text) + random.randint(60, 90)), num_return_sequences=1)
42
- pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
43
- image = pipe(response[0]["generated_text"], num_inference_steps=5).images[0]
44
  return [image, starting_text]
45
 
46
 
@@ -52,7 +51,7 @@ with block:
52
  with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
53
  text = gr.Dropdown(
54
  label="Star Sign",
55
- choices=["aries", "taurus","gemini", "cancer", "leo", "virgo", "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "Pisces"],
56
  show_label=True,
57
  max_lines=1,
58
  placeholder="Enter your prompt",
@@ -64,7 +63,7 @@ with block:
64
  )
65
 
66
  text2 = gr.Dropdown(
67
- choices=["love", "career", "wellness"],
68
  label="Category",
69
  show_label=True,
70
  max_lines=1,
 
5
 
6
  import torch
7
  from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline, set_seed
8
+ from optimum.intel.openvino import OVStableDiffusionPipeline
9
 
 
 
10
 
11
+ horoscope_model_id = "shahp7575/gpt2-horoscopes"
12
+ tokenizer = AutoTokenizer.from_pretrained(horoscope_model_id)
13
+ model = AutoModelWithLMHead.from_pretrained(horoscope_model_id)
14
+ text_generation_pipe = pipeline("text-generation", model="Gustavosta/MagicPrompt-Stable-Diffusion", tokenizer="gpt2")
15
+ stable_diffusion_pipe = OVStableDiffusionPipeline.from_pretrained("echarlaix/stable-diffusion-v1-5-openvino", revision="fp16", compile=False)
16
+ height = 128
17
+ width = 128
18
+ stable_diffusion_pipe.reshape(batch_size=1, height=height, width=width, num_images_per_prompt=1)
19
+ stable_diffusion_pipe.compile()
20
 
21
  def fn(sign, cat):
 
 
22
  prompt = f"<|category|> {cat} <|horoscope|> {sign}"
 
 
 
23
  prompt_encoded = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
24
 
25
  sample_outputs = model.generate(
 
30
  top_p=0.95,
31
  temperature=0.95,
32
  num_beams=4,
33
+ num_return_sequences=1,
34
  )
35
 
36
  final_out = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
37
  starting_text = " ".join(final_out.split(" ")[4:])
 
 
38
  seed = random.randint(100, 1000000)
39
  set_seed(seed)
40
+ response = text_generation_pipe(starting_text + " " + sign + " art", max_length=(len(starting_text) + random.randint(60, 90)), num_return_sequences=1)
41
+ image = stable_diffusion_pipe(response[0]["generated_text"], height=height, width=width, num_inference_steps=30).images[0]
42
+
43
  return [image, starting_text]
44
 
45
 
 
51
  with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
52
  text = gr.Dropdown(
53
  label="Star Sign",
54
+ choices=["Aries", "Taurus","Gemini", "Cancer", "Leo", "Virgo", "Libra", "Scorpio", "Sagittarius", "Capricorn", "Aquarius", "Pisces"],
55
  show_label=True,
56
  max_lines=1,
57
  placeholder="Enter your prompt",
 
63
  )
64
 
65
  text2 = gr.Dropdown(
66
+ choices=["Love", "Career", "Wellness"],
67
  label="Category",
68
  show_label=True,
69
  max_lines=1,
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
  transformers
2
  torch
3
- diffusers
 
 
 
 
 
1
  transformers
2
  torch
3
+ diffusers
4
+ onnx
5
+ onnxruntime
6
+ openvino
7
+ optimum-intel