patrickvonplaten commited on
Commit
66674a1
1 Parent(s): 82fc287

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +40 -0
README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```python
2
+ #!/usr/bin/env python3
3
+ from diffusers import FlaxStableDiffusionPipeline
4
+ from jax import pmap
5
+ import numpy as np
6
+ import jax
7
+ from flax.jax_utils import replicate
8
+ from flax.training.common_utils import shard
9
+
10
+
11
+ prng_seed = jax.random.PRNGKey(0)
12
+ num_inference_steps = 50
13
+
14
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("fusing/stable-diffusion-flax-new", use_auth_token=True)
15
+ del params["safety_checker"]
16
+
17
+ # pmap
18
+ p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
19
+
20
+ # prep prompts
21
+ prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
22
+ num_samples = jax.device_count()
23
+ prompt = num_samples * [prompt]
24
+ prompt_ids = pipeline.prepare_inputs(prompt)
25
+
26
+ # replicate
27
+ params = replicate(params)
28
+ prng_seed = jax.random.split(prng_seed, 8)
29
+ prompt_ids = shard(prompt_ids)
30
+
31
+ # run
32
+ images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
33
+
34
+ # get pil images
35
+ images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
36
+
37
+ import ipdb; ipdb.set_trace()
38
+ print("Images should be good")
39
+ # images_pil[0].save(...)
40
+ ```