philipp-zettl committed on
Commit
223ef25
1 Parent(s): 5d8827b

use prompt embeddings rather than prompt strings

Browse files
Files changed (1) hide show
  1. app.py +30 -2
app.py CHANGED
@@ -11,11 +11,39 @@ pipe = DiffusionPipeline.from_pretrained(
11
  )
12
  pipe.to('cuda')
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @spaces.GPU
15
  def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, num_samples):
 
16
  return pipe(
17
- prompt,
18
- negative_prompt=negative_prompt,
19
  num_inference_steps=num_inference_steps,
20
  guidance_scale=guidance_scale,
21
  width=width,
 
11
  )
12
  pipe.to('cuda')
13
 
14
+
15
def build_embeddings(enhanced_prompt, negative_prompt=None):
    """Encode a prompt of arbitrary token length into CLIP embeddings.

    The prompt is tokenized without truncation, the negative prompt is
    padded (or tokenized) to the same token length, and both are pushed
    through the text encoder in chunks of the tokenizer's model max
    length; the per-chunk embeddings are concatenated along the sequence
    dimension. This sidesteps the usual 77-token CLIP limit.

    Args:
        enhanced_prompt: the positive prompt string.
        negative_prompt: optional negative prompt; empty string if None.

    Returns:
        (prompt_embeds, negative_prompt_embeds) tensors on CUDA,
        suitable for the pipeline's ``prompt_embeds`` /
        ``negative_prompt_embeds`` arguments.
    """
    chunk_size = pipe.tokenizer.model_max_length

    # Tokenize the full prompt with no truncation so nothing is lost.
    prompt_ids = pipe.tokenizer(enhanced_prompt, return_tensors="pt").input_ids
    prompt_ids = prompt_ids.to("cuda")

    # Pad the negative prompt to exactly the positive prompt's length so
    # the two embedding tensors line up chunk for chunk.
    # NOTE(review): a negative prompt *longer* than the positive one is not
    # truncated here (truncation=False) — presumably callers keep it shorter.
    neg_ids = pipe.tokenizer(
        negative_prompt or "",
        truncation=False,
        padding="max_length",
        max_length=prompt_ids.shape[-1],
        return_tensors="pt",
    ).input_ids
    neg_ids = neg_ids.to("cuda")

    # Encode both sequences chunk by chunk, then stitch the chunks back
    # together along the sequence axis.
    positive_chunks, negative_chunks = [], []
    for start in range(0, prompt_ids.shape[-1], chunk_size):
        stop = start + chunk_size
        positive_chunks.append(pipe.text_encoder(prompt_ids[:, start:stop])[0])
        negative_chunks.append(pipe.text_encoder(neg_ids[:, start:stop])[0])

    prompt_embeds = torch.cat(positive_chunks, dim=1)
    negative_prompt_embeds = torch.cat(negative_chunks, dim=1)
    return prompt_embeds, negative_prompt_embeds
39
+
40
+
41
  @spaces.GPU
42
  def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, num_samples):
43
+ prompt_embeds, negative_prompt_embeds = build_embeddings(prompt, negative_prompt)
44
  return pipe(
45
+ prompt_embeds=prompt_embeds,
46
+ negative_prompt_embeds=negative_prompt_embeds,
47
  num_inference_steps=num_inference_steps,
48
  guidance_scale=guidance_scale,
49
  width=width,