bguisard commited on
Commit
0d29e9f
1 Parent(s): b466caf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -1
app.py CHANGED
@@ -1,3 +1,35 @@
1
  import gradio as gr
 
 
2
 
3
- gr.Interface.load("models/bguisard/stable-diffusion-nano").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import jax
3
+ from diffusers import FlaxStableDiffusionPipeline
4
 
5
+ pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
6
+ "bguisard/stable-diffusion-nano",
7
+ )
8
+
9
+ prng_seed = jax.random.PRNGKey(0)
10
+ inference_steps = 50
11
+
12
+
13
+ def generate_image(prompt: str):
14
+ prompt_ids = pipeline.prepare_inputs(prompt)
15
+ images = pipeline(
16
+ prompt_ids=prompt_ids,
17
+ params=pipeline_params,
18
+ prng_seed=prng_seed,
19
+ height=128,
20
+ width=128,
21
+ num_inference_steps=inference_steps,
22
+ jit=False,
23
+ ).images
24
+ pil_imgs = pipeline.numpy_to_pil(images)
25
+ return pil_imgs[0]
26
+
27
+
28
+ app = gr.Interface(
29
+ fn=generate_image,
30
+ inputs="text",
31
+ outputs=gr.Image(shape=(128, 128)),
32
+ examples=[["A watercolor painting of a bird"]],
33
+ )
34
+
35
+ app.launch()