File size: 3,722 Bytes
b73544b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
---
license: openrail++
---

# FLAX Latent Consistency Model (LCM) LoRA: SDXL - UNet

SDXL UNet with the LCM LoRA weights merged in (lora_scale=0.7), converted to work with Flax.

## Setup

To use on TPUs:

```bash
# Install the diffusers fork that carries the Flax LCM scheduler/pipeline changes.
git clone https://github.com/entrpn/diffusers
cd diffusers
git checkout lcm_flax
# Quote the requirement so the shell does not treat [tpu] as a glob pattern
# (zsh aborts with "no matches found" on the unquoted form).
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install transformers flax torch torchvision
pip install .
```

## Run


```python
import os
from diffusers import FlaxStableDiffusionXLPipeline
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

from diffusers import (
    FlaxUNet2DConditionModel,
    FlaxLCMScheduler
)

# Base SDXL checkpoint, the dtype used for the model weights, and the Hub
# revision that the pipeline and scheduler are loaded from.
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
weight_dtype = jnp.bfloat16
revision= 'refs/pr/95'

# Load the full SDXL Flax pipeline together with its parameter tree.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    base_model, revision=revision, dtype=weight_dtype
  )

# Drop the base UNet weights; they are replaced below by the LCM UNet params.
del params["unet"]

# Load the UNet published in this repository (LCM weights already merged).
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "jffacevedo/flax_lcm_unet",
    dtype=weight_dtype,
)

# Load the LCM scheduler from the base model's scheduler config, in float32.
scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    base_model,
    subfolder="scheduler",
    revision=revision,
    dtype=jnp.float32
)

# Swap the LCM UNet (module and params) into the pipeline.
params["unet"] = unet_params
pipeline.unet = unet

# Swap the LCM scheduler into the pipeline.
pipeline.scheduler = scheduler

# Cast every parameter leaf to weight_dtype. The scheduler state is attached
# only after this cast, so it is not converted and keeps the dtype it was
# loaded with.
params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), params)
params["scheduler"] = scheduler_state

# Default generation settings; used as generate()'s keyword defaults and for
# the warm-up call. The prompt is plain ASCII: stray zero-width characters
# inside the literal would be fed to the tokenizer as part of the prompt.
default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 42
default_guidance_scale = 1.0
default_num_steps = 4

def tokenize_prompt(prompt, neg_prompt):
    """Convert the prompt and the negative prompt into pipeline inputs.

    Both strings are passed through ``pipeline.prepare_inputs``.

    Returns:
        A ``(prompt_ids, neg_prompt_ids)`` tuple.
    """
    return pipeline.prepare_inputs(prompt), pipeline.prepare_inputs(neg_prompt)

# Number of devices attached to *this* process. flax's replicate() and pmapped
# calls operate on the local devices, so the per-device RNG key axis built from
# NUM_DEVICES must match the local count. On a single host this equals
# jax.device_count(); on multi-host setups the global count would be too large.
NUM_DEVICES = jax.local_device_count()

# Replicated copy of the parameter tree, passed to the pipeline on every call.
p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Prepare per-device copies of the prompt ids and per-device RNG keys.

    Args:
        prompt_ids: Tokenized prompt, passed through ``replicate``.
        neg_prompt_ids: Tokenized negative prompt, passed through ``replicate``.
        seed: Integer seed for the base PRNG key.

    Returns:
        ``(replicated_prompt_ids, replicated_neg_prompt_ids, rng)`` where
        ``rng`` is the seed's key split into ``NUM_DEVICES`` keys.
    """
    device_keys = jax.random.split(jax.random.PRNGKey(seed), NUM_DEVICES)
    return replicate(prompt_ids), replicate(neg_prompt_ids), device_keys

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Generate images for ``prompt`` and return them as PIL images.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt. It is tokenized and replicated,
            but the resulting ids are not forwarded to the pipeline call.
        seed: Seed for the per-device RNG keys.
        guidance_scale: Forwarded to the pipeline.
        num_inference_steps: Forwarded to the pipeline.

    Returns:
        The result of ``pipeline.numpy_to_pil`` on the flattened image batch.
    """
    token_ids, neg_token_ids = tokenize_prompt(prompt, negative_prompt)
    # NOTE(review): the replicated negative-prompt ids are unused below; the
    # pipeline runs with do_classifier_free_guidance=False.
    device_ids, _device_neg_ids, device_keys = replicate_all(token_ids, neg_token_ids, seed)

    output = pipeline(
        device_ids,
        p_params,
        device_keys,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        do_classifier_free_guidance=False,
        jit=True,
    )
    images = output.images
    print("images.shape: ", images.shape)

    # Merge the two leading axes into a single batch axis, keeping the last
    # three axes untouched, before converting to PIL.
    lead, per_lead = images.shape[:2]
    flat_shape = (lead * per_lead, *images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images.reshape(flat_shape)))

# Warm-up: the first call is timed separately and reported as compile time
# (generate() invokes the pipeline with jit=True).
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

# Timed runs: dts collects the wall-clock seconds of each generate() call,
# i counts the images written to disk across all runs.
dts = []
i = 0
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
neg_prompt = ""
for _ in range(2):
    start = time.time()

    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.time() - start
    print(f"Inference in {t}")

    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1

mean = np.mean(dts)
stdev = np.std(dts)
# One timing per generate() call, so the batch count is len(dts); i is the
# image count and is reported separately. The +/- term is a 95% interval on
# the mean (1.96 * std / sqrt(n)).
print(f"batches: {len(dts)}, images: {i},  Mean {mean:.2f} sec/batch± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95%)")
```