#!/usr/bin/env python3
"""Generate Stable Diffusion images in parallel on every JAX device.

Loads the Flax Stable Diffusion pipeline, replicates its parameters onto each
device, and runs the same prompt once per device via ``pmap`` (each device
gets its own PRNG key, so the images differ).
"""
import jax
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap

MODEL_ID = "fusing/stable-diffusion-flax-new"
PROMPT = (
    "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, "
    "40mm lens, shallow depth of field, close up, split lighting, cinematic"
)
NUM_INFERENCE_STEPS = 50
SEED = 0


def main():
    """Run one text-to-image sample per device and return the results.

    Returns:
        The output of ``pipeline.numpy_to_pil`` for the generated batch
        (presumably a list of PIL images, one per device — confirm against
        the diffusers version in use). To persist them, e.g.
        ``for i, img in enumerate(main()): img.save(f"sample_{i}.png")``.
    """
    prng_seed = jax.random.PRNGKey(SEED)

    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        MODEL_ID, use_auth_token=True
    )
    # Drop the safety-checker weights before replicating them to every device.
    # NOTE(review): presumably the pipeline then skips that model — confirm
    # against the diffusers version in use.
    del params["safety_checker"]

    # Map the pipeline call over devices. Argument 3 (num_inference_steps) is
    # a plain Python int, so it is broadcast as a static value instead of
    # being split along a device axis.
    p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

    # One copy of the prompt per device.
    num_samples = jax.device_count()
    prompts = num_samples * [PROMPT]
    prompt_ids = pipeline.prepare_inputs(prompts)

    # Copy the weights onto every device, and split the per-sample inputs so
    # that their leading axis is the device axis expected by pmap.
    params = replicate(params)
    # One PRNG key per device: the key count must match the sharded batch's
    # device axis, so it is derived from the device count, not hard-coded.
    prng_seed = jax.random.split(prng_seed, num_samples)
    prompt_ids = shard(prompt_ids)

    images = p_sample(prompt_ids, params, prng_seed, NUM_INFERENCE_STEPS).images

    # Fold the leading device axis into a flat batch, keeping the trailing
    # three image axes, then convert to PIL images.
    flat_images = np.asarray(images.reshape((num_samples,) + images.shape[-3:]))
    images_pil = pipeline.numpy_to_pil(flat_images)

    print("Images should be good")
    return images_pil


if __name__ == "__main__":
    main()