---
license: openrail++
---

# FLAX Latent Consistency Model (LCM) LoRA: SDXL - UNet

UNet with merged LCM LoRA weights (lora_scale=0.7), converted to work with FLAX.

## Setup

To use on TPUs:

```bash
git clone https://github.com/entrpn/diffusers
cd diffusers
git checkout lcm_flax
# Quote the extra so shells that glob on brackets (e.g. zsh) pass it through unchanged.
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install transformers flax torch torchvision
pip install .
```

## Run

```python
import os
from diffusers import (
    FlaxLCMScheduler,
    FlaxStableDiffusionXLPipeline,
    FlaxUNet2DConditionModel,
)
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
from jax.experimental.compilation_cache import compilation_cache as cc

# Point JAX's compilation cache at a directory under the user's home.
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

base_model = "stabilityai/stable-diffusion-xl-base-1.0"
weight_dtype = jnp.bfloat16
revision = "refs/pr/95"

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    base_model, revision=revision, dtype=weight_dtype
)

# Drop the UNet weights that came with the base pipeline; they are replaced
# below by the weights loaded from "jffacevedo/flax_lcm_unet".
del params["unet"]

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "jffacevedo/flax_lcm_unet",
    dtype=weight_dtype,
)

scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    base_model, subfolder="scheduler", revision=revision, dtype=jnp.float32
)

params["unet"] = unet_params
pipeline.unet = unet
pipeline.scheduler = scheduler

# Cast the model weights first and attach the scheduler state afterwards, so
# the scheduler state is not touched by the cast to `weight_dtype`.
params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), params)
params["scheduler"] = scheduler_state

default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 42
default_guidance_scale = 1.0
default_num_steps = 4


def tokenize_prompt(prompt, neg_prompt):
    """Return the `pipeline.prepare_inputs` output for both prompts.

    Returns:
        A `(prompt_ids, neg_prompt_ids)` tuple.
    """
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids


NUM_DEVICES = jax.device_count()

# Replicate the parameters once up front; every `generate` call reuses them.
p_params = replicate(params)


def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Replicate both sets of token ids and build one PRNG key per device.

    Returns:
        A `(p_prompt_ids, p_neg_prompt_ids, rng)` tuple, where `rng` is the
        key derived from `seed` split into `NUM_DEVICES` keys.
    """
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng


def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Run the pipeline for `prompt` and return the result of `numpy_to_pil`.

    Args:
        prompt: Text prompt to render.
        negative_prompt: Negative text prompt. It is tokenized and replicated,
            but not forwarded to the pipeline call below, which runs with
            `do_classifier_free_guidance=False`.
        seed: Seed for the PRNG key that is split across devices.
        guidance_scale: Forwarded to the pipeline unchanged.
        num_inference_steps: Forwarded to the pipeline unchanged.

    Returns:
        The output of `pipeline.numpy_to_pil` for the generated images.
    """
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        do_classifier_free_guidance=False,
        jit=True,
    ).images
    print("images.shape: ", images.shape)
    # Convert the images to PIL: merge the two leading axes into a single
    # flat batch axis and keep the last three axes as they are.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))


# The first call is timed separately and reported as compilation (jit=True).
start = time.perf_counter()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.perf_counter() - start}")

prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
neg_prompt = ""

dts = []
image_index = 0  # running index used to name the saved image files
for _ in range(2):
    start = time.perf_counter()
    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.perf_counter() - start
    print(f"Inference in {t}")
    dts.append(t)
    for img in images:
        img.save(f"{image_index:06d}.jpg")
        image_index += 1

mean = np.mean(dts)
stdev = np.std(dts)
# Half-width of a normal-approximation 95% confidence interval on the mean.
ci95 = stdev * 1.96 / np.sqrt(len(dts))
print(f"batches: {len(dts)}, Mean {mean:.2f} sec/batch ± {ci95:.2f} (95%)")
```