tranmc committed
Commit fe313c3 · Parent: 8551104

Delete train_dreambooth.py

Files changed (1)
  1. train_dreambooth.py +0 -695
train_dreambooth.py DELETED
@@ -1,695 +0,0 @@
-import argparse
-import itertools
-import math
-import os
-from pathlib import Path
-from typing import Optional
-from contextlib import nullcontext
-from diffusers.pipelines.stable_diffusion import safety_checker
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch.utils.data import Dataset
-
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from PIL import Image
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
-
-
-logger = get_logger(__name__)
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        type=str,
-        default=None,
-        required=True,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-    )
-    parser.add_argument(
-        "--tokenizer_name",
-        type=str,
-        default=None,
-        help="Pretrained tokenizer name or path if not the same as model_name",
-    )
-    parser.add_argument(
-        "--instance_data_dir",
-        type=str,
-        default=None,
-        required=True,
-        help="A folder containing the training data of instance images.",
-    )
-    parser.add_argument(
-        "--class_data_dir",
-        type=str,
-        default=None,
-        required=False,
-        help="A folder containing the training data of class images.",
-    )
-    parser.add_argument(
-        "--instance_prompt",
-        type=str,
-        default=None,
-        help="The prompt with identifier specifying the instance",
-    )
-    parser.add_argument(
-        "--class_prompt",
-        type=str,
-        default=None,
-        help="The prompt to specify images in the same class as provided instance images.",
-    )
-    parser.add_argument(
-        "--with_prior_preservation",
-        default=False,
-        action="store_true",
-        help="Flag to add prior preservation loss.",
-    )
-    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
-    parser.add_argument(
-        "--num_class_images",
-        type=int,
-        default=100,
-        help=(
-            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
-            " sampled with class_prompt."
-        ),
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="text-inversion-model",
-        help="The output directory where the model predictions and checkpoints will be written.",
-    )
-    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
-    parser.add_argument(
-        "--resolution",
-        type=int,
-        default=512,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
-    )
-    parser.add_argument(
-        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
-    parser.add_argument(
-        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
-        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
-    )
-    parser.add_argument("--num_train_epochs", type=int, default=1)
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
-    )
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
-    parser.add_argument(
-        "--gradient_checkpointing",
-        action="store_true",
-        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
-    )
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=5e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
-        "--scale_lr",
-        action="store_true",
-        default=False,
-        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
-    )
-    parser.add_argument(
-        "--lr_scheduler",
-        type=str,
-        default="constant",
-        help=(
-            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
-        ),
-    )
-    parser.add_argument(
-        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
-    )
-    parser.add_argument(
-        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
-    )
-    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
-    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
-    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
-    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
-    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
-    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
-    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
-    parser.add_argument(
-        "--hub_model_id",
-        type=str,
-        default=None,
-        help="The name of the repository to keep in sync with the local `output_dir`.",
-    )
-    parser.add_argument(
-        "--logging_dir",
-        type=str,
-        default="logs",
-        help=(
-            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
-        ),
-    )
-    parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="no",
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
-        ),
-    )
-    parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.")
-    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
-
-    args = parser.parse_args()
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
-    if args.instance_data_dir is None:
-        raise ValueError("You must specify a train data directory.")
-
-    if args.with_prior_preservation:
-        if args.class_data_dir is None:
-            raise ValueError("You must specify a data directory for class images.")
-        if args.class_prompt is None:
-            raise ValueError("You must specify prompt for class images.")
-
-    return args
-
-
-class DreamBoothDataset(Dataset):
-    """
-    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-    It pre-processes the images and the tokenizes prompts.
-    """
-
-    def __init__(
-        self,
-        instance_data_root,
-        instance_prompt,
-        tokenizer,
-        class_data_root=None,
-        class_prompt=None,
-        size=512,
-        center_crop=False,
-    ):
-        self.size = size
-        self.center_crop = center_crop
-        self.tokenizer = tokenizer
-
-        self.instance_data_root = Path(instance_data_root)
-        if not self.instance_data_root.exists():
-            raise ValueError("Instance images root doesn't exists.")
-
-        self.instance_images_path = [x for x in Path(instance_data_root).iterdir() if x.is_file()]
-        self.num_instance_images = len(self.instance_images_path)
-        self.instance_prompt = instance_prompt
-        self._length = self.num_instance_images
-
-        if class_data_root is not None:
-            self.class_data_root = Path(class_data_root)
-            self.class_data_root.mkdir(parents=True, exist_ok=True)
-            self.class_images_path = [x for x in self.class_data_root.iterdir() if x.is_file()]
-            self.num_class_images = len(self.class_images_path)
-            self._length = max(self.num_class_images, self.num_instance_images)
-            self.class_prompt = class_prompt
-        else:
-            self.class_data_root = None
-
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
-
-    def __len__(self):
-        return self._length
-
-    def __getitem__(self, index):
-        example = {}
-        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
-        example["instance_images"] = self.image_transforms(instance_image)
-        example["instance_prompt_ids"] = self.tokenizer(
-            self.instance_prompt,
-            padding="do_not_pad",
-            truncation=True,
-            max_length=self.tokenizer.model_max_length,
-        ).input_ids
-
-        if self.class_data_root:
-            class_image = Image.open(self.class_images_path[index % self.num_class_images])
-            if not class_image.mode == "RGB":
-                class_image = class_image.convert("RGB")
-            example["class_images"] = self.image_transforms(class_image)
-            example["class_prompt_ids"] = self.tokenizer(
-                self.class_prompt,
-                padding="do_not_pad",
-                truncation=True,
-                max_length=self.tokenizer.model_max_length,
-            ).input_ids
-
-        return example
-
-
-class PromptDataset(Dataset):
-    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
-
-    def __init__(self, prompt, num_samples):
-        self.prompt = prompt
-        self.num_samples = num_samples
-
-    def __len__(self):
-        return self.num_samples
-
-    def __getitem__(self, index):
-        example = {}
-        example["prompt"] = self.prompt
-        example["index"] = index
-        return example
-
-
-class LatentsDataset(Dataset):
-    def __init__(self, latents_cache, text_encoder_cache):
-        self.latents_cache = latents_cache
-        self.text_encoder_cache = text_encoder_cache
-
-    def __len__(self):
-        return len(self.latents_cache)
-
-    def __getitem__(self, index):
-        return self.latents_cache[index], self.text_encoder_cache[index]
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
-def main():
-    args = parse_args()
-    logging_dir = Path(args.output_dir, args.logging_dir)
-
-    accelerator = Accelerator(
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
-        logging_dir=logging_dir,
-    )
-
-    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
-    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
-    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
-    if args.seed is not None:
-        set_seed(args.seed)
-
-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        if not class_images_dir.exists():
-            class_images_dir.mkdir(parents=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, use_auth_token=False
-            )
-            pipeline.set_progress_bar_config(disable=True)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-            pipeline.to(accelerator.device)
-
-            with torch.autocast("cuda"), torch.inference_mode():
-                for example in tqdm(
-                    sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
-                ):
-                    images = pipeline(example["prompt"]).images
-
-                    for i, image in enumerate(images):
-                        image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-            del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-    # Handle the repository creation
-    if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            repo = Repository(args.output_dir, clone_from=repo_name)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
-            os.makedirs(args.output_dir, exist_ok=True)
-
-    # Load the tokenizer
-    if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
-    elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=False)
-
-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=False)
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=False)
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=False)
-
-    vae.requires_grad_(False)
-    if not args.train_text_encoder:
-        text_encoder.requires_grad_(False)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        if args.train_text_encoder:
-            text_encoder.gradient_checkpointing_enable()
-
-    if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
-        )
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if args.use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError(
-                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
-            )
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    params_to_optimize = (
-        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
-    )
-    optimizer = optimizer_class(
-        params_to_optimize,
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-    )
-
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
-    )
-
-    train_dataset = DreamBoothDataset(
-        instance_data_root=args.instance_data_dir,
-        instance_prompt=args.instance_prompt,
-        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
-        class_prompt=args.class_prompt,
-        tokenizer=tokenizer,
-        size=args.resolution,
-        center_crop=args.center_crop,
-    )
-
-    def collate_fn(examples):
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
-        if args.with_prior_preservation:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
-
-        batch = {
-            "input_ids": input_ids,
-            "pixel_values": pixel_values,
-        }
-        return batch
-
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
-    )
-
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
-    vae.to(accelerator.device, dtype=weight_dtype)
-    if not args.train_text_encoder:
-        text_encoder.to(accelerator.device, dtype=weight_dtype)
-
-    if not args.not_cache_latents:
-        latents_cache = []
-        text_encoder_cache = []
-        for batch in tqdm(train_dataloader, desc="Caching latents"):
-            with torch.no_grad():
-                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
-                batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
-                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
-                if args.train_text_encoder:
-                    text_encoder_cache.append(batch["input_ids"])
-                else:
-                    text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
-        train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
-        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
-
-        del vae
-        if not args.train_text_encoder:
-            del text_encoder
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-    # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
-
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
-
-    if args.train_text_encoder:
-        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-        )
-    else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, optimizer, train_dataloader, lr_scheduler
-        )
-
-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-    # Afterwards we recalculate our number of training epochs
-    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
-    # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
-    if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=vars(args))
-
-    # Train!
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-
-    logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
-    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
-    logger.info(f"  Num Epochs = {args.num_train_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {args.max_train_steps}")
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
-    global_step = 0
-    loss_avg = AverageMeter()
-    text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
-    for epoch in range(args.num_train_epochs):
-        unet.train()
-        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
-                # Convert images to latent space
-                with torch.no_grad():
-                    if not args.not_cache_latents:
-                        latent_dist = batch[0][0]
-                    else:
-                        latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
-                    latents = latent_dist.sample() * 0.18215
-
-                # Sample noise that we'll add to the latents
-                noise = torch.randn_like(latents)
-                bsz = latents.shape[0]
-                # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-                timesteps = timesteps.long()
-
-                # Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                # Get the text embedding for conditioning
-                with text_enc_context:
-                    if not args.not_cache_latents:
-                        if args.train_text_encoder:
-                            encoder_hidden_states = text_encoder(batch[0][1])[0]
-                        else:
-                            encoder_hidden_states = batch[0][1]
-                    else:
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-                # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                if args.with_prior_preservation:
-                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-                    # Compute instance loss
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                    # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
-
-                    # Add the prior loss to the instance loss.
-                    loss = loss + args.prior_loss_weight * prior_loss
-                else:
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-                accelerator.backward(loss)
-                # if accelerator.sync_gradients:
-                #     params_to_clip = (
-                #         itertools.chain(unet.parameters(), text_encoder.parameters())
-                #         if args.train_text_encoder
-                #         else unet.parameters()
-                #     )
-                #     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
-                loss_avg.update(loss.detach_(), bsz)
-
-            if not global_step % args.log_interval:
-                logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
-                progress_bar.set_postfix(**logs)
-                accelerator.log(logs, step=global_step)
-
-            progress_bar.update(1)
-            global_step += 1
-
-            if global_step >= args.max_train_steps:
-                break
-
-        accelerator.wait_for_everyone()
-
-    # Create the pipeline using using the trained modules and save it.
-    if accelerator.is_main_process:
-        if args.train_text_encoder:
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                unet=accelerator.unwrap_model(unet),
-                text_encoder=accelerator.unwrap_model(text_encoder),
-                use_auth_token=False
-            )
-        else:
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                unet=accelerator.unwrap_model(unet),
-                use_auth_token=False
-            )
-        pipeline.save_pretrained(args.output_dir)
-
-        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
-
-    accelerator.end_training()
-
-
-if __name__ == "__main__":
-    main()