multimodalart (HF staff) committed on
Commit
4a09d4f
1 Parent(s): 6db905d

Create cog_sdxl_dataset_and_utils.py

Files changed (1)
  1. cog_sdxl_dataset_and_utils.py +422 -0
cog_sdxl_dataset_and_utils.py ADDED
@@ -0,0 +1,422 @@
+ # dataset_and_utils.py file taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
+ import os
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+ import pandas as pd
+ import PIL
+ import torch
+ import torch.utils.checkpoint
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+ from PIL import Image
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+ from torch.utils.data import Dataset
+ from transformers import AutoTokenizer, PretrainedConfig
+
+
+ def prepare_image(
+     pil_image: PIL.Image.Image, w: int = 512, h: int = 512
+ ) -> torch.Tensor:
+     pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+     arr = np.array(pil_image.convert("RGB"))
+     arr = arr.astype(np.float32) / 127.5 - 1
+     arr = np.transpose(arr, [2, 0, 1])
+     image = torch.from_numpy(arr).unsqueeze(0)
+     return image
+
+
+ def prepare_mask(
+     pil_image: PIL.Image.Image, w: int = 512, h: int = 512
+ ) -> torch.Tensor:
+     pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+     arr = np.array(pil_image.convert("L"))
+     arr = arr.astype(np.float32) / 255.0
+     arr = np.expand_dims(arr, 0)
+     image = torch.from_numpy(arr).unsqueeze(0)
+     return image
+
+
+ class PreprocessedDataset(Dataset):
+     def __init__(
+         self,
+         csv_path: str,
+         tokenizer_1,
+         tokenizer_2,
+         vae_encoder,
+         text_encoder_1=None,
+         text_encoder_2=None,
+         do_cache: bool = False,
+         size: int = 512,
+         text_dropout: float = 0.0,
+         scale_vae_latents: bool = True,
+         substitute_caption_map: Dict[str, str] = {},
+     ):
+         super().__init__()
+
+         self.data = pd.read_csv(csv_path)
+         self.csv_path = csv_path
+
+         self.caption = self.data["caption"]
+         # make it lowercase
+         self.caption = self.caption.str.lower()
+         for key, value in substitute_caption_map.items():
+             self.caption = self.caption.str.replace(key.lower(), value)
+
+         self.image_path = self.data["image_path"]
+
+         if "mask_path" not in self.data.columns:
+             self.mask_path = None
+         else:
+             self.mask_path = self.data["mask_path"]
+
+         if text_encoder_1 is None:
+             self.return_text_embeddings = False
+         else:
+             self.text_encoder_1 = text_encoder_1
+             self.text_encoder_2 = text_encoder_2
+             self.return_text_embeddings = True
+             raise NotImplementedError(
+                 "Preprocessing the text encoders is not implemented yet."
+             )
+
+         self.tokenizer_1 = tokenizer_1
+         self.tokenizer_2 = tokenizer_2
+
+         self.vae_encoder = vae_encoder
+         self.scale_vae_latents = scale_vae_latents
+         self.text_dropout = text_dropout
+
+         self.size = size
+
+         if do_cache:
+             # Pre-compute tokens, VAE latents and masks for the whole dataset
+             # so the VAE encoder can be freed afterwards.
+             self.vae_latents = []
+             self.tokens_tuple = []
+             self.masks = []
+
+             self.do_cache = True
+
+             print("Captions to train on: ")
+             for idx in range(len(self.data)):
+                 token, vae_latent, mask = self._process(idx)
+                 self.vae_latents.append(vae_latent)
+                 self.tokens_tuple.append(token)
+                 self.masks.append(mask)
+
+             del self.vae_encoder
+
+         else:
+             self.do_cache = False
+
+     @torch.no_grad()
+     def _process(
+         self, idx: int
+     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         image_path = self.image_path[idx]
+         image_path = os.path.join(os.path.dirname(self.csv_path), image_path)
+
+         image = PIL.Image.open(image_path).convert("RGB")
+         image = prepare_image(image, self.size, self.size).to(
+             dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+         )
+
+         caption = self.caption[idx]
+
+         print(caption)
+
+         # tokenizer_1
+         ti1 = self.tokenizer_1(
+             caption,
+             padding="max_length",
+             max_length=77,
+             truncation=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         ).input_ids
+
+         ti2 = self.tokenizer_2(
+             caption,
+             padding="max_length",
+             max_length=77,
+             truncation=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         ).input_ids
+
+         vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
+
+         if self.scale_vae_latents:
+             vae_latent = vae_latent * self.vae_encoder.config.scaling_factor
+
+         if self.mask_path is None:
+             mask = torch.ones_like(
+                 vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+             )
+
+         else:
+             mask_path = self.mask_path[idx]
+             mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path)
+
+             mask = PIL.Image.open(mask_path)
+             mask = prepare_mask(mask, self.size, self.size).to(
+                 dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+             )
+
+             mask = torch.nn.functional.interpolate(
+                 mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest"
+             )
+             mask = mask.repeat(1, vae_latent.shape[1], 1, 1)
+
+         assert len(mask.shape) == 4 and len(vae_latent.shape) == 4
+
+         return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+     def atidx(
+         self, idx: int
+     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         if self.do_cache:
+             return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx]
+         else:
+             return self._process(idx)
+
+     def __getitem__(
+         self, idx: int
+     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         token, vae_latent, mask = self.atidx(idx)
+         return token, vae_latent, mask
+
+
+ def import_model_class_from_model_name_or_path(
+     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+ ):
+     text_encoder_config = PretrainedConfig.from_pretrained(
+         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+     )
+     model_class = text_encoder_config.architectures[0]
+
+     if model_class == "CLIPTextModel":
+         from transformers import CLIPTextModel
+
+         return CLIPTextModel
+     elif model_class == "CLIPTextModelWithProjection":
+         from transformers import CLIPTextModelWithProjection
+
+         return CLIPTextModelWithProjection
+     else:
+         raise ValueError(f"{model_class} is not supported.")
+
+
+ def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
+     tokenizer_one = AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path,
+         subfolder="tokenizer",
+         revision=revision,
+         use_fast=False,
+     )
+     tokenizer_two = AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path,
+         subfolder="tokenizer_2",
+         revision=revision,
+         use_fast=False,
+     )
+
+     # Load scheduler and models
+     noise_scheduler = DDPMScheduler.from_pretrained(
+         pretrained_model_name_or_path, subfolder="scheduler"
+     )
+     # import correct text encoder classes
+     text_encoder_cls_one = import_model_class_from_model_name_or_path(
+         pretrained_model_name_or_path, revision
+     )
+     text_encoder_cls_two = import_model_class_from_model_name_or_path(
+         pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
+     )
+     text_encoder_one = text_encoder_cls_one.from_pretrained(
+         pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
+     )
+     text_encoder_two = text_encoder_cls_two.from_pretrained(
+         pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
+     )
+
+     vae = AutoencoderKL.from_pretrained(
+         pretrained_model_name_or_path, subfolder="vae", revision=revision
+     )
+     unet = UNet2DConditionModel.from_pretrained(
+         pretrained_model_name_or_path, subfolder="unet", revision=revision
+     )
+
+     vae.requires_grad_(False)
+     text_encoder_one.requires_grad_(False)
+     text_encoder_two.requires_grad_(False)
+
+     unet.to(device, dtype=weight_dtype)
+     vae.to(device, dtype=torch.float32)
+     text_encoder_one.to(device, dtype=weight_dtype)
+     text_encoder_two.to(device, dtype=weight_dtype)
+
+     return (
+         tokenizer_one,
+         tokenizer_two,
+         noise_scheduler,
+         text_encoder_one,
+         text_encoder_two,
+         vae,
+         unet,
+     )
+
+
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
+     """
+     Returns:
+         a state dict containing just the attention processor parameters.
+     """
+     attn_processors = unet.attn_processors
+
+     attn_processors_state_dict = {}
+
+     for attn_processor_key, attn_processor in attn_processors.items():
+         for parameter_key, parameter in attn_processor.state_dict().items():
+             attn_processors_state_dict[
+                 f"{attn_processor_key}.{parameter_key}"
+             ] = parameter
+
+     return attn_processors_state_dict
+
+
+ class TokenEmbeddingsHandler:
+     def __init__(self, text_encoders, tokenizers):
+         self.text_encoders = text_encoders
+         self.tokenizers = tokenizers
+
+         self.train_ids: Optional[torch.Tensor] = None
+         self.inserting_toks: Optional[List[str]] = None
+         self.embeddings_settings = {}
+
+     def initialize_new_tokens(self, inserting_toks: List[str]):
+         idx = 0
+         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+             assert isinstance(
+                 inserting_toks, list
+             ), "inserting_toks should be a list of strings."
+             assert all(
+                 isinstance(tok, str) for tok in inserting_toks
+             ), "All elements in inserting_toks should be strings."
+
+             self.inserting_toks = inserting_toks
+             special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+             tokenizer.add_special_tokens(special_tokens_dict)
+             text_encoder.resize_token_embeddings(len(tokenizer))
+
+             self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+
+             # random initialization of new tokens
+
+             std_token_embedding = (
+                 text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+             )
+
+             print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
+
+             text_encoder.text_model.embeddings.token_embedding.weight.data[
+                 self.train_ids
+             ] = (
+                 torch.randn(
+                     len(self.train_ids), text_encoder.text_model.config.hidden_size
+                 )
+                 .to(device=self.device)
+                 .to(dtype=self.dtype)
+                 * std_token_embedding
+             )
+             self.embeddings_settings[
+                 f"original_embeddings_{idx}"
+             ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
+
+             inu = torch.ones((len(tokenizer),), dtype=torch.bool)
+             inu[self.train_ids] = False
+
+             self.embeddings_settings[f"index_no_updates_{idx}"] = inu
+
+             print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
+
+             idx += 1
+
+     def save_embeddings(self, file_path: str):
+         assert (
+             self.train_ids is not None
+         ), "Initialize new tokens before saving embeddings."
+         tensors = {}
+         for idx, text_encoder in enumerate(self.text_encoders):
+             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[
+                 0
+             ] == len(self.tokenizers[0]), "Tokenizers should be the same."
+             new_token_embeddings = (
+                 text_encoder.text_model.embeddings.token_embedding.weight.data[
+                     self.train_ids
+                 ]
+             )
+             tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+         save_file(tensors, file_path)
+
+     @property
+     def dtype(self):
+         return self.text_encoders[0].dtype
+
+     @property
+     def device(self):
+         return self.text_encoders[0].device
+
+     def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
+         # Assuming new tokens are of the format <s_i>
+         self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
+         special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+         tokenizer.add_special_tokens(special_tokens_dict)
+         text_encoder.resize_token_embeddings(len(tokenizer))
+
+         self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+         assert self.train_ids is not None, "New tokens could not be converted to IDs."
+         text_encoder.text_model.embeddings.token_embedding.weight.data[
+             self.train_ids
+         ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
+
+     @torch.no_grad()
+     def retract_embeddings(self):
+         for idx, text_encoder in enumerate(self.text_encoders):
+             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
+             text_encoder.text_model.embeddings.token_embedding.weight.data[
+                 index_no_updates
+             ] = (
+                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
+                 .to(device=text_encoder.device)
+                 .to(dtype=text_encoder.dtype)
+             )
+
+             # for the parts that were updated, we need to normalize them
+             # to have the same std as before
+             std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
+
+             index_updates = ~index_no_updates
+             new_embeddings = (
+                 text_encoder.text_model.embeddings.token_embedding.weight.data[
+                     index_updates
+                 ]
+             )
+             off_ratio = std_token_embedding / new_embeddings.std()
+
+             # exponent < 1 only partially corrects the std back toward the original
+             new_embeddings = new_embeddings * (off_ratio**0.1)
+             text_encoder.text_model.embeddings.token_embedding.weight.data[
+                 index_updates
+             ] = new_embeddings
+
+     def load_embeddings(self, file_path: str):
+         with safe_open(file_path, framework="pt", device=self.device.type) as f:
+             for idx in range(len(self.text_encoders)):
+                 text_encoder = self.text_encoders[idx]
+                 tokenizer = self.tokenizers[idx]
+
+                 loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
+                 self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
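
A minimal usage sketch follows (not part of the committed file): it wires load_models, TokenEmbeddingsHandler, and PreprocessedDataset together the way a pivotal-tuning training script might. The checkpoint ID, CSV path, token strings, and output file name are illustrative assumptions, not values taken from the commit.

# Usage sketch (illustrative only, not part of the committed file).
import torch

from cog_sdxl_dataset_and_utils import (
    PreprocessedDataset,
    TokenEmbeddingsHandler,
    load_models,
)

# Assumed inputs: an SDXL checkpoint and a CSV with "caption" and "image_path" columns.
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
CSV_PATH = "training_data.csv"

device = "cuda" if torch.cuda.is_available() else "cpu"
(
    tokenizer_one,
    tokenizer_two,
    noise_scheduler,
    text_encoder_one,
    text_encoder_two,
    vae,
    unet,
) = load_models(MODEL_ID, revision=None, device=device, weight_dtype=torch.float16)

# Register new placeholder tokens and save their freshly initialized embeddings.
handler = TokenEmbeddingsHandler(
    [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
)
handler.initialize_new_tokens(["<s0>", "<s1>"])
handler.save_embeddings("embeddings.safetensors")

# Pre-encode the dataset: tokens, VAE latents and masks are computed up front.
dataset = PreprocessedDataset(
    CSV_PATH,
    tokenizer_one,
    tokenizer_two,
    vae,
    do_cache=True,
    size=1024,
)
(ti1, ti2), latent, mask = dataset[0]

With do_cache=True the dataset tokenizes every caption and pushes every image through the VAE once at construction time, then drops the VAE encoder, so __getitem__ only returns cached tensors during training.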