FLAX Latent Consistency Model (LCM) LoRA: SDXL - UNet

SDXL UNet with the LCM LoRA weights merged in (lora_scale=0.7), converted to work with Flax.

Setup

To use on TPUs:

# Clone the entrpn/diffusers fork and switch to its lcm_flax branch.
git clone https://github.com/entrpn/diffusers
cd diffusers
git checkout lcm_flax
# Quote the extras spec so shells that glob on [] (e.g. zsh) pass it to pip unchanged.
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install transformers flax torch torchvision
# Install the checked-out diffusers branch itself.
pip install .

Run

import os
from diffusers import FlaxStableDiffusionXLPipeline
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
from jax.experimental.compilation_cache import compilation_cache as cc
# Point JAX's compilation cache at ~/jax_cache (expanded for the current user).
# NOTE(review): presumably this lets later runs reuse compiled programs instead
# of recompiling; newer JAX releases may prefer a config option over this
# call - confirm against the installed JAX version.
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

from diffusers import (
    FlaxUNet2DConditionModel,
    FlaxLCMScheduler
)

# Base SDXL checkpoint, the dtype used for model weights, and the pinned
# revision of the base repository that the pipeline and scheduler load from.
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
weight_dtype = jnp.bfloat16
revision= 'refs/pr/95'

# Load the Flax SDXL pipeline object together with its parameter dict.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    base_model, revision=revision, dtype=weight_dtype
  )

# Drop the base model's UNet weights; the replacement UNet's weights are
# inserted under the same key further down.
del params["unet"]

# Load the UNet (and its parameters) published with this model card.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "jffacevedo/flax_lcm_unet",
    dtype=weight_dtype,
)

# Build the LCM scheduler from the base model's "scheduler" subfolder.
# Unlike the model weights, it is requested in float32.
scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    base_model,
    subfolder="scheduler",
    revision=revision,
    dtype=jnp.float32
)

# Swap the replacement UNet into both the parameter dict and the pipeline.
params["unet"] = unet_params
pipeline.unet = unet

# Swap the LCM scheduler into the pipeline.
pipeline.scheduler = scheduler

# Cast every leaf of the parameter tree to weight_dtype. The scheduler state
# is attached only afterwards, so this cast does not touch it.
params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), params)
params["scheduler"] = scheduler_state

# Default generation settings; generate() uses the seed, guidance scale and
# step count as its keyword defaults, and the warm-up call uses the prompts.
# The prompt is plain ASCII: the invisible zero-width spaces (U+200B) that
# were embedded between "dolphin " and "playing" have been removed so the
# tokenizer only sees the visible text.
default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 42
default_guidance_scale = 1.0
default_num_steps = 4

def tokenize_prompt(prompt, neg_prompt):
    """Run the positive and negative prompts through the pipeline's input prep.

    Args:
        prompt: Prompt passed to ``pipeline.prepare_inputs``.
        neg_prompt: Negative prompt passed to ``pipeline.prepare_inputs``.

    Returns:
        A ``(prompt_ids, neg_prompt_ids)`` tuple holding whatever
        ``pipeline.prepare_inputs`` returned for each input, in that order.
    """
    return (
        pipeline.prepare_inputs(prompt),
        pipeline.prepare_inputs(neg_prompt),
    )

# Number of devices JAX reports; replicate_all() creates one RNG key per device.
NUM_DEVICES = jax.device_count()

# Parameter tree passed to the pipeline call, produced by flax's `replicate`
# (presumably one copy per device - confirm against flax.jax_utils).
p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Prepare the per-call inputs for a multi-device pipeline run.

    Args:
        prompt_ids: Prepared prompt inputs, passed through ``replicate``.
        neg_prompt_ids: Prepared negative-prompt inputs, passed through
            ``replicate``.
        seed: Integer seed used to build the base PRNG key.

    Returns:
        A 3-tuple ``(p_prompt_ids, p_neg_prompt_ids, rng)`` where ``rng`` is
        the seed's PRNG key split into ``NUM_DEVICES`` keys.
    """
    # Replicate both inputs first, in the same order as before.
    replicated = tuple(replicate(ids) for ids in (prompt_ids, neg_prompt_ids))
    device_keys = jax.random.split(jax.random.PRNGKey(seed), NUM_DEVICES)
    return (*replicated, device_keys)

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Generate PIL images for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt to render.
        negative_prompt: Negative prompt. It is tokenized and replicated, but
            the result is not passed to the pipeline call below.
        seed: Seed for the per-device PRNG keys.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.

    Returns:
        The result of ``pipeline.numpy_to_pil`` applied to the generated
        images after their two leading axes have been merged into one.
    """
    token_ids, neg_token_ids = tokenize_prompt(prompt, negative_prompt)
    sharded_ids, _unused_neg_ids, device_rngs = replicate_all(
        token_ids, neg_token_ids, seed
    )

    output = pipeline(
        sharded_ids,
        p_params,
        device_rngs,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        do_classifier_free_guidance=False,
        jit=True,
    )
    images = output.images
    print("images.shape: ", images.shape)

    # Collapse the first two axes into a single one, keeping the last three
    # axes unchanged, before converting to PIL.
    merged_leading = images.shape[0] * images.shape[1]
    images = images.reshape((merged_leading, *images.shape[-3:]))
    return pipeline.numpy_to_pil(np.array(images))

# Warm-up run: the first generate() call is timed on its own and reported as
# compile time, separately from the per-call inference timings measured later.
start = time.time()
print("Compiling ...")  # plain string: the f-string prefix had no placeholders
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

# Benchmark: time two generate() calls with a fixed prompt, save every
# returned image as a numbered JPEG, then print the mean time per call with a
# normal-approximation 95% half-width (1.96 * std / sqrt(n)).
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
neg_prompt = ""

dts = []  # elapsed seconds for each generate() call
i = 0     # running count of saved images; also the next file index
for _ in range(2):
    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()

    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.perf_counter() - start
    print(f"Inference in {t}")

    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1

mean = np.mean(dts)
stdev = np.std(dts)
# One timing is recorded per generate() call, so len(dts) is the batch count;
# `i` counts images and is reported separately.
print(f"batches: {len(dts)}, images: {i},  Mean {mean:.2f} sec/batch± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95%)")
Downloads last month
7
Inference API
Unable to determine this model’s pipeline type. Check the docs.