ash123 commited on
Commit
3d8eafa
1 Parent(s): 70191e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -6,10 +6,23 @@ from flax.jax_utils import replicate
6
  from flax.training.common_utils import shard
7
  from share_btn import community_icon_html, loading_icon_html, share_js
8
 
 
 
9
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
10
  "bguisard/stable-diffusion-nano-2-1",
11
- dtype=jnp.float16
12
  )
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  def generate_image(prompt: str, negative_prompt: str = "", inference_steps: int = 25, prng_seed: int = 0, guidance_scale: float = 9):
@@ -20,6 +33,7 @@ def generate_image(prompt: str, negative_prompt: str = "", inference_steps: int
20
  num_samples = 1
21
  prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
22
  prompt_ids = shard(prompt_ids)
 
23
  if negative_prompt == "":
24
  images = pipeline(
25
  prompt_ids=prompt_ids,
 
6
  from flax.training.common_utils import shard
7
  from share_btn import community_icon_html, loading_icon_html, share_js
8
 
9
+ DTYPE = jnp.float16
10
+
11
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
12
  "bguisard/stable-diffusion-nano-2-1",
13
+ dtype=DTYPE,
14
  )
15
+ if DTYPE != jnp.float32:
16
+ # There is a known issue with schedulers when loading from a pre trained
17
+ # pipeline. We need the schedulers to always use float32.
18
+ # See: https://github.com/huggingface/diffusers/issues/2155
19
+ scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(
20
+ pretrained_model_name_or_path="bguisard/stable-diffusion-nano-2-1",
21
+ subfolder="scheduler",
22
+ dtype=jnp.float32,
23
+ )
24
+ pipeline_params["scheduler"] = scheduler_params
25
+ pipeline.scheduler = scheduler
26
 
27
 
28
  def generate_image(prompt: str, negative_prompt: str = "", inference_steps: int = 25, prng_seed: int = 0, guidance_scale: float = 9):
 
33
  num_samples = 1
34
  prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
35
  prompt_ids = shard(prompt_ids)
36
+
37
  if negative_prompt == "":
38
  images = pipeline(
39
  prompt_ids=prompt_ids,