bguisard commited on
Commit
a9057b2
1 Parent(s): 680c3bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -12,7 +12,7 @@ pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
12
  def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
13
  rng = jax.random.PRNGKey(int(prng_seed))
14
  rng = jax.random.split(rng, jax.device_count())
15
- p_params = replicate(params)
16
 
17
  num_samples = 1
18
  prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
 
12
  def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
13
  rng = jax.random.PRNGKey(int(prng_seed))
14
  rng = jax.random.split(rng, jax.device_count())
15
+ p_params = replicate(pipeline_params)
16
 
17
  num_samples = 1
18
  prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)