G-AshwinKumar commited on
Commit
90816e9
1 Parent(s): 733669b
Files changed (3) hide show
  1. app.py +4 -0
  2. pre-requirements.txt +5 -0
  3. requirements.txt +5 -3
app.py CHANGED
@@ -1,11 +1,15 @@
1
  import gradio as gr
2
  import jax
 
3
  from diffusers import FlaxStableDiffusionPipeline
4
  from flax.jax_utils import replicate
5
  from flax.training.common_utils import shard
6
 
7
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
8
  "bguisard/stable-diffusion-nano",
 
 
 
9
  )
10
 
11
 
 
1
  import gradio as gr
2
  import jax
3
+ import jax.numpy as jnp
4
  from diffusers import FlaxStableDiffusionPipeline
5
  from flax.jax_utils import replicate
6
  from flax.training.common_utils import shard
7
 
8
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
9
  "bguisard/stable-diffusion-nano",
10
+ dtype=jnp.float16,
11
+ resume_download=True,
12
+ use_memory_efficient_attention=True
13
  )
14
 
15
 
pre-requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pip
2
+ setuptools
3
+ wheel
4
+ ninja
5
+ cmake
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  transformers
2
- diffusers
3
- jax[cuda11_pip] #jax[cuda11_cudnn82] #jax[cuda11_cudnn86] #jax[cuda11_cudnn805]
4
- flax
 
 
 
1
  transformers
2
+ flax
3
+ jax[cuda11_pip]
4
+ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
5
+ jaxlib
6
+ git+https://github.com/huggingface/diffusers@main