FIX: add max filename checks and trim if too long
I ran into a problem where the length of the out_folder, which is named based on the prompt, exceeded the maximum filename length. I added support to trim a long out_folder name to the maximum allowed by the operating system.
That said, characters beyond the limit are removed, which might cause collisions if multiple prompts share the same first 255 characters.
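In short, the prompt-derived folder name is clamped to the filesystem's per-component limit before the directory is created. A minimal sketch of the idea (the helper name below is illustrative; the actual logic lives in get_max_filename_length() and get_image_path() in the diff, and os.pathconf is only available on Unix-like systems):

    import os

    def clamp_to_filename_limit(name: str, fallback: int = 255) -> str:
        # Ask the OS for the per-component filename limit; fall back to 255
        # (the usual ext4/NTFS limit) if it cannot be queried.
        try:
            limit = os.pathconf('/', 'PC_NAME_MAX')
        except (AttributeError, ValueError, OSError):
            limit = fallback
        return name[:limit]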
- run_rknn-lcm.py.py +698 -0
run_rknn-lcm.py.py
ADDED
@@ -0,0 +1,698 @@
from rknnlite.api import RKNNLite
from PIL import Image
from typing import Callable, List, Optional, Union, Tuple
from transformers import CLIPFeatureExtractor, CLIPTokenizer
import torch  # Only used for `torch.from_numpy` in `pipe.scheduler.step()`
import numpy as np
import logging
from diffusers.schedulers import (
    LCMScheduler
)
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers import StableDiffusionPipeline
import PIL
import platform
import os
import time
import json
import argparse


logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class RKNN2Model:
    """ Wrapper for running RKNPU2 models """

    def __init__(self, model_dir):
        logger.info(f"Loading {model_dir}")
        start = time.time()
        self.config = json.load(open(os.path.join(model_dir, "config.json")))
        assert os.path.exists(model_dir) and os.path.exists(
            os.path.join(model_dir, "model.rknn"))
        self.rknnlite = RKNNLite()
        self.rknnlite.load_rknn(os.path.join(model_dir, "model.rknn"))
        # Multi-core will cause kernel crash
        self.rknnlite.init_runtime(core_mask=RKNNLite.NPU_CORE_AUTO)
        load_time = time.time() - start
        logger.info(f"Done. Took {load_time:.1f} seconds.")
        self.modelname = model_dir.split("/")[-1]
        self.inference_time = 0

    def __call__(self, **kwargs) -> List[np.ndarray]:
        # np.savez(f"rknn_out/{self.modelname}_input_{self.inference_time}.npz", **kwargs)
        # self.inference_time += 1
        # print(kwargs)
        input_list = [value for key, value in kwargs.items()]
        for i, input in enumerate(input_list):
            if isinstance(input, np.ndarray):
                print(f"input {i} shape: {input.shape}")

        results = self.rknnlite.inference(
            inputs=input_list, data_format='nchw')
        for res in results:
            print(f"output shape: {res.shape}")
        return results


class RKNN2LatentConsistencyPipeline(DiffusionPipeline):

    def __init__(
        self,
        text_encoder: RKNN2Model,
        unet: RKNN2Model,
        vae_decoder: RKNN2Model,
        scheduler: LCMScheduler,
        tokenizer: CLIPTokenizer,
        force_zeros_for_empty_prompt: Optional[bool] = True,
        feature_extractor: Optional[CLIPFeatureExtractor] = None,
        text_encoder_2: Optional[RKNN2Model] = None,
        tokenizer_2: Optional[CLIPTokenizer] = None
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        self.force_zeros_for_empty_prompt = force_zeros_for_empty_prompt
        self.safety_checker = None

        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.tokenizer_2 = tokenizer_2
        self.unet = unet
        self.vae_decoder = vae_decoder

        VAE_DECODER_UPSAMPLE_FACTOR = 8
        self.vae_scale_factor = VAE_DECODER_UPSAMPLE_FACTOR

    @staticmethod
    def postprocess(
        image: np.ndarray,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ):
        def numpy_to_pil(images: np.ndarray):
            """
            Convert a numpy image or a batch of images to a PIL image.
            """
            if images.ndim == 3:
                images = images[None, ...]
            images = (images * 255).round().astype("uint8")
            if images.shape[-1] == 1:
                # special case for grayscale (single channel) images
                pil_images = [Image.fromarray(
                    image.squeeze(), mode="L") for image in images]
            else:
                pil_images = [Image.fromarray(image) for image in images]

            return pil_images

        def denormalize(images: np.ndarray):
            """
            Denormalize an image array to [0,1].
            """
            return np.clip(images / 2 + 0.5, 0, 1)

        if not isinstance(image, np.ndarray):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support np array"
            )
        if output_type not in ["latent", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            logger.warning(deprecation_message)
            output_type = "np"

        if output_type == "latent":
            return image

        if do_denormalize is None:
            raise ValueError("do_denormalize is required for postprocessing")

        image = np.stack(
            [denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])], axis=0
        )
        image = image.transpose((0, 2, 3, 1))

        if output_type == "pil":
            image = numpy_to_pil(image)

        return image

    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int,
        do_classifier_free_guidance: bool,
        negative_prompt: Optional[Union[str, list]],
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`Union[str, List[str]]`):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`Optional[Union[str, list]]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="np",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="max_length", return_tensors="np").input_ids

            if not np.array_equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(
                input_ids=text_input_ids.astype(np.int32))[0]

        prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )
            negative_prompt_embeds = self.text_encoder(
                input_ids=uncond_input.input_ids.astype(np.int32))[0]

        if do_classifier_free_guidance:
            negative_prompt_embeds = np.repeat(
                negative_prompt_embeds, num_images_per_prompt, axis=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = np.concatenate(
                [negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    # Copied from https://github.com/huggingface/diffusers/blob/v0.17.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L217
    def check_inputs(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int],
        width: Optional[int],
        callback_steps: int,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(
                callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
        shape = (batch_size, num_channels_latents, height //
                 self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            if isinstance(generator, np.random.RandomState):
                latents = generator.randn(*shape).astype(dtype)
            elif isinstance(generator, torch.Generator):
                latents = torch.randn(
                    *shape, generator=generator).numpy().astype(dtype)
            else:
                raise ValueError(
                    f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got"
                    f" {type(generator)}."
                )
        elif latents.shape != shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {shape}")

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * np.float64(self.scheduler.init_noise_sigma)

        return latents

    # Adapted from https://github.com/huggingface/diffusers/blob/v0.22.0/src/diffusers/pipelines/latent_consistency/pipeline_latent_consistency.py#L264
    def __call__(
        self,
        prompt: Union[str, List[str]] = "",
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 4,
        original_inference_steps: int = None,
        guidance_scale: float = 8.5,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[np.random.RandomState,
                                  torch.Generator]] = None,
        latents: Optional[np.ndarray] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`Optional[Union[str, List[str]]]`, defaults to None):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`Optional[int]`, defaults to None):
                The height in pixels of the generated image.
            width (`Optional[int]`, defaults to None):
                The width in pixels of the generated image.
            num_inference_steps (`int`, defaults to 4):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, defaults to 8.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, defaults to 1):
                The number of images to generate per prompt.
            generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`):
                A np.random.RandomState to make generation deterministic.
            latents (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            output_type (`str`, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (Optional[Callable], defaults to `None`):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            guidance_rescale (`float`, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
                Guidance rescale factor should fix overexposure when using zero terminal SNR.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
                element is a list of `bool`s denoting whether the corresponding generated image likely represents
                "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        height = height or self.unet.config["sample_size"] * \
            self.vae_scale_factor
        width = width or self.unet.config["sample_size"] * \
            self.vae_scale_factor

        # Don't need to get negative prompts due to LCM guided distillation
        negative_prompt = None
        negative_prompt_embeds = None

        # check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if generator is None:
            generator = np.random.RandomState()

        start_time = time.time()
        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            False,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        encode_prompt_time = time.time() - start_time
        print(f"Prompt encoding time: {encode_prompt_time:.2f}s")

        # set timesteps
        self.scheduler.set_timesteps(
            num_inference_steps, original_inference_steps=original_inference_steps)
        timesteps = self.scheduler.timesteps

        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            self.unet.config["in_channels"],
            height,
            width,
            prompt_embeds.dtype,
            generator,
            latents,
        )

        bs = batch_size * num_images_per_prompt
        # get Guidance Scale Embedding
        w = np.full(bs, guidance_scale - 1, dtype=prompt_embeds.dtype)
        w_embedding = self.get_guidance_scale_embedding(
            w, embedding_dim=self.unet.config["time_cond_proj_dim"], dtype=prompt_embeds.dtype
        )

        # Adapted from diffusers to extend it for other runtimes than ORT
        timestep_dtype = np.int64

        num_warmup_steps = len(timesteps) - \
            num_inference_steps * self.scheduler.order
        inference_start = time.time()
        for i, t in enumerate(self.progress_bar(timesteps)):
            timestep = np.array([t], dtype=timestep_dtype)
            noise_pred = self.unet(
                sample=latents,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=w_embedding,
            )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents, denoised = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), return_dict=False
            )
            latents, denoised = latents.numpy(), denoised.numpy()

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
        inference_time = time.time() - inference_start
        print(f"Inference time: {inference_time:.2f}s")

        decode_start = time.time()
        if output_type == "latent":
            image = denoised
            has_nsfw_concept = None
        else:
            denoised /= self.vae_decoder.config["scaling_factor"]
            # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
            image = np.concatenate(
                [self.vae_decoder(latent_sample=denoised[i: i + 1])[0]
                 for i in range(denoised.shape[0])]
            )
            # image, has_nsfw_concept = self.run_safety_checker(image)
            has_nsfw_concept = None  # skip safety checker

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize)
        decode_time = time.time() - decode_start
        print(f"Decode time: {decode_time:.2f}s")

        total_time = encode_prompt_time + inference_time + decode_time
        print(f"Total time: {total_time:.2f}s")

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    # Adapted from https://github.com/huggingface/diffusers/blob/v0.22.0/src/diffusers/pipelines/latent_consistency/pipeline_latent_consistency.py#L264

    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=None):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`np.ndarray`):
                guidance scale values to generate embedding vectors for
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings

        Returns:
            `np.ndarray`: Embedding vectors with shape `(len(w), embedding_dim)`
        """
        w = w * 1000
        half_dim = embedding_dim // 2
        emb = np.log(10000.0) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim, dtype=dtype) * -emb)
        emb = w[:, None] * emb[None, :]
        emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=1)

        if embedding_dim % 2 == 1:  # zero pad
            emb = np.pad(emb, [(0, 0), (0, 1)])

        assert emb.shape == (w.shape[0], embedding_dim)
        return emb


def get_max_filename_length():
    if platform.system() == 'Windows':
        return get_max_filename_length_windows()
    elif platform.system() in ['Linux', 'Darwin']:  # Darwin is for MacOS
        return get_max_filename_length_unix()
    else:
        raise Exception(f"Unsupported operating system: {platform.system()}")


def get_max_filename_length_windows():
    # os.pathconf is not available on Windows, so fall back to the common
    # 255-character per-component limit used by NTFS and FAT filesystems.
    max_length = 255
    print(
        f"The maximum file name length on Windows is: {max_length} characters.")
    return max_length


def get_max_filename_length_unix():
    try:
        return os.pathconf('/', 'PC_NAME_MAX')
    except Exception as e:
        print(f"An error occurred: {e}")
        return 255  # conservative default if the limit cannot be queried


def get_image_path(args, **override_kwargs):
    """ mkdir output folder and encode metadata in the filename
    """
    out_folder = os.path.join(args.o, "_".join(
        args.prompt.replace("/", "_").rsplit(" ")))
    max_length = get_max_filename_length()
    if len(out_folder) > max_length:
        out_folder = out_folder[:max_length]
    os.makedirs(out_folder, exist_ok=True)

    out_fname = f"randomSeed_{override_kwargs.get('seed', None) or args.seed}"

    out_fname += f"_LCM_"
    out_fname += f"_numInferenceSteps{override_kwargs.get('num_inference_steps', None) or args.num_inference_steps}"

    return os.path.join(out_folder, out_fname + ".png")


def prepare_controlnet_cond(image_path, height, width):
    image = Image.open(image_path).convert("RGB")
    image = image.resize((height, width), resample=Image.LANCZOS)
    image = np.array(image).transpose(2, 0, 1) / 255.0
    return image


def main(args):
    logger.info(f"Setting random seed to {args.seed}")

    # load scheduler from /scheduler/scheduler_config.json
    scheduler_config_path = os.path.join(
        args.i, "scheduler/scheduler_config.json")
    with open(scheduler_config_path, "r") as f:
        scheduler_config = json.load(f)
    user_specified_scheduler = LCMScheduler.from_config(scheduler_config)

    print("user_specified_scheduler", user_specified_scheduler)

    pipe = RKNN2LatentConsistencyPipeline(
        text_encoder=RKNN2Model(os.path.join(args.i, "text_encoder")),
        unet=RKNN2Model(os.path.join(args.i, "unet")),
        vae_decoder=RKNN2Model(os.path.join(args.i, "vae_decoder")),
        scheduler=user_specified_scheduler,
        tokenizer=CLIPTokenizer.from_pretrained(
            "openai/clip-vit-base-patch16"),
    )

    logger.info("Beginning image generation.")
    image = pipe(
        prompt=args.prompt,
        height=int(args.size.split("x")[0]),
        width=int(args.size.split("x")[1]),
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        generator=np.random.RandomState(args.seed),
    )

    out_path = get_image_path(args)
    logger.info(f"Saving generated image to {out_path}")
    image["images"][0].save(out_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        required=True,
        help="The text prompt to be used for text-to-image generation.")
    parser.add_argument(
        "-i",
        required=True,
        help=("Path to model directory"))
    parser.add_argument("-o", required=True)
    parser.add_argument("--seed",
                        default=93,
                        type=int,
                        help="Random seed to be able to reproduce results")
    parser.add_argument(
        "-s",
        "--size",
        default="256x256",
        type=str,
        help="Image size")
    parser.add_argument(
        "--num-inference-steps",
        default=4,
        type=int,
        help="The number of iterations the unet model will be executed throughout the reverse diffusion process")
    parser.add_argument(
        "--guidance-scale",
        default=7.5,
        type=float,
        help="Controls the influence of the text prompt on sampling process (0=random images)")

    args = parser.parse_args()
    main(args)
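For reference, a typical invocation of the script might look like this (the model and output paths are illustrative):

    python run_rknn-lcm.py.py -i ./models/lcm-rknn -o ./output --prompt "an astronaut riding a horse" --num-inference-steps 4 --size 256x256 --seed 93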