|
--- |
|
license: openrail++ |
|
--- |
|
|
|
# Flax Latent Consistency Model (LCM) LoRA: SDXL UNet
|
|
|
SDXL UNet with the LCM LoRA weights merged in (`lora_scale=0.7`), converted to work with Flax.
|
|
|
## Setup |
|
|
|
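To run on TPUs, install the `lcm_flax` branch of the `entrpn/diffusers` fork (the example below imports `FlaxLCMScheduler` from it) together with JAX for TPU: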
|
|
|
```bash |
|
# Fetch the diffusers fork and switch to the branch the example below relies on.
git clone https://github.com/entrpn/diffusers
cd diffusers
git checkout lcm_flax
# Quote the extra so shells such as zsh do not try to glob-expand the brackets.
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install transformers flax torch torchvision
# Install the checked-out fork itself.
pip install .
|
``` |
|
|
|
## Run |
|
|
|
|
|
```python |
|
import os |
|
from diffusers import FlaxStableDiffusionXLPipeline |
|
import torch |
|
import time |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.jax_utils import replicate |
|
import numpy as np |
|
from jax.experimental.compilation_cache import compilation_cache as cc |
|
# Enable JAX's persistent compilation cache in ~/jax_cache so compiled
# executables can be reused across runs of this script.
# NOTE(review): newer JAX releases deprecate `initialize_cache` in favor of
# jax.config.update("jax_compilation_cache_dir", ...) — confirm against the
# JAX version installed by the setup step.
cc.initialize_cache(os.path.expanduser("~/jax_cache"))
|
|
|
from diffusers import ( |
|
FlaxUNet2DConditionModel, |
|
FlaxLCMScheduler |
|
) |
|
|
|
# Base SDXL checkpoint: every pipeline component except the UNet (which is
# swapped below) is loaded from here, including the scheduler config.
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
# dtype the pipeline weights are loaded in and later cast to (see tree_map below).
weight_dtype = jnp.bfloat16
# Hub revision passed to from_pretrained for the base model and scheduler.
# NOTE(review): presumably a PR branch that carries Flax weights — confirm on the Hub.
revision= 'refs/pr/95'

# Load the Flax SDXL pipeline together with its parameter tree.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    base_model, revision=revision, dtype=weight_dtype
)

# Drop the base UNet weights; they are replaced by the LCM UNet loaded next.
del params["unet"]

# UNet checkpoint with the LCM LoRA weights already merged in.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "jffacevedo/flax_lcm_unet",
    dtype=weight_dtype,
)

# LCM scheduler built from the base model's scheduler config, created with
# dtype=jnp.float32.
scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    base_model,
    subfolder="scheduler",
    revision=revision,
    dtype=jnp.float32
)

# Swap the LCM UNet (module and weights) into the pipeline.
params["unet"] = unet_params
pipeline.unet = unet

# Swap in the LCM scheduler.
pipeline.scheduler = scheduler

# Cast every parameter leaf currently in the tree to weight_dtype.
params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), params)
# Attach the scheduler state only AFTER the cast so the tree_map above does not
# convert it to weight_dtype; keep this ordering.
params["scheduler"] = scheduler_state
|
|
|
# Default generation settings used by generate() and the compile warm-up call.
default_prompt: str = (
    "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
)
default_neg_prompt: str = ""  # empty negative prompt
default_seed: int = 42  # seed fed to jax.random.PRNGKey in replicate_all
# 1.0 together with do_classifier_free_guidance=False in generate() means no CFG.
default_guidance_scale: float = 1.0
default_num_steps: int = 4  # LCM sampling uses very few denoising steps
|
|
|
def tokenize_prompt(prompt, neg_prompt):
    """Encode a prompt and its negative prompt with the pipeline's tokenizers.

    Returns:
        A ``(prompt_ids, neg_prompt_ids)`` pair, each produced by
        ``pipeline.prepare_inputs`` (positive prompt encoded first).
    """
    encode = pipeline.prepare_inputs
    return encode(prompt), encode(neg_prompt)
|
|
|
# Number of devices attached to THIS host. flax's `replicate` (used for
# p_params and in replicate_all) copies data onto the local devices only, so
# the per-device RNG split must use the local count. jax.device_count() is the
# global count and is larger on multi-host TPU slices, which would give the
# RNG keys a leading axis that does not match the replicated inputs.
NUM_DEVICES = jax.local_device_count()
|
|
|
# Replicate the (bf16-cast) parameter tree across devices once, up front;
# generate() passes this replicated copy to the pipeline on every call.
p_params = replicate(params)
|
|
|
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Prepare per-device inputs for a data-parallel pipeline call.

    Args:
        prompt_ids: Token ids for the prompt (from ``tokenize_prompt``).
        neg_prompt_ids: Token ids for the negative prompt.
        seed: Integer seed for ``jax.random.PRNGKey``.

    Returns:
        ``(replicated_prompt_ids, replicated_neg_prompt_ids, rng_keys)`` where
        both id arrays are copied via flax ``replicate`` and ``rng_keys`` holds
        ``NUM_DEVICES`` independent keys split from the seed.
    """
    replicated_prompt = replicate(prompt_ids)
    replicated_negative = replicate(neg_prompt_ids)
    device_keys = jax.random.split(jax.random.PRNGKey(seed), NUM_DEVICES)
    return replicated_prompt, replicated_negative, device_keys
|
|
|
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Run the LCM pipeline on ``prompt`` and return the results as PIL images.

    Args:
        prompt: Text prompt to condition on.
        negative_prompt: Tokenized and replicated alongside ``prompt``, but the
            resulting ids are NOT passed to ``pipeline`` below, so it currently
            has no effect on the output.
        seed: Seed used by ``replicate_all`` to derive one RNG key per device.
        guidance_scale: Forwarded to the pipeline. Classifier-free guidance is
            disabled in the call (``do_classifier_free_guidance=False``).
        num_inference_steps: Number of denoising steps forwarded to the pipeline.

    Returns:
        The images converted by ``pipeline.numpy_to_pil`` after the two leading
        axes of the pipeline output are merged into a single batch axis.
    """
    prompt_tokens, negative_tokens = tokenize_prompt(prompt, negative_prompt)
    sharded_prompt, _unused_negative, device_rngs = replicate_all(
        prompt_tokens, negative_tokens, seed
    )

    output = pipeline(
        sharded_prompt,
        p_params,
        device_rngs,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        do_classifier_free_guidance=False,
        jit=True,
    )
    batch = output.images
    print("images.shape: ", batch.shape)

    # Flatten the two leading axes (presumably devices x per-device batch —
    # TODO confirm) into one, keeping the trailing three image axes intact.
    outer, inner = batch.shape[:2]
    flat = batch.reshape((outer * inner, *batch.shape[-3:]))
    return pipeline.numpy_to_pil(np.array(flat))
|
|
|
# Warm-up call: the first generate() call runs the jit=True pipeline for the
# first time (hence "Compiling ..."), so it is timed separately from the
# steady-state benchmark. perf_counter is monotonic and intended for timing,
# unlike the wall-clock time.time().
start = time.perf_counter()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.perf_counter() - start}")
|
|
|
# Steady-state benchmark: time NUM_BENCHMARK_RUNS post-warm-up calls to
# generate() and save every returned image as a zero-padded JPEG.
NUM_BENCHMARK_RUNS = 2
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
neg_prompt = ""

dts = []  # per-call (per-batch) latency in seconds
i = 0  # running index of saved images across all batches
for _ in range(NUM_BENCHMARK_RUNS):
    print(f"Prompt: {prompt}")
    start = time.perf_counter()
    images = generate(prompt, neg_prompt)
    t = time.perf_counter() - start
    print(f"Inference in {t}")

    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1

mean = np.mean(dts)
# Sample standard deviation (ddof=1) for the confidence interval; it is
# undefined for a single run, so report zero spread in that case.
stdev = float(np.std(dts, ddof=1)) if len(dts) > 1 else 0.0
# Normal-approximation 95% half-width; optimistic for very few runs.
half_width = 1.96 * stdev / np.sqrt(len(dts))
# `i` counts saved images, not batches: report both explicitly.
print(
    f"batches: {len(dts)}, images: {i}, "
    f"Mean {mean:.2f} sec/batch ± {half_width:.2f} (95%)"
)
|
``` |
|
|
|
|