bguisard commited on
Commit
7aec4a9
1 Parent(s): 17a50fe

Update dtype and example

Browse files
Files changed (1) hide show
  1. app.py +27 -2
app.py CHANGED
@@ -1,12 +1,27 @@
1
  import gradio as gr
2
  import jax
3
- from diffusers import FlaxStableDiffusionPipeline
 
4
  from flax.jax_utils import replicate
5
  from flax.training.common_utils import shard
6
 
 
 
7
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
8
  "bguisard/stable-diffusion-nano-2-1",
 
9
  )
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
@@ -51,7 +66,17 @@ app = gr.Interface(
51
  "Stable Diffusion Nano allows for fast prototyping of diffusion models, "
52
  "enabling quick experimentation with easily available hardware."
53
  ),
54
- examples=[["A watercolor painting of a bird", 30, 0]],
 
 
 
 
 
 
 
 
 
55
  )
56
 
57
  app.launch()
 
 
1
  import gradio as gr
2
  import jax
3
+ import jax.numpy as jnp
4
+ from diffusers import FlaxPNDMScheduler, FlaxStableDiffusionPipeline
5
  from flax.jax_utils import replicate
6
  from flax.training.common_utils import shard
7
 
8
+ DTYPE = jnp.bfloat16
9
+
10
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
11
  "bguisard/stable-diffusion-nano-2-1",
12
+ dtype=DTYPE,
13
  )
14
+ if DTYPE != jnp.float32:
15
+ # There is a known issue with schedulers when loading from a pre trained
16
+ # pipeline. We need the schedulers to always use float32.
17
+ # See: https://github.com/huggingface/diffusers/issues/2155
18
+ scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(
19
+ pretrained_model_name_or_path="bguisard/stable-diffusion-nano-2-1",
20
+ subfolder="scheduler",
21
+ dtype=jnp.float32,
22
+ )
23
+ pipeline_params["scheduler"] = scheduler_params
24
+ pipeline.scheduler = scheduler
25
 
26
 
27
  def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
 
66
  "Stable Diffusion Nano allows for fast prototyping of diffusion models, "
67
  "enabling quick experimentation with easily available hardware."
68
  ),
69
+ # Some examples were copied from hf.co/spaces/stabilityai/stable-diffusion
70
+ examples=[
71
+ # ["A watercolor painting of a bird", 30, 0],
72
+ [
73
+ "A small cabin on top of a snowy mountain in the style of Disney, artstation",
74
+ 25,
75
+ 3129302,
76
+ ],
77
+ # ["A mecha robot in a favela in expressionist style", 30, 827198341273],
78
+ ],
79
  )
80
 
81
  app.launch()
82
+