bguisard commited on
Commit
54771e9
1 Parent(s): b314fb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -9,11 +9,16 @@ pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
9
 
10
  def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
11
  rng = jax.random.PRNGKey(int(prng_seed))
12
-
13
- prompt_ids = pipeline.prepare_inputs(prompt)
 
 
 
 
 
14
  images = pipeline(
15
  prompt_ids=prompt_ids,
16
- params=pipeline_params,
17
  prng_seed=rng,
18
  height=128,
19
  width=128,
@@ -21,8 +26,9 @@ def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
21
  jit=True,
22
  ).images
23
 
24
- pil_imgs = pipeline.numpy_to_pil(images)
25
- return pil_imgs[0]
 
26
 
27
 
28
  prompt_input = gr.inputs.Textbox(
 
9
 
10
  def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
11
  rng = jax.random.PRNGKey(int(prng_seed))
12
+ rng = jax.random.split(rng, jax.device_count())
13
+ p_params = replicate(params)
14
+
15
+ num_samples = 1
16
+ prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
17
+ prompt_ids = shard(prompt_ids)
18
+
19
  images = pipeline(
20
  prompt_ids=prompt_ids,
21
+ params=p_params,
22
  prng_seed=rng,
23
  height=128,
24
  width=128,
 
26
  jit=True,
27
  ).images
28
 
29
+ images = images.reshape((num_samples,) + output.shape[-3:])
30
+ images = pipeline.numpy_to_pil(images)
31
+ return images
32
 
33
 
34
  prompt_input = gr.inputs.Textbox(