BiliSakura committed
Commit a296060 (verified) · 1 Parent(s): e754228

Update all files for BitDance-Tokenizer-diffusers

bitdance_diffusers/pipeline_bitdance.py ADDED
@@ -0,0 +1,366 @@
+ from contextlib import nullcontext
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from einops import rearrange
+ from PIL import Image
+ from tqdm.auto import tqdm
+
+ from diffusers import DiffusionPipeline
+ from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
+
+ from .constants import SUPPORTED_IMAGE_SIZES
+
+
+ PromptType = Union[str, List[str]]
+
+
+ def _get_pkv_seq_len(past_key_values) -> int:
+     """Get cached sequence length from past_key_values (supports tuple and DynamicCache)."""
+     if hasattr(past_key_values, "get_seq_length"):
+         return past_key_values.get_seq_length()
+     return past_key_values[0][0].shape[2]
+
+
+ class BitDanceDiffusionPipeline(DiffusionPipeline):
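+     """Text-to-image pipeline for the BitDance tokenizer/decoder stack.
+
+     The text encoder is run block-autoregressively: at each step the diffusion
+     head denoises ``parallel_num`` positions into sign-quantized (±1) latent
+     tokens, which are re-embedded by the projector, fed back through the
+     language model, and finally decoded to pixels by the autoencoder.
+     """
+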
+     model_cpu_offload_seq = "text_encoder->projector->diffusion_head->autoencoder"
+
+     def __init__(
+         self,
+         tokenizer,
+         text_encoder,
+         autoencoder,
+         diffusion_head,
+         projector,
+         supported_image_sizes: Optional[List[List[int]]] = None,
+         dtype: Optional[torch.dtype] = None,
+     ) -> None:
+         super().__init__()
+         self.register_modules(
+             tokenizer=tokenizer,
+             text_encoder=text_encoder,
+             autoencoder=autoencoder,
+             diffusion_head=diffusion_head,
+             projector=projector,
+         )
+
+         image_sizes = supported_image_sizes or SUPPORTED_IMAGE_SIZES
+         self.register_to_config(supported_image_sizes=[list(size) for size in image_sizes])
+
+         self.hidden_size = self.text_encoder.config.hidden_size
+         self.vae_patch_size = self.autoencoder.patch_size
+         self.parallel_num = int(self.diffusion_head.config.parallel_num)
+         self.ps = int(self.parallel_num**0.5)
+         if self.ps * self.ps != self.parallel_num:
+             raise ValueError(
+                 f"parallel_num must be a perfect square (got {self.parallel_num})."
+             )
+
+         self._build_pos_embed()
+
+     @property
+     def supported_image_sizes(self) -> List[List[int]]:
+         return [list(size) for size in self.config.supported_image_sizes]
+
+     def _execution_device_fallback(self) -> torch.device:
+         if getattr(self, "_execution_device", None) is not None:
+             return self._execution_device
+         return next(self.text_encoder.parameters()).device
+
+     def _build_pos_embed(self) -> None:
+         max_resolution = max(max(size) for size in self.supported_image_sizes)
+         max_len = max_resolution // self.vae_patch_size
+         pos_embed_1d = self._get_1d_sincos_pos_embed(self.hidden_size // 2, max_len)
+         self.pos_embed_1d = pos_embed_1d
+
+     @staticmethod
+     def _get_1d_sincos_pos_embed(dim: int, max_len: int, pe_interpolation: float = 1.0) -> torch.Tensor:
+         if dim % 2 != 0:
+             raise ValueError(f"dim must be even, got {dim}")
+         omega = torch.arange(dim // 2, dtype=torch.float32)
+         omega /= dim / 2.0
+         omega = 1.0 / 10000**omega
+         pos = torch.arange(max_len, dtype=torch.float32) / pe_interpolation
+         out = torch.einsum("m,d->md", pos, omega)
+         emb_sin = torch.sin(out)
+         emb_cos = torch.cos(out)
+         return torch.cat([emb_sin, emb_cos], dim=1)
+
+     def _get_2d_embed(self, h: int, w: int, ps: int = 1) -> torch.Tensor:
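+         """Build a 2D sin/cos positional embedding of shape (h * w, hidden_size),
+         ordered so that each run of ps * ps rows covers one ps x ps block of
+         spatially adjacent patches."""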
+         emb_v = self.pos_embed_1d[:h]
+         emb_h = self.pos_embed_1d[:w]
+         grid_v = emb_v.view(h, 1, self.hidden_size // 2).repeat(1, w, 1)
+         grid_h = emb_h.view(1, w, self.hidden_size // 2).repeat(h, 1, 1)
+         pos_embed = torch.cat([grid_h, grid_v], dim=-1)
+         return rearrange(pos_embed, "(h p1) (w p2) c -> (h w p1 p2) c", p1=ps, p2=ps)
+
+     def _encode_prompt_to_embeds(
+         self,
+         prompt: str,
+         image_size: Tuple[int, int],
+         num_images_per_prompt: int,
+         guidance_scale: float,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
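+         """Embed the chat-formatted prompt together with the image-start,
+         resolution, and query tokens. Returns (cond_embeds, uncond_embeds or
+         None, image_start_embeds); the first two are repeated
+         ``num_images_per_prompt`` times along the batch dimension."""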
+         device = self._execution_device_fallback()
+         model = self.text_encoder.model
+         tokenizer = self.tokenizer
+
+         cond_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+         uncond_prompt = "<|im_start|>assistant\n"
+
+         cond_ids = torch.tensor(tokenizer.encode(cond_prompt), device=device, dtype=torch.long)
+         cond_emb = model.embed_tokens(cond_ids)
+         uncond_emb = None
+         if guidance_scale > 1.0:
+             uncond_ids = torch.tensor(tokenizer.encode(uncond_prompt), device=device, dtype=torch.long)
+             uncond_emb = model.embed_tokens(uncond_ids)
+
+         image_h, image_w = image_size
+         img_start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>")
+         res_h_token_id = tokenizer.convert_tokens_to_ids(f"<|res_{image_h // self.vae_patch_size}|>")
+         res_w_token_id = tokenizer.convert_tokens_to_ids(f"<|res_{image_w // self.vae_patch_size}|>")
+         img_start_emb = model.embed_tokens(torch.tensor([img_start_id, res_h_token_id, res_w_token_id], device=device))
+
+         for i in range(1, self.parallel_num):
+             query_token_id = tokenizer.convert_tokens_to_ids(f"<|query_{i}|>")
+             query_token = torch.tensor([query_token_id], device=device, dtype=torch.long)
+             query_embed = model.embed_tokens(query_token)
+             img_start_emb = torch.cat([img_start_emb, query_embed], dim=0)
+
+         input_embeds_cond = torch.cat([cond_emb, img_start_emb], dim=0).unsqueeze(0).repeat(num_images_per_prompt, 1, 1)
+         input_embeds_uncond = None
+         if guidance_scale > 1.0 and uncond_emb is not None:
+             input_embeds_uncond = torch.cat([uncond_emb, img_start_emb], dim=0).unsqueeze(0).repeat(num_images_per_prompt, 1, 1)
+         return input_embeds_cond, input_embeds_uncond, img_start_emb
+
+     def _decode_tokens_to_image(self, image_latents: torch.Tensor, image_size: Tuple[int, int], ps: int = 1) -> torch.Tensor:
+         h, w = image_size
+         image_latents = rearrange(image_latents, "b (h w p1 p2) c -> b c (h p1) (w p2)", h=h // ps, w=w // ps, p1=ps, p2=ps)
+         return self.autoencoder.decode(image_latents)
+
+     @torch.no_grad()
+     def _generate_single_prompt(
+         self,
+         prompt: str,
+         height: int,
+         width: int,
+         num_inference_steps: int,
+         guidance_scale: float,
+         num_images_per_prompt: int,
+         generator: Optional[torch.Generator],
+         show_progress_bar: bool,
+     ) -> torch.Tensor:
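+         """Generate ``num_images_per_prompt`` images for a single prompt and
+         return the decoded image tensor (rescaled and converted to the final
+         output format in ``__call__``)."""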
+         image_size = (height, width)
+         if list(image_size) not in self.supported_image_sizes:
+             raise ValueError(
+                 f"image_size {list(image_size)} is not supported. "
+                 f"Please choose from {self.supported_image_sizes}"
+             )
+
+         h, w = height // self.vae_patch_size, width // self.vae_patch_size
+         max_length = h * w
+         step_width = self.parallel_num
+         if max_length % step_width != 0:
+             raise ValueError(
+                 f"max_length ({max_length}) must be divisible by parallel_num ({step_width})."
+             )
+         num_steps = max_length // step_width
+
+         device = self._execution_device_fallback()
+         model = self.text_encoder.model
+         dtype = next(self.text_encoder.parameters()).dtype
+
+         input_embeds_cond, input_embeds_uncond, _ = self._encode_prompt_to_embeds(
+             prompt=prompt,
+             image_size=image_size,
+             num_images_per_prompt=num_images_per_prompt,
+             guidance_scale=guidance_scale,
+         )
+         pos_embed_for_diff = self._get_2d_embed(h, w, ps=self.ps).unsqueeze(0).to(device=device, dtype=dtype)
+
+         autocast_ctx = (
+             torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
+             if device.type == "cuda"
+             else nullcontext()
+         )
+
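+         # Prefill the KV cache on the text prompt, then run the last
+         # `step_width` prompt positions (resolution / <|query_*|> tokens) with a
+         # fully visible attention mask so every query position attends to the
+         # entire prefix and to the other queries.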
+         with autocast_ctx:
+             outputs_c = model(inputs_embeds=input_embeds_cond[:, :-step_width, :], use_cache=True)
+             pkv_c = outputs_c.past_key_values
+
+             bi_attn_mask = torch.ones(
+                 (input_embeds_cond.shape[0], 1, step_width, step_width + _get_pkv_seq_len(pkv_c)),
+                 dtype=torch.bool,
+                 device=device,
+             )
+             outputs_c = model(
+                 inputs_embeds=input_embeds_cond[:, -step_width:, :],
+                 past_key_values=pkv_c,
+                 use_cache=True,
+                 attention_mask=bi_attn_mask,
+             )
+             pkv_c = outputs_c.past_key_values
+             hidden_c = outputs_c.last_hidden_state[:, -step_width:]
+
+             hidden_u = None
+             pkv_u = None
+             if guidance_scale > 1.0 and input_embeds_uncond is not None:
+                 outputs_u = model(inputs_embeds=input_embeds_uncond[:, :-step_width, :], use_cache=True)
+                 pkv_u = outputs_u.past_key_values
+                 bi_attn_mask_u = torch.ones(
+                     (input_embeds_uncond.shape[0], 1, step_width, step_width + _get_pkv_seq_len(pkv_u)),
+                     dtype=torch.bool,
+                     device=device,
+                 )
+                 outputs_u = model(
+                     inputs_embeds=input_embeds_uncond[:, -step_width:, :],
+                     past_key_values=pkv_u,
+                     use_cache=True,
+                     attention_mask=bi_attn_mask_u,
+                 )
+                 pkv_u = outputs_u.past_key_values
+                 hidden_u = outputs_u.last_hidden_state[:, -step_width:]
+
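+             # Block-autoregressive decoding: each step the diffusion head samples
+             # `parallel_num` latent tokens from the current hidden states (with
+             # CFG when guidance_scale > 1); the sign-quantized tokens are
+             # re-embedded by the projector and appended to the KV cache to
+             # condition the next step.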
+             out_tokens = []
+             step_iter = range(num_steps)
+             if show_progress_bar:
+                 step_iter = tqdm(step_iter, total=num_steps, desc="Decoding steps")
+
+             for step in step_iter:
+                 if guidance_scale > 1.0 and hidden_u is not None:
+                     h_fused = torch.cat([hidden_c, hidden_u], dim=0)
+                 else:
+                     h_fused = hidden_c
+
+                 pos_slice = pos_embed_for_diff[:, step * step_width : (step + 1) * step_width, :]
+                 h_fused = h_fused + pos_slice
+                 pred_latents = self.diffusion_head.sample(
+                     h_fused,
+                     num_sampling_steps=num_inference_steps,
+                     cfg=guidance_scale,
+                     generator=generator,
+                 )
+                 curr_tokens = torch.sign(pred_latents)
+                 curr_embeds = self.projector(curr_tokens)
+                 out_tokens.append(curr_tokens[:num_images_per_prompt])
+
+                 model_input = curr_embeds + pos_slice
+                 bi_attn_mask = torch.ones(
+                     (model_input.shape[0], 1, model_input.shape[1], model_input.shape[1] + _get_pkv_seq_len(pkv_c)),
+                     dtype=torch.bool,
+                     device=device,
+                 )
+                 outputs_c = model(
+                     inputs_embeds=model_input[:num_images_per_prompt],
+                     past_key_values=pkv_c,
+                     use_cache=True,
+                     attention_mask=bi_attn_mask[:num_images_per_prompt],
+                 )
+                 pkv_c = outputs_c.past_key_values
+                 hidden_c = outputs_c.last_hidden_state[:, -step_width:]
+
+                 if guidance_scale > 1.0 and hidden_u is not None and pkv_u is not None:
+                     bi_attn_mask_u = torch.ones(
+                         (model_input.shape[0], 1, model_input.shape[1], model_input.shape[1] + _get_pkv_seq_len(pkv_u)),
+                         dtype=torch.bool,
+                         device=device,
+                     )
+                     outputs_u = model(
+                         inputs_embeds=model_input[num_images_per_prompt:],
+                         past_key_values=pkv_u,
+                         use_cache=True,
+                         attention_mask=bi_attn_mask_u[num_images_per_prompt:],
+                     )
+                     pkv_u = outputs_u.past_key_values
+                     hidden_u = outputs_u.last_hidden_state[:, -step_width:]
+
+             full_output = torch.cat(out_tokens, dim=1)
+             return self._decode_tokens_to_image(full_output, image_size=(h, w), ps=self.ps)
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: PromptType,
+         height: int = 1024,
+         width: int = 1024,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         num_images_per_prompt: int = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         output_type: str = "pil",
+         return_dict: bool = True,
+         show_progress_bar: bool = False,
+     ) -> Union[ImagePipelineOutput, Tuple]:
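+         """Generate images for a single prompt or a list of prompts.
+
+         Returns an ``ImagePipelineOutput`` (or a plain tuple when
+         ``return_dict=False``) containing PIL images, a NumPy array, or a
+         [0, 1] tensor depending on ``output_type``."""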
+         prompts = [prompt] if isinstance(prompt, str) else list(prompt)
+         if len(prompts) == 0:
+             raise ValueError("prompt must be a non-empty string or list of strings.")
+
+         if isinstance(generator, list) and len(generator) != len(prompts):
+             raise ValueError("When passing a list of generators, its length must equal len(prompt).")
+
+         image_tensors = []
+         for i, prompt_text in enumerate(prompts):
+             prompt_generator = generator[i] if isinstance(generator, list) else generator
+             images = self._generate_single_prompt(
+                 prompt=prompt_text,
+                 height=height,
+                 width=width,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 num_images_per_prompt=num_images_per_prompt,
+                 generator=prompt_generator,
+                 show_progress_bar=show_progress_bar,
+             )
+             image_tensors.append(images)
+
+         images_pt = torch.cat(image_tensors, dim=0)
+         images_pt_01 = torch.clamp((images_pt + 1.0) / 2.0, 0.0, 1.0)
+
+         if output_type == "pt":
+             output_images = images_pt_01
+         elif output_type == "np":
+             output_images = images_pt_01.permute(0, 2, 3, 1).float().cpu().numpy()
+         elif output_type == "pil":
+             images_uint8 = (
+                 torch.clamp(127.5 * images_pt + 128.0, 0, 255)
+                 .permute(0, 2, 3, 1)
+                 .to("cpu", dtype=torch.uint8)
+                 .numpy()
+             )
+             output_images = [Image.fromarray(image) for image in images_uint8]
+         else:
+             raise ValueError(f"Unsupported output_type={output_type}. Expected 'pil', 'np', or 'pt'.")
+
+         if not return_dict:
+             return (output_images,)
+         return ImagePipelineOutput(images=output_images)
+
+     @torch.no_grad()
+     def generate(
+         self,
+         prompt: str,
+         height: int = 1024,
+         width: int = 1024,
+         num_sampling_steps: int = 50,
+         guidance_scale: float = 7.5,
+         num_images: int = 1,
+         seed: Optional[int] = None,
+     ) -> List[Image.Image]:
+         generator = None
+         if seed is not None:
+             device = self._execution_device_fallback()
+             generator_device = "cuda" if device.type == "cuda" else "cpu"
+             generator = torch.Generator(device=generator_device).manual_seed(seed)
+         output = self(
+             prompt=prompt,
+             height=height,
+             width=width,
+             num_inference_steps=num_sampling_steps,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=num_images,
+             generator=generator,
+             output_type="pil",
+             return_dict=True,
+             show_progress_bar=True,
+         )
+         return output.images
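
A minimal usage sketch, assuming the pipeline is published as custom diffusers code on the Hub; the model id, loading keywords, prompt, and output path below are placeholders, and the requested height/width must be one of the sizes listed in SUPPORTED_IMAGE_SIZES:

import torch
from diffusers import DiffusionPipeline

# Hypothetical repo id and loading flags; adjust to the actual Hub location.
pipe = DiffusionPipeline.from_pretrained(
    "BiliSakura/BitDance-Tokenizer-diffusers",
    custom_pipeline="BiliSakura/BitDance-Tokenizer-diffusers",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
pipe.to("cuda")

# `generate` wraps `__call__` with a fixed seed and PIL output;
# the requested size must appear in pipe.supported_image_sizes.
images = pipe.generate(
    prompt="a watercolor painting of a red fox in the snow",
    height=1024,
    width=1024,
    num_sampling_steps=50,
    guidance_scale=7.5,
    num_images=1,
    seed=0,
)
images[0].save("bitdance_sample.png")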