jean1.yu committed
Commit 3ab117a
1 Parent(s): 5b6457f

commit from jxuhf

default_config.yaml ADDED
@@ -0,0 +1,13 @@
+ compute_environment: LOCAL_MACHINE
+ deepspeed_config: {}
+ distributed_type: 'NO'
+ downcast_bf16: 'no'
+ fsdp_config: {}
+ machine_rank: 0
+ main_process_ip: null
+ main_process_port: null
+ main_training_function: main
+ mixed_precision: fp16
+ num_machines: 1
+ num_processes: 1
+ use_cpu: false
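
Note: this accelerate config describes a single local process (no distributed backend) with fp16 mixed precision. A training run would typically be launched through it with something like the line below; the remaining script flags are elided and the paths are illustrative placeholders, not part of this commit:

    accelerate launch --config_file default_config.yaml train_dreambooth.py --instance_data_dir ./instance-images --output_dir ./output-models ...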
inference.py ADDED
@@ -0,0 +1,43 @@
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
+ import torch
+
+ device = "cuda"
+ # Use the DDIM scheduler; swap in another scheduler here if desired
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=True)
+
+ # Modify the model path to point at the desired checkpoint
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "./output-models/1500/",
+     scheduler=scheduler,
+     safety_checker=None,
+     torch_dtype=torch.float16,
+ ).to(device)
+
+ # Enable xFormers memory-efficient attention
+ pipe.enable_xformers_memory_efficient_attention()
+
+ prompt = "photo of zwx dog with Texas bluebonnet"
+ negative_prompt = ""
+ num_samples = 4
+ guidance_scale = 7.5
+ num_inference_steps = 50
+ height = 512
+ width = 512
+
+ with torch.autocast("cuda"), torch.inference_mode():
+     images = pipe(
+         prompt,
+         height=height,
+         width=width,
+         negative_prompt=negative_prompt,
+         num_images_per_prompt=num_samples,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale
+     ).images
+
+ count = 1
+ for image in images:
+     # Save each image to the local directory
+     image.save(f"img-{count}.png")
+     count += 1
+
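
Note: the sampling above is not seeded, so outputs differ between runs. If reproducible samples are wanted, a seeded generator can be passed to the pipeline call; a minimal sketch follows (the generator parameter is a standard diffusers pipeline argument, but this addition is not part of the committed script):

    # Hypothetical addition: fix the random seed for repeatable samples
    generator = torch.Generator(device=device).manual_seed(42)
    images = pipe(
        prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_samples,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images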
instance-images/image-1.jpg ADDED
instance-images/image-10.jpg ADDED
instance-images/image-11.jpg ADDED
instance-images/image-12.jpg ADDED
instance-images/image-13.jpg ADDED
instance-images/image-14.jpg ADDED
instance-images/image-15.jpg ADDED
instance-images/image-2.jpg ADDED
instance-images/image-3.jpg ADDED
instance-images/image-4.jpg ADDED
instance-images/image-5.jpg ADDED
instance-images/image-6.jpg ADDED
instance-images/image-7.jpg ADDED
instance-images/image-8.jpg ADDED
instance-images/image-9.jpg ADDED
output-models/2000/args.json ADDED
@@ -0,0 +1,59 @@
+ {
+   "pretrained_model_name_or_path": "/store/diffusers/huggingface/models--runwayml--stable-diffusion-v1-5/snapshots/63534535d4730d5976c5c647a7f2adaea1102f5b",
+   "pretrained_vae_name_or_path": "/store/diffusers/huggingface/models--stabilityai--sd-vae-ft-mse/snapshots/07d70db1bb648cea307cb6b9e32a50c8655a08e7",
+   "revision": null,
+   "tokenizer_name": null,
+   "instance_data_dir": "/store/diffusers/instance-images/",
+   "class_data_dir": "/store/diffusers/class-images/",
+   "instance_prompt": "photo of zwx dog",
+   "class_prompt": "photo of dog",
+   "save_sample_prompt": null,
+   "save_sample_negative_prompt": null,
+   "n_save_sample": 4,
+   "save_guidance_scale": 7.5,
+   "save_infer_steps": 50,
+   "pad_tokens": false,
+   "with_prior_preservation": true,
+   "prior_loss_weight": 1.0,
+   "num_class_images": 300,
+   "output_dir": "/store/diffusers/output-models/",
+   "seed": null,
+   "resolution": 512,
+   "center_crop": false,
+   "train_text_encoder": false,
+   "train_batch_size": 1,
+   "sample_batch_size": 4,
+   "num_train_epochs": 7,
+   "max_train_steps": 2000,
+   "gradient_accumulation_steps": 1,
+   "gradient_checkpointing": true,
+   "learning_rate": 1e-06,
+   "scale_lr": false,
+   "lr_scheduler": "constant",
+   "lr_warmup_steps": 200,
+   "use_8bit_adam": true,
+   "adam_beta1": 0.9,
+   "adam_beta2": 0.999,
+   "adam_weight_decay": 0.01,
+   "adam_epsilon": 1e-08,
+   "max_grad_norm": 1.0,
+   "push_to_hub": false,
+   "hub_token": null,
+   "hub_model_id": null,
+   "logging_dir": "logs",
+   "log_interval": 10,
+   "save_interval": 500,
+   "save_min_steps": 0,
+   "mixed_precision": "fp16",
+   "not_cache_latents": false,
+   "hflip": false,
+   "local_rank": -1,
+   "concepts_list": [
+     {
+       "instance_prompt": "photo of zwx dog",
+       "class_prompt": "photo of dog",
+       "instance_data_dir": "/store/diffusers/instance-images/",
+       "class_data_dir": "/store/diffusers/class-images/"
+     }
+   ]
+ }
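
Note: these are the CLI arguments the training script dumped alongside the 2000-step checkpoint. An approximate reconstruction of the corresponding launch command is sketched below; the exact invocation is not part of this commit, and the local snapshot paths are replaced by the Hub IDs they correspond to:

    accelerate launch --config_file default_config.yaml train_dreambooth.py \
        --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
        --pretrained_vae_name_or_path stabilityai/sd-vae-ft-mse \
        --instance_data_dir /store/diffusers/instance-images/ \
        --class_data_dir /store/diffusers/class-images/ \
        --instance_prompt "photo of zwx dog" \
        --class_prompt "photo of dog" \
        --with_prior_preservation --prior_loss_weight 1.0 --num_class_images 300 \
        --resolution 512 --train_batch_size 1 --sample_batch_size 4 \
        --max_train_steps 2000 --save_interval 500 \
        --gradient_checkpointing --use_8bit_adam --mixed_precision fp16 \
        --learning_rate 1e-6 --lr_scheduler constant --lr_warmup_steps 200 \
        --output_dir /store/diffusers/output-models/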
output-models/2000/model_index.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_class_name": "StableDiffusionPipeline",
+   "_diffusers_version": "0.9.0",
+   "requires_safety_checker": false,
+   "scheduler": [
+     "diffusers",
+     "DDIMScheduler"
+   ],
+   "text_encoder": [
+     "transformers",
+     "CLIPTextModel"
+   ],
+   "tokenizer": [
+     "transformers",
+     "CLIPTokenizer"
+   ],
+   "unet": [
+     "diffusers",
+     "UNet2DConditionModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
output-models/2000/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_class_name": "DDIMScheduler",
+   "_diffusers_version": "0.9.0",
+   "beta_end": 0.012,
+   "beta_schedule": "scaled_linear",
+   "beta_start": 0.00085,
+   "clip_sample": false,
+   "num_train_timesteps": 1000,
+   "prediction_type": "epsilon",
+   "set_alpha_to_one": false,
+   "steps_offset": 1,
+   "trained_betas": null
+ }
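
Note: this is the scheduler config saved inside the checkpoint directory. With the diffusers 0.9 API used elsewhere in this repo it could be reloaded directly, for example (the checkpoint path is illustrative, not part of this commit):

    from diffusers import DDIMScheduler
    # Hypothetical: reload the scheduler saved with the 2000-step checkpoint
    scheduler = DDIMScheduler.from_config("./output-models/2000", subfolder="scheduler")

Newer diffusers releases load schedulers with DDIMScheduler.from_pretrained instead.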
output-models/2000/text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_name_or_path": "/store/diffusers/huggingface/models--runwayml--stable-diffusion-v1-5/snapshots/63534535d4730d5976c5c647a7f2adaea1102f5b",
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 2,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 768,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "projection_dim": 768,
+   "torch_dtype": "float32",
+   "transformers_version": "4.25.1",
+   "vocab_size": 49408
+ }
output-models/2000/text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:770a47a9ffdcfda0b05506a7888ed714d06131d60267e6cf52765d61cf59fd67
+ size 492305335
output-models/2000/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
output-models/2000/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "bos_token": {
+     "content": "<|startoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<|endoftext|>",
+   "unk_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
output-models/2000/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": {
+     "__type": "AddedToken",
+     "content": "<|startoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "do_lower_case": true,
+   "eos_token": {
+     "__type": "AddedToken",
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "errors": "replace",
+   "model_max_length": 77,
+   "name_or_path": "/store/diffusers/huggingface/models--runwayml--stable-diffusion-v1-5/snapshots/63534535d4730d5976c5c647a7f2adaea1102f5b/tokenizer",
+   "pad_token": "<|endoftext|>",
+   "special_tokens_map_file": "./special_tokens_map.json",
+   "tokenizer_class": "CLIPTokenizer",
+   "unk_token": {
+     "__type": "AddedToken",
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
output-models/2000/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
output-models/2000/unet/config.json ADDED
@@ -0,0 +1,41 @@
+ {
+   "_class_name": "UNet2DConditionModel",
+   "_diffusers_version": "0.9.0",
+   "_name_or_path": "/store/diffusers/huggingface/models--runwayml--stable-diffusion-v1-5/snapshots/63534535d4730d5976c5c647a7f2adaea1102f5b",
+   "act_fn": "silu",
+   "attention_head_dim": 8,
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "center_input_sample": false,
+   "cross_attention_dim": 768,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dual_cross_attention": false,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_scale_factor": 1,
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 4,
+   "sample_size": 64,
+   "up_block_types": [
+     "UpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D"
+   ],
+   "use_linear_projection": false
+ }
output-models/2000/unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d96b8d3d9e6c71ac852d3df7f815a896d7f97dd56780c5950e323e4273efb856
+ size 3438364325
output-models/2000/vae/config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.9.0",
+   "_name_or_path": "/store/diffusers/huggingface/models--stabilityai--sd-vae-ft-mse/snapshots/07d70db1bb648cea307cb6b9e32a50c8655a08e7",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "in_channels": 3,
+   "latent_channels": 4,
+   "layers_per_block": 2,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 256,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
output-models/2000/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b4889b6b1d4ce7ae320a02dedaeff1780ad77d415ea0d744b476155c6377ddc
+ size 334707217
train_dreambooth.py ADDED
@@ -0,0 +1,822 @@
+ import argparse
+ import hashlib
+ import itertools
+ import random
+ import json
+ import math
+ import os
+ from contextlib import nullcontext
+ from pathlib import Path
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch.utils.data import Dataset
+
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import set_seed
+ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+ from diffusers.optimization import get_scheduler
+ from huggingface_hub import HfFolder, Repository, whoami
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+
+ torch.backends.cudnn.benchmark = True
+
+
+ logger = get_logger(__name__)
+
+
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         default=None,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--pretrained_vae_name_or_path",
+         type=str,
+         default=None,
+         help="Path to pretrained vae or vae identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         required=False,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--tokenizer_name",
+         type=str,
+         default=None,
+         help="Pretrained tokenizer name or path if not the same as model_name",
+     )
+     parser.add_argument(
+         "--instance_data_dir",
+         type=str,
+         default=None,
+         help="A folder containing the training data of instance images.",
+     )
+     parser.add_argument(
+         "--class_data_dir",
+         type=str,
+         default=None,
+         help="A folder containing the training data of class images.",
+     )
+     parser.add_argument(
+         "--instance_prompt",
+         type=str,
+         default=None,
+         help="The prompt with identifier specifying the instance",
+     )
+     parser.add_argument(
+         "--class_prompt",
+         type=str,
+         default=None,
+         help="The prompt to specify images in the same class as provided instance images.",
+     )
+     parser.add_argument(
+         "--save_sample_prompt",
+         type=str,
+         default=None,
+         help="The prompt used to generate sample outputs to save.",
+     )
+     parser.add_argument(
+         "--save_sample_negative_prompt",
+         type=str,
+         default=None,
+         help="The negative prompt used to generate sample outputs to save.",
+     )
+     parser.add_argument(
+         "--n_save_sample",
+         type=int,
+         default=4,
+         help="The number of samples to save.",
+     )
+     parser.add_argument(
+         "--save_guidance_scale",
+         type=float,
+         default=7.5,
+         help="CFG for save sample.",
+     )
+     parser.add_argument(
+         "--save_infer_steps",
+         type=int,
+         default=50,
+         help="The number of inference steps for save sample.",
+     )
+     parser.add_argument(
+         "--pad_tokens",
+         default=False,
+         action="store_true",
+         help="Flag to pad tokens to length 77.",
+     )
+     parser.add_argument(
+         "--with_prior_preservation",
+         default=False,
+         action="store_true",
+         help="Flag to add prior preservation loss.",
+     )
+     parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+     parser.add_argument(
+         "--num_class_images",
+         type=int,
+         default=100,
+         help=(
+             "Minimal class images for prior preservation loss. If there are not enough images, additional images will be"
+             " sampled with class_prompt."
+         ),
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="text-inversion-model",
+         help="The output directory where the model predictions and checkpoints will be written.",
+     )
+     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=512,
+         help=(
+             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+     )
+     parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+     parser.add_argument(
+         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+     )
+     parser.add_argument(
+         "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+     )
+     parser.add_argument("--num_train_epochs", type=int, default=1)
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=None,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of update steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--gradient_checkpointing",
+         action="store_true",
+         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=5e-6,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         action="store_true",
+         default=False,
+         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+     )
+     parser.add_argument(
+         "--lr_scheduler",
+         type=str,
+         default="constant",
+         help=(
+             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+             ' "constant", "constant_with_warmup"]'
+         ),
+     )
+     parser.add_argument(
+         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+     )
+     parser.add_argument(
+         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+     )
+     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+     parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+     parser.add_argument(
+         "--hub_model_id",
+         type=str,
+         default=None,
+         help="The name of the repository to keep in sync with the local `output_dir`.",
+     )
+     parser.add_argument(
+         "--logging_dir",
+         type=str,
+         default="logs",
+         help=(
+             "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+         ),
+     )
+     parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
+     parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
+     parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default=None,
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+     parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.")
+     parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
+     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+     parser.add_argument(
+         "--concepts_list",
+         type=str,
+         default=None,
+         help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
+     )
+
+     if input_args is not None:
+         args = parser.parse_args(input_args)
+     else:
+         args = parser.parse_args()
+
+     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     if env_local_rank != -1 and env_local_rank != args.local_rank:
+         args.local_rank = env_local_rank
+
+     return args
+
+
+ class DreamBoothDataset(Dataset):
+     """
+     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+     It pre-processes the images and tokenizes the prompts.
+     """
+
+     def __init__(
+         self,
+         concepts_list,
+         tokenizer,
+         with_prior_preservation=True,
+         size=512,
+         center_crop=False,
+         num_class_images=None,
+         pad_tokens=False,
+         hflip=False
+     ):
+         self.size = size
+         self.center_crop = center_crop
+         self.tokenizer = tokenizer
+         self.with_prior_preservation = with_prior_preservation
+         self.pad_tokens = pad_tokens
+
+         self.instance_images_path = []
+         self.class_images_path = []
+
+         for concept in concepts_list:
+             inst_img_path = [(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()]
+             self.instance_images_path.extend(inst_img_path)
+
+             if with_prior_preservation:
+                 class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()]
+                 self.class_images_path.extend(class_img_path[:num_class_images])
+
+         random.shuffle(self.instance_images_path)
+         self.num_instance_images = len(self.instance_images_path)
+         self.num_class_images = len(self.class_images_path)
+         self._length = max(self.num_class_images, self.num_instance_images)
+
+         self.image_transforms = transforms.Compose(
+             [
+                 transforms.RandomHorizontalFlip(0.5 * hflip),
+                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.5], [0.5]),
+             ]
+         )
+
+     def __len__(self):
+         return self._length
+
+     def __getitem__(self, index):
+         example = {}
+         instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
+         instance_image = Image.open(instance_path)
+         if not instance_image.mode == "RGB":
+             instance_image = instance_image.convert("RGB")
+         example["instance_images"] = self.image_transforms(instance_image)
+         example["instance_prompt_ids"] = self.tokenizer(
+             instance_prompt,
+             padding="max_length" if self.pad_tokens else "do_not_pad",
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+         ).input_ids
+
+         if self.with_prior_preservation:
+             class_path, class_prompt = self.class_images_path[index % self.num_class_images]
+             class_image = Image.open(class_path)
+             if not class_image.mode == "RGB":
+                 class_image = class_image.convert("RGB")
+             example["class_images"] = self.image_transforms(class_image)
+             example["class_prompt_ids"] = self.tokenizer(
+                 class_prompt,
+                 padding="max_length" if self.pad_tokens else "do_not_pad",
+                 truncation=True,
+                 max_length=self.tokenizer.model_max_length,
+             ).input_ids
+
+         return example
+
+
+ class PromptDataset(Dataset):
+     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+     def __init__(self, prompt, num_samples):
+         self.prompt = prompt
+         self.num_samples = num_samples
+
+     def __len__(self):
+         return self.num_samples
+
+     def __getitem__(self, index):
+         example = {}
+         example["prompt"] = self.prompt
+         example["index"] = index
+         return example
+
+
+ class LatentsDataset(Dataset):
+     def __init__(self, latents_cache, text_encoder_cache):
+         self.latents_cache = latents_cache
+         self.text_encoder_cache = text_encoder_cache
+
+     def __len__(self):
+         return len(self.latents_cache)
+
+     def __getitem__(self, index):
+         return self.latents_cache[index], self.text_encoder_cache[index]
+
+
+ class AverageMeter:
+     def __init__(self, name=None):
+         self.name = name
+         self.reset()
+
+     def reset(self):
+         self.sum = self.count = self.avg = 0
+
+     def update(self, val, n=1):
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+     if token is None:
+         token = HfFolder.get_token()
+     if organization is None:
+         username = whoami(token)["name"]
+         return f"{username}/{model_id}"
+     else:
+         return f"{organization}/{model_id}"
+
+
+ def main(args):
+     logging_dir = Path(args.output_dir, "0", args.logging_dir)
+
+     accelerator = Accelerator(
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         mixed_precision=args.mixed_precision,
+         log_with="tensorboard",
+         logging_dir=logging_dir,
+     )
+
+     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+     if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+         raise ValueError(
+             "Gradient accumulation is not supported when training the text encoder in distributed training. "
+             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+         )
+
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     if args.concepts_list is None:
+         args.concepts_list = [
+             {
+                 "instance_prompt": args.instance_prompt,
+                 "class_prompt": args.class_prompt,
+                 "instance_data_dir": args.instance_data_dir,
+                 "class_data_dir": args.class_data_dir
+             }
+         ]
+     else:
+         with open(args.concepts_list, "r") as f:
+             args.concepts_list = json.load(f)
+
+     if args.with_prior_preservation:
+         pipeline = None
+         for concept in args.concepts_list:
+             class_images_dir = Path(concept["class_data_dir"])
+             class_images_dir.mkdir(parents=True, exist_ok=True)
+             cur_class_images = len(list(class_images_dir.iterdir()))
+
+             if cur_class_images < args.num_class_images:
+                 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+                 if pipeline is None:
+                     pipeline = StableDiffusionPipeline.from_pretrained(
+                         args.pretrained_model_name_or_path,
+                         vae=AutoencoderKL.from_pretrained(
+                             args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
+                             subfolder=None if args.pretrained_vae_name_or_path else "vae",
+                             revision=None if args.pretrained_vae_name_or_path else args.revision,
+                             torch_dtype=torch_dtype
+                         ),
+                         torch_dtype=torch_dtype,
+                         safety_checker=None,
+                         revision=args.revision
+                     )
+                     pipeline.set_progress_bar_config(disable=True)
+                     pipeline.to(accelerator.device)
+
+                 num_new_images = args.num_class_images - cur_class_images
+                 logger.info(f"Number of class images to sample: {num_new_images}.")
+
+                 sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
+                 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+                 sample_dataloader = accelerator.prepare(sample_dataloader)
+
+                 with torch.autocast("cuda"), torch.inference_mode():
+                     for example in tqdm(
+                         sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+                     ):
+                         images = pipeline(example["prompt"]).images
+
+                         for i, image in enumerate(images):
+                             hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                             image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+                             image.save(image_filename)
+
+         del pipeline
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Load the tokenizer
+     if args.tokenizer_name:
+         tokenizer = CLIPTokenizer.from_pretrained(
+             args.tokenizer_name,
+             revision=args.revision,
+         )
+     elif args.pretrained_model_name_or_path:
+         tokenizer = CLIPTokenizer.from_pretrained(
+             args.pretrained_model_name_or_path,
+             subfolder="tokenizer",
+             revision=args.revision,
+         )
+
+     # Load models and create wrapper for stable diffusion
+     text_encoder = CLIPTextModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="text_encoder",
+         revision=args.revision,
+     )
+     vae = AutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="vae",
+         revision=args.revision,
+     )
+     unet = UNet2DConditionModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="unet",
+         revision=args.revision,
+         torch_dtype=torch.float32
+     )
+
+     vae.requires_grad_(False)
+     if not args.train_text_encoder:
+         text_encoder.requires_grad_(False)
+
+     if args.gradient_checkpointing:
+         unet.enable_gradient_checkpointing()
+         if args.train_text_encoder:
+             text_encoder.gradient_checkpointing_enable()
+
+     if args.scale_lr:
+         args.learning_rate = (
+             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+         )
+
+     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+     if args.use_8bit_adam:
+         try:
+             import bitsandbytes as bnb
+         except ImportError:
+             raise ImportError(
+                 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+             )
+
+         optimizer_class = bnb.optim.AdamW8bit
+     else:
+         optimizer_class = torch.optim.AdamW
+
+     params_to_optimize = (
+         itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+     )
+     optimizer = optimizer_class(
+         params_to_optimize,
+         lr=args.learning_rate,
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
+     )
+
+     noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+     train_dataset = DreamBoothDataset(
+         concepts_list=args.concepts_list,
+         tokenizer=tokenizer,
+         with_prior_preservation=args.with_prior_preservation,
+         size=args.resolution,
+         center_crop=args.center_crop,
+         num_class_images=args.num_class_images,
+         pad_tokens=args.pad_tokens,
+         hflip=args.hflip
+     )
+
+     def collate_fn(examples):
+         input_ids = [example["instance_prompt_ids"] for example in examples]
+         pixel_values = [example["instance_images"] for example in examples]
+
+         # Concat class and instance examples for prior preservation.
+         # We do this to avoid doing two forward passes.
+         if args.with_prior_preservation:
+             input_ids += [example["class_prompt_ids"] for example in examples]
+             pixel_values += [example["class_images"] for example in examples]
+
+         pixel_values = torch.stack(pixel_values)
+         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+         input_ids = tokenizer.pad(
+             {"input_ids": input_ids},
+             padding=True,
+             return_tensors="pt",
+         ).input_ids
+
+         batch = {
+             "input_ids": input_ids,
+             "pixel_values": pixel_values,
+         }
+         return batch
+
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
+     )
+
+     weight_dtype = torch.float32
+     if args.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif args.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # Move the text_encoder and vae to the GPU.
+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference, keeping weights in full precision is not required.
+     vae.to(accelerator.device, dtype=weight_dtype)
+     if not args.train_text_encoder:
+         text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+     if not args.not_cache_latents:
+         latents_cache = []
+         text_encoder_cache = []
+         for batch in tqdm(train_dataloader, desc="Caching latents"):
+             with torch.no_grad():
+                 batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
+                 batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
+                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                 if args.train_text_encoder:
+                     text_encoder_cache.append(batch["input_ids"])
+                 else:
+                     text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
+         train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
+         train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
+
+         del vae
+         if not args.train_text_encoder:
+             del text_encoder
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Scheduler and math around the number of training steps.
+     overrode_max_train_steps = False
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+         overrode_max_train_steps = True
+
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+     )
+
+     if args.train_text_encoder:
+         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+         )
+     else:
+         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, optimizer, train_dataloader, lr_scheduler
+         )
+
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if overrode_max_train_steps:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+     # Afterwards we recalculate our number of training epochs
+     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+     # We need to initialize the trackers we use, and also store our configuration.
+     # The trackers initialize automatically on the main process.
+     if accelerator.is_main_process:
+         accelerator.init_trackers("dreambooth")
+
+     # Train!
+     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+     logger.info("***** Running training *****")
+     logger.info(f"  Num examples = {len(train_dataset)}")
+     logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+     logger.info(f"  Num Epochs = {args.num_train_epochs}")
+     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+     logger.info(f"  Total optimization steps = {args.max_train_steps}")
+
+     def save_weights(step):
+         # Create the pipeline using the trained modules and save it.
+         if accelerator.is_main_process:
+             if args.train_text_encoder:
+                 text_enc_model = accelerator.unwrap_model(text_encoder)
+             else:
+                 text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+             scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+             pipeline = StableDiffusionPipeline.from_pretrained(
+                 args.pretrained_model_name_or_path,
+                 unet=accelerator.unwrap_model(unet),
+                 text_encoder=text_enc_model,
+                 vae=AutoencoderKL.from_pretrained(
+                     args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
+                     subfolder=None if args.pretrained_vae_name_or_path else "vae",
+                     revision=None if args.pretrained_vae_name_or_path else args.revision,
+                 ),
+                 safety_checker=None,
+                 scheduler=scheduler,
+                 torch_dtype=torch.float16,
+                 revision=args.revision,
+             )
+             save_dir = os.path.join(args.output_dir, f"{step}")
+             pipeline.save_pretrained(save_dir)
+             with open(os.path.join(save_dir, "args.json"), "w") as f:
+                 json.dump(args.__dict__, f, indent=2)
+
+             if args.save_sample_prompt is not None:
+                 pipeline = pipeline.to(accelerator.device)
+                 g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                 pipeline.set_progress_bar_config(disable=True)
+                 sample_dir = os.path.join(save_dir, "samples")
+                 os.makedirs(sample_dir, exist_ok=True)
+                 with torch.autocast("cuda"), torch.inference_mode():
+                     for i in tqdm(range(args.n_save_sample), desc="Generating samples"):
+                         images = pipeline(
+                             args.save_sample_prompt,
+                             negative_prompt=args.save_sample_negative_prompt,
+                             guidance_scale=args.save_guidance_scale,
+                             num_inference_steps=args.save_infer_steps,
+                             generator=g_cuda
+                         ).images
+                         images[0].save(os.path.join(sample_dir, f"{i}.png"))
+                 del pipeline
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+             print(f"[*] Weights saved at {save_dir}")
+
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+     progress_bar.set_description("Steps")
+     global_step = 0
+     loss_avg = AverageMeter()
+     text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
+     for epoch in range(args.num_train_epochs):
+         unet.train()
+         if args.train_text_encoder:
+             text_encoder.train()
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(unet):
+                 # Convert images to latent space
+                 with torch.no_grad():
+                     if not args.not_cache_latents:
+                         latent_dist = batch[0][0]
+                     else:
+                         latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
+                     latents = latent_dist.sample() * 0.18215
+
+                 # Sample noise that we'll add to the latents
+                 noise = torch.randn_like(latents)
+                 bsz = latents.shape[0]
+                 # Sample a random timestep for each image
+                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                 timesteps = timesteps.long()
+
+                 # Add noise to the latents according to the noise magnitude at each timestep
+                 # (this is the forward diffusion process)
+                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                 # Get the text embedding for conditioning
+                 with text_enc_context:
+                     if not args.not_cache_latents:
+                         if args.train_text_encoder:
+                             encoder_hidden_states = text_encoder(batch[0][1])[0]
+                         else:
+                             encoder_hidden_states = batch[0][1]
+                     else:
+                         encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                 # Predict the noise residual
+                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                 if args.with_prior_preservation:
+                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                     # Compute instance loss
+                     loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                     # Compute prior loss
+                     prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+                     # Add the prior loss to the instance loss.
+                     loss = loss + args.prior_loss_weight * prior_loss
+                 else:
+                     loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                 accelerator.backward(loss)
+                 # if accelerator.sync_gradients:
+                 #     params_to_clip = (
+                 #         itertools.chain(unet.parameters(), text_encoder.parameters())
+                 #         if args.train_text_encoder
+                 #         else unet.parameters()
+                 #     )
+                 #     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad(set_to_none=True)
+                 loss_avg.update(loss.detach_(), bsz)
+
+             if not global_step % args.log_interval:
+                 logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
+                 progress_bar.set_postfix(**logs)
+                 accelerator.log(logs, step=global_step)
+
+             if global_step > 0 and not global_step % args.save_interval and global_step >= args.save_min_steps:
+                 save_weights(global_step)
+
+             progress_bar.update(1)
+             global_step += 1
+
+             if global_step >= args.max_train_steps:
+                 break
+
+         accelerator.wait_for_everyone()
+
+     save_weights(global_step)
+
+     accelerator.end_training()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
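
Note on --concepts_list: instead of the single --instance_prompt / --class_prompt / --instance_data_dir / --class_data_dir flags, the script also accepts --concepts_list pointing at a JSON file with one entry per concept; when given, it overrides those flags. A minimal hypothetical example (file name and contents are illustrative only, not part of this commit):

    [
      {
        "instance_prompt": "photo of zwx dog",
        "class_prompt": "photo of dog",
        "instance_data_dir": "/store/diffusers/instance-images/",
        "class_data_dir": "/store/diffusers/class-images/"
      }
    ]

This would be passed as --concepts_list concepts.json.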