# YAML Metadata Warning: empty or missing YAML metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)
#!/usr/bin/env python3
# Generate Stable Diffusion images with the Flax pipeline, one prompt per
# local JAX device, running the sampler data-parallel via jax.pmap.
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
import numpy as np
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard


prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("fusing/stable-diffusion-flax-new", use_auth_token=True)
# Drop the safety-checker weights so they are not replicated onto every device.
# NOTE(review): confirm the pipeline call below does not require them.
del params["safety_checker"]

# pmap the pipeline call. Positional argument 3 (num_inference_steps) is a
# plain Python int, so it is broadcast as a static value to every device
# instead of being sharded along a leading device axis.
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# prep prompts: one copy of the prompt per device
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# replicate params onto every device and shard the per-device inputs
params = replicate(params)
# One PRNG key per device: pmap needs the leading axis of every mapped input
# to equal the device count. (Previously hard-coded to 8, which only worked
# on hosts with exactly 8 devices.)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

# run
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

# Collapse the leading device axis (and any per-device batch axis) into a
# single batch axis of num_samples images, keeping the last three image axes.
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

print("Images should be good")
# To persist results, e.g.: images_pil[0].save("sample_0.png")
# Downloads last month:

# -

# Downloads are not tracked for this model. How to track?
# Inference Providers (NEW)
# This model is not currently available via any of the supported Inference Providers.
# The model cannot be deployed to the HF Inference API: the model has no library tag.