bguisard commited on
Commit
dac85aa
1 Parent(s): 1e446d2

Update app.py

Browse files

Add seed and inference steps inputs.

Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -6,30 +6,38 @@ pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
6
  "bguisard/stable-diffusion-nano",
7
  )
8
 
9
- prng_seed = jax.random.PRNGKey(0)
10
- inference_steps = 50
11
 
 
 
12
 
13
- def generate_image(prompt: str):
14
  prompt_ids = pipeline.prepare_inputs(prompt)
15
  images = pipeline(
16
  prompt_ids=prompt_ids,
17
  params=pipeline_params,
18
- prng_seed=prng_seed,
19
  height=128,
20
  width=128,
21
- num_inference_steps=inference_steps,
22
  jit=False,
23
  ).images
 
24
  pil_imgs = pipeline.numpy_to_pil(images)
25
  return pil_imgs[0]
26
 
27
 
 
 
 
 
 
 
 
 
28
  app = gr.Interface(
29
  fn=generate_image,
30
- inputs="text",
31
  outputs=gr.Image(shape=(128, 128)),
32
  examples=[["A watercolor painting of a bird"]],
33
  )
34
 
35
- app.launch()
 
6
  "bguisard/stable-diffusion-nano",
7
  )
8
 
 
 
9
 
10
+ def generate_image(prompt: str, inference_steps: int, prng_seed: int):
11
+ rng = jax.random.PRNGKey(int(prng_seed))
12
 
 
13
  prompt_ids = pipeline.prepare_inputs(prompt)
14
  images = pipeline(
15
  prompt_ids=prompt_ids,
16
  params=pipeline_params,
17
+ prng_seed=rng,
18
  height=128,
19
  width=128,
20
+ num_inference_steps=int(inference_steps),
21
  jit=False,
22
  ).images
23
+
24
  pil_imgs = pipeline.numpy_to_pil(images)
25
  return pil_imgs[0]
26
 
27
 
28
+ prompt_input = gr.inputs.Textbox(
29
+ label="Prompt", placeholder="A watercolor painting of a bird"
30
+ )
31
+ inf_steps_input = gr.inputs.Slider(
32
+ minimum=1, maximum=100, default=30, step=1, label="Inference Steps"
33
+ )
34
+ seed_input = gr.inputs.Number(default=0, label="Seed")
35
+
36
  app = gr.Interface(
37
  fn=generate_image,
38
+ inputs=[prompt_input, inf_steps_input, seed_input],
39
  outputs=gr.Image(shape=(128, 128)),
40
  examples=[["A watercolor painting of a bird"]],
41
  )
42
 
43
+ app.launch()