erfan-yahoo committed on
Commit
1781f20
1 Parent(s): 5a9feaf

Delete train_controlnet_inpaint.py

Files changed (1)
  1. train_controlnet_inpaint.py +0 -1244
train_controlnet_inpaint.py DELETED
@@ -1,1244 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2024, Yahoo Research
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import argparse
18
- import logging
19
- import math
20
- import os
21
- import random
22
- import shutil
23
- from pathlib import Path
24
- import cv2
25
- from PIL import Image, ImageOps
26
- import accelerate
27
- import numpy as np
28
- import torch
29
- import torch.nn.functional as F
30
- import torch.utils.checkpoint
31
- import transformers
32
- from accelerate import Accelerator
33
- from accelerate.logging import get_logger
34
- from accelerate.utils import ProjectConfiguration, set_seed
35
- from datasets import load_dataset
36
- from huggingface_hub import create_repo, upload_folder
37
- from packaging import version
38
- from PIL import Image
39
- from torchvision import transforms
40
- from tqdm.auto import tqdm
41
- from transformers import AutoTokenizer, PretrainedConfig
42
-
43
- import diffusers
44
- from diffusers import (
45
- AutoencoderKL,
46
- ControlNetModel,
47
- DDPMScheduler,
48
- StableDiffusionControlNetPipeline,
49
- UNet2DConditionModel,
50
- UniPCMultistepScheduler,
51
- )
52
- from diffusers.optimization import get_scheduler
53
- from diffusers.utils import check_min_version, is_wandb_available
54
- from diffusers.utils.import_utils import is_xformers_available
55
-
56
-
57
- if is_wandb_available():
58
- import wandb
59
-
60
- # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
61
- check_min_version("0.20.0.dev0")
62
-
63
- logger = get_logger(__name__)
64
-
65
-
66
- def image_grid(imgs, rows, cols):
67
- assert len(imgs) == rows * cols
68
-
69
- w, h = imgs[0].size
70
- grid = Image.new("RGB", size=(cols * w, rows * h))
71
-
72
- for i, img in enumerate(imgs):
73
- grid.paste(img, box=(i % cols * w, i // cols * h))
74
- return grid
75
-
76
-
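# log_validation: rebuilds a full StableDiffusionControlNetPipeline from the frozen VAE,
# text encoder and UNet plus the current ControlNet, composites each validation mask over
# its inpainting image to form the conditioning input, and logs the generated samples to
# TensorBoard or Weights & Biases.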
77
- def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
78
- logger.info("Running validation... ")
79
-
80
- controlnet = accelerator.unwrap_model(controlnet)
81
-
82
- pipeline = StableDiffusionControlNetPipeline.from_pretrained(
83
- args.pretrained_model_name_or_path,
84
- vae=vae,
85
- text_encoder=text_encoder,
86
- tokenizer=tokenizer,
87
- unet=unet,
88
- controlnet=controlnet,
89
- safety_checker=None,
90
- revision=args.revision,
91
- torch_dtype=weight_dtype,
92
- )
93
- pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
94
- pipeline = pipeline.to(accelerator.device)
95
- pipeline.set_progress_bar_config(disable=True)
96
-
97
- if args.enable_xformers_memory_efficient_attention:
98
- pipeline.enable_xformers_memory_efficient_attention()
99
-
100
- if args.seed is None:
101
- generator = None
102
- else:
103
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
104
-
105
- if len(args.validation_image) == len(args.validation_prompt):
106
- validation_images = args.validation_image
107
- validation_inpainting_images = args.validation_inpainting_image
108
- validation_prompts = args.validation_prompt
109
- elif len(args.validation_image) == 1:
110
- validation_images = args.validation_image * len(args.validation_prompt)
111
- validation_inpainting_images = args.validation_inpainting_image * len(args.validation_prompt)
112
- validation_prompts = args.validation_prompt
113
- elif len(args.validation_prompt) == 1:
114
- validation_images = args.validation_image
115
- validation_inpainting_images = args.validation_inpainting_image
116
- validation_prompts = args.validation_prompt * len(args.validation_image)
117
- else:
118
- raise ValueError(
119
- "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
120
- )
121
-
122
- image_logs = []
123
-
124
- for validation_prompt, validation_image, validation_inpainting_image in zip(validation_prompts, validation_images, validation_inpainting_images):
125
- mask = Image.open(validation_image)
126
- mask = resize_with_padding(mask, (512,512))
127
-
128
- inpainting_image = Image.open(validation_inpainting_image).convert("RGB")
129
- inpainting_image = resize_with_padding(inpainting_image, (512,512))
130
-
131
- validation_image = Image.composite(inpainting_image, mask, mask.convert('L')).convert('RGB')
132
- images = []
133
- for _ in range(args.num_validation_images):
134
- with torch.autocast("cuda"):
135
- image = pipeline(
136
- validation_prompt, validation_image, num_inference_steps=20, generator=generator
137
- ).images[0]
138
- images.append(image)
139
-
140
- image_logs.append(
141
- {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
142
- )
143
-
144
- for tracker in accelerator.trackers:
145
- if tracker.name == "tensorboard":
146
- for log in image_logs:
147
- images = log["images"]
148
- validation_prompt = log["validation_prompt"]
149
- validation_image = log["validation_image"]
150
-
151
- formatted_images = []
152
-
153
- formatted_images.append(np.asarray(validation_image))
154
-
155
- for image in images:
156
- formatted_images.append(np.asarray(image))
157
- formatted_images = np.stack(formatted_images)
158
-
159
- tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
160
- elif tracker.name == "wandb":
161
- formatted_images = []
162
-
163
- for log in image_logs:
164
- images = log["images"]
165
- validation_prompt = log["validation_prompt"]
166
- validation_image = log["validation_image"]
167
-
168
- formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
169
-
170
- for image in images:
171
- image = wandb.Image(image, caption=validation_prompt)
172
- formatted_images.append(image)
173
-
174
- tracker.log({"validation": formatted_images})
175
- else:
176
- logger.warning(f"image logging not implemented for {tracker.name}")
177
-
178
- return image_logs
179
-
180
-
181
- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
182
- text_encoder_config = PretrainedConfig.from_pretrained(
183
- pretrained_model_name_or_path,
184
- subfolder="text_encoder",
185
- revision=revision,
186
- )
187
- model_class = text_encoder_config.architectures[0]
188
-
189
- if model_class == "CLIPTextModel":
190
- from transformers import CLIPTextModel
191
-
192
- return CLIPTextModel
193
- elif model_class == "RobertaSeriesModelWithTransformation":
194
- from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
195
-
196
- return RobertaSeriesModelWithTransformation
197
- else:
198
- raise ValueError(f"{model_class} is not supported.")
199
-
200
-
201
- def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
202
- img_str = ""
203
- if image_logs is not None:
204
- img_str = "You can find some example images below.\n"
205
- for i, log in enumerate(image_logs):
206
- images = log["images"]
207
- validation_prompt = log["validation_prompt"]
208
- validation_image = log["validation_image"]
209
- validation_image.save(os.path.join(repo_folder, "image_control.png"))
210
- img_str += f"prompt: {validation_prompt}\n"
211
- images = [validation_image] + images
212
- image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
213
- img_str += f"![images_{i})](./images_{i}.png)\n"
214
-
215
- yaml = f"""
216
- ---
217
- license: creativeml-openrail-m
218
- base_model: {base_model}
219
- tags:
220
- - stable-diffusion
221
- - stable-diffusion-diffusers
222
- - text-to-image
223
- - diffusers
224
- - controlnet
225
- inference: true
226
- ---
227
- """
228
- model_card = f"""
229
- # controlnet-{repo_id}
230
-
231
- These are controlnet weights trained on {base_model} with a new type of conditioning.
232
- {img_str}
233
- """
234
- with open(os.path.join(repo_folder, "README.md"), "w") as f:
235
- f.write(yaml + model_card)
236
-
237
-
238
- def parse_args(input_args=None):
239
- parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
240
- parser.add_argument(
241
- "--pretrained_model_name_or_path",
242
- type=str,
243
- default=None,
244
- required=True,
245
- help="Path to pretrained model or model identifier from huggingface.co/models.",
246
- )
247
- parser.add_argument(
248
- "--controlnet_model_name_or_path",
249
- type=str,
250
- default=None,
251
- help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
252
- " If not specified controlnet weights are initialized from unet.",
253
- )
254
- parser.add_argument(
255
- "--revision",
256
- type=str,
257
- default=None,
258
- required=False,
259
- help=(
260
- "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
261
- " float32 precision."
262
- ),
263
- )
264
- parser.add_argument(
265
- "--tokenizer_name",
266
- type=str,
267
- default=None,
268
- help="Pretrained tokenizer name or path if not the same as model_name",
269
- )
270
- parser.add_argument(
271
- "--output_dir",
272
- type=str,
273
- default="controlnet-model",
274
- help="The output directory where the model predictions and checkpoints will be written.",
275
- )
276
- parser.add_argument(
277
- "--cache_dir",
278
- type=str,
279
- default=None,
280
- help="The directory where the downloaded models and datasets will be stored.",
281
- )
282
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
283
- parser.add_argument(
284
- "--resolution",
285
- type=int,
286
- default=512,
287
- help=(
288
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
289
- " resolution"
290
- ),
291
- )
292
- parser.add_argument(
293
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
294
- )
295
- parser.add_argument("--num_train_epochs", type=int, default=1)
296
- parser.add_argument(
297
- "--max_train_steps",
298
- type=int,
299
- default=None,
300
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
301
- )
302
- parser.add_argument(
303
- "--checkpointing_steps",
304
- type=int,
305
- default=500,
306
- help=(
307
- "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
308
- "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
309
- "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
310
- "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
311
- "instructions."
312
- ),
313
- )
314
- parser.add_argument(
315
- "--checkpoints_total_limit",
316
- type=int,
317
- default=None,
318
- help=("Max number of checkpoints to store."),
319
- )
320
- parser.add_argument(
321
- "--resume_from_checkpoint",
322
- type=str,
323
- default=None,
324
- help=(
325
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
326
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
327
- ),
328
- )
329
- parser.add_argument(
330
- "--gradient_accumulation_steps",
331
- type=int,
332
- default=1,
333
- help="Number of updates steps to accumulate before performing a backward/update pass.",
334
- )
335
- parser.add_argument(
336
- "--gradient_checkpointing",
337
- action="store_true",
338
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
339
- )
340
- parser.add_argument(
341
- "--learning_rate",
342
- type=float,
343
- default=5e-6,
344
- help="Initial learning rate (after the potential warmup period) to use.",
345
- )
346
- parser.add_argument(
347
- "--scale_lr",
348
- action="store_true",
349
- default=False,
350
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
351
- )
352
- parser.add_argument(
353
- "--lr_scheduler",
354
- type=str,
355
- default="constant",
356
- help=(
357
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
358
- ' "constant", "constant_with_warmup"]'
359
- ),
360
- )
361
- parser.add_argument(
362
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
363
- )
364
- parser.add_argument(
365
- "--lr_num_cycles",
366
- type=int,
367
- default=1,
368
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
369
- )
370
- parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
371
- parser.add_argument(
372
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
373
- )
374
- parser.add_argument(
375
- "--dataloader_num_workers",
376
- type=int,
377
- default=0,
378
- help=(
379
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
380
- ),
381
- )
382
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
383
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
384
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
385
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
386
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
387
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
388
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
389
- parser.add_argument(
390
- "--hub_model_id",
391
- type=str,
392
- default=None,
393
- help="The name of the repository to keep in sync with the local `output_dir`.",
394
- )
395
- parser.add_argument(
396
- "--logging_dir",
397
- type=str,
398
- default="logs",
399
- help=(
400
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
401
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
402
- ),
403
- )
404
- parser.add_argument(
405
- "--allow_tf32",
406
- action="store_true",
407
- help=(
408
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
409
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
410
- ),
411
- )
412
- parser.add_argument(
413
- "--report_to",
414
- type=str,
415
- default="tensorboard",
416
- help=(
417
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
418
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
419
- ),
420
- )
421
- parser.add_argument(
422
- "--mixed_precision",
423
- type=str,
424
- default=None,
425
- choices=["no", "fp16", "bf16"],
426
- help=(
427
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
428
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
429
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
430
- ),
431
- )
432
- parser.add_argument(
433
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
434
- )
435
- parser.add_argument(
436
- "--set_grads_to_none",
437
- action="store_true",
438
- help=(
439
- "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
440
- " behaviors, so disable this argument if it causes any problems. More info:"
441
- " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
442
- ),
443
- )
444
- parser.add_argument(
445
- "--dataset_name",
446
- type=str,
447
- default=None,
448
- help=(
449
- "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
450
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
451
- " or to a folder containing files that 🤗 Datasets can understand."
452
- ),
453
- )
454
- parser.add_argument(
455
- "--dataset_config_name",
456
- type=str,
457
- default=None,
458
- help="The config of the Dataset, leave as None if there's only one config.",
459
- )
460
- parser.add_argument(
461
- "--train_data_dir",
462
- type=str,
463
- default=None,
464
- help=(
465
- "A folder containing the training data. Folder contents must follow the structure described in"
466
- " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
467
- " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
468
- ),
469
- )
470
- parser.add_argument(
471
- "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
472
- )
473
- parser.add_argument(
474
- "--conditioning_image_column",
475
- type=str,
476
- default="conditioning_image",
477
- help="The column of the dataset containing the controlnet conditioning image.",
478
- )
479
- parser.add_argument(
480
- "--caption_column",
481
- type=str,
482
- default="text",
483
- help="The column of the dataset containing a caption or a list of captions.",
484
- )
485
- parser.add_argument(
486
- "--max_train_samples",
487
- type=int,
488
- default=None,
489
- help=(
490
- "For debugging purposes or quicker training, truncate the number of training examples to this "
491
- "value if set."
492
- ),
493
- )
494
- parser.add_argument(
495
- "--proportion_empty_prompts",
496
- type=float,
497
- default=0,
498
- help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
499
- )
500
- parser.add_argument(
501
- "--validation_prompt",
502
- type=str,
503
- default=None,
504
- nargs="+",
505
- help=(
506
- "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
507
- " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
508
- " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
509
- ),
510
- )
511
- parser.add_argument(
512
- "--validation_image",
513
- type=str,
514
- default=None,
515
- nargs="+",
516
- help=(
517
- "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
518
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
519
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
520
- " `--validation_image` that will be used with all `--validation_prompt`s."
521
- ),
522
- )
523
- parser.add_argument(
524
- "--validation_inpainting_image",
525
- type=str,
526
- default=None,
527
- nargs="+",
528
- help=(
529
- "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
530
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
531
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
532
- " `--validation_image` that will be used with all `--validation_prompt`s."
533
- ),
534
- )
535
- parser.add_argument(
536
- "--num_validation_images",
537
- type=int,
538
- default=4,
539
- help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
540
- )
541
- parser.add_argument(
542
- "--validation_steps",
543
- type=int,
544
- default=100,
545
- help=(
546
- "Run validation every X steps. Validation consists of running the prompt"
547
- " `args.validation_prompt` multiple times: `args.num_validation_images`"
548
- " and logging the images."
549
- ),
550
- )
551
- parser.add_argument(
552
- "--tracker_project_name",
553
- type=str,
554
- default="train_controlnet",
555
- help=(
556
- "The `project_name` argument passed to Accelerator.init_trackers for"
557
- " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
558
- ),
559
- )
560
-
561
- if input_args is not None:
562
- args = parser.parse_args(input_args)
563
- else:
564
- args = parser.parse_args()
565
-
566
- if args.dataset_name is None and args.train_data_dir is None:
567
- raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
568
-
569
- if args.dataset_name is not None and args.train_data_dir is not None:
570
- raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
571
-
572
- if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
573
- raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
574
-
575
- if args.validation_prompt is not None and args.validation_image is None:
576
- raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
577
-
578
- if args.validation_prompt is None and args.validation_image is not None:
579
- raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
580
-
581
- if (
582
- args.validation_image is not None
583
- and args.validation_prompt is not None
584
- and len(args.validation_image) != 1
585
- and len(args.validation_prompt) != 1
586
- and len(args.validation_image) != len(args.validation_prompt)
587
- ):
588
- raise ValueError(
589
- "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
590
- " or the same number of `--validation_prompt`s and `--validation_image`s"
591
- )
592
-
593
- if args.resolution % 8 != 0:
594
- raise ValueError(
595
- "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
596
- )
597
-
598
- return args
599
-
600
-
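# make_train_dataset: loads the dataset from the Hub or a local folder, resolves the image,
# conditioning-image and caption columns, and attaches a transform that tokenizes the
# captions (optionally blanking a fraction of them, see --proportion_empty_prompts).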
601
- def make_train_dataset(args, tokenizer, accelerator):
602
- # Get the datasets: you can either provide your own training and evaluation files (see below)
603
- # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
604
-
605
- # In distributed training, the load_dataset function guarantees that only one local process can concurrently
606
- # download the dataset.
607
- if args.dataset_name is not None:
608
- # Downloading and loading a dataset from the hub.
609
- dataset = load_dataset(
610
- args.dataset_name,
611
- args.dataset_config_name,
612
- cache_dir=args.cache_dir,
613
- )
614
- else:
615
- if args.train_data_dir is not None:
616
- dataset = load_dataset(
617
- args.train_data_dir,
618
- cache_dir=args.cache_dir,
619
- )
620
- # See more about loading custom images at
621
- # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
622
-
623
- # Preprocessing the datasets.
624
- # We need to tokenize inputs and targets.
625
- column_names = dataset["train"].column_names
626
-
627
- # 6. Get the column names for input/target.
628
- if args.image_column is None:
629
- image_column = column_names[0]
630
- logger.info(f"image column defaulting to {image_column}")
631
- else:
632
- image_column = args.image_column
633
- if image_column not in column_names:
634
- raise ValueError(
635
- f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
636
- )
637
-
638
- if args.caption_column is None:
639
- caption_column = column_names[1]
640
- logger.info(f"caption column defaulting to {caption_column}")
641
- else:
642
- caption_column = args.caption_column
643
- if caption_column not in column_names:
644
- raise ValueError(
645
- f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
646
- )
647
-
648
- if args.conditioning_image_column is None:
649
- conditioning_image_column = column_names[2]
650
- logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
651
- else:
652
- conditioning_image_column = args.conditioning_image_column
653
- if conditioning_image_column not in column_names:
654
- raise ValueError(
655
- f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
656
- )
657
-
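# Captions are replaced with the empty string with probability --proportion_empty_prompts;
# list-valued captions are sampled randomly during training and the first entry is used
# otherwise.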
658
- def tokenize_captions(examples, is_train=True):
659
- captions = []
660
- for caption in examples[caption_column]:
661
- if random.random() < args.proportion_empty_prompts:
662
- captions.append("")
663
- elif isinstance(caption, str):
664
- captions.append(caption)
665
- elif isinstance(caption, (list, np.ndarray)):
666
- # take a random caption if there are multiple
667
- captions.append(random.choice(caption) if is_train else caption[0])
668
- else:
669
- raise ValueError(
670
- f"Caption column `{caption_column}` should contain either strings or lists of strings."
671
- )
672
- inputs = tokenizer(
673
- captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
674
- )
675
- return inputs.input_ids
676
-
677
- image_transforms = transforms.Compose(
678
- [
679
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
680
- transforms.CenterCrop(args.resolution),
681
- transforms.ToTensor(),
682
- transforms.Normalize([0.5], [0.5]),
683
- ]
684
- )
685
-
686
- conditioning_image_transforms = transforms.Compose(
687
- [
688
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
689
- transforms.CenterCrop(args.resolution),
690
- transforms.ToTensor(),
691
- ]
692
- )
693
-
694
- def preprocess_train(examples):
695
- examples["pixel_values"] = examples[image_column] #images
696
- examples["conditioning_pixel_values"] = examples[conditioning_image_column] #conditioning_images
697
- examples["input_ids"] = tokenize_captions(examples)
698
-
699
- return examples
700
-
701
- with accelerator.main_process_first():
702
- if args.max_train_samples is not None:
703
- dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
704
- # Set the training transforms
705
- train_dataset = dataset["train"].with_transform(preprocess_train)
706
-
707
- return train_dataset
708
-
709
-
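# resize_with_padding: shrinks the image in place with PIL's thumbnail() (preserving the
# aspect ratio) and pads it symmetrically so the result matches expected_size.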
710
- def resize_with_padding(img, expected_size):
711
- img.thumbnail((expected_size[0], expected_size[1]))
712
- # print(img.size)
713
- delta_width = expected_size[0] - img.size[0]
714
- delta_height = expected_size[1] - img.size[1]
715
- pad_width = delta_width // 2
716
- pad_height = delta_height // 2
717
- padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
718
- return ImageOps.expand(img, padding)
719
-
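# prepare_mask_and_masked_image: scales the image to [-1, 1], binarises the mask at 0.5 and
# zeroes out the masked region. Note: this helper is not referenced elsewhere in this
# version of the script; collate_fn builds its own conditioning images.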
720
- def prepare_mask_and_masked_image(image, mask):
721
- image = np.array(image.convert("RGB"))
722
- image = image[None].transpose(0, 3, 1, 2)
723
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
724
-
725
- mask = np.array(mask.convert("L"))
726
- mask = mask.astype(np.float32) / 255.0
727
- mask = mask[None, None]
728
- mask[mask < 0.5] = 0
729
- mask[mask >= 0.5] = 1
730
- #mask = torch.from_numpy(mask)
731
-
732
- masked_image = image * (mask < 0.5)
733
-
734
- return mask, masked_image
735
-
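# collate_fn: resizes each image so its shorter side is 768 px, takes a random 512x512 crop,
# composites the crop with its mask (Image.composite) to build the ControlNet conditioning
# image, and finally applies the tensor transforms and stacks the batch.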
736
- def collate_fn(examples):
737
- pixel_values = [example["pixel_values"].convert("RGB") for example in examples]
738
- conditioning_images = [example["conditioning_pixel_values"].convert("RGB") for example in examples]
739
- masks = []
740
- masked_images = []
741
-
742
- # Resize and random crop images
743
- for i in range(len(pixel_values)):
744
- image = np.array(pixel_values[i])
745
- mask = np.array(conditioning_images[i])
746
- dim_min_ind = np.argmin(image.shape[0:2])
747
- dim = [0, 0]
748
-
749
- resize_len = 768.0
750
- ratio = resize_len / image.shape[0:2][dim_min_ind]
751
- dim[1-dim_min_ind] = int(resize_len)
752
- dim[dim_min_ind] = int(ratio * image.shape[0:2][1-dim_min_ind])
753
- dim = tuple(dim)
754
-
755
- # resize image
756
- image = cv2.resize(image, dim, interpolation = cv2.INTER_AREA)
757
- mask = cv2.resize(mask, dim, interpolation = cv2.INTER_AREA)
758
- max_x = image.shape[1] - 512
759
- max_y = image.shape[0] - 512
760
- x = np.random.randint(0, max_x)
761
- y = np.random.randint(0, max_y)
762
- image = image[y: y + 512, x: x + 512]
763
- mask = mask[y: y + 512, x: x + 512]
764
-
765
- # fix for bluish outputs
766
- r = np.copy(image[:,:,0])
767
- image[:,:,0] = image[:,:,2]
768
- image[:,:,2] = r
769
- image = Image.fromarray(image)
770
- b, g, r = image.split()
771
- image = Image.merge("RGB", (r, g, b))
772
- pixel_values[i] = image
773
- conditioning_images[i] = Image.composite(image, Image.fromarray(mask), Image.fromarray(mask).convert('L')).convert('RGB')
774
-
775
-
776
- image_transforms = transforms.Compose(
777
- [
778
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
779
- transforms.CenterCrop(args.resolution),
780
- transforms.ToTensor(),
781
- transforms.Normalize([0.5], [0.5]),
782
- ]
783
- )
784
-
785
- conditioning_image_transforms = transforms.Compose(
786
- [
787
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
788
- transforms.CenterCrop(args.resolution),
789
- transforms.ToTensor(),
790
- transforms.Normalize([0.5], [0.5])
791
- ]
792
- )
793
-
794
- pixel_values = [image_transforms(image) for image in pixel_values]
795
- pixel_values = torch.stack(pixel_values)
796
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
797
-
798
- conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
799
- conditioning_pixel_values = torch.stack(conditioning_images)
800
- conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
801
-
802
- input_ids = torch.stack([example["input_ids"] for example in examples])
803
-
804
- # masks = torch.stack(masks)
805
- # masked_images = torch.stack(masked_images)
806
-
807
- return {
808
- "pixel_values": pixel_values,
809
- "conditioning_pixel_values": conditioning_pixel_values,
810
- "input_ids": input_ids,
811
- # "masks": masks, "masked_images": masked_images
812
- }
813
-
814
- # pixel_values = torch.stack([example["pixel_values"] for example in examples])
815
- # pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
816
-
817
- # conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
818
- # conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
819
-
820
- # input_ids = torch.stack([example["input_ids"] for example in examples])
821
-
822
- # return {
823
- # "pixel_values": pixel_values,
824
- # "conditioning_pixel_values": conditioning_pixel_values,
825
- # "input_ids": input_ids,
826
- # }
827
-
828
-
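# main: sets up Accelerate and logging, loads the frozen Stable Diffusion components
# (tokenizer, text encoder, VAE, UNet) and a trainable ControlNet (from
# --controlnet_model_name_or_path, or initialised from the UNet), then runs the training
# loop with periodic checkpointing and validation.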
829
- def main(args):
830
- logging_dir = Path(args.output_dir, args.logging_dir)
831
-
832
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
833
-
834
- accelerator = Accelerator(
835
- gradient_accumulation_steps=args.gradient_accumulation_steps,
836
- mixed_precision=args.mixed_precision,
837
- log_with=args.report_to,
838
- project_config=accelerator_project_config,
839
- )
840
-
841
- # Make one log on every process with the configuration for debugging.
842
- logging.basicConfig(
843
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
844
- datefmt="%m/%d/%Y %H:%M:%S",
845
- level=logging.INFO,
846
- )
847
- logger.info(accelerator.state, main_process_only=False)
848
- if accelerator.is_local_main_process:
849
- transformers.utils.logging.set_verbosity_warning()
850
- diffusers.utils.logging.set_verbosity_info()
851
- else:
852
- transformers.utils.logging.set_verbosity_error()
853
- diffusers.utils.logging.set_verbosity_error()
854
-
855
- # If passed along, set the training seed now.
856
- if args.seed is not None:
857
- set_seed(args.seed)
858
-
859
- # Handle the repository creation
860
- if accelerator.is_main_process:
861
- if args.output_dir is not None:
862
- os.makedirs(args.output_dir, exist_ok=True)
863
-
864
- if args.push_to_hub:
865
- repo_id = create_repo(
866
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
867
- ).repo_id
868
-
869
- # Load the tokenizer
870
- if args.tokenizer_name:
871
- tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
872
- elif args.pretrained_model_name_or_path:
873
- tokenizer = AutoTokenizer.from_pretrained(
874
- args.pretrained_model_name_or_path,
875
- subfolder="tokenizer",
876
- revision=args.revision,
877
- use_fast=False,
878
- )
879
-
880
- # import correct text encoder class
881
- text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
882
-
883
- # Load scheduler and models
884
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
885
- text_encoder = text_encoder_cls.from_pretrained(
886
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
887
- )
888
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
889
- unet = UNet2DConditionModel.from_pretrained(
890
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
891
- )
892
-
893
- if args.controlnet_model_name_or_path:
894
- logger.info("Loading existing controlnet weights")
895
- controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
896
- else:
897
- logger.info("Initializing controlnet weights from unet")
898
- controlnet = ControlNetModel.from_unet(unet)
899
-
900
- # `accelerate` 0.16.0 will have better support for customized saving
901
- if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
902
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
903
- def save_model_hook(models, weights, output_dir):
904
- i = len(weights) - 1
905
-
906
- while len(weights) > 0:
907
- weights.pop()
908
- model = models[i]
909
-
910
- sub_dir = "controlnet"
911
- model.save_pretrained(os.path.join(output_dir, sub_dir))
912
-
913
- i -= 1
914
-
915
- def load_model_hook(models, input_dir):
916
- while len(models) > 0:
917
- # pop models so that they are not loaded again
918
- model = models.pop()
919
-
920
- # load diffusers style into model
921
- load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
922
- model.register_to_config(**load_model.config)
923
-
924
- model.load_state_dict(load_model.state_dict())
925
- del load_model
926
-
927
- accelerator.register_save_state_pre_hook(save_model_hook)
928
- accelerator.register_load_state_pre_hook(load_model_hook)
929
-
930
- vae.requires_grad_(False)
931
- unet.requires_grad_(False)
932
- text_encoder.requires_grad_(False)
933
- controlnet.train()
934
-
935
- if args.enable_xformers_memory_efficient_attention:
936
- if is_xformers_available():
937
- import xformers
938
-
939
- xformers_version = version.parse(xformers.__version__)
940
- if xformers_version == version.parse("0.0.16"):
941
- logger.warning(
942
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
943
- )
944
- unet.enable_xformers_memory_efficient_attention()
945
- controlnet.enable_xformers_memory_efficient_attention()
946
- else:
947
- raise ValueError("xformers is not available. Make sure it is installed correctly")
948
-
949
- if args.gradient_checkpointing:
950
- controlnet.enable_gradient_checkpointing()
951
-
952
- # Check that all trainable models are in full precision
953
- low_precision_error_string = (
954
- " Please make sure to always have all model weights in full float32 precision when starting training - even if"
955
- " doing mixed precision training, copy of the weights should still be float32."
956
- )
957
-
958
- if accelerator.unwrap_model(controlnet).dtype != torch.float32:
959
- raise ValueError(
960
- f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
961
- )
962
-
963
- # Enable TF32 for faster training on Ampere GPUs,
964
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
965
- if args.allow_tf32:
966
- torch.backends.cuda.matmul.allow_tf32 = True
967
-
968
- if args.scale_lr:
969
- args.learning_rate = (
970
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
971
- )
972
-
973
- # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
974
- if args.use_8bit_adam:
975
- try:
976
- import bitsandbytes as bnb
977
- except ImportError:
978
- raise ImportError(
979
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
980
- )
981
-
982
- optimizer_class = bnb.optim.AdamW8bit
983
- else:
984
- optimizer_class = torch.optim.AdamW
985
-
986
- # Optimizer creation
987
- params_to_optimize = controlnet.parameters()
988
- optimizer = optimizer_class(
989
- params_to_optimize,
990
- lr=args.learning_rate,
991
- betas=(args.adam_beta1, args.adam_beta2),
992
- weight_decay=args.adam_weight_decay,
993
- eps=args.adam_epsilon,
994
- )
995
-
996
- train_dataset = make_train_dataset(args, tokenizer, accelerator)
997
-
998
- train_dataloader = torch.utils.data.DataLoader(
999
- train_dataset,
1000
- shuffle=True,
1001
- collate_fn=collate_fn,
1002
- batch_size=args.train_batch_size,
1003
- num_workers=args.dataloader_num_workers,
1004
- )
1005
-
1006
- # Scheduler and math around the number of training steps.
1007
- overrode_max_train_steps = False
1008
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1009
- if args.max_train_steps is None:
1010
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1011
- overrode_max_train_steps = True
1012
-
1013
- lr_scheduler = get_scheduler(
1014
- args.lr_scheduler,
1015
- optimizer=optimizer,
1016
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1017
- num_training_steps=args.max_train_steps * accelerator.num_processes,
1018
- num_cycles=args.lr_num_cycles,
1019
- power=args.lr_power,
1020
- )
1021
-
1022
- # Prepare everything with our `accelerator`.
1023
- controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1024
- controlnet, optimizer, train_dataloader, lr_scheduler
1025
- )
1026
-
1027
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
1028
- # as these models are only used for inference, keeping weights in full precision is not required.
1029
- weight_dtype = torch.float32
1030
- if accelerator.mixed_precision == "fp16":
1031
- weight_dtype = torch.float16
1032
- elif accelerator.mixed_precision == "bf16":
1033
- weight_dtype = torch.bfloat16
1034
-
1035
- # Move vae, unet and text_encoder to device and cast to weight_dtype
1036
- vae.to(accelerator.device, dtype=weight_dtype)
1037
- unet.to(accelerator.device, dtype=weight_dtype)
1038
- text_encoder.to(accelerator.device, dtype=weight_dtype)
1039
-
1040
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1041
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1042
- if overrode_max_train_steps:
1043
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1044
- # Afterwards we recalculate our number of training epochs
1045
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1046
-
1047
- # We need to initialize the trackers we use, and also store our configuration.
1048
- # The trackers initialize automatically on the main process.
1049
- if accelerator.is_main_process:
1050
- tracker_config = dict(vars(args))
1051
-
1052
- # tensorboard cannot handle list types for config
1053
- tracker_config.pop("validation_prompt")
1054
- tracker_config.pop("validation_image")
1055
- tracker_config.pop("validation_inpainting_image")
1056
-
1057
- accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1058
-
1059
- # Train!
1060
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1061
-
1062
- logger.info("***** Running training *****")
1063
- logger.info(f" Num examples = {len(train_dataset)}")
1064
- logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1065
- logger.info(f" Num Epochs = {args.num_train_epochs}")
1066
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1067
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1068
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1069
- logger.info(f" Total optimization steps = {args.max_train_steps}")
1070
- global_step = 0
1071
- first_epoch = 0
1072
-
1073
- # Potentially load in the weights and states from a previous save
1074
- if args.resume_from_checkpoint:
1075
- if args.resume_from_checkpoint != "latest":
1076
- path = os.path.basename(args.resume_from_checkpoint)
1077
- else:
1078
- # Get the most recent checkpoint
1079
- dirs = os.listdir(args.output_dir)
1080
- dirs = [d for d in dirs if d.startswith("checkpoint")]
1081
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1082
- path = dirs[-1] if len(dirs) > 0 else None
1083
-
1084
- if path is None:
1085
- accelerator.print(
1086
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1087
- )
1088
- args.resume_from_checkpoint = None
1089
- initial_global_step = 0
1090
- else:
1091
- accelerator.print(f"Resuming from checkpoint {path}")
1092
- accelerator.load_state(os.path.join(args.output_dir, path))
1093
- global_step = int(path.split("-")[1])
1094
-
1095
- initial_global_step = global_step
1096
- first_epoch = global_step // num_update_steps_per_epoch
1097
- else:
1098
- initial_global_step = 0
1099
-
1100
- progress_bar = tqdm(
1101
- range(0, args.max_train_steps),
1102
- initial=initial_global_step,
1103
- desc="Steps",
1104
- # Only show the progress bar once on each machine.
1105
- disable=not accelerator.is_local_main_process,
1106
- )
1107
-
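# Training loop: encode the target images into VAE latents, add noise at a random timestep,
# run the ControlNet on the noisy latents and the conditioning image, feed its residuals
# into the frozen UNet, and regress the predicted noise (or velocity) with an MSE loss.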
1108
- image_logs = None
1109
- for epoch in range(first_epoch, args.num_train_epochs):
1110
- for step, batch in enumerate(train_dataloader):
1111
- with accelerator.accumulate(controlnet):
1112
- # Convert images to latent space
1113
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1114
- latents = latents * vae.config.scaling_factor
1115
-
1116
- # Sample noise that we'll add to the latents
1117
- noise = torch.randn_like(latents)
1118
- bsz = latents.shape[0]
1119
- # Sample a random timestep for each image
1120
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1121
- timesteps = timesteps.long()
1122
-
1123
- # Add noise to the latents according to the noise magnitude at each timestep
1124
- # (this is the forward diffusion process)
1125
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1126
-
1127
- # Get the text embedding for conditioning
1128
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
1129
-
1130
- controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
1131
-
1132
- down_block_res_samples, mid_block_res_sample = controlnet(
1133
- noisy_latents,
1134
- timesteps,
1135
- encoder_hidden_states=encoder_hidden_states,
1136
- controlnet_cond=controlnet_image,
1137
- return_dict=False,
1138
- )
1139
-
1140
- # Predict the noise residual
1141
- model_pred = unet(
1142
- noisy_latents,
1143
- timesteps,
1144
- encoder_hidden_states=encoder_hidden_states,
1145
- down_block_additional_residuals=[
1146
- sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1147
- ],
1148
- mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1149
- ).sample
1150
-
1151
- # Get the target for loss depending on the prediction type
1152
- if noise_scheduler.config.prediction_type == "epsilon":
1153
- target = noise
1154
- elif noise_scheduler.config.prediction_type == "v_prediction":
1155
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
1156
- else:
1157
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1158
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1159
-
1160
- accelerator.backward(loss)
1161
- if accelerator.sync_gradients:
1162
- params_to_clip = controlnet.parameters()
1163
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1164
- optimizer.step()
1165
- lr_scheduler.step()
1166
- optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1167
-
1168
- # Checks if the accelerator has performed an optimization step behind the scenes
1169
- if accelerator.sync_gradients:
1170
- progress_bar.update(1)
1171
- global_step += 1
1172
-
1173
- if accelerator.is_main_process:
1174
- if global_step % args.checkpointing_steps == 0:
1175
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1176
- if args.checkpoints_total_limit is not None:
1177
- checkpoints = os.listdir(args.output_dir)
1178
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1179
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1180
-
1181
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1182
- if len(checkpoints) >= args.checkpoints_total_limit:
1183
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1184
- removing_checkpoints = checkpoints[0:num_to_remove]
1185
-
1186
- logger.info(
1187
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1188
- )
1189
- logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1190
-
1191
- for removing_checkpoint in removing_checkpoints:
1192
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1193
- shutil.rmtree(removing_checkpoint)
1194
-
1195
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1196
- accelerator.save_state(save_path)
1197
- logger.info(f"Saved state to {save_path}")
1198
-
1199
- if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1200
- image_logs = log_validation(
1201
- vae,
1202
- text_encoder,
1203
- tokenizer,
1204
- unet,
1205
- controlnet,
1206
- args,
1207
- accelerator,
1208
- weight_dtype,
1209
- global_step,
1210
- )
1211
-
1212
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1213
- progress_bar.set_postfix(**logs)
1214
- accelerator.log(logs, step=global_step)
1215
-
1216
- if global_step >= args.max_train_steps:
1217
- break
1218
-
1219
- # Create the pipeline using the trained modules and save it.
1220
- accelerator.wait_for_everyone()
1221
- if accelerator.is_main_process:
1222
- controlnet = accelerator.unwrap_model(controlnet)
1223
- controlnet.save_pretrained(args.output_dir)
1224
-
1225
- if args.push_to_hub:
1226
- save_model_card(
1227
- repo_id,
1228
- image_logs=image_logs,
1229
- base_model=args.pretrained_model_name_or_path,
1230
- repo_folder=args.output_dir,
1231
- )
1232
- upload_folder(
1233
- repo_id=repo_id,
1234
- folder_path=args.output_dir,
1235
- commit_message="End of training",
1236
- ignore_patterns=["step_*", "epoch_*"],
1237
- )
1238
-
1239
- accelerator.end_training()
1240
-
1241
-
1242
- if __name__ == "__main__":
1243
- args = parse_args()
1244
- main(args)
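
After training, the ControlNet weights are written to `--output_dir` via `save_pretrained`. Below is a minimal inference sketch that mirrors what `log_validation` above does; the base model id, file paths and prompt are placeholder assumptions, not values taken from this repository:

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler

# Load the trained ControlNet (placeholder path: the default --output_dir) and plug it
# into the base model it was trained against (placeholder id).
controlnet = ControlNetModel.from_pretrained("controlnet-model", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Build the conditioning image the same way log_validation does: composite the background
# over the mask. Paths are placeholders; both images are resized to the training resolution.
mask = Image.open("mask.png").convert("RGB").resize((512, 512))
background = Image.open("background.png").convert("RGB").resize((512, 512))
conditioning = Image.composite(background, mask, mask.convert("L"))

image = pipe("a product photo on a marble table", conditioning, num_inference_steps=20).images[0]
image.save("out.png")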