Bingsu committed on
Commit
a0c557d
1 Parent(s): b69ee8e

add original code

Files changed (1)
  1. textual_inversion.py +772 -0
textual_inversion.py ADDED
@@ -0,0 +1,772 @@
import argparse
import itertools
import math
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")


logger = get_logger(__name__)

def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    logger.info("Saving embeddings")
    learned_embeds = (
        accelerator.unwrap_model(text_encoder)
        .get_input_embeddings()
        .weight[placeholder_token_id]
    )
    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, save_path)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X update steps.",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data.",
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as initializer word.",
    )
    parser.add_argument(
        "--learnable_property",
        type=str,
        default="object",
        help="Choose between 'object' and 'style'",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=100,
        help="How many times to repeat the training data.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args

imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
        ]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = (
            imagenet_style_templates_small
            if learnable_property == "style"
            else imagenet_templates_small
        )
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[
                (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
            ]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example


def get_full_repo_name(
    model_id: str, organization: Optional[str] = None, token: Optional[str] = None
):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def freeze_params(params):
    for param in params:
        param.requires_grad = False

def main():
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(
                    Path(args.output_dir).name, token=args.hub_token
                )
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="tokenizer"
        )

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token, placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    # Freeze vae and unet
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    # Freeze all parameters except for the token embeddings in text encoder
    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    freeze_params(params_to_freeze)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )

    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # Move vae and unet to device
    vae.to(accelerator.device)
    unet.to(accelerator.device)

    # Keep vae and unet in eval mode as we don't train these
    vae.eval()
    unet.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = (
        args.train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = (
                    vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                )
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn(latents.shape).to(latents.device)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                model_pred = unet(
                    noisy_latents, timesteps, encoder_hidden_states
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(
                        f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                    )

                loss = (
                    F.mse_loss(model_pred, target, reduction="none")
                    .mean([1, 2, 3])
                    .mean()
                )
                accelerator.backward(loss)

                # Zero out the gradients for all token embeddings except the newly added
                # embeddings for the concept, as we only want to optimize the concept embeddings
                if accelerator.num_processes > 1:
                    grads = text_encoder.module.get_input_embeddings().weight.grad
                else:
                    grads = text_encoder.get_input_embeddings().weight.grad
                # Get the index for tokens that we want to zero the grads for
                index_grads_to_zero = (
                    torch.arange(len(tokenizer)) != placeholder_token_id
                )
                grads.data[index_grads_to_zero, :] = grads.data[
                    index_grads_to_zero, :
                ].fill_(0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % args.save_steps == 0:
                    save_path = os.path.join(
                        args.output_dir, f"learned_embeds-steps-{global_step}.bin"
                    )
                    save_progress(
                        text_encoder, placeholder_token_id, accelerator, args, save_path
                    )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
            logger.warn(
                "Enabling full model saving because --push_to_hub=True was specified."
            )
            save_full_model = True
        else:
            save_full_model = not args.only_save_embeds
        if save_full_model:
            pipeline = StableDiffusionPipeline(
                text_encoder=accelerator.unwrap_model(text_encoder),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
                scheduler=PNDMScheduler.from_pretrained(
                    args.pretrained_model_name_or_path, subfolder="scheduler"
                ),
                safety_checker=StableDiffusionSafetyChecker.from_pretrained(
                    "CompVis/stable-diffusion-safety-checker"
                ),
                feature_extractor=CLIPFeatureExtractor.from_pretrained(
                    "openai/clip-vit-base-patch32"
                ),
            )
            pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

        if args.push_to_hub:
            repo.push_to_hub(
                commit_message="End of training", blocking=False, auto_lfs_prune=True
            )

    accelerator.end_training()


if __name__ == "__main__":
    main()
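
For reference, `save_progress` above writes the learned concept to `learned_embeds.bin` as a one-entry dict, `{placeholder_token: embedding_tensor}`. The sketch below shows one way such a file could be loaded back into a Stable Diffusion pipeline for inference; it is not part of this commit, and the model id, file path, and prompt are placeholders for illustration. It simply repeats the token setup done in `main()` (add the token, resize the embedding table, copy the vector) at load time.

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder paths for illustration; substitute your own model and output directory.
model_id = "runwayml/stable-diffusion-v1-5"
embeds_path = "text-inversion-model/learned_embeds.bin"

pipe = StableDiffusionPipeline.from_pretrained(model_id)

# learned_embeds.bin maps the placeholder token to its trained embedding vector.
learned_embeds = torch.load(embeds_path, map_location="cpu")
placeholder_token, embedding = next(iter(learned_embeds.items()))

# Register the placeholder token and copy its embedding into the text encoder,
# mirroring the initialization steps in main().
if pipe.tokenizer.add_tokens(placeholder_token) == 0:
    raise ValueError(f"Token {placeholder_token} already exists in the tokenizer.")
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
token_id = pipe.tokenizer.convert_tokens_to_ids(placeholder_token)
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding.to(
    pipe.text_encoder.dtype
)

if torch.cuda.is_available():
    pipe = pipe.to("cuda")

image = pipe(f"a photo of a {placeholder_token}", num_inference_steps=50).images[0]
image.save("example.png")
```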