jpawan33 committed on
Commit 3119683
1 Parent(s): f547584

End of training

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ dog/alvan-nee-Id1DBHv4fbg-unsplash.jpeg filter=lfs diff=lfs merge=lfs -text
+ dog/alvan-nee-bQaAJCbNq3g-unsplash.jpeg filter=lfs diff=lfs merge=lfs -text
+ dog/alvan-nee-brFsZ7qszSY-unsplash.jpeg filter=lfs diff=lfs merge=lfs -text
+ dog/alvan-nee-eoqnr8ikwFE-unsplash.jpeg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,22 @@
+
+ ---
+ license: creativeml-openrail-m
+ base_model: CompVis/stable-diffusion-v1-4
+ instance_prompt: a photo of sks dog
+ tags:
+ - stable-diffusion
+ - stable-diffusion-diffusers
+ - text-to-image
+ - diffusers
+ - dreambooth
+ inference: true
+ ---
+
+ # DreamBooth - jpawan33/dreambooth
+
+ This is a DreamBooth model derived from CompVis/stable-diffusion-v1-4. The weights were trained on "a photo of sks dog" using [DreamBooth](https://dreambooth.github.io/).
+ You can find some example images below.
+
+
+
+ DreamBooth for the text encoder was enabled: False.
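
For context, a minimal inference sketch for the card above (assuming the repo loads as the standard `StableDiffusionPipeline` declared in `model_index.json` later in this commit):

```python
import torch
from diffusers import DiffusionPipeline

# Load the DreamBooth-finetuned pipeline pushed in this commit.
pipe = DiffusionPipeline.from_pretrained("jpawan33/dreambooth", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Prompt with the "sks" identifier the model was trained on.
image = pipe("a photo of sks dog in a bucket", num_inference_steps=50).images[0]
image.save("sks_dog.png")
```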
README_sdxl.md ADDED
@@ -0,0 +1,201 @@
+ # DreamBooth training example for Stable Diffusion XL (SDXL)
+
+ [DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
+
+ The `train_dreambooth_lora_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
+
+ > 💡 **Note**: For now, we only allow DreamBooth fine-tuning of the SDXL UNet via LoRA. LoRA is a parameter-efficient fine-tuning technique introduced in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+ ## Running locally with PyTorch
+
+ ### Installing the dependencies
+
+ Before running the scripts, make sure to install the library's training dependencies:
+
+ **Important**
+
+ To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, since we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+ ```bash
+ git clone https://github.com/huggingface/diffusers
+ cd diffusers
+ pip install -e .
+ ```
+
+ Then cd into the `examples/dreambooth` folder and run
+ ```bash
+ pip install -r requirements_sdxl.txt
+ ```
+
+ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+ ```bash
+ accelerate config
+ ```
+
+ Or for a default accelerate configuration without answering questions about your environment:
+
+ ```bash
+ accelerate config default
+ ```
+
+ Or if your environment doesn't support an interactive shell (e.g., a notebook):
+
+ ```python
+ from accelerate.utils import write_basic_config
+ write_basic_config()
+ ```
+
+ When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
+
+ ### Dog toy example
+
+ Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+ Let's first download it locally:
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ local_dir = "./dog"
+ snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+ )
+ ```
+
+ Make sure you are logged in to the Hugging Face Hub (for example, via `huggingface-cli login`); this will also allow us to push the trained LoRA parameters to the Hub.
+
+ Now, we can launch training using:
+
+ ```bash
+ export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+ export INSTANCE_DIR="dog"
+ export OUTPUT_DIR="lora-trained-xl"
+
+ accelerate launch train_dreambooth_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="fp16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+ ```
+
+ To better track our training experiments, we're using the following flags in the command above:
+
+ * `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+ * `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. This lets us qualitatively check if the training is progressing as expected.
+
+ Our experiments were conducted on a single 40GB A100 GPU.
+
+ ### Dog toy example with < 16GB VRAM
+
+ By making use of [`gradient_checkpointing`](https://pytorch.org/docs/stable/checkpoint.html) (which is natively supported in Diffusers) and the [`xformers`](https://github.com/facebookresearch/xformers) and [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) libraries, you can train SDXL LoRAs with less than 16GB of VRAM by adding the following flags to your accelerate launch command:
+
+ ```diff
+ + --enable_xformers_memory_efficient_attention \
+ + --gradient_checkpointing \
+ + --use_8bit_adam \
+ + --mixed_precision="fp16" \
+ ```
+
+ and making sure that you have the following libraries installed:
+
+ ```
+ bitsandbytes>=0.40.0
+ xformers>=0.0.20
+ ```
+
+ ### Inference
+
+ Once training is done, we can perform inference like so:
+
+ ```python
+ from huggingface_hub.repocard import RepoCard
+ from diffusers import DiffusionPipeline
+ import torch
+
+ lora_model_id = "lora-sdxl-dreambooth-id"  # replace with your trained LoRA repo id
+ card = RepoCard.load(lora_model_id)
+ base_model_id = card.data.to_dict()["base_model"]
+
+ pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+ pipe.load_lora_weights(lora_model_id)
+ image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
+ image.save("sks_dog.png")
+ ```
+
+ We can further refine the outputs with the [Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0):
+
+ ```python
+ from huggingface_hub.repocard import RepoCard
+ from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
+ import torch
+
+ lora_model_id = "lora-sdxl-dreambooth-id"  # replace with your trained LoRA repo id
+ card = RepoCard.load(lora_model_id)
+ base_model_id = card.data.to_dict()["base_model"]
+
+ # Load the base pipeline and load the LoRA parameters into it.
+ pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+ pipe.load_lora_weights(lora_model_id)
+
+ # Load the refiner.
+ refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+ )
+ refiner.to("cuda")
+
+ prompt = "A picture of a sks dog in a bucket"
+ generator = torch.Generator("cuda").manual_seed(0)
+
+ # Run inference.
+ image = pipe(prompt=prompt, output_type="latent", generator=generator).images[0]
+ image = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0]
+ image.save("refined_sks_dog.png")
+ ```
+
+ Here's a side-by-side comparison of the pipeline outputs with and without the Refiner:
+
+ | Without Refiner | With Refiner |
+ |---|---|
+ | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) |
+
+ ### Training with text encoder(s)
+
+ Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training, as shown below. Please keep the following points in mind:
+
+ * SDXL has two text encoders. So, we fine-tune both using LoRA.
+ * When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
+
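For example, building on the launch command above, enabling text-encoder LoRA training amounts to one extra flag (a sketch; everything else stays the same):

```diff
+ --train_text_encoder \
```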
+ ### Specifying a better VAE
+
+ SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)), as in the sketch below.
+
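A minimal sketch of adding it to the launch command above, using the fp16-fix VAE linked in this section:

```diff
+ --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
```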
+
+ ## Notes
+
+ In our experiments, we found that SDXL yields good initial results without extensive hyperparameter tuning. For example, without fine-tuning the text encoders and without using prior preservation, we observed decent results. We didn't explore further hyperparameter-tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗
+
+ ## Results
+
+ You can explore the results from a couple of our internal experiments by checking out this link: [https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl](https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl). Specifically, we used the same script with the exact same hyperparameters on the following datasets:
+
+ * [Dogs](https://huggingface.co/datasets/diffusers/dog-example)
+ * [Starbucks logo](https://huggingface.co/datasets/diffusers/starbucks-example)
+ * [Mr. Potato Head](https://huggingface.co/datasets/diffusers/potato-head-example)
+ * [Keramer face](https://huggingface.co/datasets/diffusers/keramer-face-example)
dog/alvan-nee-9M0tSjb-cpA-unsplash.jpeg ADDED
dog/alvan-nee-Id1DBHv4fbg-unsplash.jpeg ADDED

Git LFS Details

  • SHA256: a65d3a853b7c65dd4d394cb6b209f77666351d2bae7c6670c5677d8eb5981644
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
dog/alvan-nee-bQaAJCbNq3g-unsplash.jpeg ADDED

Git LFS Details

  • SHA256: 4cda55c53c11843ed368eb8eb68fd79521ac7b839bdd70f8f89589cf7006ed97
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
dog/alvan-nee-brFsZ7qszSY-unsplash.jpeg ADDED

Git LFS Details

  • SHA256: 9d8013d9efa2edb356e0f88c66de044f71247a99cab52b1628e753c2a08bb602
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
dog/alvan-nee-eoqnr8ikwFE-unsplash.jpeg ADDED

Git LFS Details

  • SHA256: 5c9805758a8f8950a35df820f3bfc32b3c6ca2a0e0e214a7978ea147a233bd54
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
logs/dreambooth/1691434852.0698752/events.out.tfevents.1691434852.ip-172-31-26-230.295956.1 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39c4f3880c40b945e6436e1a05347f8e6cf7907ea0eb3d9f680f028e55efcdfd
3
+ size 2713
logs/dreambooth/1691434852.0710692/hparams.yml ADDED
@@ -0,0 +1,57 @@
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ center_crop: false
7
+ checkpointing_steps: 500
8
+ checkpoints_total_limit: null
9
+ class_data_dir: null
10
+ class_labels_conditioning: null
11
+ class_prompt: null
12
+ dataloader_num_workers: 0
13
+ enable_xformers_memory_efficient_attention: false
14
+ gradient_accumulation_steps: 1
15
+ gradient_checkpointing: false
16
+ hub_model_id: null
17
+ hub_token: null
18
+ instance_data_dir: ./dog
19
+ instance_prompt: a photo of sks dog
20
+ learning_rate: 5.0e-06
21
+ local_rank: -1
22
+ logging_dir: logs
23
+ lr_num_cycles: 1
24
+ lr_power: 1.0
25
+ lr_scheduler: constant
26
+ lr_warmup_steps: 0
27
+ max_grad_norm: 1.0
28
+ max_train_steps: 400
29
+ mixed_precision: null
30
+ num_class_images: 100
31
+ num_train_epochs: 80
32
+ num_validation_images: 4
33
+ offset_noise: false
34
+ output_dir: /home/ubuntu/StableDiffusion/diffusers/examples/dreambooth
35
+ pre_compute_text_embeddings: false
36
+ pretrained_model_name_or_path: CompVis/stable-diffusion-v1-4
37
+ prior_generation_precision: null
38
+ prior_loss_weight: 1.0
39
+ push_to_hub: true
40
+ report_to: tensorboard
41
+ resolution: 512
42
+ resume_from_checkpoint: null
43
+ revision: null
44
+ sample_batch_size: 4
45
+ scale_lr: false
46
+ seed: null
47
+ set_grads_to_none: false
48
+ skip_save_text_encoder: false
49
+ text_encoder_use_attention_mask: false
50
+ tokenizer_max_length: null
51
+ tokenizer_name: null
52
+ train_batch_size: 1
53
+ train_text_encoder: false
54
+ use_8bit_adam: false
55
+ validation_prompt: null
56
+ validation_steps: 100
57
+ with_prior_preservation: false
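
These hyperparameters correspond to a launch command roughly like the following (a reconstruction sketch from the values above, using the `train_dreambooth.py` script added in this commit; not the exact command the author ran):

```bash
accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
  --instance_data_dir="./dog" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400 \
  --push_to_hub
```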
logs/dreambooth/events.out.tfevents.1691434852.ip-172-31-26-230.295956.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8412a30292b7870337ddccfca4b08c25f86efbee7d031cabb9c0bd4d992edf1
3
+ size 33434
model_index.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "_class_name": "StableDiffusionPipeline",
3
+ "_diffusers_version": "0.20.0.dev0",
4
+ "_name_or_path": "CompVis/stable-diffusion-v1-4",
5
+ "feature_extractor": [
6
+ "transformers",
7
+ "CLIPImageProcessor"
8
+ ],
9
+ "requires_safety_checker": true,
10
+ "safety_checker": [
11
+ "stable_diffusion",
12
+ "StableDiffusionSafetyChecker"
13
+ ],
14
+ "scheduler": [
15
+ "diffusers",
16
+ "PNDMScheduler"
17
+ ],
18
+ "text_encoder": [
19
+ "transformers",
20
+ "CLIPTextModel"
21
+ ],
22
+ "tokenizer": [
23
+ "transformers",
24
+ "CLIPTokenizer"
25
+ ],
26
+ "unet": [
27
+ "diffusers",
28
+ "UNet2DConditionModel"
29
+ ],
30
+ "vae": [
31
+ "diffusers",
32
+ "AutoencoderKL"
33
+ ]
34
+ }
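
The component layout above maps directly to subfolders of the repo, so individual parts can be loaded on their own when needed (a sketch, assuming the repo id from the README above):

```python
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

repo_id = "jpawan33/dreambooth"

# Each entry in model_index.json corresponds to a subfolder of the repo.
unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")
vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
```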
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
requirements_flax.txt ADDED
@@ -0,0 +1,8 @@
1
+ transformers>=4.25.1
2
+ flax
3
+ optax
4
+ torch
5
+ torchvision
6
+ ftfy
7
+ tensorboard
8
+ Jinja2
requirements_sdxl.txt ADDED
@@ -0,0 +1,6 @@
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
safety_checker/config.json ADDED
@@ -0,0 +1,168 @@
1
+ {
2
+ "_commit_hash": "b95be7d6f134c3a9e62ee616f310733567f069ce",
3
+ "_name_or_path": "/home/ubuntu/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce/safety_checker",
4
+ "architectures": [
5
+ "StableDiffusionSafetyChecker"
6
+ ],
7
+ "initializer_factor": 1.0,
8
+ "logit_scale_init_value": 2.6592,
9
+ "model_type": "clip",
10
+ "projection_dim": 768,
11
+ "text_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "begin_suppress_tokens": null,
18
+ "bos_token_id": 49406,
19
+ "chunk_size_feed_forward": 0,
20
+ "cross_attention_hidden_size": null,
21
+ "decoder_start_token_id": null,
22
+ "diversity_penalty": 0.0,
23
+ "do_sample": false,
24
+ "dropout": 0.0,
25
+ "early_stopping": false,
26
+ "encoder_no_repeat_ngram_size": 0,
27
+ "eos_token_id": 49407,
28
+ "exponential_decay_length_penalty": null,
29
+ "finetuning_task": null,
30
+ "forced_bos_token_id": null,
31
+ "forced_eos_token_id": null,
32
+ "hidden_act": "quick_gelu",
33
+ "hidden_size": 768,
34
+ "id2label": {
35
+ "0": "LABEL_0",
36
+ "1": "LABEL_1"
37
+ },
38
+ "initializer_factor": 1.0,
39
+ "initializer_range": 0.02,
40
+ "intermediate_size": 3072,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "layer_norm_eps": 1e-05,
48
+ "length_penalty": 1.0,
49
+ "max_length": 20,
50
+ "max_position_embeddings": 77,
51
+ "min_length": 0,
52
+ "model_type": "clip_text_model",
53
+ "no_repeat_ngram_size": 0,
54
+ "num_attention_heads": 12,
55
+ "num_beam_groups": 1,
56
+ "num_beams": 1,
57
+ "num_hidden_layers": 12,
58
+ "num_return_sequences": 1,
59
+ "output_attentions": false,
60
+ "output_hidden_states": false,
61
+ "output_scores": false,
62
+ "pad_token_id": 1,
63
+ "prefix": null,
64
+ "problem_type": null,
65
+ "projection_dim": 512,
66
+ "pruned_heads": {},
67
+ "remove_invalid_values": false,
68
+ "repetition_penalty": 1.0,
69
+ "return_dict": true,
70
+ "return_dict_in_generate": false,
71
+ "sep_token_id": null,
72
+ "suppress_tokens": null,
73
+ "task_specific_params": null,
74
+ "temperature": 1.0,
75
+ "tf_legacy_loss": false,
76
+ "tie_encoder_decoder": false,
77
+ "tie_word_embeddings": true,
78
+ "tokenizer_class": null,
79
+ "top_k": 50,
80
+ "top_p": 1.0,
81
+ "torch_dtype": null,
82
+ "torchscript": false,
83
+ "transformers_version": "4.31.0",
84
+ "typical_p": 1.0,
85
+ "use_bfloat16": false,
86
+ "vocab_size": 49408
87
+ },
88
+ "torch_dtype": "float32",
89
+ "transformers_version": null,
90
+ "vision_config": {
91
+ "_name_or_path": "",
92
+ "add_cross_attention": false,
93
+ "architectures": null,
94
+ "attention_dropout": 0.0,
95
+ "bad_words_ids": null,
96
+ "begin_suppress_tokens": null,
97
+ "bos_token_id": null,
98
+ "chunk_size_feed_forward": 0,
99
+ "cross_attention_hidden_size": null,
100
+ "decoder_start_token_id": null,
101
+ "diversity_penalty": 0.0,
102
+ "do_sample": false,
103
+ "dropout": 0.0,
104
+ "early_stopping": false,
105
+ "encoder_no_repeat_ngram_size": 0,
106
+ "eos_token_id": null,
107
+ "exponential_decay_length_penalty": null,
108
+ "finetuning_task": null,
109
+ "forced_bos_token_id": null,
110
+ "forced_eos_token_id": null,
111
+ "hidden_act": "quick_gelu",
112
+ "hidden_size": 1024,
113
+ "id2label": {
114
+ "0": "LABEL_0",
115
+ "1": "LABEL_1"
116
+ },
117
+ "image_size": 224,
118
+ "initializer_factor": 1.0,
119
+ "initializer_range": 0.02,
120
+ "intermediate_size": 4096,
121
+ "is_decoder": false,
122
+ "is_encoder_decoder": false,
123
+ "label2id": {
124
+ "LABEL_0": 0,
125
+ "LABEL_1": 1
126
+ },
127
+ "layer_norm_eps": 1e-05,
128
+ "length_penalty": 1.0,
129
+ "max_length": 20,
130
+ "min_length": 0,
131
+ "model_type": "clip_vision_model",
132
+ "no_repeat_ngram_size": 0,
133
+ "num_attention_heads": 16,
134
+ "num_beam_groups": 1,
135
+ "num_beams": 1,
136
+ "num_channels": 3,
137
+ "num_hidden_layers": 24,
138
+ "num_return_sequences": 1,
139
+ "output_attentions": false,
140
+ "output_hidden_states": false,
141
+ "output_scores": false,
142
+ "pad_token_id": null,
143
+ "patch_size": 14,
144
+ "prefix": null,
145
+ "problem_type": null,
146
+ "projection_dim": 512,
147
+ "pruned_heads": {},
148
+ "remove_invalid_values": false,
149
+ "repetition_penalty": 1.0,
150
+ "return_dict": true,
151
+ "return_dict_in_generate": false,
152
+ "sep_token_id": null,
153
+ "suppress_tokens": null,
154
+ "task_specific_params": null,
155
+ "temperature": 1.0,
156
+ "tf_legacy_loss": false,
157
+ "tie_encoder_decoder": false,
158
+ "tie_word_embeddings": true,
159
+ "tokenizer_class": null,
160
+ "top_k": 50,
161
+ "top_p": 1.0,
162
+ "torch_dtype": null,
163
+ "torchscript": false,
164
+ "transformers_version": "4.31.0",
165
+ "typical_p": 1.0,
166
+ "use_bfloat16": false
167
+ }
168
+ }
safety_checker/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:753acd54aa6d288d6c0ce9d51468eb28f495fcbaacf0edf755fa5fc7ce678cd9
3
+ size 1216062333
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.20.0.dev0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "epsilon",
10
+ "set_alpha_to_one": false,
11
+ "skip_prk_steps": true,
12
+ "steps_offset": 1,
13
+ "timestep_spacing": "leading",
14
+ "trained_betas": null
15
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "_name_or_path": "CompVis/stable-diffusion-v1-4",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.31.0",
24
+ "vocab_size": 49408
25
+ }
text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b020851da42091416889fa03bf3e527e9bc8a7f0b1164147ce06536a5c22494c
3
+ size 246187869
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "clean_up_tokenization_spaces": true,
12
+ "do_lower_case": true,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|endoftext|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "errors": "replace",
22
+ "model_max_length": 77,
23
+ "pad_token": "<|endoftext|>",
24
+ "tokenizer_class": "CLIPTokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<|endoftext|>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
train_dreambooth.py ADDED
@@ -0,0 +1,1378 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
+
16
+ import argparse
17
+ import copy
18
+ import gc
19
+ import hashlib
20
+ import itertools
21
+ import logging
22
+ import math
23
+ import os
24
+ import shutil
25
+ import warnings
26
+ from pathlib import Path
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import ProjectConfiguration, set_seed
36
+ from huggingface_hub import create_repo, model_info, upload_folder
37
+ from packaging import version
38
+ from PIL import Image
39
+ from PIL.ImageOps import exif_transpose
40
+ from torch.utils.data import Dataset
41
+ from torchvision import transforms
42
+ from tqdm.auto import tqdm
43
+ from transformers import AutoTokenizer, PretrainedConfig
44
+
45
+ import diffusers
46
+ from diffusers import (
47
+ AutoencoderKL,
48
+ DDPMScheduler,
49
+ DiffusionPipeline,
50
+ DPMSolverMultistepScheduler,
51
+ StableDiffusionPipeline,
52
+ UNet2DConditionModel,
53
+ )
54
+ from diffusers.optimization import get_scheduler
55
+ from diffusers.utils import check_min_version, is_wandb_available
56
+ from diffusers.utils.import_utils import is_xformers_available
57
+
58
+
59
+ if is_wandb_available():
60
+ import wandb
61
+
62
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
63
+ check_min_version("0.20.0.dev0")
64
+
65
+ logger = get_logger(__name__)
66
+
67
+
68
+ def save_model_card(
69
+ repo_id: str,
70
+ images=None,
71
+ base_model=str,
72
+ train_text_encoder=False,
73
+ prompt=str,
74
+ repo_folder=None,
75
+ pipeline: DiffusionPipeline = None,
76
+ ):
77
+ img_str = ""
78
+ for i, image in enumerate(images):
79
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
80
+ img_str += f"![img_{i}](./image_{i}.png)\n"
81
+
82
+ yaml = f"""
83
+ ---
84
+ license: creativeml-openrail-m
85
+ base_model: {base_model}
86
+ instance_prompt: {prompt}
87
+ tags:
88
+ - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
89
+ - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
90
+ - text-to-image
91
+ - diffusers
92
+ - dreambooth
93
+ inference: true
94
+ ---
95
+ """
96
+ model_card = f"""
97
+ # DreamBooth - {repo_id}
98
+
99
+ This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
100
+ You can find some example images in the following. \n
101
+ {img_str}
102
+
103
+ DreamBooth for the text encoder was enabled: {train_text_encoder}.
104
+ """
105
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
106
+ f.write(yaml + model_card)
107
+
108
+
109
+ def log_validation(
110
+ text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds
111
+ ):
112
+ logger.info(
113
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
114
+ f" {args.validation_prompt}."
115
+ )
116
+
117
+ pipeline_args = {}
118
+
119
+ if vae is not None:
120
+ pipeline_args["vae"] = vae
121
+
122
+ if text_encoder is not None:
123
+ text_encoder = accelerator.unwrap_model(text_encoder)
124
+
125
+ # create pipeline (note: unet and vae are loaded again in float32)
126
+ pipeline = DiffusionPipeline.from_pretrained(
127
+ args.pretrained_model_name_or_path,
128
+ tokenizer=tokenizer,
129
+ text_encoder=text_encoder,
130
+ unet=accelerator.unwrap_model(unet),
131
+ revision=args.revision,
132
+ torch_dtype=weight_dtype,
133
+ **pipeline_args,
134
+ )
135
+
136
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
137
+ scheduler_args = {}
138
+
139
+ if "variance_type" in pipeline.scheduler.config:
140
+ variance_type = pipeline.scheduler.config.variance_type
141
+
142
+ if variance_type in ["learned", "learned_range"]:
143
+ variance_type = "fixed_small"
144
+
145
+ scheduler_args["variance_type"] = variance_type
146
+
147
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
148
+ pipeline = pipeline.to(accelerator.device)
149
+ pipeline.set_progress_bar_config(disable=True)
150
+
151
+ if args.pre_compute_text_embeddings:
152
+ pipeline_args = {
153
+ "prompt_embeds": prompt_embeds,
154
+ "negative_prompt_embeds": negative_prompt_embeds,
155
+ }
156
+ else:
157
+ pipeline_args = {"prompt": args.validation_prompt}
158
+
159
+ # run inference
160
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
161
+ images = []
162
+ if args.validation_images is None:
163
+ for _ in range(args.num_validation_images):
164
+ with torch.autocast("cuda"):
165
+ image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
166
+ images.append(image)
167
+ else:
168
+ for image in args.validation_images:
169
+ image = Image.open(image)
170
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
171
+ images.append(image)
172
+
173
+ for tracker in accelerator.trackers:
174
+ if tracker.name == "tensorboard":
175
+ np_images = np.stack([np.asarray(img) for img in images])
176
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
177
+ if tracker.name == "wandb":
178
+ tracker.log(
179
+ {
180
+ "validation": [
181
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
182
+ ]
183
+ }
184
+ )
185
+
186
+ del pipeline
187
+ torch.cuda.empty_cache()
188
+
189
+ return images
190
+
191
+
192
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
193
+ text_encoder_config = PretrainedConfig.from_pretrained(
194
+ pretrained_model_name_or_path,
195
+ subfolder="text_encoder",
196
+ revision=revision,
197
+ )
198
+ model_class = text_encoder_config.architectures[0]
199
+
200
+ if model_class == "CLIPTextModel":
201
+ from transformers import CLIPTextModel
202
+
203
+ return CLIPTextModel
204
+ elif model_class == "RobertaSeriesModelWithTransformation":
205
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
206
+
207
+ return RobertaSeriesModelWithTransformation
208
+ elif model_class == "T5EncoderModel":
209
+ from transformers import T5EncoderModel
210
+
211
+ return T5EncoderModel
212
+ else:
213
+ raise ValueError(f"{model_class} is not supported.")
214
+
215
+
216
+ def parse_args(input_args=None):
217
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
218
+ parser.add_argument(
219
+ "--pretrained_model_name_or_path",
220
+ type=str,
221
+ default=None,
222
+ required=True,
223
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
224
+ )
225
+ parser.add_argument(
226
+ "--revision",
227
+ type=str,
228
+ default=None,
229
+ required=False,
230
+ help=(
231
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
232
+ " float32 precision."
233
+ ),
234
+ )
235
+ parser.add_argument(
236
+ "--tokenizer_name",
237
+ type=str,
238
+ default=None,
239
+ help="Pretrained tokenizer name or path if not the same as model_name",
240
+ )
241
+ parser.add_argument(
242
+ "--instance_data_dir",
243
+ type=str,
244
+ default=None,
245
+ required=True,
246
+ help="A folder containing the training data of instance images.",
247
+ )
248
+ parser.add_argument(
249
+ "--class_data_dir",
250
+ type=str,
251
+ default=None,
252
+ required=False,
253
+ help="A folder containing the training data of class images.",
254
+ )
255
+ parser.add_argument(
256
+ "--instance_prompt",
257
+ type=str,
258
+ default=None,
259
+ required=True,
260
+ help="The prompt with identifier specifying the instance",
261
+ )
262
+ parser.add_argument(
263
+ "--class_prompt",
264
+ type=str,
265
+ default=None,
266
+ help="The prompt to specify images in the same class as provided instance images.",
267
+ )
268
+ parser.add_argument(
269
+ "--with_prior_preservation",
270
+ default=False,
271
+ action="store_true",
272
+ help="Flag to add prior preservation loss.",
273
+ )
274
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
275
+ parser.add_argument(
276
+ "--num_class_images",
277
+ type=int,
278
+ default=100,
279
+ help=(
280
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
281
+ " class_data_dir, additional images will be sampled with class_prompt."
282
+ ),
283
+ )
284
+ parser.add_argument(
285
+ "--output_dir",
286
+ type=str,
287
+ default="text-inversion-model",
288
+ help="The output directory where the model predictions and checkpoints will be written.",
289
+ )
290
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
291
+ parser.add_argument(
292
+ "--resolution",
293
+ type=int,
294
+ default=512,
295
+ help=(
296
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
297
+ " resolution"
298
+ ),
299
+ )
300
+ parser.add_argument(
301
+ "--center_crop",
302
+ default=False,
303
+ action="store_true",
304
+ help=(
305
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
306
+ " cropped. The images will be resized to the resolution first before cropping."
307
+ ),
308
+ )
309
+ parser.add_argument(
310
+ "--train_text_encoder",
311
+ action="store_true",
312
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
313
+ )
314
+ parser.add_argument(
315
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
316
+ )
317
+ parser.add_argument(
318
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
319
+ )
320
+ parser.add_argument("--num_train_epochs", type=int, default=1)
321
+ parser.add_argument(
322
+ "--max_train_steps",
323
+ type=int,
324
+ default=None,
325
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
326
+ )
327
+ parser.add_argument(
328
+ "--checkpointing_steps",
329
+ type=int,
330
+ default=500,
331
+ help=(
332
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
333
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
334
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
335
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
336
+ "instructions."
337
+ ),
338
+ )
339
+ parser.add_argument(
340
+ "--checkpoints_total_limit",
341
+ type=int,
342
+ default=None,
343
+ help=(
344
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
345
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
346
+ " for more details"
347
+ ),
348
+ )
349
+ parser.add_argument(
350
+ "--resume_from_checkpoint",
351
+ type=str,
352
+ default=None,
353
+ help=(
354
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
355
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
356
+ ),
357
+ )
358
+ parser.add_argument(
359
+ "--gradient_accumulation_steps",
360
+ type=int,
361
+ default=1,
362
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
363
+ )
364
+ parser.add_argument(
365
+ "--gradient_checkpointing",
366
+ action="store_true",
367
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
368
+ )
369
+ parser.add_argument(
370
+ "--learning_rate",
371
+ type=float,
372
+ default=5e-6,
373
+ help="Initial learning rate (after the potential warmup period) to use.",
374
+ )
375
+ parser.add_argument(
376
+ "--scale_lr",
377
+ action="store_true",
378
+ default=False,
379
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
380
+ )
381
+ parser.add_argument(
382
+ "--lr_scheduler",
383
+ type=str,
384
+ default="constant",
385
+ help=(
386
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
387
+ ' "constant", "constant_with_warmup"]'
388
+ ),
389
+ )
390
+ parser.add_argument(
391
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
392
+ )
393
+ parser.add_argument(
394
+ "--lr_num_cycles",
395
+ type=int,
396
+ default=1,
397
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
398
+ )
399
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
400
+ parser.add_argument(
401
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
402
+ )
403
+ parser.add_argument(
404
+ "--dataloader_num_workers",
405
+ type=int,
406
+ default=0,
407
+ help=(
408
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
409
+ ),
410
+ )
411
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
412
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
413
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
414
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
415
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
416
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
417
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
418
+ parser.add_argument(
419
+ "--hub_model_id",
420
+ type=str,
421
+ default=None,
422
+ help="The name of the repository to keep in sync with the local `output_dir`.",
423
+ )
424
+ parser.add_argument(
425
+ "--logging_dir",
426
+ type=str,
427
+ default="logs",
428
+ help=(
429
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
430
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
431
+ ),
432
+ )
433
+ parser.add_argument(
434
+ "--allow_tf32",
435
+ action="store_true",
436
+ help=(
437
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
438
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
439
+ ),
440
+ )
441
+ parser.add_argument(
442
+ "--report_to",
443
+ type=str,
444
+ default="tensorboard",
445
+ help=(
446
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
447
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
448
+ ),
449
+ )
450
+ parser.add_argument(
451
+ "--validation_prompt",
452
+ type=str,
453
+ default=None,
454
+ help="A prompt that is used during validation to verify that the model is learning.",
455
+ )
456
+ parser.add_argument(
457
+ "--num_validation_images",
458
+ type=int,
459
+ default=4,
460
+ help="Number of images that should be generated during validation with `validation_prompt`.",
461
+ )
462
+ parser.add_argument(
463
+ "--validation_steps",
464
+ type=int,
465
+ default=100,
466
+ help=(
467
+ "Run validation every X steps. Validation consists of running the prompt"
468
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
469
+ " and logging the images."
470
+ ),
471
+ )
472
+ parser.add_argument(
473
+ "--mixed_precision",
474
+ type=str,
475
+ default=None,
476
+ choices=["no", "fp16", "bf16"],
477
+ help=(
478
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
479
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
480
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
481
+ ),
482
+ )
483
+ parser.add_argument(
484
+ "--prior_generation_precision",
485
+ type=str,
486
+ default=None,
487
+ choices=["no", "fp32", "fp16", "bf16"],
488
+ help=(
489
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
490
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
491
+ ),
492
+ )
493
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
494
+ parser.add_argument(
495
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
496
+ )
497
+ parser.add_argument(
498
+ "--set_grads_to_none",
499
+ action="store_true",
500
+ help=(
501
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
502
+ " behaviors, so disable this argument if it causes any problems. More info:"
503
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
504
+ ),
505
+ )
506
+
507
+ parser.add_argument(
508
+ "--offset_noise",
509
+ action="store_true",
510
+ default=False,
511
+ help=(
512
+ "Fine-tuning against a modified noise"
513
+ " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
514
+ ),
515
+ )
516
+ parser.add_argument(
517
+ "--pre_compute_text_embeddings",
518
+ action="store_true",
519
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
520
+ )
521
+ parser.add_argument(
522
+ "--tokenizer_max_length",
523
+ type=int,
524
+ default=None,
525
+ required=False,
526
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
527
+ )
528
+ parser.add_argument(
529
+ "--text_encoder_use_attention_mask",
530
+ action="store_true",
531
+ required=False,
532
+ help="Whether to use attention mask for the text encoder",
533
+ )
534
+ parser.add_argument(
535
+ "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
536
+ )
537
+ parser.add_argument(
538
+ "--validation_images",
539
+ required=False,
540
+ default=None,
541
+ nargs="+",
542
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
543
+ )
544
+ parser.add_argument(
545
+ "--class_labels_conditioning",
546
+ required=False,
547
+ default=None,
548
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
549
+ )
550
+
551
+ if input_args is not None:
552
+ args = parser.parse_args(input_args)
553
+ else:
554
+ args = parser.parse_args()
555
+
556
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
557
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
558
+ args.local_rank = env_local_rank
559
+
560
+ if args.with_prior_preservation:
561
+ if args.class_data_dir is None:
562
+ raise ValueError("You must specify a data directory for class images.")
563
+ if args.class_prompt is None:
564
+ raise ValueError("You must specify prompt for class images.")
565
+ else:
566
+ # logger is not available yet
567
+ if args.class_data_dir is not None:
568
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
569
+ if args.class_prompt is not None:
570
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
571
+
572
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
573
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
574
+
575
+ return args
576
+
577
+
578
+ class DreamBoothDataset(Dataset):
579
+ """
580
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
581
+ It pre-processes the images and the tokenizes prompts.
582
+ """
583
+
584
+ def __init__(
585
+ self,
586
+ instance_data_root,
587
+ instance_prompt,
588
+ tokenizer,
589
+ class_data_root=None,
590
+ class_prompt=None,
591
+ class_num=None,
592
+ size=512,
593
+ center_crop=False,
594
+ encoder_hidden_states=None,
595
+ class_prompt_encoder_hidden_states=None,
596
+ tokenizer_max_length=None,
597
+ ):
598
+ self.size = size
599
+ self.center_crop = center_crop
600
+ self.tokenizer = tokenizer
601
+ self.encoder_hidden_states = encoder_hidden_states
602
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
603
+ self.tokenizer_max_length = tokenizer_max_length
604
+
605
+ self.instance_data_root = Path(instance_data_root)
606
+ if not self.instance_data_root.exists():
607
+ raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
608
+
609
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
610
+ self.num_instance_images = len(self.instance_images_path)
611
+ self.instance_prompt = instance_prompt
612
+ self._length = self.num_instance_images
613
+
614
+ if class_data_root is not None:
615
+ self.class_data_root = Path(class_data_root)
616
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
617
+ self.class_images_path = list(self.class_data_root.iterdir())
618
+ if class_num is not None:
619
+ self.num_class_images = min(len(self.class_images_path), class_num)
620
+ else:
621
+ self.num_class_images = len(self.class_images_path)
622
+ self._length = max(self.num_class_images, self.num_instance_images)
623
+ self.class_prompt = class_prompt
624
+ else:
625
+ self.class_data_root = None
626
+
627
+ self.image_transforms = transforms.Compose(
628
+ [
629
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
630
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
631
+ transforms.ToTensor(),
632
+ transforms.Normalize([0.5], [0.5]),
633
+ ]
634
+ )
635
+
636
+ def __len__(self):
637
+ return self._length
638
+
639
+ def __getitem__(self, index):
640
+ example = {}
641
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
642
+ instance_image = exif_transpose(instance_image)
643
+
644
+ if not instance_image.mode == "RGB":
645
+ instance_image = instance_image.convert("RGB")
646
+ example["instance_images"] = self.image_transforms(instance_image)
647
+
648
+ if self.encoder_hidden_states is not None:
649
+ example["instance_prompt_ids"] = self.encoder_hidden_states
650
+ else:
651
+ text_inputs = tokenize_prompt(
652
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
653
+ )
654
+ example["instance_prompt_ids"] = text_inputs.input_ids
655
+ example["instance_attention_mask"] = text_inputs.attention_mask
656
+
657
+ if self.class_data_root:
658
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
659
+ class_image = exif_transpose(class_image)
660
+
661
+ if not class_image.mode == "RGB":
662
+ class_image = class_image.convert("RGB")
663
+ example["class_images"] = self.image_transforms(class_image)
664
+
665
+ if self.class_prompt_encoder_hidden_states is not None:
666
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
667
+ else:
668
+ class_text_inputs = tokenize_prompt(
669
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
670
+ )
671
+ example["class_prompt_ids"] = class_text_inputs.input_ids
672
+ example["class_attention_mask"] = class_text_inputs.attention_mask
673
+
674
+ return example
675
+
676
+
677
+ def collate_fn(examples, with_prior_preservation=False):
678
+ has_attention_mask = "instance_attention_mask" in examples[0]
679
+
680
+ input_ids = [example["instance_prompt_ids"] for example in examples]
681
+ pixel_values = [example["instance_images"] for example in examples]
682
+
683
+ if has_attention_mask:
684
+ attention_mask = [example["instance_attention_mask"] for example in examples]
685
+
686
+ # Concat class and instance examples for prior preservation.
687
+ # We do this to avoid doing two forward passes.
688
+ if with_prior_preservation:
689
+ input_ids += [example["class_prompt_ids"] for example in examples]
690
+ pixel_values += [example["class_images"] for example in examples]
691
+
692
+ if has_attention_mask:
693
+ attention_mask += [example["class_attention_mask"] for example in examples]
694
+
695
+ pixel_values = torch.stack(pixel_values)
696
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
697
+
698
+ input_ids = torch.cat(input_ids, dim=0)
699
+
700
+ batch = {
701
+ "input_ids": input_ids,
702
+ "pixel_values": pixel_values,
703
+ }
704
+
705
+ if has_attention_mask:
706
+ attention_mask = torch.cat(attention_mask, dim=0)
707
+ batch["attention_mask"] = attention_mask
708
+
709
+ return batch
710
+
711
+
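A sketch of how `DreamBoothDataset` and `collate_fn` above are typically wired into a dataloader (illustrative only; the actual wiring presumably sits further down in `main()`, beyond the portion of the script shown here):

```python
# Build the instance/class dataset from the CLI arguments parsed above.
train_dataset = DreamBoothDataset(
    instance_data_root=args.instance_data_dir,
    instance_prompt=args.instance_prompt,
    tokenizer=tokenizer,
    class_data_root=args.class_data_dir if args.with_prior_preservation else None,
    class_prompt=args.class_prompt,
    class_num=args.num_class_images,
    size=args.resolution,
    center_crop=args.center_crop,
)

# collate_fn concatenates class examples onto instance examples when
# prior preservation is enabled, so both share a single forward pass.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
    num_workers=args.dataloader_num_workers,
)
```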
712
+ class PromptDataset(Dataset):
713
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
714
+
715
+ def __init__(self, prompt, num_samples):
716
+ self.prompt = prompt
717
+ self.num_samples = num_samples
718
+
719
+ def __len__(self):
720
+ return self.num_samples
721
+
722
+ def __getitem__(self, index):
723
+ example = {}
724
+ example["prompt"] = self.prompt
725
+ example["index"] = index
726
+ return example
727
+
728
+
729
+ def model_has_vae(args):
730
+ config_file_name = os.path.join("vae", AutoencoderKL.config_name)
731
+ if os.path.isdir(args.pretrained_model_name_or_path):
732
+ config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
733
+ return os.path.isfile(config_file_name)
734
+ else:
735
+ files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
736
+ return any(file.rfilename == config_file_name for file in files_in_repo)
737
+
738
+
739
+ def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
740
+ if tokenizer_max_length is not None:
741
+ max_length = tokenizer_max_length
742
+ else:
743
+ max_length = tokenizer.model_max_length
744
+
745
+ text_inputs = tokenizer(
746
+ prompt,
747
+ truncation=True,
748
+ padding="max_length",
749
+ max_length=max_length,
750
+ return_tensors="pt",
751
+ )
752
+
753
+ return text_inputs
754
+
755
+
756
+ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
757
+ text_input_ids = input_ids.to(text_encoder.device)
758
+
759
+ if text_encoder_use_attention_mask:
760
+ attention_mask = attention_mask.to(text_encoder.device)
761
+ else:
762
+ attention_mask = None
763
+
764
+ prompt_embeds = text_encoder(
765
+ text_input_ids,
766
+ attention_mask=attention_mask,
767
+ )
768
+ prompt_embeds = prompt_embeds[0]
769
+
770
+ return prompt_embeds
771
+
772
+
773
+ def main(args):
774
+ logging_dir = Path(args.output_dir, args.logging_dir)
775
+
776
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
777
+
778
+ accelerator = Accelerator(
779
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
780
+ mixed_precision=args.mixed_precision,
781
+ log_with=args.report_to,
782
+ project_config=accelerator_project_config,
783
+ )
784
+
785
+ if args.report_to == "wandb":
786
+ if not is_wandb_available():
787
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
788
+
789
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
790
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
791
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
792
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
793
+ raise ValueError(
794
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
795
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
796
+ )
797
+
798
+ # Make one log on every process with the configuration for debugging.
799
+ logging.basicConfig(
800
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
801
+ datefmt="%m/%d/%Y %H:%M:%S",
802
+ level=logging.INFO,
803
+ )
804
+ logger.info(accelerator.state, main_process_only=False)
805
+ if accelerator.is_local_main_process:
806
+ transformers.utils.logging.set_verbosity_warning()
807
+ diffusers.utils.logging.set_verbosity_info()
808
+ else:
809
+ transformers.utils.logging.set_verbosity_error()
810
+ diffusers.utils.logging.set_verbosity_error()
811
+
812
+ # If passed along, set the training seed now.
813
+ if args.seed is not None:
814
+ set_seed(args.seed)
815
+
816
+ # Generate class images if prior preservation is enabled.
817
+ if args.with_prior_preservation:
818
+ class_images_dir = Path(args.class_data_dir)
819
+ if not class_images_dir.exists():
820
+ class_images_dir.mkdir(parents=True)
821
+ cur_class_images = len(list(class_images_dir.iterdir()))
822
+
823
+ if cur_class_images < args.num_class_images:
824
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
825
+ if args.prior_generation_precision == "fp32":
826
+ torch_dtype = torch.float32
827
+ elif args.prior_generation_precision == "fp16":
828
+ torch_dtype = torch.float16
829
+ elif args.prior_generation_precision == "bf16":
830
+ torch_dtype = torch.bfloat16
831
+ pipeline = DiffusionPipeline.from_pretrained(
832
+ args.pretrained_model_name_or_path,
833
+ torch_dtype=torch_dtype,
834
+ safety_checker=None,
835
+ revision=args.revision,
836
+ )
837
+ pipeline.set_progress_bar_config(disable=True)
838
+
839
+ num_new_images = args.num_class_images - cur_class_images
840
+ logger.info(f"Number of class images to sample: {num_new_images}.")
841
+
842
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
843
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
844
+
845
+ sample_dataloader = accelerator.prepare(sample_dataloader)
846
+ pipeline.to(accelerator.device)
847
+
848
+ for example in tqdm(
849
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
850
+ ):
851
+ images = pipeline(example["prompt"]).images
852
+
853
+ for i, image in enumerate(images):
854
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
855
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
856
+ image.save(image_filename)
857
+
858
+ del pipeline
859
+ if torch.cuda.is_available():
860
+ torch.cuda.empty_cache()
861
+
862
+ # Handle the repository creation
863
+ if accelerator.is_main_process:
864
+ if args.output_dir is not None:
865
+ os.makedirs(args.output_dir, exist_ok=True)
866
+
867
+ if args.push_to_hub:
868
+ repo_id = create_repo(
869
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
870
+ ).repo_id
871
+
872
+ # Load the tokenizer
873
+ if args.tokenizer_name:
874
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
875
+ elif args.pretrained_model_name_or_path:
876
+ tokenizer = AutoTokenizer.from_pretrained(
877
+ args.pretrained_model_name_or_path,
878
+ subfolder="tokenizer",
879
+ revision=args.revision,
880
+ use_fast=False,
881
+ )
882
+
883
+ # import correct text encoder class
884
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
885
+
886
+ # Load scheduler and models
887
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
888
+ text_encoder = text_encoder_cls.from_pretrained(
889
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
890
+ )
891
+
892
+ if model_has_vae(args):
893
+ vae = AutoencoderKL.from_pretrained(
894
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
895
+ )
896
+ else:
897
+ vae = None
898
+
899
+ unet = UNet2DConditionModel.from_pretrained(
900
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
901
+ )
902
+
903
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
904
+ def save_model_hook(models, weights, output_dir):
905
+ for model in models:
906
+ sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
907
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
908
+
909
+ # make sure to pop weight so that corresponding model is not saved again
910
+ weights.pop()
911
+
912
+ def load_model_hook(models, input_dir):
913
+ while len(models) > 0:
914
+ # pop models so that they are not loaded again
915
+ model = models.pop()
916
+
917
+ if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
918
+ # load transformers style into model
919
+ load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
920
+ model.config = load_model.config
921
+ else:
922
+ # load diffusers style into model
923
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
924
+ model.register_to_config(**load_model.config)
925
+
926
+ model.load_state_dict(load_model.state_dict())
927
+ del load_model
928
+
929
+ accelerator.register_save_state_pre_hook(save_model_hook)
930
+ accelerator.register_load_state_pre_hook(load_model_hook)
931
+
932
+ if vae is not None:
933
+ vae.requires_grad_(False)
934
+
935
+ if not args.train_text_encoder:
936
+ text_encoder.requires_grad_(False)
937
+
938
+ if args.enable_xformers_memory_efficient_attention:
939
+ if is_xformers_available():
940
+ import xformers
941
+
942
+ xformers_version = version.parse(xformers.__version__)
943
+ if xformers_version == version.parse("0.0.16"):
944
+ logger.warn(
945
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
946
+ )
947
+ unet.enable_xformers_memory_efficient_attention()
948
+ else:
949
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
950
+
951
+ if args.gradient_checkpointing:
952
+ unet.enable_gradient_checkpointing()
953
+ if args.train_text_encoder:
954
+ text_encoder.gradient_checkpointing_enable()
955
+
956
+ # Check that all trainable models are in full precision
957
+ low_precision_error_string = (
958
+ "Please make sure to always have all model weights in full float32 precision when starting training - even if"
959
+ " doing mixed precision training. copy of the weights should still be float32."
960
+ )
961
+
962
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
963
+ raise ValueError(
964
+ f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
965
+ )
966
+
967
+ if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
968
+ raise ValueError(
969
+ f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
970
+ f" {low_precision_error_string}"
971
+ )
972
+
973
+ # Enable TF32 for faster training on Ampere GPUs,
974
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
975
+ if args.allow_tf32:
976
+ torch.backends.cuda.matmul.allow_tf32 = True
977
+
978
+ if args.scale_lr:
979
+ args.learning_rate = (
980
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
981
+ )
982
+
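With `--scale_lr`, the base learning rate is multiplied by the effective batch size (gradient accumulation steps × per-device batch size × number of processes). A quick arithmetic check with hypothetical values:

```python
# Hypothetical values illustrating the scaling rule above.
learning_rate = 5e-6
gradient_accumulation_steps = 2
train_batch_size = 1
num_processes = 4  # e.g. 4 GPUs

scaled_lr = learning_rate * gradient_accumulation_steps * train_batch_size * num_processes
print(scaled_lr)  # 4e-05
```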
983
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
984
+ if args.use_8bit_adam:
985
+ try:
986
+ import bitsandbytes as bnb
987
+ except ImportError:
988
+ raise ImportError(
989
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
990
+ )
991
+
992
+ optimizer_class = bnb.optim.AdamW8bit
993
+ else:
994
+ optimizer_class = torch.optim.AdamW
995
+
996
+ # Optimizer creation
997
+ params_to_optimize = (
998
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
999
+ )
1000
+ optimizer = optimizer_class(
1001
+ params_to_optimize,
1002
+ lr=args.learning_rate,
1003
+ betas=(args.adam_beta1, args.adam_beta2),
1004
+ weight_decay=args.adam_weight_decay,
1005
+ eps=args.adam_epsilon,
1006
+ )
1007
+
1008
+ if args.pre_compute_text_embeddings:
1009
+
1010
+ def compute_text_embeddings(prompt):
1011
+ with torch.no_grad():
1012
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
1013
+ prompt_embeds = encode_prompt(
1014
+ text_encoder,
1015
+ text_inputs.input_ids,
1016
+ text_inputs.attention_mask,
1017
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1018
+ )
1019
+
1020
+ return prompt_embeds
1021
+
1022
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
1023
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
1024
+
1025
+ if args.validation_prompt is not None:
1026
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
1027
+ else:
1028
+ validation_prompt_encoder_hidden_states = None
1029
+
1030
+ if args.class_prompt is not None:
1031
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
1032
+ else:
1033
+ pre_computed_class_prompt_encoder_hidden_states = None
1034
+
1035
+ text_encoder = None
1036
+ tokenizer = None
1037
+
1038
+ gc.collect()
1039
+ torch.cuda.empty_cache()
1040
+ else:
1041
+ pre_computed_encoder_hidden_states = None
1042
+ validation_prompt_encoder_hidden_states = None
1043
+ validation_prompt_negative_prompt_embeds = None
1044
+ pre_computed_class_prompt_encoder_hidden_states = None
1045
+
1046
+ # Dataset and DataLoaders creation:
1047
+ train_dataset = DreamBoothDataset(
1048
+ instance_data_root=args.instance_data_dir,
1049
+ instance_prompt=args.instance_prompt,
1050
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1051
+ class_prompt=args.class_prompt,
1052
+ class_num=args.num_class_images,
1053
+ tokenizer=tokenizer,
1054
+ size=args.resolution,
1055
+ center_crop=args.center_crop,
1056
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
1057
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
1058
+ tokenizer_max_length=args.tokenizer_max_length,
1059
+ )
1060
+
1061
+ train_dataloader = torch.utils.data.DataLoader(
1062
+ train_dataset,
1063
+ batch_size=args.train_batch_size,
1064
+ shuffle=True,
1065
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1066
+ num_workers=args.dataloader_num_workers,
1067
+ )
1068
+
1069
+ # Scheduler and math around the number of training steps.
1070
+ overrode_max_train_steps = False
1071
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1072
+ if args.max_train_steps is None:
1073
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1074
+ overrode_max_train_steps = True
1075
+
1076
+ lr_scheduler = get_scheduler(
1077
+ args.lr_scheduler,
1078
+ optimizer=optimizer,
1079
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1080
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1081
+ num_cycles=args.lr_num_cycles,
1082
+ power=args.lr_power,
1083
+ )
1084
+
1085
+ # Prepare everything with our `accelerator`.
1086
+ if args.train_text_encoder:
1087
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1088
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1089
+ )
1090
+ else:
1091
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1092
+ unet, optimizer, train_dataloader, lr_scheduler
1093
+ )
1094
+
1095
+ # For mixed precision training we cast the non-trainable weights (vae and, when frozen, the text encoder) to half-precision
1096
+ # as these weights are only used for inference, keeping weights in full precision is not required.
1097
+ weight_dtype = torch.float32
1098
+ if accelerator.mixed_precision == "fp16":
1099
+ weight_dtype = torch.float16
1100
+ elif accelerator.mixed_precision == "bf16":
1101
+ weight_dtype = torch.bfloat16
1102
+
1103
+ # Move vae and text_encoder to device and cast to weight_dtype
1104
+ if vae is not None:
1105
+ vae.to(accelerator.device, dtype=weight_dtype)
1106
+
1107
+ if not args.train_text_encoder and text_encoder is not None:
1108
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
1109
+
1110
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1111
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1112
+ if overrode_max_train_steps:
1113
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1114
+ # Afterwards we recalculate our number of training epochs
1115
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1116
+
1117
+ # We need to initialize the trackers we use, and also store our configuration.
1118
+ # The trackers initialize automatically on the main process.
1119
+ if accelerator.is_main_process:
1120
+ tracker_config = vars(copy.deepcopy(args))
1121
+ tracker_config.pop("validation_images")
1122
+ accelerator.init_trackers("dreambooth", config=tracker_config)
1123
+
1124
+ # Train!
1125
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1126
+
1127
+ logger.info("***** Running training *****")
1128
+ logger.info(f" Num examples = {len(train_dataset)}")
1129
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1130
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1131
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1132
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1133
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1134
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1135
+ global_step = 0
1136
+ first_epoch = 0
1137
+
1138
+ # Potentially load in the weights and states from a previous save
1139
+ if args.resume_from_checkpoint:
1140
+ if args.resume_from_checkpoint != "latest":
1141
+ path = os.path.basename(args.resume_from_checkpoint)
1142
+ else:
1143
+ # Get the most recent checkpoint
1144
+ dirs = os.listdir(args.output_dir)
1145
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1146
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1147
+ path = dirs[-1] if len(dirs) > 0 else None
1148
+
1149
+ if path is None:
1150
+ accelerator.print(
1151
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1152
+ )
1153
+ args.resume_from_checkpoint = None
1154
+ else:
1155
+ accelerator.print(f"Resuming from checkpoint {path}")
1156
+ accelerator.load_state(os.path.join(args.output_dir, path))
1157
+ global_step = int(path.split("-")[1])
1158
+
1159
+ resume_global_step = global_step * args.gradient_accumulation_steps
1160
+ first_epoch = global_step // num_update_steps_per_epoch
1161
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
1162
+
1163
+ # Only show the progress bar once on each machine.
1164
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
1165
+ progress_bar.set_description("Steps")
1166
+
1167
+ for epoch in range(first_epoch, args.num_train_epochs):
1168
+ unet.train()
1169
+ if args.train_text_encoder:
1170
+ text_encoder.train()
1171
+ for step, batch in enumerate(train_dataloader):
1172
+ # Skip steps until we reach the resumed step
1173
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
1174
+ if step % args.gradient_accumulation_steps == 0:
1175
+ progress_bar.update(1)
1176
+ continue
1177
+
1178
+ with accelerator.accumulate(unet):
1179
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
1180
+
1181
+ if vae is not None:
1182
+ # Convert images to latent space
1183
+ model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1184
+ model_input = model_input * vae.config.scaling_factor
1185
+ else:
1186
+ model_input = pixel_values
1187
+
1188
+ # Sample noise that we'll add to the model input
1189
+ if args.offset_noise:
1190
+ noise = torch.randn_like(model_input) + 0.1 * torch.randn(
1191
+ model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
1192
+ )
1193
+ else:
1194
+ noise = torch.randn_like(model_input)
1195
+ bsz, channels, height, width = model_input.shape
1196
+ # Sample a random timestep for each image
1197
+ timesteps = torch.randint(
1198
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1199
+ )
1200
+ timesteps = timesteps.long()
1201
+
1202
+ # Add noise to the model input according to the noise magnitude at each timestep
1203
+ # (this is the forward diffusion process)
1204
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1205
+
1206
+ # Get the text embedding for conditioning
1207
+ if args.pre_compute_text_embeddings:
1208
+ encoder_hidden_states = batch["input_ids"]
1209
+ else:
1210
+ encoder_hidden_states = encode_prompt(
1211
+ text_encoder,
1212
+ batch["input_ids"],
1213
+ batch["attention_mask"],
1214
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1215
+ )
1216
+
1217
+ if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
1218
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
1219
+
1220
+ if args.class_labels_conditioning == "timesteps":
1221
+ class_labels = timesteps
1222
+ else:
1223
+ class_labels = None
1224
+
1225
+ # Predict the noise residual
1226
+ model_pred = unet(
1227
+ noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
1228
+ ).sample
1229
+
1230
+ if model_pred.shape[1] == 6:
1231
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
1232
+
1233
+ # Get the target for loss depending on the prediction type
1234
+ if noise_scheduler.config.prediction_type == "epsilon":
1235
+ target = noise
1236
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1237
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1238
+ else:
1239
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1240
+
1241
+ if args.with_prior_preservation:
1242
+ # Chunk the model_pred and target into two parts and compute the loss on each part separately.
1243
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1244
+ target, target_prior = torch.chunk(target, 2, dim=0)
1245
+
1246
+ # Compute instance loss
1247
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1248
+
1249
+ # Compute prior loss
1250
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1251
+
1252
+ # Add the prior loss to the instance loss.
1253
+ loss = loss + args.prior_loss_weight * prior_loss
1254
+ else:
1255
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1256
+
1257
+ accelerator.backward(loss)
1258
+ if accelerator.sync_gradients:
1259
+ params_to_clip = (
1260
+ itertools.chain(unet.parameters(), text_encoder.parameters())
1261
+ if args.train_text_encoder
1262
+ else unet.parameters()
1263
+ )
1264
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1265
+ optimizer.step()
1266
+ lr_scheduler.step()
1267
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1268
+
1269
+ # Checks if the accelerator has performed an optimization step behind the scenes
1270
+ if accelerator.sync_gradients:
1271
+ progress_bar.update(1)
1272
+ global_step += 1
1273
+
1274
+ if accelerator.is_main_process:
1275
+ if global_step % args.checkpointing_steps == 0:
1276
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1277
+ if args.checkpoints_total_limit is not None:
1278
+ checkpoints = os.listdir(args.output_dir)
1279
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1280
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1281
+
1282
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1283
+ if len(checkpoints) >= args.checkpoints_total_limit:
1284
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1285
+ removing_checkpoints = checkpoints[0:num_to_remove]
1286
+
1287
+ logger.info(
1288
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1289
+ )
1290
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1291
+
1292
+ for removing_checkpoint in removing_checkpoints:
1293
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1294
+ shutil.rmtree(removing_checkpoint)
1295
+
1296
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1297
+ accelerator.save_state(save_path)
1298
+ logger.info(f"Saved state to {save_path}")
1299
+
1300
+ images = []
1301
+
1302
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1303
+ images = log_validation(
1304
+ text_encoder,
1305
+ tokenizer,
1306
+ unet,
1307
+ vae,
1308
+ args,
1309
+ accelerator,
1310
+ weight_dtype,
1311
+ epoch,
1312
+ validation_prompt_encoder_hidden_states,
1313
+ validation_prompt_negative_prompt_embeds,
1314
+ )
1315
+
1316
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1317
+ progress_bar.set_postfix(**logs)
1318
+ accelerator.log(logs, step=global_step)
1319
+
1320
+ if global_step >= args.max_train_steps:
1321
+ break
1322
+
1323
+ # Create the pipeline using the trained modules and save it.
1324
+ accelerator.wait_for_everyone()
1325
+ if accelerator.is_main_process:
1326
+ pipeline_args = {}
1327
+
1328
+ if text_encoder is not None:
1329
+ pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
1330
+
1331
+ if args.skip_save_text_encoder:
1332
+ pipeline_args["text_encoder"] = None
1333
+
1334
+ pipeline = DiffusionPipeline.from_pretrained(
1335
+ args.pretrained_model_name_or_path,
1336
+ unet=accelerator.unwrap_model(unet),
1337
+ revision=args.revision,
1338
+ **pipeline_args,
1339
+ )
1340
+
1341
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1342
+ scheduler_args = {}
1343
+
1344
+ if "variance_type" in pipeline.scheduler.config:
1345
+ variance_type = pipeline.scheduler.config.variance_type
1346
+
1347
+ if variance_type in ["learned", "learned_range"]:
1348
+ variance_type = "fixed_small"
1349
+
1350
+ scheduler_args["variance_type"] = variance_type
1351
+
1352
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1353
+
1354
+ pipeline.save_pretrained(args.output_dir)
1355
+
1356
+ if args.push_to_hub:
1357
+ save_model_card(
1358
+ repo_id,
1359
+ images=images,
1360
+ base_model=args.pretrained_model_name_or_path,
1361
+ train_text_encoder=args.train_text_encoder,
1362
+ prompt=args.instance_prompt,
1363
+ repo_folder=args.output_dir,
1364
+ pipeline=pipeline,
1365
+ )
1366
+ upload_folder(
1367
+ repo_id=repo_id,
1368
+ folder_path=args.output_dir,
1369
+ commit_message="End of training",
1370
+ ignore_patterns=["step_*", "epoch_*"],
1371
+ )
1372
+
1373
+ accelerator.end_training()
1374
+
1375
+
1376
+ if __name__ == "__main__":
1377
+ args = parse_args()
1378
+ main(args)
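The training loop above implements the standard denoising objective: encode the images, add noise at a random timestep with the scheduler, and regress the UNet output against the target (the noise itself for epsilon prediction). A minimal standalone sketch of that objective on dummy latents, not a drop-in replacement for the loop:

```python
# Sketch of the epsilon-prediction objective on dummy latents (a random tensor
# stands in for the UNet output, which a real run would compute).
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(1, 4, 64, 64)  # stand-in for VAE-encoded pixel values
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,)).long()

noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

model_pred = torch.randn_like(latents)  # placeholder for unet(noisy_latents, ...).sample
loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
print(loss.item())
```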
train_dreambooth_flax.py ADDED
@@ -0,0 +1,709 @@
1
+ import argparse
2
+ import hashlib
3
+ import logging
4
+ import math
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ import numpy as np
12
+ import optax
13
+ import torch
14
+ import torch.utils.checkpoint
15
+ import transformers
16
+ from flax import jax_utils
17
+ from flax.training import train_state
18
+ from flax.training.common_utils import shard
19
+ from huggingface_hub import HfFolder, Repository, create_repo, whoami
20
+ from jax.experimental.compilation_cache import compilation_cache as cc
21
+ from PIL import Image
22
+ from torch.utils.data import Dataset
23
+ from torchvision import transforms
24
+ from tqdm.auto import tqdm
25
+ from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
26
+
27
+ from diffusers import (
28
+ FlaxAutoencoderKL,
29
+ FlaxDDPMScheduler,
30
+ FlaxPNDMScheduler,
31
+ FlaxStableDiffusionPipeline,
32
+ FlaxUNet2DConditionModel,
33
+ )
34
+ from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
35
+ from diffusers.utils import check_min_version
36
+
37
+
38
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
39
+ check_min_version("0.20.0.dev0")
40
+
41
+ # Cache compiled models across invocations of this script.
42
+ cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ def parse_args():
48
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
49
+ parser.add_argument(
50
+ "--pretrained_model_name_or_path",
51
+ type=str,
52
+ default=None,
53
+ required=True,
54
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
55
+ )
56
+ parser.add_argument(
57
+ "--pretrained_vae_name_or_path",
58
+ type=str,
59
+ default=None,
60
+ help="Path to pretrained vae or vae identifier from huggingface.co/models.",
61
+ )
62
+ parser.add_argument(
63
+ "--revision",
64
+ type=str,
65
+ default=None,
66
+ required=False,
67
+ help="Revision of pretrained model identifier from huggingface.co/models.",
68
+ )
69
+ parser.add_argument(
70
+ "--tokenizer_name",
71
+ type=str,
72
+ default=None,
73
+ help="Pretrained tokenizer name or path if not the same as model_name",
74
+ )
75
+ parser.add_argument(
76
+ "--instance_data_dir",
77
+ type=str,
78
+ default=None,
79
+ required=True,
80
+ help="A folder containing the training data of instance images.",
81
+ )
82
+ parser.add_argument(
83
+ "--class_data_dir",
84
+ type=str,
85
+ default=None,
86
+ required=False,
87
+ help="A folder containing the training data of class images.",
88
+ )
89
+ parser.add_argument(
90
+ "--instance_prompt",
91
+ type=str,
92
+ default=None,
93
+ help="The prompt with identifier specifying the instance",
94
+ )
95
+ parser.add_argument(
96
+ "--class_prompt",
97
+ type=str,
98
+ default=None,
99
+ help="The prompt to specify images in the same class as provided instance images.",
100
+ )
101
+ parser.add_argument(
102
+ "--with_prior_preservation",
103
+ default=False,
104
+ action="store_true",
105
+ help="Flag to add prior preservation loss.",
106
+ )
107
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
108
+ parser.add_argument(
109
+ "--num_class_images",
110
+ type=int,
111
+ default=100,
112
+ help=(
113
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
114
+ " class_data_dir, additional images will be sampled with class_prompt."
115
+ ),
116
+ )
117
+ parser.add_argument(
118
+ "--output_dir",
119
+ type=str,
120
+ default="text-inversion-model",
121
+ help="The output directory where the model predictions and checkpoints will be written.",
122
+ )
123
+ parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
124
+ parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
125
+ parser.add_argument(
126
+ "--resolution",
127
+ type=int,
128
+ default=512,
129
+ help=(
130
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
131
+ " resolution"
132
+ ),
133
+ )
134
+ parser.add_argument(
135
+ "--center_crop",
136
+ default=False,
137
+ action="store_true",
138
+ help=(
139
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
140
+ " cropped. The images will be resized to the resolution first before cropping."
141
+ ),
142
+ )
143
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
144
+ parser.add_argument(
145
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
146
+ )
147
+ parser.add_argument(
148
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
149
+ )
150
+ parser.add_argument("--num_train_epochs", type=int, default=1)
151
+ parser.add_argument(
152
+ "--max_train_steps",
153
+ type=int,
154
+ default=None,
155
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
156
+ )
157
+ parser.add_argument(
158
+ "--learning_rate",
159
+ type=float,
160
+ default=5e-6,
161
+ help="Initial learning rate (after the potential warmup period) to use.",
162
+ )
163
+ parser.add_argument(
164
+ "--scale_lr",
165
+ action="store_true",
166
+ default=False,
167
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
168
+ )
169
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
170
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
171
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
172
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
173
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
174
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
175
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
176
+ parser.add_argument(
177
+ "--hub_model_id",
178
+ type=str,
179
+ default=None,
180
+ help="The name of the repository to keep in sync with the local `output_dir`.",
181
+ )
182
+ parser.add_argument(
183
+ "--logging_dir",
184
+ type=str,
185
+ default="logs",
186
+ help=(
187
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
188
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
189
+ ),
190
+ )
191
+ parser.add_argument(
192
+ "--mixed_precision",
193
+ type=str,
194
+ default="no",
195
+ choices=["no", "fp16", "bf16"],
196
+ help=(
197
+ "Whether to use mixed precision. Choose"
198
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
199
+ "and an Nvidia Ampere GPU."
200
+ ),
201
+ )
202
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
203
+
204
+ args = parser.parse_args()
205
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
206
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
207
+ args.local_rank = env_local_rank
208
+
209
+ if args.instance_data_dir is None:
210
+ raise ValueError("You must specify a train data directory.")
211
+
212
+ if args.with_prior_preservation:
213
+ if args.class_data_dir is None:
214
+ raise ValueError("You must specify a data directory for class images.")
215
+ if args.class_prompt is None:
216
+ raise ValueError("You must specify prompt for class images.")
217
+
218
+ return args
219
+
220
+
221
+ class DreamBoothDataset(Dataset):
222
+ """
223
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
224
+ It pre-processes the images and the tokenizes prompts.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ instance_data_root,
230
+ instance_prompt,
231
+ tokenizer,
232
+ class_data_root=None,
233
+ class_prompt=None,
234
+ class_num=None,
235
+ size=512,
236
+ center_crop=False,
237
+ ):
238
+ self.size = size
239
+ self.center_crop = center_crop
240
+ self.tokenizer = tokenizer
241
+
242
+ self.instance_data_root = Path(instance_data_root)
243
+ if not self.instance_data_root.exists():
244
+ raise ValueError("Instance images root doesn't exists.")
245
+
246
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
247
+ self.num_instance_images = len(self.instance_images_path)
248
+ self.instance_prompt = instance_prompt
249
+ self._length = self.num_instance_images
250
+
251
+ if class_data_root is not None:
252
+ self.class_data_root = Path(class_data_root)
253
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
254
+ self.class_images_path = list(self.class_data_root.iterdir())
255
+ if class_num is not None:
256
+ self.num_class_images = min(len(self.class_images_path), class_num)
257
+ else:
258
+ self.num_class_images = len(self.class_images_path)
259
+ self._length = max(self.num_class_images, self.num_instance_images)
260
+ self.class_prompt = class_prompt
261
+ else:
262
+ self.class_data_root = None
263
+
264
+ self.image_transforms = transforms.Compose(
265
+ [
266
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
267
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
268
+ transforms.ToTensor(),
269
+ transforms.Normalize([0.5], [0.5]),
270
+ ]
271
+ )
272
+
273
+ def __len__(self):
274
+ return self._length
275
+
276
+ def __getitem__(self, index):
277
+ example = {}
278
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
279
+ if not instance_image.mode == "RGB":
280
+ instance_image = instance_image.convert("RGB")
281
+ example["instance_images"] = self.image_transforms(instance_image)
282
+ example["instance_prompt_ids"] = self.tokenizer(
283
+ self.instance_prompt,
284
+ padding="do_not_pad",
285
+ truncation=True,
286
+ max_length=self.tokenizer.model_max_length,
287
+ ).input_ids
288
+
289
+ if self.class_data_root:
290
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
291
+ if not class_image.mode == "RGB":
292
+ class_image = class_image.convert("RGB")
293
+ example["class_images"] = self.image_transforms(class_image)
294
+ example["class_prompt_ids"] = self.tokenizer(
295
+ self.class_prompt,
296
+ padding="do_not_pad",
297
+ truncation=True,
298
+ max_length=self.tokenizer.model_max_length,
299
+ ).input_ids
300
+
301
+ return example
302
+
303
+
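`DreamBoothDataset` can be exercised on its own to check the preprocessing. A rough sketch; the image folder, prompt, and model id below are placeholders:

```python
# Rough usage sketch for DreamBoothDataset; paths, prompt, and model id are placeholders.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="tokenizer"
)
dataset = DreamBoothDataset(
    instance_data_root="./instance_images",  # folder with a handful of subject photos
    instance_prompt="a photo of a dog",
    tokenizer=tokenizer,
    size=512,
)
example = dataset[0]
print(example["instance_images"].shape)     # torch.Size([3, 512, 512])
print(len(example["instance_prompt_ids"]))  # token count; padding happens in collate_fn
```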
304
+ class PromptDataset(Dataset):
305
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
306
+
307
+ def __init__(self, prompt, num_samples):
308
+ self.prompt = prompt
309
+ self.num_samples = num_samples
310
+
311
+ def __len__(self):
312
+ return self.num_samples
313
+
314
+ def __getitem__(self, index):
315
+ example = {}
316
+ example["prompt"] = self.prompt
317
+ example["index"] = index
318
+ return example
319
+
320
+
321
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
322
+ if token is None:
323
+ token = HfFolder.get_token()
324
+ if organization is None:
325
+ username = whoami(token)["name"]
326
+ return f"{username}/{model_id}"
327
+ else:
328
+ return f"{organization}/{model_id}"
329
+
330
+
331
+ def get_params_to_save(params):
332
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
333
+
334
+
335
+ def main():
336
+ args = parse_args()
337
+
338
+ logging.basicConfig(
339
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
340
+ datefmt="%m/%d/%Y %H:%M:%S",
341
+ level=logging.INFO,
342
+ )
343
+ # Set up logging; we only want one process per machine to log things on the screen.
344
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
345
+ if jax.process_index() == 0:
346
+ transformers.utils.logging.set_verbosity_info()
347
+ else:
348
+ transformers.utils.logging.set_verbosity_error()
349
+
350
+ if args.seed is not None:
351
+ set_seed(args.seed)
352
+
353
+ rng = jax.random.PRNGKey(args.seed)
354
+
355
+ if args.with_prior_preservation:
356
+ class_images_dir = Path(args.class_data_dir)
357
+ if not class_images_dir.exists():
358
+ class_images_dir.mkdir(parents=True)
359
+ cur_class_images = len(list(class_images_dir.iterdir()))
360
+
361
+ if cur_class_images < args.num_class_images:
362
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
363
+ args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
364
+ )
365
+ pipeline.set_progress_bar_config(disable=True)
366
+
367
+ num_new_images = args.num_class_images - cur_class_images
368
+ logger.info(f"Number of class images to sample: {num_new_images}.")
369
+
370
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
371
+ total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
372
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
373
+
374
+ for example in tqdm(
375
+ sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
376
+ ):
377
+ prompt_ids = pipeline.prepare_inputs(example["prompt"])
378
+ prompt_ids = shard(prompt_ids)
379
+ p_params = jax_utils.replicate(params)
380
+ rng = jax.random.split(rng)[0]
381
+ sample_rng = jax.random.split(rng, jax.device_count())
382
+ images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
383
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
384
+ images = pipeline.numpy_to_pil(np.array(images))
385
+
386
+ for i, image in enumerate(images):
387
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
388
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
389
+ image.save(image_filename)
390
+
391
+ del pipeline
392
+
393
+ # Handle the repository creation
394
+ if jax.process_index() == 0:
395
+ if args.push_to_hub:
396
+ if args.hub_model_id is None:
397
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
398
+ else:
399
+ repo_name = args.hub_model_id
400
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
401
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
402
+
403
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
404
+ if "step_*" not in gitignore:
405
+ gitignore.write("step_*\n")
406
+ if "epoch_*" not in gitignore:
407
+ gitignore.write("epoch_*\n")
408
+ elif args.output_dir is not None:
409
+ os.makedirs(args.output_dir, exist_ok=True)
410
+
411
+ # Load the tokenizer
412
+ if args.tokenizer_name:
413
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
414
+ elif args.pretrained_model_name_or_path:
415
+ tokenizer = CLIPTokenizer.from_pretrained(
416
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
417
+ )
418
+ else:
419
+ raise NotImplementedError("No tokenizer specified!")
420
+
421
+ train_dataset = DreamBoothDataset(
422
+ instance_data_root=args.instance_data_dir,
423
+ instance_prompt=args.instance_prompt,
424
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
425
+ class_prompt=args.class_prompt,
426
+ class_num=args.num_class_images,
427
+ tokenizer=tokenizer,
428
+ size=args.resolution,
429
+ center_crop=args.center_crop,
430
+ )
431
+
432
+ def collate_fn(examples):
433
+ input_ids = [example["instance_prompt_ids"] for example in examples]
434
+ pixel_values = [example["instance_images"] for example in examples]
435
+
436
+ # Concat class and instance examples for prior preservation.
437
+ # We do this to avoid doing two forward passes.
438
+ if args.with_prior_preservation:
439
+ input_ids += [example["class_prompt_ids"] for example in examples]
440
+ pixel_values += [example["class_images"] for example in examples]
441
+
442
+ pixel_values = torch.stack(pixel_values)
443
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
444
+
445
+ input_ids = tokenizer.pad(
446
+ {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
447
+ ).input_ids
448
+
449
+ batch = {
450
+ "input_ids": input_ids,
451
+ "pixel_values": pixel_values,
452
+ }
453
+ batch = {k: v.numpy() for k, v in batch.items()}
454
+ return batch
455
+
456
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
457
+ if len(train_dataset) < total_train_batch_size:
458
+ raise ValueError(
459
+ f"Training batch size is {total_train_batch_size}, but your dataset only contains"
460
+ f" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that"
461
+ f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
462
+ )
463
+
464
+ train_dataloader = torch.utils.data.DataLoader(
465
+ train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
466
+ )
467
+
468
+ weight_dtype = jnp.float32
469
+ if args.mixed_precision == "fp16":
470
+ weight_dtype = jnp.float16
471
+ elif args.mixed_precision == "bf16":
472
+ weight_dtype = jnp.bfloat16
473
+
474
+ if args.pretrained_vae_name_or_path:
475
+ # TODO(patil-suraj): Upload flax weights for the VAE
476
+ vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
477
+ else:
478
+ vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})
479
+
480
+ # Load models and create wrapper for stable diffusion
481
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
482
+ args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
483
+ )
484
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
485
+ vae_arg,
486
+ dtype=weight_dtype,
487
+ **vae_kwargs,
488
+ )
489
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
490
+ args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
491
+ )
492
+
493
+ # Optimization
494
+ if args.scale_lr:
495
+ args.learning_rate = args.learning_rate * total_train_batch_size
496
+
497
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
498
+
499
+ adamw = optax.adamw(
500
+ learning_rate=constant_scheduler,
501
+ b1=args.adam_beta1,
502
+ b2=args.adam_beta2,
503
+ eps=args.adam_epsilon,
504
+ weight_decay=args.adam_weight_decay,
505
+ )
506
+
507
+ optimizer = optax.chain(
508
+ optax.clip_by_global_norm(args.max_grad_norm),
509
+ adamw,
510
+ )
511
+
512
+ unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
513
+ text_encoder_state = train_state.TrainState.create(
514
+ apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
515
+ )
516
+
517
+ noise_scheduler = FlaxDDPMScheduler(
518
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
519
+ )
520
+ noise_scheduler_state = noise_scheduler.create_state()
521
+
522
+ # Initialize our training
523
+ train_rngs = jax.random.split(rng, jax.local_device_count())
524
+
525
+ def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng):
526
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
527
+
528
+ if args.train_text_encoder:
529
+ params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params}
530
+ else:
531
+ params = {"unet": unet_state.params}
532
+
533
+ def compute_loss(params):
534
+ # Convert images to latent space
535
+ vae_outputs = vae.apply(
536
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
537
+ )
538
+ latents = vae_outputs.latent_dist.sample(sample_rng)
539
+ # (NHWC) -> (NCHW)
540
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
541
+ latents = latents * vae.config.scaling_factor
542
+
543
+ # Sample noise that we'll add to the latents
544
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
545
+ noise = jax.random.normal(noise_rng, latents.shape)
546
+ # Sample a random timestep for each image
547
+ bsz = latents.shape[0]
548
+ timesteps = jax.random.randint(
549
+ timestep_rng,
550
+ (bsz,),
551
+ 0,
552
+ noise_scheduler.config.num_train_timesteps,
553
+ )
554
+
555
+ # Add noise to the latents according to the noise magnitude at each timestep
556
+ # (this is the forward diffusion process)
557
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
558
+
559
+ # Get the text embedding for conditioning
560
+ if args.train_text_encoder:
561
+ encoder_hidden_states = text_encoder_state.apply_fn(
562
+ batch["input_ids"], params=params["text_encoder"], dropout_rng=dropout_rng, train=True
563
+ )[0]
564
+ else:
565
+ encoder_hidden_states = text_encoder(
566
+ batch["input_ids"], params=text_encoder_state.params, train=False
567
+ )[0]
568
+
569
+ # Predict the noise residual
570
+ model_pred = unet.apply(
571
+ {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
572
+ ).sample
573
+
574
+ # Get the target for loss depending on the prediction type
575
+ if noise_scheduler.config.prediction_type == "epsilon":
576
+ target = noise
577
+ elif noise_scheduler.config.prediction_type == "v_prediction":
578
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
579
+ else:
580
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
581
+
582
+ if args.with_prior_preservation:
583
+ # Chunk the model_pred and target into two parts and compute the loss on each part separately.
584
+ model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
585
+ target, target_prior = jnp.split(target, 2, axis=0)
586
+
587
+ # Compute instance loss
588
+ loss = (target - model_pred) ** 2
589
+ loss = loss.mean()
590
+
591
+ # Compute prior loss
592
+ prior_loss = (target_prior - model_pred_prior) ** 2
593
+ prior_loss = prior_loss.mean()
594
+
595
+ # Add the prior loss to the instance loss.
596
+ loss = loss + args.prior_loss_weight * prior_loss
597
+ else:
598
+ loss = (target - model_pred) ** 2
599
+ loss = loss.mean()
600
+
601
+ return loss
602
+
603
+ grad_fn = jax.value_and_grad(compute_loss)
604
+ loss, grad = grad_fn(params)
605
+ grad = jax.lax.pmean(grad, "batch")
606
+
607
+ new_unet_state = unet_state.apply_gradients(grads=grad["unet"])
608
+ if args.train_text_encoder:
609
+ new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad["text_encoder"])
610
+ else:
611
+ new_text_encoder_state = text_encoder_state
612
+
613
+ metrics = {"loss": loss}
614
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
615
+
616
+ return new_unet_state, new_text_encoder_state, metrics, new_train_rng
617
+
618
+ # Create parallel version of the train step
619
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
620
+
621
+ # Replicate the train state on each device
622
+ unet_state = jax_utils.replicate(unet_state)
623
+ text_encoder_state = jax_utils.replicate(text_encoder_state)
624
+ vae_params = jax_utils.replicate(vae_params)
625
+
626
+ # Train!
627
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
628
+
629
+ # Scheduler and math around the number of training steps.
630
+ if args.max_train_steps is None:
631
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
632
+
633
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
634
+
635
+ logger.info("***** Running training *****")
636
+ logger.info(f" Num examples = {len(train_dataset)}")
637
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
638
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
639
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
640
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
641
+
642
+ def checkpoint(step=None):
643
+ # Create the pipeline using the trained modules and save it.
644
+ scheduler, _ = FlaxPNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
645
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
646
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
647
+ )
648
+ pipeline = FlaxStableDiffusionPipeline(
649
+ text_encoder=text_encoder,
650
+ vae=vae,
651
+ unet=unet,
652
+ tokenizer=tokenizer,
653
+ scheduler=scheduler,
654
+ safety_checker=safety_checker,
655
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
656
+ )
657
+
658
+ outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
659
+ pipeline.save_pretrained(
660
+ outdir,
661
+ params={
662
+ "text_encoder": get_params_to_save(text_encoder_state.params),
663
+ "vae": get_params_to_save(vae_params),
664
+ "unet": get_params_to_save(unet_state.params),
665
+ "safety_checker": safety_checker.params,
666
+ },
667
+ )
668
+
669
+ if args.push_to_hub:
670
+ message = f"checkpoint-{step}" if step is not None else "End of training"
671
+ repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
672
+
673
+ global_step = 0
674
+
675
+ epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
676
+ for epoch in epochs:
677
+ # ======================== Training ================================
678
+
679
+ train_metrics = []
680
+
681
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
682
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
683
+ # train
684
+ for batch in train_dataloader:
685
+ batch = shard(batch)
686
+ unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
687
+ unet_state, text_encoder_state, vae_params, batch, train_rngs
688
+ )
689
+ train_metrics.append(train_metric)
690
+
691
+ train_step_progress_bar.update(jax.local_device_count())
692
+
693
+ global_step += 1
694
+ if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
695
+ checkpoint(global_step)
696
+ if global_step >= args.max_train_steps:
697
+ break
698
+
699
+ train_metric = jax_utils.unreplicate(train_metric)
700
+
701
+ train_step_progress_bar.close()
702
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
703
+
704
+ if jax.process_index() == 0:
705
+ checkpoint()
706
+
707
+
708
+ if __name__ == "__main__":
709
+ main()
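The script above follows the usual Flax data-parallel recipe: replicate the train state across local devices, shard each batch along its leading axis, and run the step through `jax.pmap` with a `pmean` over the device axis. A toy sketch of that pattern, independent of the script:

```python
# Toy data-parallel step: replicate params, shard the batch, pmap the step function.
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard

params = {"w": jnp.ones((4,))}
p_params = jax_utils.replicate(params)  # one copy of the params per local device

def step(params, batch):
    loss = jnp.mean((batch @ params["w"]) ** 2)
    return jax.lax.pmean(loss, axis_name="batch")  # average the loss across devices

p_step = jax.pmap(step, axis_name="batch")

global_batch = jnp.ones((jax.local_device_count() * 2, 4))
loss = p_step(p_params, shard(global_batch))  # shard splits the leading axis across devices
print(loss)  # one (identical) value per device
```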
train_dreambooth_lora.py ADDED
@@ -0,0 +1,1424 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ # limitations under the License.
17
+ import copy
18
+ import gc
19
+ import hashlib
20
+ import itertools
21
+ import logging
22
+ import math
23
+ import os
24
+ import shutil
25
+ import warnings
26
+ from pathlib import Path
27
+ from typing import Dict
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.utils.checkpoint
33
+ import transformers
34
+ from accelerate import Accelerator
35
+ from accelerate.logging import get_logger
36
+ from accelerate.utils import ProjectConfiguration, set_seed
37
+ from huggingface_hub import create_repo, upload_folder
38
+ from packaging import version
39
+ from PIL import Image
40
+ from PIL.ImageOps import exif_transpose
41
+ from torch.utils.data import Dataset
42
+ from torchvision import transforms
43
+ from tqdm.auto import tqdm
44
+ from transformers import AutoTokenizer, PretrainedConfig
45
+
46
+ import diffusers
47
+ from diffusers import (
48
+ AutoencoderKL,
49
+ DDPMScheduler,
50
+ DiffusionPipeline,
51
+ DPMSolverMultistepScheduler,
52
+ StableDiffusionPipeline,
53
+ UNet2DConditionModel,
54
+ )
55
+ from diffusers.loaders import (
56
+ LoraLoaderMixin,
57
+ text_encoder_lora_state_dict,
58
+ )
59
+ from diffusers.models.attention_processor import (
60
+ AttnAddedKVProcessor,
61
+ AttnAddedKVProcessor2_0,
62
+ LoRAAttnAddedKVProcessor,
63
+ LoRAAttnProcessor,
64
+ LoRAAttnProcessor2_0,
65
+ SlicedAttnAddedKVProcessor,
66
+ )
67
+ from diffusers.optimization import get_scheduler
68
+ from diffusers.utils import check_min_version, is_wandb_available
69
+ from diffusers.utils.import_utils import is_xformers_available
70
+
71
+
72
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
73
+ check_min_version("0.20.0.dev0")
74
+
75
+ logger = get_logger(__name__)
76
+
77
+
78
+ def save_model_card(
79
+ repo_id: str,
80
+ images=None,
81
+ base_model=str,
82
+ train_text_encoder=False,
83
+ prompt=str,
84
+ repo_folder=None,
85
+ pipeline: DiffusionPipeline = None,
86
+ ):
87
+ img_str = ""
88
+ for i, image in enumerate(images):
89
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
90
+ img_str += f"![img_{i}](./image_{i}.png)\n"
91
+
92
+ yaml = f"""
93
+ ---
94
+ license: creativeml-openrail-m
95
+ base_model: {base_model}
96
+ instance_prompt: {prompt}
97
+ tags:
98
+ - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
99
+ - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
100
+ - text-to-image
101
+ - diffusers
102
+ - lora
103
+ inference: true
104
+ ---
105
+ """
106
+ model_card = f"""
107
+ # LoRA DreamBooth - {repo_id}
108
+
109
+ These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
110
+ {img_str}
111
+
112
+ LoRA for the text encoder was enabled: {train_text_encoder}.
113
+ """
114
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
115
+ f.write(yaml + model_card)
116
+
117
+
118
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
119
+ text_encoder_config = PretrainedConfig.from_pretrained(
120
+ pretrained_model_name_or_path,
121
+ subfolder="text_encoder",
122
+ revision=revision,
123
+ )
124
+ model_class = text_encoder_config.architectures[0]
125
+
126
+ if model_class == "CLIPTextModel":
127
+ from transformers import CLIPTextModel
128
+
129
+ return CLIPTextModel
130
+ elif model_class == "RobertaSeriesModelWithTransformation":
131
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
132
+
133
+ return RobertaSeriesModelWithTransformation
134
+ elif model_class == "T5EncoderModel":
135
+ from transformers import T5EncoderModel
136
+
137
+ return T5EncoderModel
138
+ else:
139
+ raise ValueError(f"{model_class} is not supported.")
140
+
141
+
142
+ def parse_args(input_args=None):
143
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
144
+ parser.add_argument(
145
+ "--pretrained_model_name_or_path",
146
+ type=str,
147
+ default=None,
148
+ required=True,
149
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
150
+ )
151
+ parser.add_argument(
152
+ "--revision",
153
+ type=str,
154
+ default=None,
155
+ required=False,
156
+ help="Revision of pretrained model identifier from huggingface.co/models.",
157
+ )
158
+ parser.add_argument(
159
+ "--tokenizer_name",
160
+ type=str,
161
+ default=None,
162
+ help="Pretrained tokenizer name or path if not the same as model_name",
163
+ )
164
+ parser.add_argument(
165
+ "--instance_data_dir",
166
+ type=str,
167
+ default=None,
168
+ required=True,
169
+ help="A folder containing the training data of instance images.",
170
+ )
171
+ parser.add_argument(
172
+ "--class_data_dir",
173
+ type=str,
174
+ default=None,
175
+ required=False,
176
+ help="A folder containing the training data of class images.",
177
+ )
178
+ parser.add_argument(
179
+ "--instance_prompt",
180
+ type=str,
181
+ default=None,
182
+ required=True,
183
+ help="The prompt with identifier specifying the instance",
184
+ )
185
+ parser.add_argument(
186
+ "--class_prompt",
187
+ type=str,
188
+ default=None,
189
+ help="The prompt to specify images in the same class as provided instance images.",
190
+ )
191
+ parser.add_argument(
192
+ "--validation_prompt",
193
+ type=str,
194
+ default=None,
195
+ help="A prompt that is used during validation to verify that the model is learning.",
196
+ )
197
+ parser.add_argument(
198
+ "--num_validation_images",
199
+ type=int,
200
+ default=4,
201
+ help="Number of images that should be generated during validation with `validation_prompt`.",
202
+ )
203
+ parser.add_argument(
204
+ "--validation_epochs",
205
+ type=int,
206
+ default=50,
207
+ help=(
208
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
209
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
210
+ ),
211
+ )
212
+ parser.add_argument(
213
+ "--with_prior_preservation",
214
+ default=False,
215
+ action="store_true",
216
+ help="Flag to add prior preservation loss.",
217
+ )
218
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
219
+ parser.add_argument(
220
+ "--num_class_images",
221
+ type=int,
222
+ default=100,
223
+ help=(
224
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
225
+ " class_data_dir, additional images will be sampled with class_prompt."
226
+ ),
227
+ )
228
+ parser.add_argument(
229
+ "--output_dir",
230
+ type=str,
231
+ default="lora-dreambooth-model",
232
+ help="The output directory where the model predictions and checkpoints will be written.",
233
+ )
234
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
235
+ parser.add_argument(
236
+ "--resolution",
237
+ type=int,
238
+ default=512,
239
+ help=(
240
+ "The resolution for input images; all the images in the train/validation dataset will be resized to this"
241
+ " resolution"
242
+ ),
243
+ )
244
+ parser.add_argument(
245
+ "--center_crop",
246
+ default=False,
247
+ action="store_true",
248
+ help=(
249
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
250
+ " cropped. The images will be resized to the resolution first before cropping."
251
+ ),
252
+ )
253
+ parser.add_argument(
254
+ "--train_text_encoder",
255
+ action="store_true",
256
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
257
+ )
258
+ parser.add_argument(
259
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
260
+ )
261
+ parser.add_argument(
262
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
263
+ )
264
+ parser.add_argument("--num_train_epochs", type=int, default=1)
265
+ parser.add_argument(
266
+ "--max_train_steps",
267
+ type=int,
268
+ default=None,
269
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
270
+ )
271
+ parser.add_argument(
272
+ "--checkpointing_steps",
273
+ type=int,
274
+ default=500,
275
+ help=(
276
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
277
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
278
+ " training using `--resume_from_checkpoint`."
279
+ ),
280
+ )
281
+ parser.add_argument(
282
+ "--checkpoints_total_limit",
283
+ type=int,
284
+ default=None,
285
+ help=("Max number of checkpoints to store."),
286
+ )
287
+ parser.add_argument(
288
+ "--resume_from_checkpoint",
289
+ type=str,
290
+ default=None,
291
+ help=(
292
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
293
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
294
+ ),
295
+ )
296
+ parser.add_argument(
297
+ "--gradient_accumulation_steps",
298
+ type=int,
299
+ default=1,
300
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
301
+ )
302
+ parser.add_argument(
303
+ "--gradient_checkpointing",
304
+ action="store_true",
305
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
306
+ )
307
+ parser.add_argument(
308
+ "--learning_rate",
309
+ type=float,
310
+ default=5e-4,
311
+ help="Initial learning rate (after the potential warmup period) to use.",
312
+ )
313
+ parser.add_argument(
314
+ "--scale_lr",
315
+ action="store_true",
316
+ default=False,
317
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
318
+ )
319
+ parser.add_argument(
320
+ "--lr_scheduler",
321
+ type=str,
322
+ default="constant",
323
+ help=(
324
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
325
+ ' "constant", "constant_with_warmup"]'
326
+ ),
327
+ )
328
+ parser.add_argument(
329
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
330
+ )
331
+ parser.add_argument(
332
+ "--lr_num_cycles",
333
+ type=int,
334
+ default=1,
335
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
336
+ )
337
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
338
+ parser.add_argument(
339
+ "--dataloader_num_workers",
340
+ type=int,
341
+ default=0,
342
+ help=(
343
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
344
+ ),
345
+ )
346
+ parser.add_argument(
347
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
348
+ )
349
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
350
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
351
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
352
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
353
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
354
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
355
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
356
+ parser.add_argument(
357
+ "--hub_model_id",
358
+ type=str,
359
+ default=None,
360
+ help="The name of the repository to keep in sync with the local `output_dir`.",
361
+ )
362
+ parser.add_argument(
363
+ "--logging_dir",
364
+ type=str,
365
+ default="logs",
366
+ help=(
367
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
368
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
369
+ ),
370
+ )
371
+ parser.add_argument(
372
+ "--allow_tf32",
373
+ action="store_true",
374
+ help=(
375
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
376
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
377
+ ),
378
+ )
379
+ parser.add_argument(
380
+ "--report_to",
381
+ type=str,
382
+ default="tensorboard",
383
+ help=(
384
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
385
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
386
+ ),
387
+ )
388
+ parser.add_argument(
389
+ "--mixed_precision",
390
+ type=str,
391
+ default=None,
392
+ choices=["no", "fp16", "bf16"],
393
+ help=(
394
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
395
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
396
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
397
+ ),
398
+ )
399
+ parser.add_argument(
400
+ "--prior_generation_precision",
401
+ type=str,
402
+ default=None,
403
+ choices=["no", "fp32", "fp16", "bf16"],
404
+ help=(
405
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
406
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
407
+ ),
408
+ )
409
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
410
+ parser.add_argument(
411
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
412
+ )
413
+ parser.add_argument(
414
+ "--pre_compute_text_embeddings",
415
+ action="store_true",
416
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
417
+ )
418
+ parser.add_argument(
419
+ "--tokenizer_max_length",
420
+ type=int,
421
+ default=None,
422
+ required=False,
423
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
424
+ )
425
+ parser.add_argument(
426
+ "--text_encoder_use_attention_mask",
427
+ action="store_true",
428
+ required=False,
429
+ help="Whether to use attention mask for the text encoder",
430
+ )
431
+ parser.add_argument(
432
+ "--validation_images",
433
+ required=False,
434
+ default=None,
435
+ nargs="+",
436
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
437
+ )
438
+ parser.add_argument(
439
+ "--class_labels_conditioning",
440
+ required=False,
441
+ default=None,
442
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
443
+ )
444
+ parser.add_argument(
445
+ "--rank",
446
+ type=int,
447
+ default=4,
448
+ help=("The dimension of the LoRA update matrices."),
449
+ )
450
+
451
+ if input_args is not None:
452
+ args = parser.parse_args(input_args)
453
+ else:
454
+ args = parser.parse_args()
455
+
456
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
457
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
458
+ args.local_rank = env_local_rank
459
+
460
+ if args.with_prior_preservation:
461
+ if args.class_data_dir is None:
462
+ raise ValueError("You must specify a data directory for class images.")
463
+ if args.class_prompt is None:
464
+ raise ValueError("You must specify prompt for class images.")
465
+ else:
466
+ # logger is not available yet
467
+ if args.class_data_dir is not None:
468
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
469
+ if args.class_prompt is not None:
470
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
471
+
472
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
473
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
474
+
475
+ return args
476
+
477
+
478
+ class DreamBoothDataset(Dataset):
479
+ """
480
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
481
+ It pre-processes the images and tokenizes the prompts.
482
+ """
483
+
484
+ def __init__(
485
+ self,
486
+ instance_data_root,
487
+ instance_prompt,
488
+ tokenizer,
489
+ class_data_root=None,
490
+ class_prompt=None,
491
+ class_num=None,
492
+ size=512,
493
+ center_crop=False,
494
+ encoder_hidden_states=None,
495
+ class_prompt_encoder_hidden_states=None,
496
+ tokenizer_max_length=None,
497
+ ):
498
+ self.size = size
499
+ self.center_crop = center_crop
500
+ self.tokenizer = tokenizer
501
+ self.encoder_hidden_states = encoder_hidden_states
502
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
503
+ self.tokenizer_max_length = tokenizer_max_length
504
+
505
+ self.instance_data_root = Path(instance_data_root)
506
+ if not self.instance_data_root.exists():
507
+ raise ValueError("Instance images root doesn't exist.")
508
+
509
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
510
+ self.num_instance_images = len(self.instance_images_path)
511
+ self.instance_prompt = instance_prompt
512
+ self._length = self.num_instance_images
513
+
514
+ if class_data_root is not None:
515
+ self.class_data_root = Path(class_data_root)
516
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
517
+ self.class_images_path = list(self.class_data_root.iterdir())
518
+ if class_num is not None:
519
+ self.num_class_images = min(len(self.class_images_path), class_num)
520
+ else:
521
+ self.num_class_images = len(self.class_images_path)
522
+ self._length = max(self.num_class_images, self.num_instance_images)
523
+ self.class_prompt = class_prompt
524
+ else:
525
+ self.class_data_root = None
526
+
527
+ self.image_transforms = transforms.Compose(
528
+ [
529
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
530
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
531
+ transforms.ToTensor(),
532
+ transforms.Normalize([0.5], [0.5]),
533
+ ]
534
+ )
535
+
536
+ def __len__(self):
537
+ return self._length
538
+
539
+ def __getitem__(self, index):
540
+ example = {}
541
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
542
+ instance_image = exif_transpose(instance_image)
543
+
544
+ if not instance_image.mode == "RGB":
545
+ instance_image = instance_image.convert("RGB")
546
+ example["instance_images"] = self.image_transforms(instance_image)
547
+
548
+ if self.encoder_hidden_states is not None:
549
+ example["instance_prompt_ids"] = self.encoder_hidden_states
550
+ else:
551
+ text_inputs = tokenize_prompt(
552
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
553
+ )
554
+ example["instance_prompt_ids"] = text_inputs.input_ids
555
+ example["instance_attention_mask"] = text_inputs.attention_mask
556
+
557
+ if self.class_data_root:
558
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
559
+ class_image = exif_transpose(class_image)
560
+
561
+ if not class_image.mode == "RGB":
562
+ class_image = class_image.convert("RGB")
563
+ example["class_images"] = self.image_transforms(class_image)
564
+
565
+ if self.class_prompt_encoder_hidden_states is not None:
566
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
567
+ else:
568
+ class_text_inputs = tokenize_prompt(
569
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
570
+ )
571
+ example["class_prompt_ids"] = class_text_inputs.input_ids
572
+ example["class_attention_mask"] = class_text_inputs.attention_mask
573
+
574
+ return example
575
+
576
+
577
+ def collate_fn(examples, with_prior_preservation=False):
578
+ has_attention_mask = "instance_attention_mask" in examples[0]
579
+
580
+ input_ids = [example["instance_prompt_ids"] for example in examples]
581
+ pixel_values = [example["instance_images"] for example in examples]
582
+
583
+ if has_attention_mask:
584
+ attention_mask = [example["instance_attention_mask"] for example in examples]
585
+
586
+ # Concat class and instance examples for prior preservation.
587
+ # We do this to avoid doing two forward passes.
588
+ if with_prior_preservation:
589
+ input_ids += [example["class_prompt_ids"] for example in examples]
590
+ pixel_values += [example["class_images"] for example in examples]
591
+ if has_attention_mask:
592
+ attention_mask += [example["class_attention_mask"] for example in examples]
593
+
594
+ pixel_values = torch.stack(pixel_values)
595
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
596
+
597
+ input_ids = torch.cat(input_ids, dim=0)
598
+
599
+ batch = {
600
+ "input_ids": input_ids,
601
+ "pixel_values": pixel_values,
602
+ }
603
+
604
+ if has_attention_mask:
605
+ batch["attention_mask"] = attention_mask
606
+
607
+ return batch
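
As a side note (illustration only, not part of the committed file): with prior preservation enabled, `collate_fn` stacks the instance examples first and the class examples second, which is what makes the later `torch.chunk(model_pred, 2, dim=0)` split in the training loop line up. A quick sketch, assuming a `DreamBoothDataset` built with class images as in `main()` and no pre-computed embeddings:

```python
# Illustration only: `train_dataset` is assumed to be a DreamBoothDataset with class_data_root set.
batch = collate_fn([train_dataset[0], train_dataset[1]], with_prior_preservation=True)

# Instance rows come first, class rows second, so with 512x512 images and the CLIP tokenizer:
#   batch["pixel_values"].shape == (4, 3, 512, 512)   # 2 instance + 2 class images
#   batch["input_ids"].shape    == (4, 77)            # 2 instance prompts + 2 class prompts
```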
608
+
609
+
610
+ class PromptDataset(Dataset):
611
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
612
+
613
+ def __init__(self, prompt, num_samples):
614
+ self.prompt = prompt
615
+ self.num_samples = num_samples
616
+
617
+ def __len__(self):
618
+ return self.num_samples
619
+
620
+ def __getitem__(self, index):
621
+ example = {}
622
+ example["prompt"] = self.prompt
623
+ example["index"] = index
624
+ return example
625
+
626
+
627
+ def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
628
+ if tokenizer_max_length is not None:
629
+ max_length = tokenizer_max_length
630
+ else:
631
+ max_length = tokenizer.model_max_length
632
+
633
+ text_inputs = tokenizer(
634
+ prompt,
635
+ truncation=True,
636
+ padding="max_length",
637
+ max_length=max_length,
638
+ return_tensors="pt",
639
+ )
640
+
641
+ return text_inputs
642
+
643
+
644
+ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
645
+ text_input_ids = input_ids.to(text_encoder.device)
646
+
647
+ if text_encoder_use_attention_mask:
648
+ attention_mask = attention_mask.to(text_encoder.device)
649
+ else:
650
+ attention_mask = None
651
+
652
+ prompt_embeds = text_encoder(
653
+ text_input_ids,
654
+ attention_mask=attention_mask,
655
+ )
656
+ prompt_embeds = prompt_embeds[0]
657
+
658
+ return prompt_embeds
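
Taken together, `tokenize_prompt` and `encode_prompt` turn a raw prompt string into the conditioning tensor fed to the UNet. A minimal usage sketch (illustration only; `tokenizer` and `text_encoder` are assumed to be the objects loaded later in `main()`):

```python
# Illustration only: assumes `tokenizer` and `text_encoder` were loaded as in main().
text_inputs = tokenize_prompt(tokenizer, "a photo of sks dog")
prompt_embeds = encode_prompt(
    text_encoder,
    text_inputs.input_ids,
    text_inputs.attention_mask,
    text_encoder_use_attention_mask=False,
)
# For the CLIP ViT-L/14 text encoder used by SD v1-x, prompt_embeds has shape (1, 77, 768).
```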
659
+
660
+
661
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
662
+ r"""
663
+ Returns:
664
+ a state dict containing just the attention processor parameters.
665
+ """
666
+ attn_processors = unet.attn_processors
667
+
668
+ attn_processors_state_dict = {}
669
+
670
+ for attn_processor_key, attn_processor in attn_processors.items():
671
+ for parameter_key, parameter in attn_processor.state_dict().items():
672
+ attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
673
+
674
+ return attn_processors_state_dict
675
+
676
+
677
+ def main(args):
678
+ logging_dir = Path(args.output_dir, args.logging_dir)
679
+
680
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
681
+
682
+ accelerator = Accelerator(
683
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
684
+ mixed_precision=args.mixed_precision,
685
+ log_with=args.report_to,
686
+ project_config=accelerator_project_config,
687
+ )
688
+
689
+ if args.report_to == "wandb":
690
+ if not is_wandb_available():
691
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
692
+ import wandb
693
+
694
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
695
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
696
+ # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
697
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
698
+ raise ValueError(
699
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
700
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
701
+ )
702
+
703
+ # Make one log on every process with the configuration for debugging.
704
+ logging.basicConfig(
705
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
706
+ datefmt="%m/%d/%Y %H:%M:%S",
707
+ level=logging.INFO,
708
+ )
709
+ logger.info(accelerator.state, main_process_only=False)
710
+ if accelerator.is_local_main_process:
711
+ transformers.utils.logging.set_verbosity_warning()
712
+ diffusers.utils.logging.set_verbosity_info()
713
+ else:
714
+ transformers.utils.logging.set_verbosity_error()
715
+ diffusers.utils.logging.set_verbosity_error()
716
+
717
+ # If passed along, set the training seed now.
718
+ if args.seed is not None:
719
+ set_seed(args.seed)
720
+
721
+ # Generate class images if prior preservation is enabled.
722
+ if args.with_prior_preservation:
723
+ class_images_dir = Path(args.class_data_dir)
724
+ if not class_images_dir.exists():
725
+ class_images_dir.mkdir(parents=True)
726
+ cur_class_images = len(list(class_images_dir.iterdir()))
727
+
728
+ if cur_class_images < args.num_class_images:
729
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
730
+ if args.prior_generation_precision == "fp32":
731
+ torch_dtype = torch.float32
732
+ elif args.prior_generation_precision == "fp16":
733
+ torch_dtype = torch.float16
734
+ elif args.prior_generation_precision == "bf16":
735
+ torch_dtype = torch.bfloat16
736
+ pipeline = DiffusionPipeline.from_pretrained(
737
+ args.pretrained_model_name_or_path,
738
+ torch_dtype=torch_dtype,
739
+ safety_checker=None,
740
+ revision=args.revision,
741
+ )
742
+ pipeline.set_progress_bar_config(disable=True)
743
+
744
+ num_new_images = args.num_class_images - cur_class_images
745
+ logger.info(f"Number of class images to sample: {num_new_images}.")
746
+
747
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
748
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
749
+
750
+ sample_dataloader = accelerator.prepare(sample_dataloader)
751
+ pipeline.to(accelerator.device)
752
+
753
+ for example in tqdm(
754
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
755
+ ):
756
+ images = pipeline(example["prompt"]).images
757
+
758
+ for i, image in enumerate(images):
759
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
760
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
761
+ image.save(image_filename)
762
+
763
+ del pipeline
764
+ if torch.cuda.is_available():
765
+ torch.cuda.empty_cache()
766
+
767
+ # Handle the repository creation
768
+ if accelerator.is_main_process:
769
+ if args.output_dir is not None:
770
+ os.makedirs(args.output_dir, exist_ok=True)
771
+
772
+ if args.push_to_hub:
773
+ repo_id = create_repo(
774
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
775
+ ).repo_id
776
+
777
+ # Load the tokenizer
778
+ if args.tokenizer_name:
779
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
780
+ elif args.pretrained_model_name_or_path:
781
+ tokenizer = AutoTokenizer.from_pretrained(
782
+ args.pretrained_model_name_or_path,
783
+ subfolder="tokenizer",
784
+ revision=args.revision,
785
+ use_fast=False,
786
+ )
787
+
788
+ # import correct text encoder class
789
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
790
+
791
+ # Load scheduler and models
792
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
793
+ text_encoder = text_encoder_cls.from_pretrained(
794
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
795
+ )
796
+ try:
797
+ vae = AutoencoderKL.from_pretrained(
798
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
799
+ )
800
+ except OSError:
801
+ # IF does not have a VAE so let's just set it to None
802
+ # We don't have to error out here
803
+ vae = None
804
+
805
+ unet = UNet2DConditionModel.from_pretrained(
806
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
807
+ )
808
+
809
+ # We only train the additional adapter LoRA layers
810
+ if vae is not None:
811
+ vae.requires_grad_(False)
812
+ text_encoder.requires_grad_(False)
813
+ unet.requires_grad_(False)
814
+
815
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
816
+ # as these weights are only used for inference, keeping weights in full precision is not required.
817
+ weight_dtype = torch.float32
818
+ if accelerator.mixed_precision == "fp16":
819
+ weight_dtype = torch.float16
820
+ elif accelerator.mixed_precision == "bf16":
821
+ weight_dtype = torch.bfloat16
822
+
823
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
824
+ unet.to(accelerator.device, dtype=weight_dtype)
825
+ if vae is not None:
826
+ vae.to(accelerator.device, dtype=weight_dtype)
827
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
828
+
829
+ if args.enable_xformers_memory_efficient_attention:
830
+ if is_xformers_available():
831
+ import xformers
832
+
833
+ xformers_version = version.parse(xformers.__version__)
834
+ if xformers_version == version.parse("0.0.16"):
835
+ logger.warn(
836
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
837
+ )
838
+ unet.enable_xformers_memory_efficient_attention()
839
+ else:
840
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
841
+
842
+ if args.gradient_checkpointing:
843
+ unet.enable_gradient_checkpointing()
844
+ if args.train_text_encoder:
845
+ text_encoder.gradient_checkpointing_enable()
846
+
847
+ # now we will add new LoRA weights to the attention layers
848
+ # It's important to realize here how many attention weights will be added and of which sizes
849
+ # The sizes of the attention layers consist only of two different variables:
850
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
851
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
852
+
853
+ # Let's first see how many attention processors we will have to set.
854
+ # For Stable Diffusion, it should be equal to:
855
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
856
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
857
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
858
+ # => 32 layers
859
+
860
+ # Set correct lora layers
861
+ unet_lora_attn_procs = {}
862
+ unet_lora_parameters = []
863
+ for name, attn_processor in unet.attn_processors.items():
864
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
865
+ if name.startswith("mid_block"):
866
+ hidden_size = unet.config.block_out_channels[-1]
867
+ elif name.startswith("up_blocks"):
868
+ block_id = int(name[len("up_blocks.")])
869
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
870
+ elif name.startswith("down_blocks"):
871
+ block_id = int(name[len("down_blocks.")])
872
+ hidden_size = unet.config.block_out_channels[block_id]
873
+
874
+ if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
875
+ lora_attn_processor_class = LoRAAttnAddedKVProcessor
876
+ else:
877
+ lora_attn_processor_class = (
878
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
879
+ )
880
+
881
+ module = lora_attn_processor_class(
882
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
883
+ )
884
+ unet_lora_attn_procs[name] = module
885
+ unet_lora_parameters.extend(module.parameters())
886
+
887
+ unet.set_attn_processor(unet_lora_attn_procs)
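
To make the sizing comment above concrete (illustration only, not part of the committed file): for the SD v1-4 UNet, `block_out_channels` is `(320, 640, 1280, 1280)` and `cross_attention_dim` is 768, so a few representative processor names resolve as follows:

```python
# Illustration only, using the SD v1-4 UNet config values quoted above.
# Cross-attention ("attn2") processors condition on text, so they keep cross_attention_dim=768;
# self-attention ("attn1") processors get cross_attention_dim=None.
examples = {
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor": (640, 768),
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor": (1280, None),
    # up_blocks index into the reversed block_out_channels list: (1280, 1280, 640, 320)
    "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor": (320, 768),
}  # name -> (hidden_size, cross_attention_dim), one LoRA processor of rank `--rank` per entry
```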
888
+
889
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
890
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
891
+ if args.train_text_encoder:
892
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
893
+ text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
894
+
895
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
896
+ def save_model_hook(models, weights, output_dir):
897
+ # there are only two options here. Either there are just the unet attn processor layers
898
+ # or there are both the unet and text encoder attention layers
899
+ unet_lora_layers_to_save = None
900
+ text_encoder_lora_layers_to_save = None
901
+
902
+ for model in models:
903
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
904
+ unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
905
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
906
+ text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
907
+ else:
908
+ raise ValueError(f"unexpected save model: {model.__class__}")
909
+
910
+ # make sure to pop weight so that corresponding model is not saved again
911
+ weights.pop()
912
+
913
+ LoraLoaderMixin.save_lora_weights(
914
+ output_dir,
915
+ unet_lora_layers=unet_lora_layers_to_save,
916
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
917
+ )
918
+
919
+ def load_model_hook(models, input_dir):
920
+ unet_ = None
921
+ text_encoder_ = None
922
+
923
+ while len(models) > 0:
924
+ model = models.pop()
925
+
926
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
927
+ unet_ = model
928
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
929
+ text_encoder_ = model
930
+ else:
931
+ raise ValueError(f"unexpected save model: {model.__class__}")
932
+
933
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
934
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
935
+ LoraLoaderMixin.load_lora_into_text_encoder(
936
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
937
+ )
938
+
939
+ accelerator.register_save_state_pre_hook(save_model_hook)
940
+ accelerator.register_load_state_pre_hook(load_model_hook)
941
+
942
+ # Enable TF32 for faster training on Ampere GPUs,
943
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
944
+ if args.allow_tf32:
945
+ torch.backends.cuda.matmul.allow_tf32 = True
946
+
947
+ if args.scale_lr:
948
+ args.learning_rate = (
949
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
950
+ )
951
+
952
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
953
+ if args.use_8bit_adam:
954
+ try:
955
+ import bitsandbytes as bnb
956
+ except ImportError:
957
+ raise ImportError(
958
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
959
+ )
960
+
961
+ optimizer_class = bnb.optim.AdamW8bit
962
+ else:
963
+ optimizer_class = torch.optim.AdamW
964
+
965
+ # Optimizer creation
966
+ params_to_optimize = (
967
+ itertools.chain(unet_lora_parameters, text_lora_parameters)
968
+ if args.train_text_encoder
969
+ else unet_lora_parameters
970
+ )
971
+ optimizer = optimizer_class(
972
+ params_to_optimize,
973
+ lr=args.learning_rate,
974
+ betas=(args.adam_beta1, args.adam_beta2),
975
+ weight_decay=args.adam_weight_decay,
976
+ eps=args.adam_epsilon,
977
+ )
978
+
979
+ if args.pre_compute_text_embeddings:
980
+
981
+ def compute_text_embeddings(prompt):
982
+ with torch.no_grad():
983
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
984
+ prompt_embeds = encode_prompt(
985
+ text_encoder,
986
+ text_inputs.input_ids,
987
+ text_inputs.attention_mask,
988
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
989
+ )
990
+
991
+ return prompt_embeds
992
+
993
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
994
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
995
+
996
+ if args.validation_prompt is not None:
997
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
998
+ else:
999
+ validation_prompt_encoder_hidden_states = None
1000
+
1001
+ if args.class_prompt is not None:
1002
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
1003
+ else:
1004
+ pre_computed_class_prompt_encoder_hidden_states = None
1005
+
1006
+ text_encoder = None
1007
+ tokenizer = None
1008
+
1009
+ gc.collect()
1010
+ torch.cuda.empty_cache()
1011
+ else:
1012
+ pre_computed_encoder_hidden_states = None
1013
+ validation_prompt_encoder_hidden_states = None
1014
+ validation_prompt_negative_prompt_embeds = None
1015
+ pre_computed_class_prompt_encoder_hidden_states = None
1016
+
1017
+ # Dataset and DataLoaders creation:
1018
+ train_dataset = DreamBoothDataset(
1019
+ instance_data_root=args.instance_data_dir,
1020
+ instance_prompt=args.instance_prompt,
1021
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1022
+ class_prompt=args.class_prompt,
1023
+ class_num=args.num_class_images,
1024
+ tokenizer=tokenizer,
1025
+ size=args.resolution,
1026
+ center_crop=args.center_crop,
1027
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
1028
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
1029
+ tokenizer_max_length=args.tokenizer_max_length,
1030
+ )
1031
+
1032
+ train_dataloader = torch.utils.data.DataLoader(
1033
+ train_dataset,
1034
+ batch_size=args.train_batch_size,
1035
+ shuffle=True,
1036
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1037
+ num_workers=args.dataloader_num_workers,
1038
+ )
1039
+
1040
+ # Scheduler and math around the number of training steps.
1041
+ overrode_max_train_steps = False
1042
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1043
+ if args.max_train_steps is None:
1044
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1045
+ overrode_max_train_steps = True
1046
+
1047
+ lr_scheduler = get_scheduler(
1048
+ args.lr_scheduler,
1049
+ optimizer=optimizer,
1050
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1051
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1052
+ num_cycles=args.lr_num_cycles,
1053
+ power=args.lr_power,
1054
+ )
1055
+
1056
+ # Prepare everything with our `accelerator`.
1057
+ if args.train_text_encoder:
1058
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1059
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1060
+ )
1061
+ else:
1062
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1063
+ unet, optimizer, train_dataloader, lr_scheduler
1064
+ )
1065
+
1066
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1067
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1068
+ if overrode_max_train_steps:
1069
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1070
+ # Afterwards we recalculate our number of training epochs
1071
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1072
+
1073
+ # We need to initialize the trackers we use, and also store our configuration.
1074
+ # The trackers initialize automatically on the main process.
1075
+ if accelerator.is_main_process:
1076
+ tracker_config = vars(copy.deepcopy(args))
1077
+ tracker_config.pop("validation_images")
1078
+ accelerator.init_trackers("dreambooth-lora", config=tracker_config)
1079
+
1080
+ # Train!
1081
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1082
+
1083
+ logger.info("***** Running training *****")
1084
+ logger.info(f" Num examples = {len(train_dataset)}")
1085
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1086
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1087
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1088
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1089
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1090
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1091
+ global_step = 0
1092
+ first_epoch = 0
1093
+
1094
+ # Potentially load in the weights and states from a previous save
1095
+ if args.resume_from_checkpoint:
1096
+ if args.resume_from_checkpoint != "latest":
1097
+ path = os.path.basename(args.resume_from_checkpoint)
1098
+ else:
1099
+ # Get the most recent checkpoint
1100
+ dirs = os.listdir(args.output_dir)
1101
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1102
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1103
+ path = dirs[-1] if len(dirs) > 0 else None
1104
+
1105
+ if path is None:
1106
+ accelerator.print(
1107
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1108
+ )
1109
+ args.resume_from_checkpoint = None
1110
+ else:
1111
+ accelerator.print(f"Resuming from checkpoint {path}")
1112
+ accelerator.load_state(os.path.join(args.output_dir, path))
1113
+ global_step = int(path.split("-")[1])
1114
+
1115
+ resume_global_step = global_step * args.gradient_accumulation_steps
1116
+ first_epoch = global_step // num_update_steps_per_epoch
1117
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
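
A quick numeric illustration of the resume arithmetic above (not part of the committed file), with made-up values:

```python
# Hypothetical values, for illustration only.
gradient_accumulation_steps = 2
num_update_steps_per_epoch = 100          # optimizer updates per epoch
global_step = 250                         # parsed from a "checkpoint-250" directory name

resume_global_step = global_step * gradient_accumulation_steps          # 500 dataloader steps already consumed
first_epoch = global_step // num_update_steps_per_epoch                 # resume inside epoch 2
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)  # skip 100 batches in that epoch
```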
1118
+
1119
+ # Only show the progress bar once on each machine.
1120
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
1121
+ progress_bar.set_description("Steps")
1122
+
1123
+ for epoch in range(first_epoch, args.num_train_epochs):
1124
+ unet.train()
1125
+ if args.train_text_encoder:
1126
+ text_encoder.train()
1127
+ for step, batch in enumerate(train_dataloader):
1128
+ # Skip steps until we reach the resumed step
1129
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
1130
+ if step % args.gradient_accumulation_steps == 0:
1131
+ progress_bar.update(1)
1132
+ continue
1133
+
1134
+ with accelerator.accumulate(unet):
1135
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
1136
+
1137
+ if vae is not None:
1138
+ # Convert images to latent space
1139
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1140
+ model_input = model_input * vae.config.scaling_factor
1141
+ else:
1142
+ model_input = pixel_values
1143
+
1144
+ # Sample noise that we'll add to the latents
1145
+ noise = torch.randn_like(model_input)
1146
+ bsz, channels, height, width = model_input.shape
1147
+ # Sample a random timestep for each image
1148
+ timesteps = torch.randint(
1149
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1150
+ )
1151
+ timesteps = timesteps.long()
1152
+
1153
+ # Add noise to the model input according to the noise magnitude at each timestep
1154
+ # (this is the forward diffusion process)
1155
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1156
+
1157
+ # Get the text embedding for conditioning
1158
+ if args.pre_compute_text_embeddings:
1159
+ encoder_hidden_states = batch["input_ids"]
1160
+ else:
1161
+ encoder_hidden_states = encode_prompt(
1162
+ text_encoder,
1163
+ batch["input_ids"],
1164
+ batch["attention_mask"],
1165
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1166
+ )
1167
+
1168
+ if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
1169
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
1170
+
1171
+ if args.class_labels_conditioning == "timesteps":
1172
+ class_labels = timesteps
1173
+ else:
1174
+ class_labels = None
1175
+
1176
+ # Predict the noise residual
1177
+ model_pred = unet(
1178
+ noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
1179
+ ).sample
1180
+
1181
+ # if model predicts variance, throw away the prediction. we will only train on the
1182
+ # simplified training objective. This means that all schedulers using the fine tuned
1183
+ # model must be configured to use one of the fixed variance types.
1184
+ if model_pred.shape[1] == 6:
1185
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
1186
+
1187
+ # Get the target for loss depending on the prediction type
1188
+ if noise_scheduler.config.prediction_type == "epsilon":
1189
+ target = noise
1190
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1191
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1192
+ else:
1193
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1194
+
1195
+ if args.with_prior_preservation:
1196
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1197
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1198
+ target, target_prior = torch.chunk(target, 2, dim=0)
1199
+
1200
+ # Compute instance loss
1201
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1202
+
1203
+ # Compute prior loss
1204
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1205
+
1206
+ # Add the prior loss to the instance loss.
1207
+ loss = loss + args.prior_loss_weight * prior_loss
1208
+ else:
1209
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1210
+
1211
+ accelerator.backward(loss)
1212
+ if accelerator.sync_gradients:
1213
+ params_to_clip = (
1214
+ itertools.chain(unet_lora_parameters, text_lora_parameters)
1215
+ if args.train_text_encoder
1216
+ else unet_lora_parameters
1217
+ )
1218
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1219
+ optimizer.step()
1220
+ lr_scheduler.step()
1221
+ optimizer.zero_grad()
1222
+
1223
+ # Checks if the accelerator has performed an optimization step behind the scenes
1224
+ if accelerator.sync_gradients:
1225
+ progress_bar.update(1)
1226
+ global_step += 1
1227
+
1228
+ if accelerator.is_main_process:
1229
+ if global_step % args.checkpointing_steps == 0:
1230
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1231
+ if args.checkpoints_total_limit is not None:
1232
+ checkpoints = os.listdir(args.output_dir)
1233
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1234
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1235
+
1236
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1237
+ if len(checkpoints) >= args.checkpoints_total_limit:
1238
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1239
+ removing_checkpoints = checkpoints[0:num_to_remove]
1240
+
1241
+ logger.info(
1242
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1243
+ )
1244
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1245
+
1246
+ for removing_checkpoint in removing_checkpoints:
1247
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1248
+ shutil.rmtree(removing_checkpoint)
1249
+
1250
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1251
+ accelerator.save_state(save_path)
1252
+ logger.info(f"Saved state to {save_path}")
1253
+
1254
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1255
+ progress_bar.set_postfix(**logs)
1256
+ accelerator.log(logs, step=global_step)
1257
+
1258
+ if global_step >= args.max_train_steps:
1259
+ break
1260
+
1261
+ if accelerator.is_main_process:
1262
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1263
+ logger.info(
1264
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1265
+ f" {args.validation_prompt}."
1266
+ )
1267
+ # create pipeline
1268
+ pipeline = DiffusionPipeline.from_pretrained(
1269
+ args.pretrained_model_name_or_path,
1270
+ unet=accelerator.unwrap_model(unet),
1271
+ text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
1272
+ revision=args.revision,
1273
+ torch_dtype=weight_dtype,
1274
+ )
1275
+
1276
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1277
+ scheduler_args = {}
1278
+
1279
+ if "variance_type" in pipeline.scheduler.config:
1280
+ variance_type = pipeline.scheduler.config.variance_type
1281
+
1282
+ if variance_type in ["learned", "learned_range"]:
1283
+ variance_type = "fixed_small"
1284
+
1285
+ scheduler_args["variance_type"] = variance_type
1286
+
1287
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1288
+ pipeline.scheduler.config, **scheduler_args
1289
+ )
1290
+
1291
+ pipeline = pipeline.to(accelerator.device)
1292
+ pipeline.set_progress_bar_config(disable=True)
1293
+
1294
+ # run inference
1295
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1296
+ if args.pre_compute_text_embeddings:
1297
+ pipeline_args = {
1298
+ "prompt_embeds": validation_prompt_encoder_hidden_states,
1299
+ "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
1300
+ }
1301
+ else:
1302
+ pipeline_args = {"prompt": args.validation_prompt}
1303
+
1304
+ if args.validation_images is None:
1305
+ images = []
1306
+ for _ in range(args.num_validation_images):
1307
+ with torch.cuda.amp.autocast():
1308
+ image = pipeline(**pipeline_args, generator=generator).images[0]
1309
+ images.append(image)
1310
+ else:
1311
+ images = []
1312
+ for image in args.validation_images:
1313
+ image = Image.open(image)
1314
+ with torch.cuda.amp.autocast():
1315
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
1316
+ images.append(image)
1317
+
1318
+ for tracker in accelerator.trackers:
1319
+ if tracker.name == "tensorboard":
1320
+ np_images = np.stack([np.asarray(img) for img in images])
1321
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1322
+ if tracker.name == "wandb":
1323
+ tracker.log(
1324
+ {
1325
+ "validation": [
1326
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1327
+ for i, image in enumerate(images)
1328
+ ]
1329
+ }
1330
+ )
1331
+
1332
+ del pipeline
1333
+ torch.cuda.empty_cache()
1334
+
1335
+ # Save the lora layers
1336
+ accelerator.wait_for_everyone()
1337
+ if accelerator.is_main_process:
1338
+ unet = accelerator.unwrap_model(unet)
1339
+ unet = unet.to(torch.float32)
1340
+ unet_lora_layers = unet_attn_processors_state_dict(unet)
1341
+
1342
+ if text_encoder is not None and args.train_text_encoder:
1343
+ text_encoder = accelerator.unwrap_model(text_encoder)
1344
+ text_encoder = text_encoder.to(torch.float32)
1345
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
1346
+ else:
1347
+ text_encoder_lora_layers = None
1348
+
1349
+ LoraLoaderMixin.save_lora_weights(
1350
+ save_directory=args.output_dir,
1351
+ unet_lora_layers=unet_lora_layers,
1352
+ text_encoder_lora_layers=text_encoder_lora_layers,
1353
+ )
1354
+
1355
+ # Final inference
1356
+ # Load previous pipeline
1357
+ pipeline = DiffusionPipeline.from_pretrained(
1358
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
1359
+ )
1360
+
1361
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1362
+ scheduler_args = {}
1363
+
1364
+ if "variance_type" in pipeline.scheduler.config:
1365
+ variance_type = pipeline.scheduler.config.variance_type
1366
+
1367
+ if variance_type in ["learned", "learned_range"]:
1368
+ variance_type = "fixed_small"
1369
+
1370
+ scheduler_args["variance_type"] = variance_type
1371
+
1372
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1373
+
1374
+ pipeline = pipeline.to(accelerator.device)
1375
+
1376
+ # load attention processors
1377
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
1378
+
1379
+ # run inference
1380
+ images = []
1381
+ if args.validation_prompt and args.num_validation_images > 0:
1382
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1383
+ images = [
1384
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1385
+ for _ in range(args.num_validation_images)
1386
+ ]
1387
+
1388
+ for tracker in accelerator.trackers:
1389
+ if tracker.name == "tensorboard":
1390
+ np_images = np.stack([np.asarray(img) for img in images])
1391
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1392
+ if tracker.name == "wandb":
1393
+ tracker.log(
1394
+ {
1395
+ "test": [
1396
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1397
+ for i, image in enumerate(images)
1398
+ ]
1399
+ }
1400
+ )
1401
+
1402
+ if args.push_to_hub:
1403
+ save_model_card(
1404
+ repo_id,
1405
+ images=images,
1406
+ base_model=args.pretrained_model_name_or_path,
1407
+ train_text_encoder=args.train_text_encoder,
1408
+ prompt=args.instance_prompt,
1409
+ repo_folder=args.output_dir,
1410
+ pipeline=pipeline,
1411
+ )
1412
+ upload_folder(
1413
+ repo_id=repo_id,
1414
+ folder_path=args.output_dir,
1415
+ commit_message="End of training",
1416
+ ignore_patterns=["step_*", "epoch_*"],
1417
+ )
1418
+
1419
+ accelerator.end_training()
1420
+
1421
+
1422
+ if __name__ == "__main__":
1423
+ args = parse_args()
1424
+ main(args)
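
For reference, the LoRA weights this script saves can be loaded back into a fresh pipeline for inference, mirroring the final-inference block above. A minimal sketch (the prompt, output directory, and dtype are placeholders, not taken from the commit):

```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Placeholders for illustration: the base model and the script's default --output_dir.
base_model = "CompVis/stable-diffusion-v1-4"
lora_dir = "lora-dreambooth-model"

pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Load the attention-processor weights written by LoraLoaderMixin.save_lora_weights(...)
pipe.load_lora_weights(lora_dir, weight_name="pytorch_lora_weights.bin")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")
```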
train_dreambooth_lora_sdxl.py ADDED
@@ -0,0 +1,1368 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import argparse
17
+ import gc
18
+ import hashlib
19
+ import itertools
20
+ import logging
21
+ import math
22
+ import os
23
+ import shutil
24
+ import warnings
25
+ from pathlib import Path
26
+ from typing import Dict
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import ProjectConfiguration, set_seed
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from PIL import Image
39
+ from PIL.ImageOps import exif_transpose
40
+ from torch.utils.data import Dataset
41
+ from torchvision import transforms
42
+ from tqdm.auto import tqdm
43
+ from transformers import AutoTokenizer, PretrainedConfig
44
+
45
+ import diffusers
46
+ from diffusers import (
47
+ AutoencoderKL,
48
+ DDPMScheduler,
49
+ DPMSolverMultistepScheduler,
50
+ StableDiffusionXLPipeline,
51
+ UNet2DConditionModel,
52
+ )
53
+ from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
54
+ from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
55
+ from diffusers.optimization import get_scheduler
56
+ from diffusers.utils import check_min_version, is_wandb_available
57
+ from diffusers.utils.import_utils import is_xformers_available
58
+
59
+
60
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
61
+ check_min_version("0.20.0.dev0")
62
+
63
+ logger = get_logger(__name__)
64
+
65
+
66
+ def save_model_card(
67
+ repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
68
+ ):
69
+ img_str = ""
70
+ for i, image in enumerate(images):
71
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
72
+ img_str += f"![img_{i}](./image_{i}.png)\n"
73
+
74
+ yaml = f"""
75
+ ---
76
+ license: openrail++
77
+ base_model: {base_model}
78
+ instance_prompt: {prompt}
79
+ tags:
80
+ - stable-diffusion-xl
81
+ - stable-diffusion-xl-diffusers
82
+ - text-to-image
83
+ - diffusers
84
+ - lora
85
+ inference: true
86
+ ---
87
+ """
88
+ model_card = f"""
89
+ # LoRA DreamBooth - {repo_id}
90
+
91
+ These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below. \n
92
+ {img_str}
93
+
94
+ LoRA for the text encoder was enabled: {train_text_encoder}.
95
+
96
+ Special VAE used for training: {vae_path}.
97
+ """
98
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
99
+ f.write(yaml + model_card)
100
+
101
+
102
+ def import_model_class_from_model_name_or_path(
103
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
104
+ ):
105
+ text_encoder_config = PretrainedConfig.from_pretrained(
106
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
107
+ )
108
+ model_class = text_encoder_config.architectures[0]
109
+
110
+ if model_class == "CLIPTextModel":
111
+ from transformers import CLIPTextModel
112
+
113
+ return CLIPTextModel
114
+ elif model_class == "CLIPTextModelWithProjection":
115
+ from transformers import CLIPTextModelWithProjection
116
+
117
+ return CLIPTextModelWithProjection
118
+ else:
119
+ raise ValueError(f"{model_class} is not supported.")
120
+
121
+
122
+ def parse_args(input_args=None):
123
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
124
+ parser.add_argument(
125
+ "--pretrained_model_name_or_path",
126
+ type=str,
127
+ default=None,
128
+ required=True,
129
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
130
+ )
131
+ parser.add_argument(
132
+ "--pretrained_vae_model_name_or_path",
133
+ type=str,
134
+ default=None,
135
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
136
+ )
137
+ parser.add_argument(
138
+ "--revision",
139
+ type=str,
140
+ default=None,
141
+ required=False,
142
+ help="Revision of pretrained model identifier from huggingface.co/models.",
143
+ )
144
+ parser.add_argument(
145
+ "--instance_data_dir",
146
+ type=str,
147
+ default=None,
148
+ required=True,
149
+ help="A folder containing the training data of instance images.",
150
+ )
151
+ parser.add_argument(
152
+ "--class_data_dir",
153
+ type=str,
154
+ default=None,
155
+ required=False,
156
+ help="A folder containing the training data of class images.",
157
+ )
158
+ parser.add_argument(
159
+ "--instance_prompt",
160
+ type=str,
161
+ default=None,
162
+ required=True,
163
+ help="The prompt with identifier specifying the instance",
164
+ )
165
+ parser.add_argument(
166
+ "--class_prompt",
167
+ type=str,
168
+ default=None,
169
+ help="The prompt to specify images in the same class as provided instance images.",
170
+ )
171
+ parser.add_argument(
172
+ "--validation_prompt",
173
+ type=str,
174
+ default=None,
175
+ help="A prompt that is used during validation to verify that the model is learning.",
176
+ )
177
+ parser.add_argument(
178
+ "--num_validation_images",
179
+ type=int,
180
+ default=4,
181
+ help="Number of images that should be generated during validation with `validation_prompt`.",
182
+ )
183
+ parser.add_argument(
184
+ "--validation_epochs",
185
+ type=int,
186
+ default=50,
187
+ help=(
188
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
189
+ " `args.validation_prompt` a total of `args.num_validation_images` times."
190
+ ),
191
+ )
192
+ parser.add_argument(
193
+ "--with_prior_preservation",
194
+ default=False,
195
+ action="store_true",
196
+ help="Flag to add prior preservation loss.",
197
+ )
198
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
199
+ parser.add_argument(
200
+ "--num_class_images",
201
+ type=int,
202
+ default=100,
203
+ help=(
204
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
205
+ " class_data_dir, additional images will be sampled with class_prompt."
206
+ ),
207
+ )
208
+ parser.add_argument(
209
+ "--output_dir",
210
+ type=str,
211
+ default="lora-dreambooth-model",
212
+ help="The output directory where the model predictions and checkpoints will be written.",
213
+ )
214
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
215
+ parser.add_argument(
216
+ "--resolution",
217
+ type=int,
218
+ default=1024,
219
+ help=(
220
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
221
+ " resolution"
222
+ ),
223
+ )
224
+ parser.add_argument(
225
+ "--crops_coords_top_left_h",
226
+ type=int,
227
+ default=0,
228
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
229
+ )
230
+ parser.add_argument(
231
+ "--crops_coords_top_left_w",
232
+ type=int,
233
+ default=0,
234
+ help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
235
+ )
236
+ parser.add_argument(
237
+ "--center_crop",
238
+ default=False,
239
+ action="store_true",
240
+ help=(
241
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
242
+ " cropped. The images will be resized to the resolution first before cropping."
243
+ ),
244
+ )
245
+ parser.add_argument(
246
+ "--train_text_encoder",
247
+ action="store_true",
248
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
249
+ )
250
+ parser.add_argument(
251
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
252
+ )
253
+ parser.add_argument(
254
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
255
+ )
256
+ parser.add_argument("--num_train_epochs", type=int, default=1)
257
+ parser.add_argument(
258
+ "--max_train_steps",
259
+ type=int,
260
+ default=None,
261
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
262
+ )
263
+ parser.add_argument(
264
+ "--checkpointing_steps",
265
+ type=int,
266
+ default=500,
267
+ help=(
268
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
269
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
270
+ " training using `--resume_from_checkpoint`."
271
+ ),
272
+ )
273
+ parser.add_argument(
274
+ "--checkpoints_total_limit",
275
+ type=int,
276
+ default=None,
277
+ help=("Max number of checkpoints to store."),
278
+ )
279
+ parser.add_argument(
280
+ "--resume_from_checkpoint",
281
+ type=str,
282
+ default=None,
283
+ help=(
284
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
285
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
286
+ ),
287
+ )
288
+ parser.add_argument(
289
+ "--gradient_accumulation_steps",
290
+ type=int,
291
+ default=1,
292
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
293
+ )
294
+ parser.add_argument(
295
+ "--gradient_checkpointing",
296
+ action="store_true",
297
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
298
+ )
299
+ parser.add_argument(
300
+ "--learning_rate",
301
+ type=float,
302
+ default=5e-4,
303
+ help="Initial learning rate (after the potential warmup period) to use.",
304
+ )
305
+ parser.add_argument(
306
+ "--scale_lr",
307
+ action="store_true",
308
+ default=False,
309
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
310
+ )
311
+ parser.add_argument(
312
+ "--lr_scheduler",
313
+ type=str,
314
+ default="constant",
315
+ help=(
316
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
317
+ ' "constant", "constant_with_warmup"]'
318
+ ),
319
+ )
320
+ parser.add_argument(
321
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
322
+ )
323
+ parser.add_argument(
324
+ "--lr_num_cycles",
325
+ type=int,
326
+ default=1,
327
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
328
+ )
329
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
330
+ parser.add_argument(
331
+ "--dataloader_num_workers",
332
+ type=int,
333
+ default=0,
334
+ help=(
335
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
336
+ ),
337
+ )
338
+ parser.add_argument(
339
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
340
+ )
341
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
342
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
343
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
344
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
345
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
346
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
347
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
348
+ parser.add_argument(
349
+ "--hub_model_id",
350
+ type=str,
351
+ default=None,
352
+ help="The name of the repository to keep in sync with the local `output_dir`.",
353
+ )
354
+ parser.add_argument(
355
+ "--logging_dir",
356
+ type=str,
357
+ default="logs",
358
+ help=(
359
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
360
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
361
+ ),
362
+ )
363
+ parser.add_argument(
364
+ "--allow_tf32",
365
+ action="store_true",
366
+ help=(
367
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
368
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
369
+ ),
370
+ )
371
+ parser.add_argument(
372
+ "--report_to",
373
+ type=str,
374
+ default="tensorboard",
375
+ help=(
376
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
377
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
378
+ ),
379
+ )
380
+ parser.add_argument(
381
+ "--mixed_precision",
382
+ type=str,
383
+ default=None,
384
+ choices=["no", "fp16", "bf16"],
385
+ help=(
386
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
387
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
388
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
389
+ ),
390
+ )
391
+ parser.add_argument(
392
+ "--prior_generation_precision",
393
+ type=str,
394
+ default=None,
395
+ choices=["no", "fp32", "fp16", "bf16"],
396
+ help=(
397
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
398
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
399
+ ),
400
+ )
401
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
402
+ parser.add_argument(
403
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
404
+ )
405
+ parser.add_argument(
406
+ "--rank",
407
+ type=int,
408
+ default=4,
409
+ help=("The dimension of the LoRA update matrices."),
410
+ )
411
+
412
+ if input_args is not None:
413
+ args = parser.parse_args(input_args)
414
+ else:
415
+ args = parser.parse_args()
416
+
417
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
418
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
419
+ args.local_rank = env_local_rank
420
+
421
+ if args.with_prior_preservation:
422
+ if args.class_data_dir is None:
423
+ raise ValueError("You must specify a data directory for class images.")
424
+ if args.class_prompt is None:
425
+ raise ValueError("You must specify prompt for class images.")
426
+ else:
427
+ # logger is not available yet
428
+ if args.class_data_dir is not None:
429
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
430
+ if args.class_prompt is not None:
431
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
432
+
433
+ return args
434
+
435
+
436
+ class DreamBoothDataset(Dataset):
437
+ """
438
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
439
+ It pre-processes the images.
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ instance_data_root,
445
+ class_data_root=None,
446
+ class_num=None,
447
+ size=1024,
448
+ center_crop=False,
449
+ ):
450
+ self.size = size
451
+ self.center_crop = center_crop
452
+
453
+ self.instance_data_root = Path(instance_data_root)
454
+ if not self.instance_data_root.exists():
455
+ raise ValueError("Instance images root doesn't exist.")
456
+
457
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
458
+ self.num_instance_images = len(self.instance_images_path)
459
+ self._length = self.num_instance_images
460
+
461
+ if class_data_root is not None:
462
+ self.class_data_root = Path(class_data_root)
463
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
464
+ self.class_images_path = list(self.class_data_root.iterdir())
465
+ if class_num is not None:
466
+ self.num_class_images = min(len(self.class_images_path), class_num)
467
+ else:
468
+ self.num_class_images = len(self.class_images_path)
469
+ self._length = max(self.num_class_images, self.num_instance_images)
470
+ else:
471
+ self.class_data_root = None
472
+
473
+ self.image_transforms = transforms.Compose(
474
+ [
475
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
476
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
477
+ transforms.ToTensor(),
478
+ transforms.Normalize([0.5], [0.5]),
479
+ ]
480
+ )
481
+
482
+ def __len__(self):
483
+ return self._length
484
+
485
+ def __getitem__(self, index):
486
+ example = {}
487
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
488
+ instance_image = exif_transpose(instance_image)
489
+
490
+ if not instance_image.mode == "RGB":
491
+ instance_image = instance_image.convert("RGB")
492
+ example["instance_images"] = self.image_transforms(instance_image)
493
+
494
+ if self.class_data_root:
495
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
496
+ class_image = exif_transpose(class_image)
497
+
498
+ if not class_image.mode == "RGB":
499
+ class_image = class_image.convert("RGB")
500
+ example["class_images"] = self.image_transforms(class_image)
501
+
502
+ return example
503
+
504
+
505
+ def collate_fn(examples, with_prior_preservation=False):
506
+ pixel_values = [example["instance_images"] for example in examples]
507
+
508
+ # Concat class and instance examples for prior preservation.
509
+ # We do this to avoid doing two forward passes.
510
+ if with_prior_preservation:
511
+ pixel_values += [example["class_images"] for example in examples]
512
+
513
+ pixel_values = torch.stack(pixel_values)
514
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
515
+
516
+ batch = {"pixel_values": pixel_values}
517
+ return batch
518
+
519
+
520
+ class PromptDataset(Dataset):
521
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
522
+
523
+ def __init__(self, prompt, num_samples):
524
+ self.prompt = prompt
525
+ self.num_samples = num_samples
526
+
527
+ def __len__(self):
528
+ return self.num_samples
529
+
530
+ def __getitem__(self, index):
531
+ example = {}
532
+ example["prompt"] = self.prompt
533
+ example["index"] = index
534
+ return example
535
+
536
+
537
+ def tokenize_prompt(tokenizer, prompt):
538
+ text_inputs = tokenizer(
539
+ prompt,
540
+ padding="max_length",
541
+ max_length=tokenizer.model_max_length,
542
+ truncation=True,
543
+ return_tensors="pt",
544
+ )
545
+ text_input_ids = text_inputs.input_ids
546
+ return text_input_ids
547
+
548
+
549
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
550
+ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
551
+ prompt_embeds_list = []
552
+
553
+ for i, text_encoder in enumerate(text_encoders):
554
+ if tokenizers is not None:
555
+ tokenizer = tokenizers[i]
556
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
557
+ else:
558
+ assert text_input_ids_list is not None
559
+ text_input_ids = text_input_ids_list[i]
560
+
561
+ prompt_embeds = text_encoder(
562
+ text_input_ids.to(text_encoder.device),
563
+ output_hidden_states=True,
564
+ )
565
+
566
+ # We are always only interested in the pooled output of the final text encoder
567
+ pooled_prompt_embeds = prompt_embeds[0]
568
+ prompt_embeds = prompt_embeds.hidden_states[-2]
569
+ bs_embed, seq_len, _ = prompt_embeds.shape
570
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
571
+ prompt_embeds_list.append(prompt_embeds)
572
+
573
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
574
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
575
+ return prompt_embeds, pooled_prompt_embeds
576
+
577
+
578
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
579
+ """
580
+ Returns:
581
+ a state dict containing just the attention processor parameters.
582
+ """
583
+ attn_processors = unet.attn_processors
584
+
585
+ attn_processors_state_dict = {}
586
+
587
+ for attn_processor_key, attn_processor in attn_processors.items():
588
+ for parameter_key, parameter in attn_processor.state_dict().items():
589
+ attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
590
+
591
+ return attn_processors_state_dict
592
+
593
+
594
+ def main(args):
595
+ logging_dir = Path(args.output_dir, args.logging_dir)
596
+
597
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
598
+
599
+ accelerator = Accelerator(
600
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
601
+ mixed_precision=args.mixed_precision,
602
+ log_with=args.report_to,
603
+ project_config=accelerator_project_config,
604
+ )
605
+
606
+ if args.report_to == "wandb":
607
+ if not is_wandb_available():
608
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
609
+ import wandb
610
+
611
+ # Make one log on every process with the configuration for debugging.
612
+ logging.basicConfig(
613
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
614
+ datefmt="%m/%d/%Y %H:%M:%S",
615
+ level=logging.INFO,
616
+ )
617
+ logger.info(accelerator.state, main_process_only=False)
618
+ if accelerator.is_local_main_process:
619
+ transformers.utils.logging.set_verbosity_warning()
620
+ diffusers.utils.logging.set_verbosity_info()
621
+ else:
622
+ transformers.utils.logging.set_verbosity_error()
623
+ diffusers.utils.logging.set_verbosity_error()
624
+
625
+ # If passed along, set the training seed now.
626
+ if args.seed is not None:
627
+ set_seed(args.seed)
628
+
629
+ # Generate class images if prior preservation is enabled.
630
+ if args.with_prior_preservation:
631
+ class_images_dir = Path(args.class_data_dir)
632
+ if not class_images_dir.exists():
633
+ class_images_dir.mkdir(parents=True)
634
+ cur_class_images = len(list(class_images_dir.iterdir()))
635
+
636
+ if cur_class_images < args.num_class_images:
637
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
638
+ if args.prior_generation_precision == "fp32":
639
+ torch_dtype = torch.float32
640
+ elif args.prior_generation_precision == "fp16":
641
+ torch_dtype = torch.float16
642
+ elif args.prior_generation_precision == "bf16":
643
+ torch_dtype = torch.bfloat16
644
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
645
+ args.pretrained_model_name_or_path,
646
+ torch_dtype=torch_dtype,
647
+ revision=args.revision,
648
+ )
649
+ pipeline.set_progress_bar_config(disable=True)
650
+
651
+ num_new_images = args.num_class_images - cur_class_images
652
+ logger.info(f"Number of class images to sample: {num_new_images}.")
653
+
654
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
655
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
656
+
657
+ sample_dataloader = accelerator.prepare(sample_dataloader)
658
+ pipeline.to(accelerator.device)
659
+
660
+ for example in tqdm(
661
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
662
+ ):
663
+ images = pipeline(example["prompt"]).images
664
+
665
+ for i, image in enumerate(images):
666
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
667
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
668
+ image.save(image_filename)
669
+
670
+ del pipeline
671
+ if torch.cuda.is_available():
672
+ torch.cuda.empty_cache()
673
+
674
+ # Handle the repository creation
675
+ if accelerator.is_main_process:
676
+ if args.output_dir is not None:
677
+ os.makedirs(args.output_dir, exist_ok=True)
678
+
679
+ if args.push_to_hub:
680
+ repo_id = create_repo(
681
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
682
+ ).repo_id
683
+
684
+ # Load the tokenizers
685
+ tokenizer_one = AutoTokenizer.from_pretrained(
686
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
687
+ )
688
+ tokenizer_two = AutoTokenizer.from_pretrained(
689
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
690
+ )
691
+
692
+ # import correct text encoder classes
693
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
694
+ args.pretrained_model_name_or_path, args.revision
695
+ )
696
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
697
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
698
+ )
699
+
700
+ # Load scheduler and models
701
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
702
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
703
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
704
+ )
705
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
706
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
707
+ )
708
+ vae_path = (
709
+ args.pretrained_model_name_or_path
710
+ if args.pretrained_vae_model_name_or_path is None
711
+ else args.pretrained_vae_model_name_or_path
712
+ )
713
+ vae = AutoencoderKL.from_pretrained(
714
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
715
+ )
716
+ unet = UNet2DConditionModel.from_pretrained(
717
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
718
+ )
719
+
720
+ # We only train the additional adapter LoRA layers
721
+ vae.requires_grad_(False)
722
+ text_encoder_one.requires_grad_(False)
723
+ text_encoder_two.requires_grad_(False)
724
+ unet.requires_grad_(False)
725
+
726
+ # For mixed precision training, we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
727
+ # as these weights are only used for inference, keeping weights in full precision is not required.
728
+ weight_dtype = torch.float32
729
+ if accelerator.mixed_precision == "fp16":
730
+ weight_dtype = torch.float16
731
+ elif accelerator.mixed_precision == "bf16":
732
+ weight_dtype = torch.bfloat16
733
+
734
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
735
+ # The VAE is in float32 to avoid NaN losses.
736
+ unet.to(accelerator.device, dtype=weight_dtype)
737
+ if args.pretrained_vae_model_name_or_path is None:
738
+ vae.to(accelerator.device, dtype=torch.float32)
739
+ else:
740
+ vae.to(accelerator.device, dtype=weight_dtype)
741
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
742
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
743
+
744
+ if args.enable_xformers_memory_efficient_attention:
745
+ if is_xformers_available():
746
+ import xformers
747
+
748
+ xformers_version = version.parse(xformers.__version__)
749
+ if xformers_version == version.parse("0.0.16"):
750
+ logger.warn(
751
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
752
+ )
753
+ unet.enable_xformers_memory_efficient_attention()
754
+ else:
755
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
756
+
757
+ if args.gradient_checkpointing:
758
+ unet.enable_gradient_checkpointing()
759
+ if args.train_text_encoder:
760
+ text_encoder_one.gradient_checkpointing_enable()
761
+ text_encoder_two.gradient_checkpointing_enable()
762
+
763
+ # now we will add new LoRA weights to the attention layers
764
+ # Set correct lora layers
765
+ unet_lora_attn_procs = {}
766
+ unet_lora_parameters = []
767
+ for name, attn_processor in unet.attn_processors.items():
768
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
769
+ if name.startswith("mid_block"):
770
+ hidden_size = unet.config.block_out_channels[-1]
771
+ elif name.startswith("up_blocks"):
772
+ block_id = int(name[len("up_blocks.")])
773
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
774
+ elif name.startswith("down_blocks"):
775
+ block_id = int(name[len("down_blocks.")])
776
+ hidden_size = unet.config.block_out_channels[block_id]
777
+
778
+ lora_attn_processor_class = (
779
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
780
+ )
781
+ module = lora_attn_processor_class(
782
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
783
+ )
784
+ unet_lora_attn_procs[name] = module
785
+ unet_lora_parameters.extend(module.parameters())
786
+
787
+ unet.set_attn_processor(unet_lora_attn_procs)
788
+
789
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
790
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
791
+ if args.train_text_encoder:
792
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
793
+ text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
794
+ text_encoder_one, dtype=torch.float32, rank=args.rank
795
+ )
796
+ text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
797
+ text_encoder_two, dtype=torch.float32, rank=args.rank
798
+ )
799
+
800
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
801
+ def save_model_hook(models, weights, output_dir):
802
+ # there are only two options here: either they are just the unet attn processor layers
803
+ # or they are the unet and text encoder attention layers
804
+ unet_lora_layers_to_save = None
805
+ text_encoder_one_lora_layers_to_save = None
806
+ text_encoder_two_lora_layers_to_save = None
807
+
808
+ for model in models:
809
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
810
+ unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
811
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
812
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
813
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
814
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
815
+ else:
816
+ raise ValueError(f"unexpected save model: {model.__class__}")
817
+
818
+ # make sure to pop the weight so that the corresponding model is not saved again
819
+ weights.pop()
820
+
821
+ StableDiffusionXLPipeline.save_lora_weights(
822
+ output_dir,
823
+ unet_lora_layers=unet_lora_layers_to_save,
824
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
825
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
826
+ )
827
+
828
+ def load_model_hook(models, input_dir):
829
+ unet_ = None
830
+ text_encoder_one_ = None
831
+ text_encoder_two_ = None
832
+
833
+ while len(models) > 0:
834
+ model = models.pop()
835
+
836
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
837
+ unet_ = model
838
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
839
+ text_encoder_one_ = model
840
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
841
+ text_encoder_two_ = model
842
+ else:
843
+ raise ValueError(f"unexpected save model: {model.__class__}")
844
+
845
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
846
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
847
+ LoraLoaderMixin.load_lora_into_text_encoder(
848
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
849
+ )
850
+ LoraLoaderMixin.load_lora_into_text_encoder(
851
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
852
+ )
853
+
854
+ accelerator.register_save_state_pre_hook(save_model_hook)
855
+ accelerator.register_load_state_pre_hook(load_model_hook)
856
+
857
+ # Enable TF32 for faster training on Ampere GPUs,
858
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
859
+ if args.allow_tf32:
860
+ torch.backends.cuda.matmul.allow_tf32 = True
861
+
862
+ if args.scale_lr:
863
+ args.learning_rate = (
864
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
865
+ )
866
+
867
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
868
+ if args.use_8bit_adam:
869
+ try:
870
+ import bitsandbytes as bnb
871
+ except ImportError:
872
+ raise ImportError(
873
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
874
+ )
875
+
876
+ optimizer_class = bnb.optim.AdamW8bit
877
+ else:
878
+ optimizer_class = torch.optim.AdamW
879
+
880
+ # Optimizer creation
881
+ params_to_optimize = (
882
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
883
+ if args.train_text_encoder
884
+ else unet_lora_parameters
885
+ )
886
+ optimizer = optimizer_class(
887
+ params_to_optimize,
888
+ lr=args.learning_rate,
889
+ betas=(args.adam_beta1, args.adam_beta2),
890
+ weight_decay=args.adam_weight_decay,
891
+ eps=args.adam_epsilon,
892
+ )
893
+
894
+ # Computes additional embeddings/ids required by the SDXL UNet.
895
+ # regular text embeddings (when `train_text_encoder` is not True)
896
+ # pooled text embeddings
897
+ # time ids
898
+
899
+ def compute_time_ids():
900
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
901
+ original_size = (args.resolution, args.resolution)
902
+ target_size = (args.resolution, args.resolution)
903
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
904
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
905
+ add_time_ids = torch.tensor([add_time_ids])
906
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
907
+ return add_time_ids
908
+
909
+ if not args.train_text_encoder:
910
+ tokenizers = [tokenizer_one, tokenizer_two]
911
+ text_encoders = [text_encoder_one, text_encoder_two]
912
+
913
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
914
+ with torch.no_grad():
915
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
916
+ prompt_embeds = prompt_embeds.to(accelerator.device)
917
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
918
+ return prompt_embeds, pooled_prompt_embeds
919
+
920
+ # Handle instance prompt.
921
+ instance_time_ids = compute_time_ids()
922
+ if not args.train_text_encoder:
923
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
924
+ args.instance_prompt, text_encoders, tokenizers
925
+ )
926
+
927
+ # Handle class prompt for prior-preservation.
928
+ if args.with_prior_preservation:
929
+ class_time_ids = compute_time_ids()
930
+ if not args.train_text_encoder:
931
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
932
+ args.class_prompt, text_encoders, tokenizers
933
+ )
934
+
935
+ # Clear the memory here.
936
+ if not args.train_text_encoder:
937
+ del tokenizers, text_encoders
938
+ gc.collect()
939
+ torch.cuda.empty_cache()
940
+
941
+ # Pack the statically computed variables appropriately. This is so that we don't
942
+ # have to pass them to the dataloader.
943
+ add_time_ids = instance_time_ids
944
+ if args.with_prior_preservation:
945
+ add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
946
+
947
+ if not args.train_text_encoder:
948
+ prompt_embeds = instance_prompt_hidden_states
949
+ unet_add_text_embeds = instance_pooled_prompt_embeds
950
+ if args.with_prior_preservation:
951
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
952
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
953
+ else:
954
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
955
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
956
+ if args.with_prior_preservation:
957
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
958
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
959
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
960
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
961
+
962
+ # Dataset and DataLoaders creation:
963
+ train_dataset = DreamBoothDataset(
964
+ instance_data_root=args.instance_data_dir,
965
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
966
+ class_num=args.num_class_images,
967
+ size=args.resolution,
968
+ center_crop=args.center_crop,
969
+ )
970
+
971
+ train_dataloader = torch.utils.data.DataLoader(
972
+ train_dataset,
973
+ batch_size=args.train_batch_size,
974
+ shuffle=True,
975
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
976
+ num_workers=args.dataloader_num_workers,
977
+ )
978
+
979
+ # Scheduler and math around the number of training steps.
980
+ overrode_max_train_steps = False
981
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
982
+ if args.max_train_steps is None:
983
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
984
+ overrode_max_train_steps = True
985
+
986
+ lr_scheduler = get_scheduler(
987
+ args.lr_scheduler,
988
+ optimizer=optimizer,
989
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
990
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
991
+ num_cycles=args.lr_num_cycles,
992
+ power=args.lr_power,
993
+ )
994
+
995
+ # Prepare everything with our `accelerator`.
996
+ if args.train_text_encoder:
997
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
998
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
999
+ )
1000
+ else:
1001
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1002
+ unet, optimizer, train_dataloader, lr_scheduler
1003
+ )
1004
+
1005
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1006
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1007
+ if overrode_max_train_steps:
1008
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1009
+ # Afterwards we recalculate our number of training epochs
1010
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1011
+
1012
+ # We need to initialize the trackers we use, and also store our configuration.
1013
+ # The trackers initialize automatically on the main process.
1014
+ if accelerator.is_main_process:
1015
+ accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
1016
+
1017
+ # Train!
1018
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1019
+
1020
+ logger.info("***** Running training *****")
1021
+ logger.info(f" Num examples = {len(train_dataset)}")
1022
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1023
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1024
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1025
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1026
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1027
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1028
+ global_step = 0
1029
+ first_epoch = 0
1030
+
1031
+ # Potentially load in the weights and states from a previous save
1032
+ if args.resume_from_checkpoint:
1033
+ if args.resume_from_checkpoint != "latest":
1034
+ path = os.path.basename(args.resume_from_checkpoint)
1035
+ else:
1036
+ # Get the most recent checkpoint
1037
+ dirs = os.listdir(args.output_dir)
1038
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1039
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1040
+ path = dirs[-1] if len(dirs) > 0 else None
1041
+
1042
+ if path is None:
1043
+ accelerator.print(
1044
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1045
+ )
1046
+ args.resume_from_checkpoint = None
1047
+ else:
1048
+ accelerator.print(f"Resuming from checkpoint {path}")
1049
+ accelerator.load_state(os.path.join(args.output_dir, path))
1050
+ global_step = int(path.split("-")[1])
1051
+
1052
+ resume_global_step = global_step * args.gradient_accumulation_steps
1053
+ first_epoch = global_step // num_update_steps_per_epoch
1054
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
1055
+
1056
+ # Only show the progress bar once on each machine.
1057
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
1058
+ progress_bar.set_description("Steps")
1059
+
1060
+ for epoch in range(first_epoch, args.num_train_epochs):
1061
+ unet.train()
1062
+ if args.train_text_encoder:
1063
+ text_encoder_one.train()
1064
+ text_encoder_two.train()
1065
+ for step, batch in enumerate(train_dataloader):
1066
+ # Skip steps until we reach the resumed step
1067
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
1068
+ if step % args.gradient_accumulation_steps == 0:
1069
+ progress_bar.update(1)
1070
+ continue
1071
+
1072
+ with accelerator.accumulate(unet):
1073
+ if args.pretrained_vae_model_name_or_path is None:
1074
+ pixel_values = batch["pixel_values"]
1075
+ else:
1076
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
1077
+
1078
+ # Convert images to latent space
1079
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1080
+ model_input = model_input * vae.config.scaling_factor
1081
+ if args.pretrained_vae_model_name_or_path is None:
1082
+ model_input = model_input.to(weight_dtype)
1083
+
1084
+ # Sample noise that we'll add to the latents
1085
+ noise = torch.randn_like(model_input)
1086
+ bsz = model_input.shape[0]
1087
+ # Sample a random timestep for each image
1088
+ timesteps = torch.randint(
1089
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1090
+ )
1091
+ timesteps = timesteps.long()
1092
+
1093
+ # Add noise to the model input according to the noise magnitude at each timestep
1094
+ # (this is the forward diffusion process)
1095
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1096
+
1097
+ # Calculate the elements to repeat depending on the use of prior-preservation.
1098
+ elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
1099
+
1100
+ # Predict the noise residual
1101
+ if not args.train_text_encoder:
1102
+ unet_added_conditions = {
1103
+ "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
1104
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
1105
+ }
1106
+ prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
1107
+ model_pred = unet(
1108
+ noisy_model_input,
1109
+ timesteps,
1110
+ prompt_embeds,
1111
+ added_cond_kwargs=unet_added_conditions,
1112
+ ).sample
1113
+ else:
1114
+ unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
1115
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1116
+ text_encoders=[text_encoder_one, text_encoder_two],
1117
+ tokenizers=None,
1118
+ prompt=None,
1119
+ text_input_ids_list=[tokens_one, tokens_two],
1120
+ )
1121
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
1122
+ prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
1123
+ model_pred = unet(
1124
+ noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
1125
+ ).sample
1126
+
1127
+ # Get the target for loss depending on the prediction type
1128
+ if noise_scheduler.config.prediction_type == "epsilon":
1129
+ target = noise
1130
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1131
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1132
+ else:
1133
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1134
+
1135
+ if args.with_prior_preservation:
1136
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1137
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1138
+ target, target_prior = torch.chunk(target, 2, dim=0)
1139
+
1140
+ # Compute instance loss
1141
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1142
+
1143
+ # Compute prior loss
1144
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1145
+
1146
+ # Add the prior loss to the instance loss.
1147
+ loss = loss + args.prior_loss_weight * prior_loss
1148
+ else:
1149
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1150
+
1151
+ accelerator.backward(loss)
1152
+ if accelerator.sync_gradients:
1153
+ params_to_clip = (
1154
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
1155
+ if args.train_text_encoder
1156
+ else unet_lora_parameters
1157
+ )
1158
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1159
+ optimizer.step()
1160
+ lr_scheduler.step()
1161
+ optimizer.zero_grad()
1162
+
1163
+ # Checks if the accelerator has performed an optimization step behind the scenes
1164
+ if accelerator.sync_gradients:
1165
+ progress_bar.update(1)
1166
+ global_step += 1
1167
+
1168
+ if accelerator.is_main_process:
1169
+ if global_step % args.checkpointing_steps == 0:
1170
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1171
+ if args.checkpoints_total_limit is not None:
1172
+ checkpoints = os.listdir(args.output_dir)
1173
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1174
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1175
+
1176
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1177
+ if len(checkpoints) >= args.checkpoints_total_limit:
1178
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1179
+ removing_checkpoints = checkpoints[0:num_to_remove]
1180
+
1181
+ logger.info(
1182
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1183
+ )
1184
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1185
+
1186
+ for removing_checkpoint in removing_checkpoints:
1187
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1188
+ shutil.rmtree(removing_checkpoint)
1189
+
1190
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1191
+ accelerator.save_state(save_path)
1192
+ logger.info(f"Saved state to {save_path}")
1193
+
1194
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1195
+ progress_bar.set_postfix(**logs)
1196
+ accelerator.log(logs, step=global_step)
1197
+
1198
+ if global_step >= args.max_train_steps:
1199
+ break
1200
+
1201
+ if accelerator.is_main_process:
1202
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1203
+ logger.info(
1204
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1205
+ f" {args.validation_prompt}."
1206
+ )
1207
+ # create pipeline
1208
+ if not args.train_text_encoder:
1209
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
1210
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
1211
+ )
1212
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
1213
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
1214
+ )
1215
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1216
+ args.pretrained_model_name_or_path,
1217
+ vae=vae,
1218
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
1219
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1220
+ unet=accelerator.unwrap_model(unet),
1221
+ revision=args.revision,
1222
+ torch_dtype=weight_dtype,
1223
+ )
1224
+
1225
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1226
+ scheduler_args = {}
1227
+
1228
+ if "variance_type" in pipeline.scheduler.config:
1229
+ variance_type = pipeline.scheduler.config.variance_type
1230
+
1231
+ if variance_type in ["learned", "learned_range"]:
1232
+ variance_type = "fixed_small"
1233
+
1234
+ scheduler_args["variance_type"] = variance_type
1235
+
1236
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1237
+ pipeline.scheduler.config, **scheduler_args
1238
+ )
1239
+
1240
+ pipeline = pipeline.to(accelerator.device)
1241
+ pipeline.set_progress_bar_config(disable=True)
1242
+
1243
+ # run inference
1244
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1245
+ pipeline_args = {"prompt": args.validation_prompt}
1246
+
1247
+ with torch.cuda.amp.autocast():
1248
+ images = [
1249
+ pipeline(**pipeline_args, generator=generator).images[0]
1250
+ for _ in range(args.num_validation_images)
1251
+ ]
1252
+
1253
+ for tracker in accelerator.trackers:
1254
+ if tracker.name == "tensorboard":
1255
+ np_images = np.stack([np.asarray(img) for img in images])
1256
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1257
+ if tracker.name == "wandb":
1258
+ tracker.log(
1259
+ {
1260
+ "validation": [
1261
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1262
+ for i, image in enumerate(images)
1263
+ ]
1264
+ }
1265
+ )
1266
+
1267
+ del pipeline
1268
+ torch.cuda.empty_cache()
1269
+
1270
+ # Save the lora layers
1271
+ accelerator.wait_for_everyone()
1272
+ if accelerator.is_main_process:
1273
+ unet = accelerator.unwrap_model(unet)
1274
+ unet = unet.to(torch.float32)
1275
+ unet_lora_layers = unet_attn_processors_state_dict(unet)
1276
+
1277
+ if args.train_text_encoder:
1278
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
1279
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
1280
+ text_encoder_two = accelerator.unwrap_model(text_encoder_two)
1281
+ text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
1282
+ else:
1283
+ text_encoder_lora_layers = None
1284
+ text_encoder_2_lora_layers = None
1285
+
1286
+ StableDiffusionXLPipeline.save_lora_weights(
1287
+ save_directory=args.output_dir,
1288
+ unet_lora_layers=unet_lora_layers,
1289
+ text_encoder_lora_layers=text_encoder_lora_layers,
1290
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1291
+ )
1292
+
1293
+ # Final inference
1294
+ # Load previous pipeline
1295
+ vae = AutoencoderKL.from_pretrained(
1296
+ vae_path,
1297
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1298
+ revision=args.revision,
1299
+ torch_dtype=weight_dtype,
1300
+ )
1301
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1302
+ args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
1303
+ )
1304
+
1305
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1306
+ scheduler_args = {}
1307
+
1308
+ if "variance_type" in pipeline.scheduler.config:
1309
+ variance_type = pipeline.scheduler.config.variance_type
1310
+
1311
+ if variance_type in ["learned", "learned_range"]:
1312
+ variance_type = "fixed_small"
1313
+
1314
+ scheduler_args["variance_type"] = variance_type
1315
+
1316
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1317
+
1318
+ pipeline = pipeline.to(accelerator.device)
1319
+
1320
+ # load attention processors
1321
+ pipeline.load_lora_weights(args.output_dir)
1322
+
1323
+ # run inference
1324
+ images = []
1325
+ if args.validation_prompt and args.num_validation_images > 0:
1326
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1327
+ images = [
1328
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1329
+ for _ in range(args.num_validation_images)
1330
+ ]
1331
+
1332
+ for tracker in accelerator.trackers:
1333
+ if tracker.name == "tensorboard":
1334
+ np_images = np.stack([np.asarray(img) for img in images])
1335
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1336
+ if tracker.name == "wandb":
1337
+ tracker.log(
1338
+ {
1339
+ "test": [
1340
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1341
+ for i, image in enumerate(images)
1342
+ ]
1343
+ }
1344
+ )
1345
+
1346
+ if args.push_to_hub:
1347
+ save_model_card(
1348
+ repo_id,
1349
+ images=images,
1350
+ base_model=args.pretrained_model_name_or_path,
1351
+ train_text_encoder=args.train_text_encoder,
1352
+ prompt=args.instance_prompt,
1353
+ repo_folder=args.output_dir,
1354
+ vae_path=args.pretrained_vae_model_name_or_path,
1355
+ )
1356
+ upload_folder(
1357
+ repo_id=repo_id,
1358
+ folder_path=args.output_dir,
1359
+ commit_message="End of training",
1360
+ ignore_patterns=["step_*", "epoch_*"],
1361
+ )
1362
+
1363
+ accelerator.end_training()
1364
+
1365
+
1366
+ if __name__ == "__main__":
1367
+ args = parse_args()
1368
+ main(args)
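The script above ends by reloading its own LoRA output with `pipeline.load_lora_weights(args.output_dir)`. For readers consuming this commit, here is a minimal, hedged sketch of the same pattern outside the training script; the base model id, LoRA directory, and prompt are placeholders, not values taken from this repository.

```python
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler

# Placeholders: point these at the SDXL base model and the --output_dir (or Hub repo)
# produced by train_dreambooth_lora_sdxl.py.
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lora_dir = "lora-dreambooth-model"

pipe = StableDiffusionXLPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
# Mirror the script: swap in DPMSolverMultistepScheduler for inference.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(lora_dir)  # loads the LoRA attention-processor weights saved by the script
pipe = pipe.to("cuda")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog_sdxl.png")
```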
unet/config.json ADDED
@@ -0,0 +1,65 @@
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.20.0.dev0",
4
+ "_name_or_path": "CompVis/stable-diffusion-v1-4",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "center_input_sample": false,
17
+ "class_embed_type": null,
18
+ "class_embeddings_concat": false,
19
+ "conv_in_kernel": 3,
20
+ "conv_out_kernel": 3,
21
+ "cross_attention_dim": 768,
22
+ "cross_attention_norm": null,
23
+ "down_block_types": [
24
+ "CrossAttnDownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D",
27
+ "DownBlock2D"
28
+ ],
29
+ "downsample_padding": 1,
30
+ "dual_cross_attention": false,
31
+ "encoder_hid_dim": null,
32
+ "encoder_hid_dim_type": null,
33
+ "flip_sin_to_cos": true,
34
+ "freq_shift": 0,
35
+ "in_channels": 4,
36
+ "layers_per_block": 2,
37
+ "mid_block_only_cross_attention": null,
38
+ "mid_block_scale_factor": 1,
39
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
40
+ "norm_eps": 1e-05,
41
+ "norm_num_groups": 32,
42
+ "num_attention_heads": null,
43
+ "num_class_embeds": null,
44
+ "only_cross_attention": false,
45
+ "out_channels": 4,
46
+ "projection_class_embeddings_input_dim": null,
47
+ "resnet_out_scale_factor": 1.0,
48
+ "resnet_skip_time_act": false,
49
+ "resnet_time_scale_shift": "default",
50
+ "sample_size": 64,
51
+ "time_cond_proj_dim": null,
52
+ "time_embedding_act_fn": null,
53
+ "time_embedding_dim": null,
54
+ "time_embedding_type": "positional",
55
+ "timestep_post_act": null,
56
+ "transformer_layers_per_block": 1,
57
+ "up_block_types": [
58
+ "UpBlock2D",
59
+ "CrossAttnUpBlock2D",
60
+ "CrossAttnUpBlock2D",
61
+ "CrossAttnUpBlock2D"
62
+ ],
63
+ "upcast_attention": false,
64
+ "use_linear_projection": false
65
+ }
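Note that this config describes a `UNet2DConditionModel` with `cross_attention_dim` 768 and `sample_size` 64, i.e. a Stable Diffusion v1-4 style UNet rather than an SDXL one. Below is a minimal, hedged sketch of loading it from this repository's `unet/` subfolder; the repo id is a placeholder.

```python
import torch
from diffusers import UNet2DConditionModel

# Placeholder: replace with this model repository's id or a local clone path.
repo = "your-username/your-dreambooth-repo"

unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", torch_dtype=torch.float16)
print(unet.config.cross_attention_dim)  # 768, as in the config above
print(unet.config.sample_size)          # 64 latent pixels, i.e. 512x512 images with the 8x VAE downscale
```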
unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d36ad3e9ee24f61f68ef2b7cd8bc49bededf60d0a5471f62d01738048725fa7b
3
+ size 3438375973
vae/config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.20.0.dev0",
4
+ "_name_or_path": "/home/ubuntu/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce/vae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "layers_per_block": 2,
22
+ "norm_num_groups": 32,
23
+ "out_channels": 3,
24
+ "sample_size": 512,
25
+ "scaling_factor": 0.18215,
26
+ "up_block_types": [
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D",
30
+ "UpDecoderBlock2D"
31
+ ]
32
+ }
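This VAE config matches the stock Stable Diffusion v1-4 `AutoencoderKL` (4 latent channels, `scaling_factor` 0.18215). As a hedged illustration of how the training scripts above use it, the sketch below encodes a dummy batch into scaled latents; the repo id and the random input are placeholders.

```python
import torch
from diffusers import AutoencoderKL

repo = "your-username/your-dreambooth-repo"  # placeholder
vae = AutoencoderKL.from_pretrained(repo, subfolder="vae")

# Dummy 512x512 RGB batch normalized to [-1, 1], like the DreamBoothDataset transforms produce.
pixel_values = torch.randn(1, 3, 512, 512)

with torch.no_grad():
    latents = vae.encode(pixel_values).latent_dist.sample()
    latents = latents * vae.config.scaling_factor  # 0.18215, exactly as in the training loop
print(latents.shape)  # torch.Size([1, 4, 64, 64])
```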
vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c02e0c7c263a4c0630ca3a72380ff55b9e38e0ab41d64dff7bf620a58342bc75
3
+ size 334712113