jffacevedo commited on
Commit
b73544b
1 Parent(s): cf84649

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +145 -0
README.md ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: openrail++
3
+ ---
4
+
5
+ # FLAX Latent Consistency Model (LCM) LoRA: SDXL - UNet
6
+
7
+ Unet with merged LCM weights (lora_scale=0.7) and converted to work with FLAX.
8
+
9
+ ## Setup
10
+
11
+ To use on TPUs:
12
+
13
+ ```bash
14
+ git clone https://github.com/entrpn/diffusers
15
+ cd diffusers
16
+ git checkout lcm_flax
17
+ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
18
+ pip install transformers flax torch torchvision
19
+ pip install .
20
+ ```
21
+
22
+ ## Run
23
+
24
+
25
+ ```python
26
+ import os
27
+ from diffusers import FlaxStableDiffusionXLPipeline
28
+ import torch
29
+ import time
30
+ import jax
31
+ import jax.numpy as jnp
32
+ from flax.jax_utils import replicate
33
+ import numpy as np
34
+ from jax.experimental.compilation_cache import compilation_cache as cc
35
+ cc.initialize_cache(os.path.expanduser("~/jax_cache"))
36
+
37
+ from diffusers import (
38
+ FlaxUNet2DConditionModel,
39
+ FlaxLCMScheduler
40
+ )
41
+
42
+ base_model = "stabilityai/stable-diffusion-xl-base-1.0"
43
+ lcm_model = "sd_lora_model"
44
+ weight_dtype = jnp.bfloat16
45
+ revision= 'refs/pr/95'
46
+
47
+ pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
48
+ base_model, revision=revision, dtype=weight_dtype
49
+ )
50
+
51
+ del params["unet"]
52
+
53
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
54
+ "jffacevedo/flax_lcm_unet",
55
+ dtype=weight_dtype,
56
+ )
57
+
58
+ scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
59
+ base_model,
60
+ subfolder="scheduler",
61
+ revision=revision,
62
+ dtype=jnp.float32
63
+ )
64
+
65
+ params["unet"] = unet_params
66
+ pipeline.unet = unet
67
+
68
+ pipeline.scheduler = scheduler
69
+
70
+ params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), params)
71
+ params["scheduler"] = scheduler_state
72
+
73
+ default_prompt = "high-quality photo of a baby dolphin ​​playing in a pool and wearing a party hat"
74
+ default_neg_prompt = ""
75
+ default_seed = 42
76
+ default_guidance_scale = 1.0
77
+ default_num_steps = 4
78
+
79
+ def tokenize_prompt(prompt, neg_prompt):
80
+ prompt_ids = pipeline.prepare_inputs(prompt)
81
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
82
+ return prompt_ids, neg_prompt_ids
83
+
84
+ NUM_DEVICES = jax.device_count()
85
+
86
+ p_params = replicate(params)
87
+
88
+ def replicate_all(prompt_ids, neg_prompt_ids, seed):
89
+ p_prompt_ids = replicate(prompt_ids)
90
+ p_neg_prompt_ids = replicate(neg_prompt_ids)
91
+ rng = jax.random.PRNGKey(seed)
92
+ rng = jax.random.split(rng, NUM_DEVICES)
93
+ return p_prompt_ids, p_neg_prompt_ids, rng
94
+
95
+ def generate(
96
+ prompt,
97
+ negative_prompt,
98
+ seed=default_seed,
99
+ guidance_scale=default_guidance_scale,
100
+ num_inference_steps=default_num_steps,
101
+ ):
102
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
103
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
104
+ images = pipeline(
105
+ prompt_ids,
106
+ p_params,
107
+ rng,
108
+ num_inference_steps=num_inference_steps,
109
+ guidance_scale=guidance_scale,
110
+ do_classifier_free_guidance=False,
111
+ jit=True,
112
+ ).images
113
+ print("images.shape: ", images.shape)
114
+ # convert the images to PIL
115
+ images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
116
+ return pipeline.numpy_to_pil(np.array(images))
117
+
118
+ start = time.time()
119
+ print(f"Compiling ...")
120
+ generate(default_prompt, default_neg_prompt)
121
+ print(f"Compiled in {time.time() - start}")
122
+
123
+ dts = []
124
+ i = 0
125
+ for x in range(2):
126
+
127
+ start = time.time()
128
+ prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
129
+ neg_prompt = ""
130
+
131
+ print(f"Prompt: {prompt}")
132
+ images = generate(prompt, neg_prompt)
133
+ t = time.time() - start
134
+ print(f"Inference in {t}")
135
+
136
+ dts.append(t)
137
+ for img in images:
138
+ img.save(f'{i:06d}.jpg')
139
+ i += 1
140
+
141
+ mean = np.mean(dts)
142
+ stdev = np.std(dts)
143
+ print(f"batches: {i}, Mean {mean:.2f} sec/batch± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95%)")
144
+ ```
145
+