ysharma (HF staff) committed on
Commit 5d28775
1 Parent(s): 47e5f02

Upload files

.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ contents/alpha_scale.gif filter=lfs diff=lfs merge=lfs -text
+ contents/alpha_scale.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ data_*
+ output_*
+ __pycache__
+ *.pyc
+ __test*
+ merged_lora*
README.md CHANGED
@@ -1,13 +1,139 @@
- ---
- title: Low Rank Adaptation
- emoji: 🐨
- colorFrom: red
- colorTo: yellow
- sdk: gradio
- sdk_version: 3.12.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Low-rank Adaptation for Fast Text-to-Image Diffusion Fine-tuning
+
+ <!-- #region -->
+ <p align="center">
+ <img src="contents/alpha_scale.gif">
+ </p>
+ <!-- #endregion -->
+
+ > Using LoRA to fine-tune on an illustration dataset: $W = W_0 + \alpha \Delta W$, where $\alpha$ is the merging ratio. The gif above scales alpha from 0 to 1. Setting alpha to 0 is the same as using the original model, and setting alpha to 1 is the same as using the fully fine-tuned model.
+
+ <!-- #region -->
+ <p align="center">
+ <img src="contents/disney_lora.jpg">
+ </p>
+ <!-- #endregion -->
+
+ > "style of sks, baby lion", with the Disney-style LoRA model.
+
+ <!-- #region -->
+ <p align="center">
+ <img src="contents/pop_art.jpg">
+ </p>
+ <!-- #endregion -->
+
+ > "style of sks, superman", with the pop-art-style LoRA model.
+
+ ## Main Features
+
+ - Fine-tune Stable Diffusion models twice as fast as the Dreambooth method, using Low-rank Adaptation
+ - Get an insanely small end result, easy to share and download
+ - Easy to use, compatible with diffusers
+ - Sometimes even better performance than full fine-tuning (extensive comparisons are left as future work)
+ - Merge checkpoints by merging LoRAs
+
+ # Lengthy Introduction
+
+ Thanks to the generous work of Stability AI and Hugging Face, many people have enjoyed fine-tuning Stable Diffusion models to fit their needs and generate higher-fidelity images. **However, the fine-tuning process is very slow, and it is not easy to find a good balance between the number of steps and the quality of the results.**
+
+ Also, the final result (the fully fine-tuned model) is very large. Some people instead work with textual inversion as an alternative, but this is clearly suboptimal: textual inversion only creates a small word embedding, and the final images are not as good as those of a fully fine-tuned model.
+
+ Well, what's the alternative? In the LLM domain, researchers have developed efficient fine-tuning methods. LoRA, especially, tackles the very problem the community currently has: end users with the open-sourced Stable Diffusion model want to try the various fine-tuned models created by the community, but those models are too large to download and use. LoRA instead fine-tunes the "residual" of the model rather than the entire model: i.e., it trains $\Delta W$ instead of $W$.
+
+ $$
+ W' = W + \Delta W
+ $$
+
+ We can further decompose $\Delta W$ into low-rank matrices: $\Delta W = A B^T$, where $A \in \mathbb{R}^{n \times d}$, $B \in \mathbb{R}^{m \times d}$, and $d \ll n$.
+ This is the key idea of LoRA: we can then fine-tune $A$ and $B$ instead of $W$. In the end, you get an insanely small model, as $A$ and $B$ are much smaller than $W$.
+
+ Also, not all of the parameters need tuning: the authors found that tuning only $Q, K, V, O$ (i.e., the attention layers) of the transformer is often enough. (This is also why the end result is so small.) This repo follows the same idea.
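+
+ To get a sense of the savings, here is a minimal sketch (plain Python, illustrative shapes only; rank 4 is the value hard-coded in `LoraInjectedLinear` in this repo) of how much smaller the trainable part becomes:
+
+ ```python
+ # Parameter count of a full weight matrix vs. its rank-d LoRA factors (illustrative numbers).
+ n, m, d = 768, 768, 4          # e.g., one attention projection, LoRA rank 4
+
+ full_params = n * m            # parameters in W
+ lora_params = n * d + m * d    # parameters in A and B
+
+ print(full_params, lora_params, lora_params / full_params)
+ # 589824 6144 0.0104... -> roughly 1% of the original layer is trained and shipped
+ ```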
+
+ Enough of the lengthy introduction, let's get to the code.
+
+ # Installation
+
+ ```bash
+ pip install git+https://github.com/cloneofsimo/lora.git
+ ```
+
+ # Getting Started
+
+ ## Fine-tuning Stable Diffusion with LoRA
+
+ Basic usage is as follows: prepare sets of $A, B$ matrices in a unet model, and fine-tune them.
+
+ ```python
+ import itertools
+
+ import torch.optim as optim
+ from diffusers import UNet2DConditionModel
+
+ from lora_diffusion import inject_trainable_lora, extract_lora_ups_down
+
+ ...
+
+ unet = UNet2DConditionModel.from_pretrained(
+     pretrained_model_name_or_path,
+     subfolder="unet",
+ )
+ unet.requires_grad_(False)
+ # Replaces the attention Linear layers with LoRA-injected ones; the unet stays
+ # frozen and only the returned LoRA parameters require gradients.
+ unet_lora_params, train_names = inject_trainable_lora(unet)
+ optimizer = optim.Adam(
+     itertools.chain(*unet_lora_params, text_encoder.parameters()), lr=1e-4
+ )
+ ```
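+
+ After training, only the injected LoRA matrices need to be serialized. A minimal sketch using `save_lora_weight` from `lora_diffusion/lora.py` (this is also what `train_lora_dreambooth.py` does at the end of training):
+
+ ```python
+ from lora_diffusion import save_lora_weight
+
+ # Stores the lora_up / lora_down weights of every injected layer as a flat list.
+ save_lora_weight(unet, "./lora_weight.pt")
+ ```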
+
+ An example of this can be found in `train_lora_dreambooth.py`. Run this example with
+
+ ```bash
+ run_lora_db.sh
+ ```
+
+ ## Loading, merging, and interpolating trained LoRAs
+
+ We've seen that people have been merging different checkpoints with different ratios, and this seems to be very useful to the community. LoRAs are extremely easy to merge.
+
+ By the nature of LoRA, one can interpolate between different fine-tuned models by adding different $A, B$ matrices.
+
+ Currently, the LoRA CLI has two options: merge a unet with a LoRA, or merge a LoRA with another LoRA.
+
+ ### Merging unet with LoRA
+
+ ```bash
+ $ lora_add --path_1 PATH_TO_DIFFUSER_FORMAT_MODEL --path_2 PATH_TO_LORA.PT --mode upl --alpha 1.0 --output_path OUTPUT_PATH
+ ```
+
+ `path_1` can be either a local path or a Hugging Face model name. When adding a LoRA to the unet, alpha is the scaling constant below:
+
+ $$
+ W' = W + \alpha \Delta W
+ $$
+
+ So set alpha to 1.0 to fully add the LoRA. If the LoRA seems to have too much effect (i.e., it overfits), set alpha to a lower value. If the LoRA seems to have too little effect, set alpha higher than 1.0. You can tune these values to your needs.
+
+ **Example**
+
+ ```bash
+ $ lora_add --path_1 stabilityai/stable-diffusion-2-base --path_2 lora_illust.pt --mode upl --alpha 1.0 --output_path merged_model
+ ```
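+
+ The same merge can be done from Python. Here is a sketch of what `--mode upl` does internally (see `lora_diffusion/cli_lora_add.py`), using `weight_apply_lora`:
+
+ ```python
+ import torch
+ from diffusers import StableDiffusionPipeline
+ from lora_diffusion import weight_apply_lora
+
+ pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base").to("cpu")
+
+ # Bake the LoRA residual into the unet weights: W <- W + alpha * (up @ down)
+ weight_apply_lora(pipe.unet, torch.load("lora_illust.pt"), alpha=1.0)
+
+ pipe.save_pretrained("merged_model")
+ ```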
+
+ ### Merging LoRA with LoRA
+
+ ```bash
+ $ lora_add --path_1 PATH_TO_LORA.PT --path_2 PATH_TO_LORA.PT --mode lpl --alpha 0.5 --output_path OUTPUT_PATH.PT
+ ```
+
+ Here alpha is the ratio of the first model to the second model, i.e.,
+
+ $$
+ \Delta W = (\alpha A_1 + (1 - \alpha) A_2) (\alpha B_1 + (1 - \alpha) B_2)^T
+ $$
+
+ Set alpha to 0.5 to get the average of the two models. Set alpha closer to 1.0 to get more effect of the first model, and closer to 0.0 to get more effect of the second model.
+
+ **Example**
+
+ ```bash
+ $ lora_add --path_1 lora_illust.pt --path_2 lora_pop.pt --mode lpl --alpha 0.3 --output_path lora_merged.pt
+ ```
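+
+ In Python, this interpolation (what `--mode lpl` does in `lora_diffusion/cli_lora_add.py`) is just a pairwise blend of the saved up/down matrices. A minimal sketch:
+
+ ```python
+ import torch
+
+ alpha = 0.3
+ l1 = torch.load("lora_illust.pt")
+ l2 = torch.load("lora_pop.pt")
+
+ # Weights are stored as a flat list [up_0, down_0, up_1, down_1, ...],
+ # so blending element-wise blends each layer's up and down factors.
+ merged = [alpha * w1.data + (1 - alpha) * w2.data for w1, w2 in zip(l1, l2)]
+
+ torch.save(merged, "lora_merged.pt")
+ ```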
+
+ ### Making Inference with a trained LoRA
+
+ Check out `scripts/run_inference.ipynb` for an example of how to run inference with a LoRA.
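+
+ A minimal sketch of such an inference pass (not the notebook itself), using `monkeypatch_lora` and `tune_lora_scale` from `lora_diffusion/lora.py`; the base model and prompt here are just the examples used elsewhere in this repo:
+
+ ```python
+ import torch
+ from diffusers import StableDiffusionPipeline
+ from lora_diffusion import monkeypatch_lora, tune_lora_scale
+
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
+ ).to("cuda")
+
+ # Replace the unet's attention Linear layers with LoRA-injected ones and load
+ # the trained up/down weights into them.
+ monkeypatch_lora(pipe.unet, torch.load("lora_illust.pt"))
+
+ # alpha (the merging ratio shown in the gif at the top) can be changed at any time.
+ tune_lora_scale(pipe.unet, 0.8)
+
+ image = pipe("style of sks, baby lion", num_inference_steps=50, guidance_scale=7.5).images[0]
+ image.save("baby_lion.jpg")
+ ```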
contents/alpha_scale.gif ADDED

Git LFS Details

  • SHA256: 43e9966f27a2b9823956545970d3b2ed5b2f376a1dab5d653f21a977a919e164
  • Pointer size: 132 Bytes
  • Size of remote file: 5.23 MB
contents/alpha_scale.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1ad74f5f69d99bfcbeee1d4d2b3900ac1ca7ff83fba5ddf8269ffed8a56c9c6e
+ size 5247140
contents/disney_lora.jpg ADDED
contents/pop_art.jpg ADDED
lora_diffusion/__init__.py ADDED
@@ -0,0 +1 @@
+ from .lora import *
lora_diffusion/cli_lora_add.py ADDED
@@ -0,0 +1,49 @@
+ from typing import Literal, Union, Dict
+
+ import fire
+ from diffusers import StableDiffusionPipeline
+
+ import torch
+ from .lora import tune_lora_scale, weight_apply_lora
+
+
+ def add(
+     path_1: str,
+     path_2: str,
+     output_path: str = "./merged_lora.pt",
+     alpha: float = 0.5,
+     mode: Literal["lpl", "upl"] = "lpl",
+ ):
+     if mode == "lpl":
+         out_list = []
+         l1 = torch.load(path_1)
+         l2 = torch.load(path_2)
+
+         l1pairs = zip(l1[::2], l1[1::2])
+         l2pairs = zip(l2[::2], l2[1::2])
+
+         for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
+             x1.data = alpha * x1.data + (1 - alpha) * x2.data
+             y1.data = alpha * y1.data + (1 - alpha) * y2.data
+
+             out_list.append(x1)
+             out_list.append(y1)
+
+         torch.save(out_list, output_path)
+
+     elif mode == "upl":
+
+         loaded_pipeline = StableDiffusionPipeline.from_pretrained(
+             path_1,
+         ).to("cpu")
+
+         weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
+
+         if output_path.endswith(".pt"):
+             output_path = output_path[:-3]
+
+         loaded_pipeline.save_pretrained(output_path)
+
+
+ def main():
+     fire.Fire(add)
lora_diffusion/lora.py ADDED
@@ -0,0 +1,166 @@
+ import math
+ from typing import Callable, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import PIL
+ import torch
+ import torch.nn.functional as F
+
+ import torch.nn as nn
+
+
+ class LoraInjectedLinear(nn.Module):
+     def __init__(self, in_features, out_features, bias=False):
+         super().__init__()
+         self.linear = nn.Linear(in_features, out_features, bias)
+         self.lora_down = nn.Linear(in_features, 4, bias=False)
+         self.lora_up = nn.Linear(4, out_features, bias=False)
+         self.scale = 1.0
+
+         nn.init.normal_(self.lora_down.weight, std=1 / 16)
+         nn.init.zeros_(self.lora_up.weight)
+
+     def forward(self, input):
+         return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
+
+
+ def inject_trainable_lora(
+     model: nn.Module, target_replace_module: List[str] = ["CrossAttention", "Attention"]
+ ):
+     """
+     inject lora into model, and returns lora parameter groups.
+     """
+
+     require_grad_params = []
+     names = []
+
+     for _module in model.modules():
+         if _module.__class__.__name__ in target_replace_module:
+
+             for name, _child_module in _module.named_modules():
+                 if _child_module.__class__.__name__ == "Linear":
+
+                     weight = _child_module.weight
+                     bias = _child_module.bias
+                     _tmp = LoraInjectedLinear(
+                         _child_module.in_features,
+                         _child_module.out_features,
+                         _child_module.bias is not None,
+                     )
+                     _tmp.linear.weight = weight
+                     if bias is not None:
+                         _tmp.linear.bias = bias
+
+                     # switch the module
+                     _module._modules[name] = _tmp
+
+                     require_grad_params.append(
+                         _module._modules[name].lora_up.parameters()
+                     )
+                     require_grad_params.append(
+                         _module._modules[name].lora_down.parameters()
+                     )
+
+                     _module._modules[name].lora_up.weight.requires_grad = True
+                     _module._modules[name].lora_down.weight.requires_grad = True
+                     names.append(name)
+
+     return require_grad_params, names
+
+
+ def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
+
+     loras = []
+
+     for _module in model.modules():
+         if _module.__class__.__name__ in target_replace_module:
+             for _child_module in _module.modules():
+                 if _child_module.__class__.__name__ == "LoraInjectedLinear":
+                     loras.append((_child_module.lora_up, _child_module.lora_down))
+     if len(loras) == 0:
+         raise ValueError("No lora injected.")
+     return loras
+
+
+ def save_lora_weight(model, path="./lora.pt"):
+     weights = []
+     for _up, _down in extract_lora_ups_down(model):
+         weights.append(_up.weight)
+         weights.append(_down.weight)
+
+     torch.save(weights, path)
+
+
+ def save_lora_as_json(model, path="./lora.json"):
+     weights = []
+     for _up, _down in extract_lora_ups_down(model):
+         weights.append(_up.weight.detach().cpu().numpy().tolist())
+         weights.append(_down.weight.detach().cpu().numpy().tolist())
+
+     import json
+
+     with open(path, "w") as f:
+         json.dump(weights, f)
+
+
+ def weight_apply_lora(
+     model, loras, target_replace_module=["CrossAttention", "Attention"], alpha=1.0
+ ):
+
+     for _module in model.modules():
+         if _module.__class__.__name__ in target_replace_module:
+             for _child_module in _module.modules():
+                 if _child_module.__class__.__name__ == "Linear":
+
+                     weight = _child_module.weight
+
+                     up_weight = loras.pop(0).detach().to(weight.device)
+                     down_weight = loras.pop(0).detach().to(weight.device)
+
+                     # W <- W + U * D
+                     weight = weight + alpha * (up_weight @ down_weight).type(
+                         weight.dtype
+                     )
+                     _child_module.weight = nn.Parameter(weight)
+
+
+ def monkeypatch_lora(
+     model, loras, target_replace_module=["CrossAttention", "Attention"]
+ ):
+     for _module in model.modules():
+         if _module.__class__.__name__ in target_replace_module:
+             for name, _child_module in _module.named_modules():
+                 if _child_module.__class__.__name__ == "Linear":
+
+                     weight = _child_module.weight
+                     bias = _child_module.bias
+                     _tmp = LoraInjectedLinear(
+                         _child_module.in_features,
+                         _child_module.out_features,
+                         _child_module.bias is not None,
+                     )
+                     _tmp.linear.weight = weight
+
+                     if bias is not None:
+                         _tmp.linear.bias = bias
+
+                     # switch the module
+                     _module._modules[name] = _tmp
+
+                     up_weight = loras.pop(0)
+                     down_weight = loras.pop(0)
+
+                     _module._modules[name].lora_up.weight = nn.Parameter(
+                         up_weight.type(weight.dtype)
+                     )
+                     _module._modules[name].lora_down.weight = nn.Parameter(
+                         down_weight.type(weight.dtype)
+                     )
+
+                     _module._modules[name].to(weight.device)
+
+
+ def tune_lora_scale(model, alpha: float = 1.0):
+     for _module in model.modules():
+         if _module.__class__.__name__ == "LoraInjectedLinear":
+             _module.scale = alpha
lora_disney.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:72f687f810b86bb8cc64d2ece59886e2e96d29e3f57f97340ee147d168b8a5fe
+ size 3397249
lora_illust.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f6acb0bc0cd5f96299be7839f89f58727e2666e58861e55866ea02125c97aba
+ size 3397249
lora_pop.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:18a1565852a08cfcff63e90670286c9427e3958f57de9b84e3f8b2c9a3a14b6c
+ size 3397249
requirements.txt CHANGED
@@ -1,4 +1,4 @@
- diffusers["torch"]
+ diffusers>=0.9.0
  transformers
- git+https://github.com/huggingface/accelerate
- git+https://github.com/cloneofsimo/lora.git
+ scipy
+ ftfy
run_lora_db.sh ADDED
@@ -0,0 +1,17 @@
+ #https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
+ export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
+ export INSTANCE_DIR="./data_example"
+ export OUTPUT_DIR="./output_example"
+
+ accelerate launch train_lora_dreambooth.py \
+   --pretrained_model_name_or_path=$MODEL_NAME \
+   --instance_data_dir=$INSTANCE_DIR \
+   --output_dir=$OUTPUT_DIR \
+   --instance_prompt="style of sks" \
+   --resolution=512 \
+   --train_batch_size=1 \
+   --gradient_accumulation_steps=1 \
+   --learning_rate=1e-4 \
+   --lr_scheduler="constant" \
+   --lr_warmup_steps=0 \
+   --max_train_steps=30000
scripts/make_alpha_gifs.ipynb ADDED
The diff for this file is too large to render. See raw diff
scripts/run_inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
setup.py ADDED
@@ -0,0 +1,25 @@
+ import os
+
+ import pkg_resources
+ from setuptools import find_packages, setup
+
+ setup(
+     name="lora_diffusion",
+     py_modules=["lora_diffusion"],
+     version="0.0.1",
+     description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
+     author="Simo Ryu",
+     packages=find_packages(),
+     entry_points={
+         "console_scripts": [
+             "lora_add = lora_diffusion.cli_lora_add:main",
+         ],
+     },
+     install_requires=[
+         str(r)
+         for r in pkg_resources.parse_requirements(
+             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
+         )
+     ],
+     include_package_data=True,
+ )
train_lora_dreambooth.py ADDED
@@ -0,0 +1,964 @@
1
+ # Bootstrapped from:
2
+ # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
3
+
4
+ import argparse
5
+ import hashlib
6
+ import itertools
7
+ import math
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint
15
+
16
+
17
+ from accelerate import Accelerator
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import set_seed
20
+ from diffusers import (
21
+ AutoencoderKL,
22
+ DDPMScheduler,
23
+ StableDiffusionPipeline,
24
+ UNet2DConditionModel,
25
+ )
26
+ from diffusers.optimization import get_scheduler
27
+ from huggingface_hub import HfFolder, Repository, whoami
28
+
29
+ from tqdm.auto import tqdm
30
+ from transformers import CLIPTextModel, CLIPTokenizer
31
+
32
+ from lora_diffusion import (
33
+ inject_trainable_lora,
34
+ save_lora_weight,
35
+ extract_lora_ups_down,
36
+ )
37
+
38
+ from torch.utils.data import Dataset
39
+ from PIL import Image
40
+ from torchvision import transforms
41
+
42
+ from pathlib import Path
43
+
44
+ import random
45
+ import re
46
+
47
+
48
+ class DreamBoothDataset(Dataset):
49
+ """
50
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
51
+ It pre-processes the images and the tokenizes prompts.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ instance_data_root,
57
+ instance_prompt,
58
+ tokenizer,
59
+ class_data_root=None,
60
+ class_prompt=None,
61
+ size=512,
62
+ center_crop=False,
63
+ ):
64
+ self.size = size
65
+ self.center_crop = center_crop
66
+ self.tokenizer = tokenizer
67
+
68
+ self.instance_data_root = Path(instance_data_root)
69
+ if not self.instance_data_root.exists():
70
+ raise ValueError("Instance images root doesn't exists.")
71
+
72
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
73
+ self.num_instance_images = len(self.instance_images_path)
74
+ self.instance_prompt = instance_prompt
75
+ self._length = self.num_instance_images
76
+
77
+ if class_data_root is not None:
78
+ self.class_data_root = Path(class_data_root)
79
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
80
+ self.class_images_path = list(self.class_data_root.iterdir())
81
+ self.num_class_images = len(self.class_images_path)
82
+ self._length = max(self.num_class_images, self.num_instance_images)
83
+ self.class_prompt = class_prompt
84
+ else:
85
+ self.class_data_root = None
86
+
87
+ self.image_transforms = transforms.Compose(
88
+ [
89
+ transforms.Resize(
90
+ size, interpolation=transforms.InterpolationMode.BILINEAR
91
+ ),
92
+ transforms.CenterCrop(size)
93
+ if center_crop
94
+ else transforms.RandomCrop(size),
95
+ transforms.ToTensor(),
96
+ transforms.Normalize([0.5], [0.5]),
97
+ ]
98
+ )
99
+
100
+ def __len__(self):
101
+ return self._length
102
+
103
+ def __getitem__(self, index):
104
+ example = {}
105
+ instance_image = Image.open(
106
+ self.instance_images_path[index % self.num_instance_images]
107
+ )
108
+ if not instance_image.mode == "RGB":
109
+ instance_image = instance_image.convert("RGB")
110
+ example["instance_images"] = self.image_transforms(instance_image)
111
+ example["instance_prompt_ids"] = self.tokenizer(
112
+ self.instance_prompt,
113
+ padding="do_not_pad",
114
+ truncation=True,
115
+ max_length=self.tokenizer.model_max_length,
116
+ ).input_ids
117
+
118
+ if self.class_data_root:
119
+ class_image = Image.open(
120
+ self.class_images_path[index % self.num_class_images]
121
+ )
122
+ if not class_image.mode == "RGB":
123
+ class_image = class_image.convert("RGB")
124
+ example["class_images"] = self.image_transforms(class_image)
125
+ example["class_prompt_ids"] = self.tokenizer(
126
+ self.class_prompt,
127
+ padding="do_not_pad",
128
+ truncation=True,
129
+ max_length=self.tokenizer.model_max_length,
130
+ ).input_ids
131
+
132
+ return example
133
+
134
+
135
+ class DreamBoothLabled(Dataset):
136
+ """
137
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
138
+ It pre-processes the images and the tokenizes prompts.
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ instance_data_root,
144
+ instance_prompt,
145
+ tokenizer,
146
+ class_data_root=None,
147
+ class_prompt=None,
148
+ size=512,
149
+ center_crop=False,
150
+ ):
151
+ self.size = size
152
+ self.center_crop = center_crop
153
+ self.tokenizer = tokenizer
154
+
155
+ self.instance_data_root = Path(instance_data_root)
156
+ if not self.instance_data_root.exists():
157
+ raise ValueError("Instance images root doesn't exists.")
158
+
159
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
160
+ self.num_instance_images = len(self.instance_images_path)
161
+ self.instance_prompt = instance_prompt
162
+ self._length = self.num_instance_images
163
+
164
+ if class_data_root is not None:
165
+ self.class_data_root = Path(class_data_root)
166
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
167
+ self.class_images_path = list(self.class_data_root.iterdir())
168
+ self.num_class_images = len(self.class_images_path)
169
+ self._length = max(self.num_class_images, self.num_instance_images)
170
+ self.class_prompt = class_prompt
171
+ else:
172
+ self.class_data_root = None
173
+
174
+ self.image_transforms = transforms.Compose(
175
+ [
176
+ transforms.Resize(
177
+ size, interpolation=transforms.InterpolationMode.BILINEAR
178
+ ),
179
+ transforms.CenterCrop(size)
180
+ if center_crop
181
+ else transforms.RandomCrop(size),
182
+ transforms.ToTensor(),
183
+ transforms.Normalize([0.5], [0.5]),
184
+ ]
185
+ )
186
+
187
+ def __len__(self):
188
+ return self._length
189
+
190
+ def __getitem__(self, index):
191
+ example = {}
192
+ instance_image = Image.open(
193
+ self.instance_images_path[index % self.num_instance_images]
194
+ )
195
+
196
+ instance_prompt = (
197
+ str(self.instance_images_path[index % self.num_instance_images])
198
+ .split("/")[-1]
199
+ .split(".")[0]
200
+ .replace("-", " ")
201
+ )
202
+ # remove numbers in prompt
203
+ instance_prompt = re.sub(r"\d+", "", instance_prompt)
204
+ # print(instance_prompt)
205
+
206
+ _svg = random.choice(["svg", "flat color", "vector illustration", "sks"])
207
+ instance_prompt = f"{instance_prompt}, style of {_svg}"
208
+
209
+ if not instance_image.mode == "RGB":
210
+ instance_image = instance_image.convert("RGB")
211
+ example["instance_images"] = self.image_transforms(instance_image)
212
+ example["instance_prompt_ids"] = self.tokenizer(
213
+ instance_prompt,
214
+ padding="do_not_pad",
215
+ truncation=True,
216
+ max_length=self.tokenizer.model_max_length,
217
+ ).input_ids
218
+
219
+ if self.class_data_root:
220
+ class_image = Image.open(
221
+ self.class_images_path[index % self.num_class_images]
222
+ )
223
+ if not class_image.mode == "RGB":
224
+ class_image = class_image.convert("RGB")
225
+ example["class_images"] = self.image_transforms(class_image)
226
+ example["class_prompt_ids"] = self.tokenizer(
227
+ self.class_prompt,
228
+ padding="do_not_pad",
229
+ truncation=True,
230
+ max_length=self.tokenizer.model_max_length,
231
+ ).input_ids
232
+
233
+ return example
234
+
235
+
236
+ class PromptDataset(Dataset):
237
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
238
+
239
+ def __init__(self, prompt, num_samples):
240
+ self.prompt = prompt
241
+ self.num_samples = num_samples
242
+
243
+ def __len__(self):
244
+ return self.num_samples
245
+
246
+ def __getitem__(self, index):
247
+ example = {}
248
+ example["prompt"] = self.prompt
249
+ example["index"] = index
250
+ return example
251
+
252
+
253
+ logger = get_logger(__name__)
254
+
255
+
256
+ def parse_args(input_args=None):
257
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
258
+ parser.add_argument(
259
+ "--pretrained_model_name_or_path",
260
+ type=str,
261
+ default=None,
262
+ required=True,
263
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
264
+ )
265
+ parser.add_argument(
266
+ "--revision",
267
+ type=str,
268
+ default=None,
269
+ required=False,
270
+ help="Revision of pretrained model identifier from huggingface.co/models.",
271
+ )
272
+ parser.add_argument(
273
+ "--tokenizer_name",
274
+ type=str,
275
+ default=None,
276
+ help="Pretrained tokenizer name or path if not the same as model_name",
277
+ )
278
+ parser.add_argument(
279
+ "--instance_data_dir",
280
+ type=str,
281
+ default=None,
282
+ required=True,
283
+ help="A folder containing the training data of instance images.",
284
+ )
285
+ parser.add_argument(
286
+ "--class_data_dir",
287
+ type=str,
288
+ default=None,
289
+ required=False,
290
+ help="A folder containing the training data of class images.",
291
+ )
292
+ parser.add_argument(
293
+ "--instance_prompt",
294
+ type=str,
295
+ default=None,
296
+ required=True,
297
+ help="The prompt with identifier specifying the instance",
298
+ )
299
+ parser.add_argument(
300
+ "--class_prompt",
301
+ type=str,
302
+ default=None,
303
+ help="The prompt to specify images in the same class as provided instance images.",
304
+ )
305
+ parser.add_argument(
306
+ "--with_prior_preservation",
307
+ default=False,
308
+ action="store_true",
309
+ help="Flag to add prior preservation loss.",
310
+ )
311
+ parser.add_argument(
312
+ "--prior_loss_weight",
313
+ type=float,
314
+ default=1.0,
315
+ help="The weight of prior preservation loss.",
316
+ )
317
+ parser.add_argument(
318
+ "--num_class_images",
319
+ type=int,
320
+ default=100,
321
+ help=(
322
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
323
+ " sampled with class_prompt."
324
+ ),
325
+ )
326
+ parser.add_argument(
327
+ "--output_dir",
328
+ type=str,
329
+ default="text-inversion-model",
330
+ help="The output directory where the model predictions and checkpoints will be written.",
331
+ )
332
+ parser.add_argument(
333
+ "--seed", type=int, default=None, help="A seed for reproducible training."
334
+ )
335
+ parser.add_argument(
336
+ "--resolution",
337
+ type=int,
338
+ default=512,
339
+ help=(
340
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
341
+ " resolution"
342
+ ),
343
+ )
344
+ parser.add_argument(
345
+ "--center_crop",
346
+ action="store_true",
347
+ help="Whether to center crop images before resizing to resolution",
348
+ )
349
+ parser.add_argument(
350
+ "--train_text_encoder",
351
+ action="store_true",
352
+ help="Whether to train the text encoder",
353
+ )
354
+ parser.add_argument(
355
+ "--train_batch_size",
356
+ type=int,
357
+ default=4,
358
+ help="Batch size (per device) for the training dataloader.",
359
+ )
360
+ parser.add_argument(
361
+ "--sample_batch_size",
362
+ type=int,
363
+ default=4,
364
+ help="Batch size (per device) for sampling images.",
365
+ )
366
+ parser.add_argument("--num_train_epochs", type=int, default=1)
367
+ parser.add_argument(
368
+ "--max_train_steps",
369
+ type=int,
370
+ default=None,
371
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
372
+ )
373
+ parser.add_argument(
374
+ "--save_steps",
375
+ type=int,
376
+ default=500,
377
+ help="Save checkpoint every X updates steps.",
378
+ )
379
+ parser.add_argument(
380
+ "--gradient_accumulation_steps",
381
+ type=int,
382
+ default=1,
383
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
384
+ )
385
+ parser.add_argument(
386
+ "--gradient_checkpointing",
387
+ action="store_true",
388
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
389
+ )
390
+ parser.add_argument(
391
+ "--learning_rate",
392
+ type=float,
393
+ default=5e-6,
394
+ help="Initial learning rate (after the potential warmup period) to use.",
395
+ )
396
+ parser.add_argument(
397
+ "--scale_lr",
398
+ action="store_true",
399
+ default=False,
400
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
401
+ )
402
+ parser.add_argument(
403
+ "--lr_scheduler",
404
+ type=str,
405
+ default="constant",
406
+ help=(
407
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
408
+ ' "constant", "constant_with_warmup"]'
409
+ ),
410
+ )
411
+ parser.add_argument(
412
+ "--lr_warmup_steps",
413
+ type=int,
414
+ default=500,
415
+ help="Number of steps for the warmup in the lr scheduler.",
416
+ )
417
+ parser.add_argument(
418
+ "--use_8bit_adam",
419
+ action="store_true",
420
+ help="Whether or not to use 8-bit Adam from bitsandbytes.",
421
+ )
422
+ parser.add_argument(
423
+ "--adam_beta1",
424
+ type=float,
425
+ default=0.9,
426
+ help="The beta1 parameter for the Adam optimizer.",
427
+ )
428
+ parser.add_argument(
429
+ "--adam_beta2",
430
+ type=float,
431
+ default=0.999,
432
+ help="The beta2 parameter for the Adam optimizer.",
433
+ )
434
+ parser.add_argument(
435
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
436
+ )
437
+ parser.add_argument(
438
+ "--adam_epsilon",
439
+ type=float,
440
+ default=1e-08,
441
+ help="Epsilon value for the Adam optimizer",
442
+ )
443
+ parser.add_argument(
444
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
445
+ )
446
+ parser.add_argument(
447
+ "--push_to_hub",
448
+ action="store_true",
449
+ help="Whether or not to push the model to the Hub.",
450
+ )
451
+ parser.add_argument(
452
+ "--hub_token",
453
+ type=str,
454
+ default=None,
455
+ help="The token to use to push to the Model Hub.",
456
+ )
457
+ parser.add_argument(
458
+ "--hub_model_id",
459
+ type=str,
460
+ default=None,
461
+ help="The name of the repository to keep in sync with the local `output_dir`.",
462
+ )
463
+ parser.add_argument(
464
+ "--logging_dir",
465
+ type=str,
466
+ default="logs",
467
+ help=(
468
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
469
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
470
+ ),
471
+ )
472
+ parser.add_argument(
473
+ "--mixed_precision",
474
+ type=str,
475
+ default=None,
476
+ choices=["no", "fp16", "bf16"],
477
+ help=(
478
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
479
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
480
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
481
+ ),
482
+ )
483
+ parser.add_argument(
484
+ "--local_rank",
485
+ type=int,
486
+ default=-1,
487
+ help="For distributed training: local_rank",
488
+ )
489
+
490
+ if input_args is not None:
491
+ args = parser.parse_args(input_args)
492
+ else:
493
+ args = parser.parse_args()
494
+
495
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
496
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
497
+ args.local_rank = env_local_rank
498
+
499
+ if args.with_prior_preservation:
500
+ if args.class_data_dir is None:
501
+ raise ValueError("You must specify a data directory for class images.")
502
+ if args.class_prompt is None:
503
+ raise ValueError("You must specify prompt for class images.")
504
+ else:
505
+ if args.class_data_dir is not None:
506
+ logger.warning(
507
+ "You need not use --class_data_dir without --with_prior_preservation."
508
+ )
509
+ if args.class_prompt is not None:
510
+ logger.warning(
511
+ "You need not use --class_prompt without --with_prior_preservation."
512
+ )
513
+
514
+ return args
515
+
516
+
517
+ def get_full_repo_name(
518
+ model_id: str, organization: Optional[str] = None, token: Optional[str] = None
519
+ ):
520
+ if token is None:
521
+ token = HfFolder.get_token()
522
+ if organization is None:
523
+ username = whoami(token)["name"]
524
+ return f"{username}/{model_id}"
525
+ else:
526
+ return f"{organization}/{model_id}"
527
+
528
+
529
+ def main(args):
530
+ logging_dir = Path(args.output_dir, args.logging_dir)
531
+
532
+ accelerator = Accelerator(
533
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
534
+ mixed_precision=args.mixed_precision,
535
+ log_with="tensorboard",
536
+ logging_dir=logging_dir,
537
+ )
538
+
539
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
540
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
541
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
542
+ if (
543
+ args.train_text_encoder
544
+ and args.gradient_accumulation_steps > 1
545
+ and accelerator.num_processes > 1
546
+ ):
547
+ raise ValueError(
548
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
549
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
550
+ )
551
+
552
+ if args.seed is not None:
553
+ set_seed(args.seed)
554
+
555
+ if args.with_prior_preservation:
556
+ class_images_dir = Path(args.class_data_dir)
557
+ if not class_images_dir.exists():
558
+ class_images_dir.mkdir(parents=True)
559
+ cur_class_images = len(list(class_images_dir.iterdir()))
560
+
561
+ if cur_class_images < args.num_class_images:
562
+ torch_dtype = (
563
+ torch.float16 if accelerator.device.type == "cuda" else torch.float32
564
+ )
565
+ pipeline = StableDiffusionPipeline.from_pretrained(
566
+ args.pretrained_model_name_or_path,
567
+ torch_dtype=torch_dtype,
568
+ safety_checker=None,
569
+ revision=args.revision,
570
+ )
571
+ pipeline.set_progress_bar_config(disable=True)
572
+
573
+ num_new_images = args.num_class_images - cur_class_images
574
+ logger.info(f"Number of class images to sample: {num_new_images}.")
575
+
576
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
577
+ sample_dataloader = torch.utils.data.DataLoader(
578
+ sample_dataset, batch_size=args.sample_batch_size
579
+ )
580
+
581
+ sample_dataloader = accelerator.prepare(sample_dataloader)
582
+ pipeline.to(accelerator.device)
583
+
584
+ for example in tqdm(
585
+ sample_dataloader,
586
+ desc="Generating class images",
587
+ disable=not accelerator.is_local_main_process,
588
+ ):
589
+ images = pipeline(example["prompt"]).images
590
+
591
+ for i, image in enumerate(images):
592
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
593
+ image_filename = (
594
+ class_images_dir
595
+ / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
596
+ )
597
+ image.save(image_filename)
598
+
599
+ del pipeline
600
+ if torch.cuda.is_available():
601
+ torch.cuda.empty_cache()
602
+
603
+ # Handle the repository creation
604
+ if accelerator.is_main_process:
605
+ if args.push_to_hub:
606
+ if args.hub_model_id is None:
607
+ repo_name = get_full_repo_name(
608
+ Path(args.output_dir).name, token=args.hub_token
609
+ )
610
+ else:
611
+ repo_name = args.hub_model_id
612
+ repo = Repository(args.output_dir, clone_from=repo_name)
613
+
614
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
615
+ if "step_*" not in gitignore:
616
+ gitignore.write("step_*\n")
617
+ if "epoch_*" not in gitignore:
618
+ gitignore.write("epoch_*\n")
619
+ elif args.output_dir is not None:
620
+ os.makedirs(args.output_dir, exist_ok=True)
621
+
622
+ # Load the tokenizer
623
+ if args.tokenizer_name:
624
+ tokenizer = CLIPTokenizer.from_pretrained(
625
+ args.tokenizer_name,
626
+ revision=args.revision,
627
+ )
628
+ elif args.pretrained_model_name_or_path:
629
+ tokenizer = CLIPTokenizer.from_pretrained(
630
+ args.pretrained_model_name_or_path,
631
+ subfolder="tokenizer",
632
+ revision=args.revision,
633
+ )
634
+
635
+ # Load models and create wrapper for stable diffusion
636
+ text_encoder = CLIPTextModel.from_pretrained(
637
+ args.pretrained_model_name_or_path,
638
+ subfolder="text_encoder",
639
+ revision=args.revision,
640
+ )
641
+ vae = AutoencoderKL.from_pretrained(
642
+ args.pretrained_model_name_or_path,
643
+ subfolder="vae",
644
+ revision=args.revision,
645
+ )
646
+ unet = UNet2DConditionModel.from_pretrained(
647
+ args.pretrained_model_name_or_path,
648
+ subfolder="unet",
649
+ revision=args.revision,
650
+ )
651
+ unet.requires_grad_(False)
652
+ unet_lora_params, train_names = inject_trainable_lora(unet)
653
+
654
+ for _up, _down in extract_lora_ups_down(unet):
655
+ print(_up.weight)
656
+ print(_down.weight)
657
+ break
658
+
659
+ vae.requires_grad_(False)
660
+ if not args.train_text_encoder:
661
+ text_encoder.requires_grad_(False)
662
+
663
+ if args.gradient_checkpointing:
664
+ unet.enable_gradient_checkpointing()
665
+ if args.train_text_encoder:
666
+ text_encoder.gradient_checkpointing_enable()
667
+
668
+ if args.scale_lr:
669
+ args.learning_rate = (
670
+ args.learning_rate
671
+ * args.gradient_accumulation_steps
672
+ * args.train_batch_size
673
+ * accelerator.num_processes
674
+ )
675
+
676
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
677
+ if args.use_8bit_adam:
678
+ try:
679
+ import bitsandbytes as bnb
680
+ except ImportError:
681
+ raise ImportError(
682
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
683
+ )
684
+
685
+ optimizer_class = bnb.optim.AdamW8bit
686
+ else:
687
+ optimizer_class = torch.optim.AdamW
688
+
689
+ params_to_optimize = (
690
+ itertools.chain(*unet_lora_params, text_encoder.parameters())
691
+ if args.train_text_encoder
692
+ else itertools.chain(*unet_lora_params)
693
+ )
694
+ optimizer = optimizer_class(
695
+ params_to_optimize,
696
+ lr=args.learning_rate,
697
+ betas=(args.adam_beta1, args.adam_beta2),
698
+ weight_decay=args.adam_weight_decay,
699
+ eps=args.adam_epsilon,
700
+ )
701
+
702
+ noise_scheduler = DDPMScheduler.from_config(
703
+ args.pretrained_model_name_or_path, subfolder="scheduler"
704
+ )
705
+
706
+ train_dataset = DreamBoothDataset(
707
+ instance_data_root=args.instance_data_dir,
708
+ instance_prompt=args.instance_prompt,
709
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
710
+ class_prompt=args.class_prompt,
711
+ tokenizer=tokenizer,
712
+ size=args.resolution,
713
+ center_crop=args.center_crop,
714
+ )
715
+
716
+ def collate_fn(examples):
717
+ input_ids = [example["instance_prompt_ids"] for example in examples]
718
+ pixel_values = [example["instance_images"] for example in examples]
719
+
720
+ # Concat class and instance examples for prior preservation.
721
+ # We do this to avoid doing two forward passes.
722
+ if args.with_prior_preservation:
723
+ input_ids += [example["class_prompt_ids"] for example in examples]
724
+ pixel_values += [example["class_images"] for example in examples]
725
+
726
+ pixel_values = torch.stack(pixel_values)
727
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
728
+
729
+ input_ids = tokenizer.pad(
730
+ {"input_ids": input_ids},
731
+ padding="max_length",
732
+ max_length=tokenizer.model_max_length,
733
+ return_tensors="pt",
734
+ ).input_ids
735
+
736
+ batch = {
737
+ "input_ids": input_ids,
738
+ "pixel_values": pixel_values,
739
+ }
740
+ return batch
741
+
742
+ train_dataloader = torch.utils.data.DataLoader(
743
+ train_dataset,
744
+ batch_size=args.train_batch_size,
745
+ shuffle=True,
746
+ collate_fn=collate_fn,
747
+ num_workers=1,
748
+ )
749
+
750
+ # Scheduler and math around the number of training steps.
751
+ overrode_max_train_steps = False
752
+ num_update_steps_per_epoch = math.ceil(
753
+ len(train_dataloader) / args.gradient_accumulation_steps
754
+ )
755
+ if args.max_train_steps is None:
756
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
757
+ overrode_max_train_steps = True
758
+
759
+ lr_scheduler = get_scheduler(
760
+ args.lr_scheduler,
761
+ optimizer=optimizer,
762
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
763
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
764
+ )
765
+
766
+ if args.train_text_encoder:
767
+ (
768
+ unet,
769
+ text_encoder,
770
+ optimizer,
771
+ train_dataloader,
772
+ lr_scheduler,
773
+ ) = accelerator.prepare(
774
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
775
+ )
776
+ else:
777
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
778
+ unet, optimizer, train_dataloader, lr_scheduler
779
+ )
780
+
781
+ weight_dtype = torch.float32
782
+ if accelerator.mixed_precision == "fp16":
783
+ weight_dtype = torch.float16
784
+ elif accelerator.mixed_precision == "bf16":
785
+ weight_dtype = torch.bfloat16
786
+
787
+ # Move text_encode and vae to gpu.
788
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
789
+ # as these models are only used for inference, keeping weights in full precision is not required.
790
+ vae.to(accelerator.device, dtype=weight_dtype)
791
+ if not args.train_text_encoder:
792
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
793
+
794
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
795
+ num_update_steps_per_epoch = math.ceil(
796
+ len(train_dataloader) / args.gradient_accumulation_steps
797
+ )
798
+ if overrode_max_train_steps:
799
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
800
+ # Afterwards we recalculate our number of training epochs
801
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
802
+
803
+ # We need to initialize the trackers we use, and also store our configuration.
804
+ # The trackers initializes automatically on the main process.
805
+ if accelerator.is_main_process:
806
+ accelerator.init_trackers("dreambooth", config=vars(args))
807
+
808
+ # Train!
809
+ total_batch_size = (
810
+ args.train_batch_size
811
+ * accelerator.num_processes
812
+ * args.gradient_accumulation_steps
813
+ )
814
+
815
+ print("***** Running training *****")
816
+ print(f" Num examples = {len(train_dataset)}")
817
+ print(f" Num batches each epoch = {len(train_dataloader)}")
818
+ print(f" Num Epochs = {args.num_train_epochs}")
819
+ print(f" Instantaneous batch size per device = {args.train_batch_size}")
820
+ print(
821
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
822
+ )
823
+ print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
824
+ print(f" Total optimization steps = {args.max_train_steps}")
825
+ # Only show the progress bar once on each machine.
826
+ progress_bar = tqdm(
827
+ range(args.max_train_steps), disable=not accelerator.is_local_main_process
828
+ )
829
+ progress_bar.set_description("Steps")
830
+ global_step = 0
831
+
832
+ for epoch in range(args.num_train_epochs):
833
+ unet.train()
834
+ if args.train_text_encoder:
835
+ text_encoder.train()
836
+ for step, batch in enumerate(train_dataloader):
837
+
838
+ # Convert images to latent space
839
+ latents = vae.encode(
840
+ batch["pixel_values"].to(dtype=weight_dtype)
841
+ ).latent_dist.sample()
842
+ latents = latents * 0.18215
843
+
844
+ # Sample noise that we'll add to the latents
845
+ noise = torch.randn_like(latents)
846
+ bsz = latents.shape[0]
847
+ # Sample a random timestep for each image
848
+ timesteps = torch.randint(
849
+ 0,
850
+ noise_scheduler.config.num_train_timesteps,
851
+ (bsz,),
852
+ device=latents.device,
853
+ )
854
+ timesteps = timesteps.long()
855
+
856
+ # Add noise to the latents according to the noise magnitude at each timestep
857
+ # (this is the forward diffusion process)
858
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
859
+
860
+ # Get the text embedding for conditioning
861
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
862
+
863
+ # Predict the noise residual
864
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
865
+
866
+ # Get the target for loss depending on the prediction type
867
+ if noise_scheduler.config.prediction_type == "epsilon":
868
+ target = noise
869
+ elif noise_scheduler.config.prediction_type == "v_prediction":
870
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
871
+ else:
872
+ raise ValueError(
873
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}"
874
+ )
875
+
876
+ if args.with_prior_preservation:
877
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
878
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
879
+ target, target_prior = torch.chunk(target, 2, dim=0)
880
+
881
+ # Compute instance loss
882
+ loss = (
883
+ F.mse_loss(model_pred.float(), target.float(), reduction="none")
884
+ .mean([1, 2, 3])
885
+ .mean()
886
+ )
887
+
888
+ # Compute prior loss
889
+ prior_loss = F.mse_loss(
890
+ model_pred_prior.float(), target_prior.float(), reduction="mean"
891
+ )
892
+
893
+ # Add the prior loss to the instance loss.
894
+ loss = loss + args.prior_loss_weight * prior_loss
895
+ else:
896
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
897
+
898
+ accelerator.backward(loss)
899
+ if accelerator.sync_gradients:
900
+ params_to_clip = (
901
+ itertools.chain(unet.parameters(), text_encoder.parameters())
902
+ if args.train_text_encoder
903
+ else unet.parameters()
904
+ )
905
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
906
+ optimizer.step()
907
+ lr_scheduler.step()
908
+ progress_bar.update(1)
909
+ optimizer.zero_grad()
910
+
911
+ # Checks if the accelerator has performed an optimization step behind the scenes
912
+ if accelerator.sync_gradients:
913
+
914
+ global_step += 1
915
+
916
+ if global_step % args.save_steps == 0:
917
+ if accelerator.is_main_process:
918
+ pipeline = StableDiffusionPipeline.from_pretrained(
919
+ args.pretrained_model_name_or_path,
920
+ unet=accelerator.unwrap_model(unet),
921
+ text_encoder=accelerator.unwrap_model(text_encoder),
922
+ revision=args.revision,
923
+ )
924
+
925
+ save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
926
+
927
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
928
+ progress_bar.set_postfix(**logs)
929
+ accelerator.log(logs, step=global_step)
930
+
931
+ if global_step >= args.max_train_steps:
932
+ break
933
+
934
+ accelerator.wait_for_everyone()
935
+
936
+ # Create the pipeline using using the trained modules and save it.
937
+ if accelerator.is_main_process:
938
+ pipeline = StableDiffusionPipeline.from_pretrained(
939
+ args.pretrained_model_name_or_path,
940
+ unet=accelerator.unwrap_model(unet),
941
+ text_encoder=accelerator.unwrap_model(text_encoder),
942
+ revision=args.revision,
943
+ )
944
+
945
+ print("\n\nLora TRAINING DONE!\n\n")
946
+
947
+ save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
948
+
949
+ for _up, _down in extract_lora_ups_down(pipeline.unet):
950
+ print("First Layer's Up Weight is now : ", _up.weight)
951
+ print("First Layer's Down Weight is now : ", _down.weight)
952
+ break
953
+
954
+ if args.push_to_hub:
955
+ repo.push_to_hub(
956
+ commit_message="End of training", blocking=False, auto_lfs_prune=True
957
+ )
958
+
959
+ accelerator.end_training()
960
+
961
+
962
+ if __name__ == "__main__":
963
+ args = parse_args()
964
+ main(args)