andzhang01 committed
Commit a4be04d • 1 Parent(s): 35df028
Upload 42 files

Browse files:
- .gitattributes +4 -34
- .gitignore +11 -0
- LICENSE.md +201 -0
- README-ja.md +147 -0
- README.md +338 -0
- XTI_hijack.py +209 -0
- _typos.toml +15 -0
- activate.bat +5 -0
- activate.ps1 +1 -0
- config_README-ja.md +279 -0
- dreambooth_gui.py +954 -0
- fine_tune.py +440 -0
- fine_tune_README.md +465 -0
- fine_tune_README_ja.md +140 -0
- finetune_gui.py +900 -0
- gen_img_diffusers.py +0 -0
- gui.bat +16 -0
- gui.ps1 +21 -0
- gui.sh +15 -0
- kohya_gui.py +114 -0
- lora_gui.py +1294 -0
- requirements.txt +34 -0
- setup.bat +48 -0
- setup.py +10 -0
- setup.sh +648 -0
- style.css +21 -0
- textual_inversion_gui.py +1014 -0
- train_README-ja.md +945 -0
- train_db.py +437 -0
- train_db_README-ja.md +167 -0
- train_db_README.md +295 -0
- train_network.py +773 -0
- train_network_README-ja.md +479 -0
- train_network_README.md +189 -0
- train_textual_inversion.py +598 -0
- train_textual_inversion_XTI.py +650 -0
- train_ti_README-ja.md +105 -0
- train_ti_README.md +62 -0
- upgrade.bat +16 -0
- upgrade.ps1 +14 -0
- upgrade.sh +16 -0
- utilities.cmd +1 -0
.gitattributes CHANGED
@@ -1,34 +1,4 @@
-*.
-*.
-*.
-*.
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.sh text eol=lf
+*.ps1 text eol=crlf
+*.bat text eol=crlf
+*.cmd text eol=crlf
.gitignore ADDED
@@ -0,0 +1,11 @@
venv
__pycache__
cudnn_windows
.vscode
*.egg-info
build
wd14_tagger_model
.DS_Store
locon
gui-user.bat
gui-user.ps1
LICENSE.md ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [2022] [kohya-ss]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README-ja.md ADDED
@@ -0,0 +1,147 @@
## About this repository

This repository contains scripts for Stable Diffusion training, image generation, and other utilities.

[README in English](./README.md) ← update information is available there

[bmaltais's repository](https://github.com/bmaltais/kohya_ss) provides features that make these scripts easier to use, such as a GUI and PowerShell scripts (in English); please see it as well. Thanks to bmaltais.

The following scripts are included:

* DreamBooth training, supporting U-Net and Text Encoder
* fine-tuning, same as above
* image generation
* model conversion (conversion between Stable Diffusion ckpt/safetensors and Diffusers)

## Usage

Documentation is available in this repository and in articles on note.com; please refer to them (everything may eventually be moved here).

* [Training, common guide](./train_README-ja.md): data preparation, options, and so on
* [Dataset configuration](./config_README-ja.md)
* [DreamBooth training](./train_db_README-ja.md)
* [fine-tuning guide](./fine_tune_README_ja.md)
* [LoRA training](./train_network_README-ja.md)
* [Textual Inversion training](./train_ti_README-ja.md)
* note.com [image generation script](https://note.com/kohya_ss/n/n2693183a798e)
* note.com [model conversion script](https://note.com/kohya_ss/n/n374f316fe4ad)

## Required programs for Windows

Python 3.10.6 and Git are required.

- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
- git: https://git-scm.com/download/win

If you use PowerShell, change the security settings as follows so that venv can be used.
(Note that this allows execution of scripts in general, not just venv.)

- Open PowerShell as administrator.
- Enter "Set-ExecutionPolicy Unrestricted" and answer Y.
- Close the administrator PowerShell.

## Installation on Windows

The example below installs PyTorch 1.12.1 for CUDA 11.6. If you use CUDA 11.3 or PyTorch 1.13, adjust the commands accordingly.

(If the `python -m venv ...` line only prints "python", change `python` to `py`, as in `py -m venv ...`.)

Open a regular (non-administrator) PowerShell and run the following in order:

```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts

python -m venv venv
.\venv\Scripts\activate

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl

cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

accelerate config
```

<!--
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps xformers==0.0.16
-->

In the command prompt, the commands are:

```bat
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts

python -m venv venv
.\venv\Scripts\activate

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl

copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

accelerate config
```

(Note: `python -m venv venv` appears safer than `python -m venv --system-site-packages venv`, so the instructions were changed. If packages are installed in the global Python, the latter causes various problems.)

Answer the accelerate config questions as follows. (If you train with bf16, answer bf16 to the last question.)

Note: since 0.15.0, pressing the cursor keys to make a selection crashes in a Japanese environment. You can select with the number keys 0, 1, 2, ..., so use those instead.

```txt
- This machine
- No distributed training
- NO
- NO
- NO
- all
- fp16
```

Note: in some cases you may get the error `ValueError: fp16 mixed precision requires a GPU`. If so, answer "0" to the sixth question (`What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:`). (The GPU with id `0` will be used.)

### About PyTorch and xformers versions

Training may not work correctly with other versions. Unless you have a specific reason, use the specified versions.

## Upgrading

When a new release is available, you can update with the following commands:

```powershell
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --use-pep517 --upgrade -r requirements.txt
```

Once the commands succeed, the new version is ready to use.

## Acknowledgements

The LoRA implementation is based on [cloneofsimo's repository](https://github.com/cloneofsimo/lora). Thank you.

The expansion to Conv2d 3x3 was first released by [cloneofsimo](https://github.com/cloneofsimo/lora), and KohakuBlueleaf demonstrated its effectiveness in [LoCon](https://github.com/KohakuBlueleaf/LoCon). Deep thanks to KohakuBlueleaf.

## License

The scripts are licensed under ASL 2.0 (including code derived from Diffusers and cloneofsimo's repository), but portions of the code are under other licenses:

[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT

[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT

[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
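README-ja.md pins specific PyTorch, torchvision, and xformers builds and warns that training may fail with other versions. The snippet below is an optional sanity check, not part of the uploaded repository: run it inside the activated venv to confirm the environment matches the pinned builds before running `accelerate config`. The exact version strings in the comments are the ones named in the install steps above.

```python
# Optional environment check (illustrative, not part of the repo).
import torch
import torchvision

print("torch:", torch.__version__)              # install steps above expect 1.12.1+cu116
print("torchvision:", torchvision.__version__)  # install steps above expect 0.13.1+cu116
print("CUDA available:", torch.cuda.is_available())

try:
    import xformers
    print("xformers:", getattr(xformers, "__version__", "unknown"))
except ImportError:
    print("xformers is not installed; memory-efficient attention will be unavailable")
```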
README.md ADDED
@@ -0,0 +1,338 @@
# Kohya's GUI

This repository provides a Windows-focused Gradio GUI for [Kohya's Stable Diffusion trainers](https://github.com/kohya-ss/sd-scripts). The GUI allows you to set the training parameters and generate and run the required CLI commands to train the model.

If you run on Linux and would like to use the GUI, there is now a port of it as a docker container. You can find the project [here](https://github.com/P2Enjoy/kohya_ss-docker).

### Table of Contents

- [Tutorials](#tutorials)
- [Required Dependencies](#required-dependencies)
  - [Linux/macOS](#linux-and-macos-dependencies)
- [Installation](#installation)
  - [Linux/macOS](#linux-and-macos)
    - [Default Install Locations](#install-location)
  - [Windows](#windows)
  - [CUDNN 8.6](#optional--cudnn-86)
- [Upgrading](#upgrading)
  - [Windows](#windows-upgrade)
  - [Linux/macOS](#linux-and-macos-upgrade)
- [Launching the GUI](#starting-gui-service)
  - [Windows](#launching-the-gui-on-windows)
  - [Linux/macOS](#launching-the-gui-on-linux-and-macos)
  - [Direct Launch via Python Script](#launching-the-gui-directly-using-kohyaguipy)
- [Dreambooth](#dreambooth)
- [Finetune](#finetune)
- [Train Network](#train-network)
- [LoRA](#lora)
- [Troubleshooting](#troubleshooting)
  - [Page File Limit](#page-file-limit)
  - [No module called tkinter](#no-module-called-tkinter)
  - [FileNotFoundError](#filenotfounderror)
- [Change History](#change-history)

## Tutorials

[How to Create a LoRA Part 1: Dataset Preparation](https://www.youtube.com/watch?v=N4_-fB62Hwk):

[![LoRA Part 1 Tutorial](https://img.youtube.com/vi/N4_-fB62Hwk/0.jpg)](https://www.youtube.com/watch?v=N4_-fB62Hwk)

[How to Create a LoRA Part 2: Training the Model](https://www.youtube.com/watch?v=k5imq01uvUY):

[![LoRA Part 2 Tutorial](https://img.youtube.com/vi/k5imq01uvUY/0.jpg)](https://www.youtube.com/watch?v=k5imq01uvUY)

## Required Dependencies

- Install [Python 3.10](https://www.python.org/ftp/python/3.10.9/python-3.10.9-amd64.exe)
  - make sure to tick the box to add Python to the 'PATH' environment variable
- Install [Git](https://git-scm.com/download/win)
- Install [Visual Studio 2015, 2017, 2019, and 2022 redistributable](https://aka.ms/vs/17/release/vc_redist.x64.exe)

### Linux and macOS dependencies

These dependencies are taken care of via `setup.sh` in the installation section. No additional steps should be needed unless the scripts inform you otherwise.

## Installation

### Runpod

Follow the instructions found in this discussion: https://github.com/bmaltais/kohya_ss/discussions/379

### Linux and macOS

In the terminal, run

```
git clone https://github.com/bmaltais/kohya_ss.git
cd kohya_ss
# May need to chmod +x ./setup.sh if you're on a machine with stricter security.
# There are additional options if needed for a runpod environment.
# Call 'setup.sh -h' or 'setup.sh --help' for more information.
./setup.sh
```

Setup.sh help included here:

```bash
Kohya_SS Installation Script for POSIX operating systems.

The following options are useful in a runpod environment,
but will not affect a local machine install.

Usage:
  setup.sh -b dev -d /workspace/kohya_ss -g https://mycustom.repo.tld/custom_fork.git
  setup.sh --branch=dev --dir=/workspace/kohya_ss --git-repo=https://mycustom.repo.tld/custom_fork.git

Options:
  -b BRANCH, --branch=BRANCH    Select which branch of kohya to check out on new installs.
  -d DIR, --dir=DIR             The full path you want kohya_ss installed to.
  -g REPO, --git_repo=REPO      You can optionally provide a git repo to check out for runpod installation. Useful for custom forks.
  -h, --help                    Show this screen.
  -i, --interactive             Interactively configure accelerate instead of using default config file.
  -n, --no-update               Do not update kohya_ss repo. No git pull or clone operations.
  -p, --public                  Expose public URL in runpod mode. Won't have an effect in other modes.
  -r, --runpod                  Forces a runpod installation. Useful if detection fails for any reason.
  -s, --skip-space-check        Skip the 10Gb minimum storage space check.
  -u, --no-gui                  Skips launching the GUI.
  -v, --verbose                 Increase verbosity levels up to 3.
```

#### Install location

The default install location for Linux is the directory where the script is located, if a previous installation is detected in that location.
Otherwise, it will fall back to `/opt/kohya_ss`. If `/opt` is not writeable, the fallback is `$HOME/kohya_ss`. Lastly, if all else fails, it will simply install to the current folder you are in (PWD).

On macOS and other non-Linux machines, it will first try to detect an install where the script is run from, and then run setup there if one is detected.
If a previous install isn't found at that location, it will default to installing to `$HOME/kohya_ss`, followed by the current directory if there's no access to `$HOME`.
You can override this behavior by specifying an install directory with the -d option.

If you are using interactive mode, the default values for the accelerate config screen after running the script are "This machine", "None", and "No" for the remaining questions.
These are the same answers as the Windows install.

### Windows

- Install [Python 3.10](https://www.python.org/ftp/python/3.10.9/python-3.10.9-amd64.exe)
  - make sure to tick the box to add Python to the 'PATH' environment variable
- Install [Git](https://git-scm.com/download/win)
- Install [Visual Studio 2015, 2017, 2019, and 2022 redistributable](https://aka.ms/vs/17/release/vc_redist.x64.exe)

In the terminal, run:

```
git clone https://github.com/bmaltais/kohya_ss.git
cd kohya_ss
.\setup.bat
```

If this is a first install, answer No when asked `Do you want to uninstall previous versions of torch and associated files before installing`.

Then configure accelerate with the same answers as in the macOS instructions when prompted.

### Optional: CUDNN 8.6

This step is optional but can improve the learning speed for NVIDIA 30X0/40X0 owners. It allows for a larger training batch size and faster training speed.

Due to the file size, I can't host the DLLs needed for CUDNN 8.6 on GitHub. I strongly advise you to download them for a speed boost in sample generation (almost 50% on a 4090 GPU); you can download them [here](https://b1.thefileditch.ch/mwxKTEtelILoIbMbruuM.zip).

To install, simply unzip the directory and place the `cudnn_windows` folder in the root of this repo.

Run the following commands to install:

```
.\venv\Scripts\activate

python .\tools\cudann_1.8_install.py
```

Once the commands have completed successfully you should be ready to use the new version. macOS support is not tested and has been mostly taken from https://gist.github.com/jstayco/9f5733f05b9dc29de95c4056a023d645

## Upgrading

The following commands will work from the root directory of the project if you'd prefer to not run scripts.
These commands will work on any OS.

```bash
git pull

.\venv\Scripts\activate

pip install --use-pep517 --upgrade -r requirements.txt
```

### Windows Upgrade

When a new release comes out, you can upgrade your repo with the following commands in the root directory:

```powershell
upgrade.bat
```

### Linux and macOS Upgrade

You can cd into the root directory and simply run

```bash
# Refresh and update everything
./setup.sh

# This will refresh everything, but NOT clone or pull the git repo.
./setup.sh --no-git-update
```

Once the commands have completed successfully you should be ready to use the new version.

# Starting GUI Service

The following command line arguments can be passed to the scripts on any OS to configure the underlying service.

```
--listen: the IP address to listen on for connections to Gradio.
--username: a username for authentication.
--password: a password for authentication.
--server_port: the port to run the server listener on.
--inbrowser: opens the Gradio UI in a web browser.
--share: shares the Gradio UI.
```

### Launching the GUI on Windows

The two scripts to launch the GUI on Windows are gui.ps1 and gui.bat in the root directory.
You can use whichever script you prefer.

To launch the Gradio UI, run the script in a terminal with the desired command line arguments, for example:

`gui.ps1 --listen 127.0.0.1 --server_port 7860 --inbrowser --share`

or

`gui.bat --listen 127.0.0.1 --server_port 7860 --inbrowser --share`

## Launching the GUI on Linux and macOS

Run the launcher script with the desired command line arguments similar to Windows.

`gui.sh --listen 127.0.0.1 --server_port 7860 --inbrowser --share`

## Launching the GUI directly using kohya_gui.py

To run the GUI directly, bypassing the wrapper scripts, simply use this command from the root project directory:

```
.\venv\Scripts\activate

python .\kohya_gui.py
```

## Dreambooth

You can find the Dreambooth-specific documentation here: [Dreambooth README](train_db_README.md)

## Finetune

You can find the finetune-specific documentation here: [Finetune README](fine_tune_README.md)

## Train Network

You can find the train network-specific documentation here: [Train network README](train_network_README.md)

## LoRA

Training a LoRA currently uses the `train_network.py` code. You can create a LoRA network by using the all-in-one `gui.cmd` or by running the dedicated LoRA training GUI with:

```
.\venv\Scripts\activate

python lora_gui.py
```

Once you have created the LoRA network, you can generate images via auto1111 by installing [this extension](https://github.com/kohya-ss/sd-webui-additional-networks).

### Naming of LoRA

The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.

1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)

    LoRA for Linear layers and Conv2d layers with a 1x1 kernel

2. __LoRA-C3Lier__ : (LoRA for __C__ onvolutional layers with a __3__ x3 kernel and __Li__ n __e__ a __r__ layers)

    In addition to 1., LoRA for Conv2d layers with a 3x3 kernel

LoRA-LierLa is the default LoRA type for `train_network.py` (without the `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.

To use LoRA-C3Lier with the Web UI, please use our extension.

## Sample image generation during training

A prompt file might look like this, for example:

```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy, bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28

# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy, bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```

Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.

* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.

Prompt weighting such as `( )` and `[ ]` works.

## Troubleshooting

### Page File Limit

- X error relating to `page file`: Increase the page file size limit in Windows.

### No module called tkinter

- Re-install [Python 3.10](https://www.python.org/ftp/python/3.10.9/python-3.10.9-amd64.exe) on your system.

### FileNotFoundError

This is usually related to an installation issue. Make sure you do not have any python modules installed locally that could conflict with the ones installed in the venv:

1. Open a new powershell terminal and make sure no venv is active.
2. Run the following commands:

```
pip freeze > uninstall.txt
pip uninstall -r uninstall.txt
```

This will store a backup file with your current locally installed pip packages and then uninstall them. Then, redo the installation instructions within the kohya_ss venv.

## Change History

* 2023/04/22 (v21.5.5)
  - Update LoRA merge GUI to support SD checkpoint merge and up to 4 LoRA merging
  - Fixed `lora_interrogator.py` not working. Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi!
  - Fixed the handling of tags containing `_` in `tag_images_by_wd14_tagger.py`.
  - Add new Extract DyLoRA gui to the Utilities tab.
  - Add new Merge LyCORIS models into checkpoint gui to the Utilities tab.
  - Add new info on startup to help debug things
* 2023/04/17 (v21.5.4)
  - Fixed a bug that caused an error when loading DyLoRA with the `--network_weight` option in `train_network.py`.
  - Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf!
  - Upgrade Gradio to latest release
  - Fix issue when Adafactor is used as optimizer and LR Warmup is not 0: https://github.com/bmaltais/kohya_ss/issues/617
  - Added support for DyLoRA in `train_network.py`. Please refer to [here](./train_network_README-ja.md#dylora) for details (currently only in Japanese).
  - Added support for caching latents to disk in each training script. Please specify __both__ `--cache_latents` and `--cache_latents_to_disk` options.
    - The files are saved in the same folder as the images with the extension `.npz`. If you specify the `--flip_aug` option, the files with `_flip.npz` will also be saved.
    - Multi-GPU training has not been tested.
    - This feature is not tested with all combinations of datasets and training scripts, so there may be bugs.
  - Added a workaround for an error that occurs when training with `fp16` or `bf16` in `fine_tune.py`.
  - Implemented DyLoRA GUI support. There will now be a new `DyLoRA Unit` slider when the LoRA type is selected as `kohya DyLoRA` to specify the desired Unit value for DyLoRA training.
  - Update gui.bat and gui.ps1 based on: https://github.com/bmaltais/kohya_ss/issues/188
  - Update `setup.bat` to install torch 2.0.0 instead of 1.2.1. If you want to upgrade from 1.2.1 to 2.0.0, run setup.bat again, select 1 to uninstall the previous torch modules, then select 2 for torch 2.0.0
* 2023/04/09 (v21.5.2)
  - Added support for training with weighted captions. Thanks to AI-Casanova for the great contribution!
    - Please refer to the PR for details: [PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
    - Specify the `--weighted_captions` option. It is available for all training scripts except Textual Inversion and XTI.
    - This option is also applicable to token strings of the DreamBooth method.
    - The syntax for weighted captions is almost the same as the Web UI, and you can use things like `(abc)`, `[abc]`, and `(abc:1.23)`. Nesting is also possible.
    - If you include a comma in the parentheses, the parentheses will not be properly matched in the prompt shuffle/dropout, so do not include a comma in the parentheses.
  - Run gui.sh from any place
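To make the sample-prompt options documented in README.md above (`--n`, `--w`, `--h`, `--d`, `--l`, `--s`) concrete, here is a small illustrative parser. It is not the parser used by the training scripts, and the default values are assumptions; it only shows how one prompt line maps onto the documented fields.

```python
# Illustrative only: split a sample-prompt line into its prompt and options.
import re

def parse_sample_prompt(line: str) -> dict:
    """Return the positive prompt and the generation options from one prompt line."""
    # Defaults below are assumptions for illustration, not the scripts' defaults.
    opts = {"prompt": "", "negative": "", "width": 512, "height": 512,
            "seed": None, "scale": 7.5, "steps": 30}
    # Each option starts with " --x "; everything before the first one is the prompt.
    parts = re.split(r"\s--(\w)\s+", line)
    opts["prompt"] = parts[0].strip()
    for flag, value in zip(parts[1::2], parts[2::2]):
        value = value.strip()
        if flag == "n":
            opts["negative"] = value          # negative prompt up to the next option
        elif flag == "w":
            opts["width"] = int(value)        # image width
        elif flag == "h":
            opts["height"] = int(value)       # image height
        elif flag == "d":
            opts["seed"] = int(value)         # seed
        elif flag == "l":
            opts["scale"] = float(value)      # CFG scale
        elif flag == "s":
            opts["steps"] = int(value)        # number of steps
    return opts

print(parse_sample_prompt(
    "masterpiece, best quality, (1girl), in white shirts --n low quality, worst quality --w 768 --h 768 --d 1 --l 7.5 --s 28"
))
```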
XTI_hijack.py ADDED
@@ -0,0 +1,209 @@
import torch
import torch.utils.checkpoint
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

# Added while editing: `logger` is referenced below but was never defined in the uploaded file.
import logging

logger = logging.getLogger(__name__)


def unet_forward_XTI(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    Args:
        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

    Returns:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.config.num_class_embeds is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")
        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    # Unlike the stock UNet forward, each cross-attention down block receives its own
    # slice of `encoder_hidden_states` (two entries per block).
    down_block_res_samples = (sample,)
    down_i = 0
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
            )
            down_i += 2
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    # 4. mid
    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])

    # 5. up
    # Each cross-attention up block receives three entries of `encoder_hidden_states`.
    up_i = 7
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
                upsample_size=upsample_size,
            )
            up_i += 3
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)


def downblock_forward_XTI(
    self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
    output_states = ()
    i = 0

    for resnet, attn in zip(self.resnets, self.attentions):
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            # Each attention layer uses its own entry of `encoder_hidden_states`.
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

        output_states += (hidden_states,)
        i += 1

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)

        output_states += (hidden_states,)

    return hidden_states, output_states


def upblock_forward_XTI(
    self,
    hidden_states,
    res_hidden_states_tuple,
    temb=None,
    encoder_hidden_states=None,
    upsample_size=None,
):
    i = 0
    for resnet, attn in zip(self.resnets, self.attentions):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

        i += 1

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states
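XTI_hijack.py only defines the replacement `forward` functions; the actual wiring is done elsewhere in the repository (e.g. `train_textual_inversion_XTI.py`). The sketch below shows one way such functions could be bound onto a standard diffusers `UNet2DConditionModel` using `types.MethodType`; the model id and the binding approach are assumptions for illustration, not the repository's exact integration. With the standard SD 1.x UNet layout, the indexing in `unet_forward_XTI` implies that `encoder_hidden_states` is a sequence of 16 per-layer text embeddings (2 per cross-attention down block, 1 for the mid block, 3 per cross-attention up block).

```python
# Minimal sketch (not the repository's actual wiring) of attaching the XTI
# forward functions so the UNet consumes a list of per-layer text embeddings.
import types

from diffusers import UNet2DConditionModel

from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

# Assumption: any Stable Diffusion 1.x UNet; the model id is only an example.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Replace the top-level forward so it slices a list of embeddings per block.
unet.forward = types.MethodType(unet_forward_XTI, unet)

# Replace the forward of every cross-attention down/up block.
for block in unet.down_blocks:
    if getattr(block, "has_cross_attention", False):
        block.forward = types.MethodType(downblock_forward_XTI, block)
for block in unet.up_blocks:
    if getattr(block, "has_cross_attention", False):
        block.forward = types.MethodType(upblock_forward_XTI, block)

# After this, `unet(sample, timestep, encoder_hidden_states=list_of_16_embeddings)`
# routes a different text embedding to each cross-attention layer.
```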
_typos.toml ADDED
@@ -0,0 +1,15 @@
# Files for typos
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started

[default.extend-identifiers]

[default.extend-words]
NIN="NIN"
parms="parms"
nin="nin"
extention="extention" # Intentionally left
nd="nd"


[files]
extend-exclude = ["_typos.toml"]
activate.bat ADDED
@@ -0,0 +1,5 @@
@echo off

call venv\Scripts\activate.bat

pause
activate.ps1 ADDED
@@ -0,0 +1 @@
.\venv\Scripts\activate
config_README-ja.md ADDED
@@ -0,0 +1,279 @@
For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
|
2 |
+
|
3 |
+
`--dataset_config` で渡すことができる設定ファイルに関する説明です。
|
4 |
+
|
5 |
+
## 概要
|
6 |
+
|
7 |
+
設定ファイルを渡すことにより、ユーザが細かい設定を行えるようにします。
|
8 |
+
|
9 |
+
* 複数のデータセットが設定可能になります
|
10 |
+
* 例えば `resolution` をデータセットごとに設定して、それらを混合して学習できます。
|
11 |
+
* DreamBooth の手法と fine tuning の手法の両方に対応している学習方法では、DreamBooth 方式と fine tuning 方式のデータセットを混合することが可能です。
|
12 |
+
* サブセットごとに設定を変更することが可能になります
|
13 |
+
* データセットを画像ディレクトリ別またはメタデータ別に分割したものがサブセットです。いくつかのサブセットが集まってデータセットを構成します。
|
14 |
+
* `keep_tokens` や `flip_aug` 等のオプションはサブセットごとに設定可能です。一方、`resolution` や `batch_size` といったオプションはデータセットごとに設定可能で、同じデータセットに属するサブセットでは値が共通になります。詳しくは後述します。
|
15 |
+
|
16 |
+
設定ファイルの形式は JSON か TOML を利用できます。記述のしやすさを考えると [TOML](https://toml.io/ja/v1.0.0-rc.2) を利用するのがオススメです。以下、TOML の利用を前提に説明します。
|
17 |
+
|
18 |
+
TOML で記述した設定ファイルの例です。
|
19 |
+
|
20 |
+
```toml
|
21 |
+
[general]
|
22 |
+
shuffle_caption = true
|
23 |
+
caption_extension = '.txt'
|
24 |
+
keep_tokens = 1
|
25 |
+
|
26 |
+
# これは DreamBooth 方式のデータセット
|
27 |
+
[[datasets]]
|
28 |
+
resolution = 512
|
29 |
+
batch_size = 4
|
30 |
+
keep_tokens = 2
|
31 |
+
|
32 |
+
[[datasets.subsets]]
|
33 |
+
image_dir = 'C:\hoge'
|
34 |
+
class_tokens = 'hoge girl'
|
35 |
+
# このサブセットは keep_tokens = 2 (所属する datasets の値が使われる)
|
36 |
+
|
37 |
+
[[datasets.subsets]]
|
38 |
+
image_dir = 'C:\fuga'
|
39 |
+
class_tokens = 'fuga boy'
|
40 |
+
keep_tokens = 3
|
41 |
+
|
42 |
+
[[datasets.subsets]]
|
43 |
+
is_reg = true
|
44 |
+
image_dir = 'C:\reg'
|
45 |
+
class_tokens = 'human'
|
46 |
+
keep_tokens = 1
|
47 |
+
|
48 |
+
# これは fine tuning 方式のデータセット
|
49 |
+
[[datasets]]
|
50 |
+
resolution = [768, 768]
|
51 |
+
batch_size = 2
|
52 |
+
|
53 |
+
[[datasets.subsets]]
|
54 |
+
image_dir = 'C:\piyo'
|
55 |
+
metadata_file = 'C:\piyo\piyo_md.json'
|
56 |
+
# このサブセットは keep_tokens = 1 (general の値が使われる)
|
57 |
+
```
|
58 |
+
|
59 |
+
この例では、3 つのディレクトリを DreamBooth 方式のデータセットとして 512x512 (batch size 4) で学習させ、1 つのディレクトリを fine tuning 方式のデータセットとして 768x768 (batch size 2) で学習させることになります。
|

## Settings for datasets and subsets

Settings for datasets and subsets can be registered in several places.

* `[general]`
  * The place to specify options applied to all datasets or all subsets.
  * If an option with the same name also exists in a per-dataset or per-subset setting, the dataset/subset setting takes priority.
* `[[datasets]]`
  * `datasets` is where dataset-level settings are registered: options applied individually to each dataset.
  * If per-subset settings exist, they take priority.
* `[[datasets.subsets]]`
  * `datasets.subsets` is where subset-level settings are registered: options applied individually to each subset.

A rough picture of how the image directories in the earlier example correspond to the registration places:

```
C:\
├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
├─ reg  -> [[datasets.subsets]] No.3 ┘ |
└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
```

Each image directory corresponds to exactly one `[[datasets.subsets]]`. One or more `[[datasets.subsets]]` combine to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.

The options that can be specified differ by registration place, but when the same option is specified in more than one place, the value at the lower (more specific) place takes priority. Checking how the `keep_tokens` option is handled in the earlier example should make this easier to understand (see the sketch below).
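
As an illustration of this precedence (subset over dataset over `[general]`), here is a minimal sketch assuming only the `toml` package. The helper `resolve_option` is made up for the example and is not the library's actual merging code.

```python
# Illustrative only: how the layered precedence described above could be resolved.
import toml

config = toml.loads(r"""
[general]
keep_tokens = 1

[[datasets]]
keep_tokens = 2

[[datasets.subsets]]
image_dir = 'C:\hoge'   # no keep_tokens here -> inherits 2 from the dataset

[[datasets.subsets]]
image_dir = 'C:\fuga'
keep_tokens = 3         # overrides both the dataset value and [general]
""")

def resolve_option(name, general, dataset, subset):
    # The subset value wins over the dataset value, which wins over [general].
    for scope in (subset, dataset, general):
        if name in scope:
            return scope[name]
    return None

general = config.get("general", {})
for dataset in config["datasets"]:
    for subset in dataset.get("subsets", []):
        value = resolve_option("keep_tokens", general, dataset, subset)
        print(subset["image_dir"], "->", value)
# prints: C:\hoge -> 2  and  C:\fuga -> 3
```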

In addition, the options that can be specified change depending on which approaches the training method supports.

* Options specific to the DreamBooth approach
* Options specific to the fine tuning approach
* Options available when the caption dropout technique can be used

For training methods that support both the DreamBooth approach and the fine tuning approach, the two can be used together.
One caveat when combining them: whether a dataset is DreamBooth style or fine tuning style is determined per dataset, so DreamBooth-style subsets and fine tuning-style subsets cannot be mixed within the same dataset.
In other words, to use both, subsets of different styles must be assigned to different datasets.

In terms of program behavior, a subset is judged to be fine tuning style if the `metadata_file` option described later is present.
Therefore, for subsets belonging to the same dataset, there is no problem as long as either all of them have the `metadata_file` option or none of them do. A minimal sketch of this rule follows.
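
The following is an illustrative sketch of that classification rule, not the library's actual validation code; the function name `check_dataset_styles` is hypothetical.

```python
# Minimal sketch of the style rule described above (not the library's validation code).
import toml

def check_dataset_styles(config: dict) -> None:
    """Classify each dataset as DreamBooth or fine tuning style and reject mixtures."""
    for i, dataset in enumerate(config.get("datasets", [])):
        subsets = dataset.get("subsets", [])
        # A subset counts as fine tuning style when it has `metadata_file`.
        has_metadata = [("metadata_file" in s) for s in subsets]
        if any(has_metadata) and not all(has_metadata):
            raise ValueError(
                f"dataset #{i} mixes DreamBooth-style and fine tuning-style subsets; "
                "put them in separate [[datasets]] entries"
            )
        style = "fine tuning" if subsets and all(has_metadata) else "DreamBooth"
        print(f"dataset #{i}: {style} style ({len(subsets)} subset(s))")

# Example usage with a hypothetical config path:
# check_dataset_styles(toml.load("dataset_config.toml"))
```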

The available options are explained below. For options whose names match command-line arguments, the explanation is generally omitted; refer to the other READMEs.

### Options common to all training methods

Options that can be specified regardless of the training method.

#### Dataset-level options

Options related to dataset settings. They cannot be written in `datasets.subsets`.

| Option name | Example | `[general]` | `[[datasets]]` |
| ---- | ---- | ---- | ---- |
| `batch_size` | `1` | o | o |
| `bucket_no_upscale` | `true` | o | o |
| `bucket_reso_steps` | `64` | o | o |
| `enable_bucket` | `true` | o | o |
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |

* `batch_size`
  * Equivalent to the command-line argument `--train_batch_size`.

These settings are fixed per dataset.
That is, all subsets belonging to the dataset share them.
If you want datasets with different resolutions, define them as separate datasets, as in the example above, so that each can have its own resolution.

#### Subset-level options

Options related to subset settings.

| Option name | Example | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `color_aug` | `false` | o | o | o |
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
| `flip_aug` | `true` | o | o | o |
| `keep_tokens` | `2` | o | o | o |
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |

* `num_repeats`
  * Specifies how many times the images in the subset are repeated. It corresponds to `--dataset_repeats` in fine tuning, but `num_repeats` can be specified with any training method (see the sketch below).
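
To make the effect of `num_repeats` concrete, here is a small illustrative sketch of how it scales the number of samples seen per epoch; this is the same `repeats × image count` arithmetic the GUI scripts in this commit use when estimating steps. The directory names and numbers below are made up.

```python
# Illustrative only: effect of num_repeats on samples seen per epoch.
# Each entry is (image_dir, num_repeats, number of images) -- the values are made up.
subsets = [
    ("C:\\hoge", 10, 20),   # 10 repeats x 20 images = 200 samples
    ("C:\\fuga",  2, 50),   #  2 repeats x 50 images = 100 samples
]

samples_per_epoch = sum(repeats * n_images for _, repeats, n_images in subsets)
print(samples_per_epoch)    # 300; divide by batch_size to estimate steps per epoch
```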

### Options specific to the DreamBooth approach

The DreamBooth-specific options are subset-level options only.

#### Subset-level options

Options related to the settings of DreamBooth-style subsets.

| Option name | Example | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `is_reg` | `false` | - | - | o |

First, note that `image_dir` must be a path with the image files placed directly under it. In the conventional DreamBooth method images had to be placed in subdirectories; that layout is not compatible with this specification. Also, naming a folder something like `5_cat` does not set the repeat count or the class name. If you want to set these individually, they must be specified explicitly with `num_repeats` and `class_tokens`.

* `image_dir`
  * Specifies the path of the image directory. Required.
  * Images must be placed directly under the directory.
* `class_tokens`
  * Sets the class tokens.
  * Used during training only when no caption file exists for an image; the decision is made per image (see the sketch after this list). If `class_tokens` is not specified and no caption file is found either, an error occurs.
* `is_reg`
  * Specifies whether the images in the subset are regularization images. If unspecified, it defaults to `false`, i.e. the images are treated as not being regularization images.
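
The per-image fallback from caption file to `class_tokens` can be summarized with a short sketch. It is illustrative only, not the dataset loader's actual code; the helper name is made up.

```python
# Illustrative sketch of the caption fallback described above (not the actual loader code).
import os
from typing import Optional

def caption_for_image(image_path: str, caption_extension: str, class_tokens: Optional[str]) -> str:
    caption_path = os.path.splitext(image_path)[0] + caption_extension
    if os.path.exists(caption_path):
        # A caption file next to the image always wins.
        with open(caption_path, encoding="utf-8") as f:
            return f.read().strip()
    if class_tokens is not None:
        # Otherwise fall back to the subset's class_tokens.
        return class_tokens
    raise ValueError(f"no caption file and no class_tokens for {image_path}")
```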

### Options specific to the fine tuning approach

The fine tuning-specific options are subset-level options only.

#### Subset-level options

Options related to the settings of fine tuning-style subsets.

| Option name | Example | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o |
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |

* `image_dir`
  * Specifies the path of the image directory. Unlike the DreamBooth approach it is not required, but setting it is recommended.
  * You only do not need to specify it when the metadata file was generated with `--full_path`.
  * Images must be placed directly under the directory.
* `metadata_file`
  * Specifies the path of the metadata file used by the subset. Required.
  * Equivalent to the command-line argument `--in_json`.
  * Because a metadata file must be specified per subset, it is best to avoid creating a single metadata file that spans multiple directories. It is strongly recommended to prepare one metadata file per image directory and register them as separate subsets.

### Options available when the caption dropout technique can be used

The options for when caption dropout can be used are subset-level options only.
Regardless of whether the subset is DreamBooth style or fine tuning style, they can be specified as long as the training method supports caption dropout.

#### Subset-level options

Options related to the settings of subsets for which caption dropout can be used.

| Option name | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
| ---- | ---- | ---- | ---- |
| `caption_dropout_every_n_epochs` | o | o | o |
| `caption_dropout_rate` | o | o | o |
| `caption_tag_dropout_rate` | o | o | o |

## Behavior when duplicate subsets exist

For DreamBooth-style datasets, subsets within the same dataset that have the same `image_dir` are considered duplicates.
For fine tuning-style datasets, subsets within the same dataset that have the same `metadata_file` are considered duplicates.
If duplicate subsets exist in a dataset, the second and subsequent ones are ignored.

On the other hand, subsets belonging to different datasets are not considered duplicates.
For example, if you put subsets with the same `image_dir` into different datasets, as below, they are not treated as duplicates.
This is useful when you want to train on the same images at different resolutions.

```toml
# Existing in separate datasets, these are not considered duplicates and both are used for training

[[datasets]]
resolution = 512

[[datasets.subsets]]
image_dir = 'C:\hoge'

[[datasets]]
resolution = 768

[[datasets.subsets]]
image_dir = 'C:\hoge'
```

## Combination with command-line arguments

Some options in the configuration file overlap in role with command-line options.

The following command-line options are ignored when a configuration file is passed:

* `--train_data_dir`
* `--reg_data_dir`
* `--in_json`

For the following command-line options, if a value is given both on the command line and in the configuration file, the configuration file value takes priority. Unless otherwise noted, the config option has the same name.

| Command-line option | Config file option that takes priority |
| ---------------------------------- | ---------------------------------- |
| `--bucket_no_upscale` | |
| `--bucket_reso_steps` | |
| `--caption_dropout_every_n_epochs` | |
| `--caption_dropout_rate` | |
| `--caption_extension` | |
| `--caption_tag_dropout_rate` | |
| `--color_aug` | |
| `--dataset_repeats` | `num_repeats` |
| `--enable_bucket` | |
| `--face_crop_aug_range` | |
| `--flip_aug` | |
| `--keep_tokens` | |
| `--min_bucket_reso` | |
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--train_batch_size` | `batch_size` |
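
For reference, here is a hedged sketch of how a training run that passes `--dataset_config` might be launched, modeled on how the GUI scripts in this commit assemble their `accelerate launch` command strings. The model path, output directory, and config filename are placeholders, not values taken from this document.

```python
# Hedged sketch: building and running a training command that passes --dataset_config,
# in the same style as the run_cmd strings assembled by the GUI scripts in this repo.
# All concrete paths/filenames below are placeholders.
import os
import subprocess

run_cmd = 'accelerate launch --num_cpu_threads_per_process=2 "fine_tune.py"'
run_cmd += ' --dataset_config="dataset_config.toml"'               # the TOML described above
run_cmd += ' --pretrained_model_name_or_path="model.safetensors"'
run_cmd += ' --output_dir="output"'
# --train_data_dir / --in_json are omitted: they are ignored when --dataset_config is given.

print(run_cmd)
if os.name == 'posix':
    os.system(run_cmd)
else:
    subprocess.run(run_cmd)
```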

## Error guide

Currently an external library is used to check whether the configuration file is written correctly, but it is not well maintained and the error messages can be hard to understand.
We plan to improve this in the future.

As a stopgap, frequently seen errors and how to deal with them are listed below.
If you get an error even though the file should be correct, or if you cannot make sense of the error at all, it may be a bug, so please contact us.

* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: a required option has not been provided. You most likely forgot to specify it or wrote the option name incorrectly.
  * The `...` part shows where the error occurred. For example, `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` means that `image_dir` is missing from the settings of the 0th `subsets` of the 0th `datasets`.
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: the format of the specified value is invalid. The value is most likely written in the wrong form. The `int` part changes depending on the option in question; the "Example" columns in this README may help.
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: an unsupported option name is present. You most likely misspelled an option name or accidentally included a stray key.
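
As a small aid while the error messages remain terse, here is a hedged sketch of catching the validation error around config loading. It assumes validation is triggered by `library.config_util.load_user_config`, which appears in the training scripts in this commit; depending on the version, the error may instead surface later, when the blueprint is generated from the config.

```python
# Hedged sketch: surfacing voluptuous validation errors with a pointer back to this guide.
# Assumes validation happens inside library.config_util.load_user_config; in some versions
# it may instead be raised when the blueprint is generated from the loaded config.
from voluptuous.error import MultipleInvalid

import library.config_util as config_util

def load_config_or_explain(path: str):
    try:
        return config_util.load_user_config(path)
    except MultipleInvalid as e:
        # The message includes the offending location, e.g.
        # "required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']"
        print(f"Invalid dataset config {path}: {e}")
        print("See the error guide above for the most common causes.")
        raise
```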
dreambooth_gui.py
ADDED
@@ -0,0 +1,954 @@
1 |
+
# v1: initial release
|
2 |
+
# v2: add open and save folder icons
|
3 |
+
# v3: Add new Utilities tab for Dreambooth folder preparation
|
4 |
+
# v3.1: Adding captioning of images to utilities
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import json
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import pathlib
|
12 |
+
import argparse
|
13 |
+
from library.common_gui import (
|
14 |
+
get_folder_path,
|
15 |
+
remove_doublequote,
|
16 |
+
get_file_path,
|
17 |
+
get_any_file_path,
|
18 |
+
get_saveasfile_path,
|
19 |
+
color_aug_changed,
|
20 |
+
save_inference_file,
|
21 |
+
gradio_advanced_training,
|
22 |
+
run_cmd_advanced_training,
|
23 |
+
run_cmd_training,
|
24 |
+
gradio_training,
|
25 |
+
gradio_config,
|
26 |
+
gradio_source_model,
|
27 |
+
# set_legacy_8bitadam,
|
28 |
+
update_my_data,
|
29 |
+
check_if_model_exist,
|
30 |
+
)
|
31 |
+
from library.tensorboard_gui import (
|
32 |
+
gradio_tensorboard,
|
33 |
+
start_tensorboard,
|
34 |
+
stop_tensorboard,
|
35 |
+
)
|
36 |
+
from library.dreambooth_folder_creation_gui import (
|
37 |
+
gradio_dreambooth_folder_creation_tab,
|
38 |
+
)
|
39 |
+
from library.utilities import utilities_tab
|
40 |
+
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
41 |
+
from easygui import msgbox
|
42 |
+
|
43 |
+
folder_symbol = '\U0001f4c2' # 📂
|
44 |
+
refresh_symbol = '\U0001f504' # 🔄
|
45 |
+
save_style_symbol = '\U0001f4be' # 💾
|
46 |
+
document_symbol = '\U0001F4C4' # 📄
|
47 |
+
|
48 |
+
|
49 |
+
def save_configuration(
|
50 |
+
save_as,
|
51 |
+
file_path,
|
52 |
+
pretrained_model_name_or_path,
|
53 |
+
v2,
|
54 |
+
v_parameterization,
|
55 |
+
logging_dir,
|
56 |
+
train_data_dir,
|
57 |
+
reg_data_dir,
|
58 |
+
output_dir,
|
59 |
+
max_resolution,
|
60 |
+
learning_rate,
|
61 |
+
lr_scheduler,
|
62 |
+
lr_warmup,
|
63 |
+
train_batch_size,
|
64 |
+
epoch,
|
65 |
+
save_every_n_epochs,
|
66 |
+
mixed_precision,
|
67 |
+
save_precision,
|
68 |
+
seed,
|
69 |
+
num_cpu_threads_per_process,
|
70 |
+
cache_latents,
|
71 |
+
caption_extension,
|
72 |
+
enable_bucket,
|
73 |
+
gradient_checkpointing,
|
74 |
+
full_fp16,
|
75 |
+
no_token_padding,
|
76 |
+
stop_text_encoder_training,
|
77 |
+
# use_8bit_adam,
|
78 |
+
xformers,
|
79 |
+
save_model_as,
|
80 |
+
shuffle_caption,
|
81 |
+
save_state,
|
82 |
+
resume,
|
83 |
+
prior_loss_weight,
|
84 |
+
color_aug,
|
85 |
+
flip_aug,
|
86 |
+
clip_skip,
|
87 |
+
vae,
|
88 |
+
output_name,
|
89 |
+
max_token_length,
|
90 |
+
max_train_epochs,
|
91 |
+
max_data_loader_n_workers,
|
92 |
+
mem_eff_attn,
|
93 |
+
gradient_accumulation_steps,
|
94 |
+
model_list,
|
95 |
+
keep_tokens,
|
96 |
+
persistent_data_loader_workers,
|
97 |
+
bucket_no_upscale,
|
98 |
+
random_crop,
|
99 |
+
bucket_reso_steps,
|
100 |
+
caption_dropout_every_n_epochs,
|
101 |
+
caption_dropout_rate,
|
102 |
+
optimizer,
|
103 |
+
optimizer_args,
|
104 |
+
noise_offset,
|
105 |
+
sample_every_n_steps,
|
106 |
+
sample_every_n_epochs,
|
107 |
+
sample_sampler,
|
108 |
+
sample_prompts,
|
109 |
+
additional_parameters,
|
110 |
+
vae_batch_size,
|
111 |
+
min_snr_gamma,weighted_captions,
|
112 |
+
):
|
113 |
+
# Get list of function parameters and values
|
114 |
+
parameters = list(locals().items())
|
115 |
+
|
116 |
+
original_file_path = file_path
|
117 |
+
|
118 |
+
save_as_bool = True if save_as.get('label') == 'True' else False
|
119 |
+
|
120 |
+
if save_as_bool:
|
121 |
+
print('Save as...')
|
122 |
+
file_path = get_saveasfile_path(file_path)
|
123 |
+
else:
|
124 |
+
print('Save...')
|
125 |
+
if file_path == None or file_path == '':
|
126 |
+
file_path = get_saveasfile_path(file_path)
|
127 |
+
|
128 |
+
# print(file_path)
|
129 |
+
|
130 |
+
if file_path == None or file_path == '':
|
131 |
+
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
132 |
+
|
133 |
+
# Return the values of the variables as a dictionary
|
134 |
+
variables = {
|
135 |
+
name: value
|
136 |
+
for name, value in parameters # locals().items()
|
137 |
+
if name
|
138 |
+
not in [
|
139 |
+
'file_path',
|
140 |
+
'save_as',
|
141 |
+
]
|
142 |
+
}
|
143 |
+
|
144 |
+
# Extract the destination directory from the file path
|
145 |
+
destination_directory = os.path.dirname(file_path)
|
146 |
+
|
147 |
+
# Create the destination directory if it doesn't exist
|
148 |
+
if not os.path.exists(destination_directory):
|
149 |
+
os.makedirs(destination_directory)
|
150 |
+
|
151 |
+
# Save the data to the selected file
|
152 |
+
with open(file_path, 'w') as file:
|
153 |
+
json.dump(variables, file, indent=2)
|
154 |
+
|
155 |
+
return file_path
|
156 |
+
|
157 |
+
|
158 |
+
def open_configuration(
|
159 |
+
ask_for_file,
|
160 |
+
file_path,
|
161 |
+
pretrained_model_name_or_path,
|
162 |
+
v2,
|
163 |
+
v_parameterization,
|
164 |
+
logging_dir,
|
165 |
+
train_data_dir,
|
166 |
+
reg_data_dir,
|
167 |
+
output_dir,
|
168 |
+
max_resolution,
|
169 |
+
learning_rate,
|
170 |
+
lr_scheduler,
|
171 |
+
lr_warmup,
|
172 |
+
train_batch_size,
|
173 |
+
epoch,
|
174 |
+
save_every_n_epochs,
|
175 |
+
mixed_precision,
|
176 |
+
save_precision,
|
177 |
+
seed,
|
178 |
+
num_cpu_threads_per_process,
|
179 |
+
cache_latents,
|
180 |
+
caption_extension,
|
181 |
+
enable_bucket,
|
182 |
+
gradient_checkpointing,
|
183 |
+
full_fp16,
|
184 |
+
no_token_padding,
|
185 |
+
stop_text_encoder_training,
|
186 |
+
# use_8bit_adam,
|
187 |
+
xformers,
|
188 |
+
save_model_as,
|
189 |
+
shuffle_caption,
|
190 |
+
save_state,
|
191 |
+
resume,
|
192 |
+
prior_loss_weight,
|
193 |
+
color_aug,
|
194 |
+
flip_aug,
|
195 |
+
clip_skip,
|
196 |
+
vae,
|
197 |
+
output_name,
|
198 |
+
max_token_length,
|
199 |
+
max_train_epochs,
|
200 |
+
max_data_loader_n_workers,
|
201 |
+
mem_eff_attn,
|
202 |
+
gradient_accumulation_steps,
|
203 |
+
model_list,
|
204 |
+
keep_tokens,
|
205 |
+
persistent_data_loader_workers,
|
206 |
+
bucket_no_upscale,
|
207 |
+
random_crop,
|
208 |
+
bucket_reso_steps,
|
209 |
+
caption_dropout_every_n_epochs,
|
210 |
+
caption_dropout_rate,
|
211 |
+
optimizer,
|
212 |
+
optimizer_args,
|
213 |
+
noise_offset,
|
214 |
+
sample_every_n_steps,
|
215 |
+
sample_every_n_epochs,
|
216 |
+
sample_sampler,
|
217 |
+
sample_prompts,
|
218 |
+
additional_parameters,
|
219 |
+
vae_batch_size,
|
220 |
+
min_snr_gamma,weighted_captions,
|
221 |
+
):
|
222 |
+
# Get list of function parameters and values
|
223 |
+
parameters = list(locals().items())
|
224 |
+
|
225 |
+
ask_for_file = True if ask_for_file.get('label') == 'True' else False
|
226 |
+
|
227 |
+
original_file_path = file_path
|
228 |
+
|
229 |
+
if ask_for_file:
|
230 |
+
file_path = get_file_path(file_path)
|
231 |
+
|
232 |
+
if not file_path == '' and not file_path == None:
|
233 |
+
# load variables from JSON file
|
234 |
+
with open(file_path, 'r') as f:
|
235 |
+
my_data = json.load(f)
|
236 |
+
print('Loading config...')
|
237 |
+
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
|
238 |
+
my_data = update_my_data(my_data)
|
239 |
+
else:
|
240 |
+
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
241 |
+
my_data = {}
|
242 |
+
|
243 |
+
values = [file_path]
|
244 |
+
for key, value in parameters:
|
245 |
+
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
246 |
+
if not key in ['ask_for_file', 'file_path']:
|
247 |
+
values.append(my_data.get(key, value))
|
248 |
+
return tuple(values)
|
249 |
+
|
250 |
+
|
251 |
+
def train_model(
|
252 |
+
pretrained_model_name_or_path,
|
253 |
+
v2,
|
254 |
+
v_parameterization,
|
255 |
+
logging_dir,
|
256 |
+
train_data_dir,
|
257 |
+
reg_data_dir,
|
258 |
+
output_dir,
|
259 |
+
max_resolution,
|
260 |
+
learning_rate,
|
261 |
+
lr_scheduler,
|
262 |
+
lr_warmup,
|
263 |
+
train_batch_size,
|
264 |
+
epoch,
|
265 |
+
save_every_n_epochs,
|
266 |
+
mixed_precision,
|
267 |
+
save_precision,
|
268 |
+
seed,
|
269 |
+
num_cpu_threads_per_process,
|
270 |
+
cache_latents,
|
271 |
+
caption_extension,
|
272 |
+
enable_bucket,
|
273 |
+
gradient_checkpointing,
|
274 |
+
full_fp16,
|
275 |
+
no_token_padding,
|
276 |
+
stop_text_encoder_training_pct,
|
277 |
+
# use_8bit_adam,
|
278 |
+
xformers,
|
279 |
+
save_model_as,
|
280 |
+
shuffle_caption,
|
281 |
+
save_state,
|
282 |
+
resume,
|
283 |
+
prior_loss_weight,
|
284 |
+
color_aug,
|
285 |
+
flip_aug,
|
286 |
+
clip_skip,
|
287 |
+
vae,
|
288 |
+
output_name,
|
289 |
+
max_token_length,
|
290 |
+
max_train_epochs,
|
291 |
+
max_data_loader_n_workers,
|
292 |
+
mem_eff_attn,
|
293 |
+
gradient_accumulation_steps,
|
294 |
+
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
295 |
+
keep_tokens,
|
296 |
+
persistent_data_loader_workers,
|
297 |
+
bucket_no_upscale,
|
298 |
+
random_crop,
|
299 |
+
bucket_reso_steps,
|
300 |
+
caption_dropout_every_n_epochs,
|
301 |
+
caption_dropout_rate,
|
302 |
+
optimizer,
|
303 |
+
optimizer_args,
|
304 |
+
noise_offset,
|
305 |
+
sample_every_n_steps,
|
306 |
+
sample_every_n_epochs,
|
307 |
+
sample_sampler,
|
308 |
+
sample_prompts,
|
309 |
+
additional_parameters,
|
310 |
+
vae_batch_size,
|
311 |
+
min_snr_gamma,weighted_captions,
|
312 |
+
):
|
313 |
+
if pretrained_model_name_or_path == '':
|
314 |
+
msgbox('Source model information is missing')
|
315 |
+
return
|
316 |
+
|
317 |
+
if train_data_dir == '':
|
318 |
+
msgbox('Image folder path is missing')
|
319 |
+
return
|
320 |
+
|
321 |
+
if not os.path.exists(train_data_dir):
|
322 |
+
msgbox('Image folder does not exist')
|
323 |
+
return
|
324 |
+
|
325 |
+
if reg_data_dir != '':
|
326 |
+
if not os.path.exists(reg_data_dir):
|
327 |
+
msgbox('Regularisation folder does not exist')
|
328 |
+
return
|
329 |
+
|
330 |
+
if output_dir == '':
|
331 |
+
msgbox('Output folder path is missing')
|
332 |
+
return
|
333 |
+
|
334 |
+
if check_if_model_exist(output_name, output_dir, save_model_as):
|
335 |
+
return
|
336 |
+
|
337 |
+
if optimizer == 'Adafactor' and lr_warmup != '0':
|
338 |
+
msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning")
|
339 |
+
lr_warmup = '0'
|
340 |
+
|
341 |
+
# Get a list of all subfolders in train_data_dir, excluding hidden folders
|
342 |
+
subfolders = [
|
343 |
+
f
|
344 |
+
for f in os.listdir(train_data_dir)
|
345 |
+
if os.path.isdir(os.path.join(train_data_dir, f))
|
346 |
+
and not f.startswith('.')
|
347 |
+
]
|
348 |
+
|
349 |
+
# Check if subfolders are present. If not let the user know and return
|
350 |
+
if not subfolders:
|
351 |
+
print(
|
352 |
+
'\033[33mNo subfolders were found in',
|
353 |
+
train_data_dir,
|
354 |
+
" can't train\...033[0m",
|
355 |
+
)
|
356 |
+
return
|
357 |
+
|
358 |
+
total_steps = 0
|
359 |
+
|
360 |
+
# Loop through each subfolder and extract the number of repeats
|
361 |
+
for folder in subfolders:
|
362 |
+
# Extract the number of repeats from the folder name
|
363 |
+
try:
|
364 |
+
repeats = int(folder.split('_')[0])
|
365 |
+
except ValueError:
|
366 |
+
print(
|
367 |
+
'\033[33mSubfolder',
|
368 |
+
folder,
|
369 |
+
"does not have a proper repeat value, please correct the name or remove it... can't train...\033[0m",
|
370 |
+
)
|
371 |
+
continue
|
372 |
+
|
373 |
+
# Count the number of images in the folder
|
374 |
+
num_images = len(
|
375 |
+
[
|
376 |
+
f
|
377 |
+
for f, lower_f in (
|
378 |
+
(file, file.lower())
|
379 |
+
for file in os.listdir(
|
380 |
+
os.path.join(train_data_dir, folder)
|
381 |
+
)
|
382 |
+
)
|
383 |
+
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
384 |
+
]
|
385 |
+
)
|
386 |
+
|
387 |
+
if num_images == 0:
|
388 |
+
print(f'{folder} folder contain no images, skipping...')
|
389 |
+
else:
|
390 |
+
# Calculate the total number of steps for this folder
|
391 |
+
steps = repeats * num_images
|
392 |
+
total_steps += steps
|
393 |
+
|
394 |
+
# Print the result
|
395 |
+
print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
|
396 |
+
|
397 |
+
if total_steps == 0:
|
398 |
+
print(
|
399 |
+
'\033[33mNo images were found in folder',
|
400 |
+
train_data_dir,
|
401 |
+
'... please rectify!\033[0m',
|
402 |
+
)
|
403 |
+
return
|
404 |
+
|
405 |
+
# Print the result
|
406 |
+
# print(f"{total_steps} total steps")
|
407 |
+
|
408 |
+
if reg_data_dir == '':
|
409 |
+
reg_factor = 1
|
410 |
+
else:
|
411 |
+
print(
|
412 |
+
'\033[94mRegularisation images are used... Will double the number of steps required...\033[0m'
|
413 |
+
)
|
414 |
+
reg_factor = 2
|
415 |
+
|
416 |
+
# calculate max_train_steps
|
417 |
+
max_train_steps = int(
|
418 |
+
math.ceil(
|
419 |
+
float(total_steps)
|
420 |
+
/ int(train_batch_size)
|
421 |
+
* int(epoch)
|
422 |
+
* int(reg_factor)
|
423 |
+
)
|
424 |
+
)
|
425 |
+
print(f'max_train_steps = {max_train_steps}')
|
426 |
+
|
427 |
+
# calculate stop encoder training
|
428 |
+
if int(stop_text_encoder_training_pct) == -1:
|
429 |
+
stop_text_encoder_training = -1
|
430 |
+
elif stop_text_encoder_training_pct == None:
|
431 |
+
stop_text_encoder_training = 0
|
432 |
+
else:
|
433 |
+
stop_text_encoder_training = math.ceil(
|
434 |
+
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
|
435 |
+
)
|
436 |
+
print(f'stop_text_encoder_training = {stop_text_encoder_training}')
|
437 |
+
|
438 |
+
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
439 |
+
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
440 |
+
|
441 |
+
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
|
442 |
+
if v2:
|
443 |
+
run_cmd += ' --v2'
|
444 |
+
if v_parameterization:
|
445 |
+
run_cmd += ' --v_parameterization'
|
446 |
+
if enable_bucket:
|
447 |
+
run_cmd += ' --enable_bucket'
|
448 |
+
if no_token_padding:
|
449 |
+
run_cmd += ' --no_token_padding'
|
450 |
+
if weighted_captions:
|
451 |
+
run_cmd += ' --weighted_captions'
|
452 |
+
run_cmd += (
|
453 |
+
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
454 |
+
)
|
455 |
+
run_cmd += f' --train_data_dir="{train_data_dir}"'
|
456 |
+
if len(reg_data_dir):
|
457 |
+
run_cmd += f' --reg_data_dir="{reg_data_dir}"'
|
458 |
+
run_cmd += f' --resolution={max_resolution}'
|
459 |
+
run_cmd += f' --output_dir="{output_dir}"'
|
460 |
+
run_cmd += f' --logging_dir="{logging_dir}"'
|
461 |
+
if not stop_text_encoder_training == 0:
|
462 |
+
run_cmd += (
|
463 |
+
f' --stop_text_encoder_training={stop_text_encoder_training}'
|
464 |
+
)
|
465 |
+
if not save_model_as == 'same as source model':
|
466 |
+
run_cmd += f' --save_model_as={save_model_as}'
|
467 |
+
# if not resume == '':
|
468 |
+
# run_cmd += f' --resume={resume}'
|
469 |
+
if not float(prior_loss_weight) == 1.0:
|
470 |
+
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
471 |
+
if not vae == '':
|
472 |
+
run_cmd += f' --vae="{vae}"'
|
473 |
+
if not output_name == '':
|
474 |
+
run_cmd += f' --output_name="{output_name}"'
|
475 |
+
if int(max_token_length) > 75:
|
476 |
+
run_cmd += f' --max_token_length={max_token_length}'
|
477 |
+
if not max_train_epochs == '':
|
478 |
+
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
479 |
+
if not max_data_loader_n_workers == '':
|
480 |
+
run_cmd += (
|
481 |
+
f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
482 |
+
)
|
483 |
+
if int(gradient_accumulation_steps) > 1:
|
484 |
+
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
485 |
+
|
486 |
+
run_cmd += run_cmd_training(
|
487 |
+
learning_rate=learning_rate,
|
488 |
+
lr_scheduler=lr_scheduler,
|
489 |
+
lr_warmup_steps=lr_warmup_steps,
|
490 |
+
train_batch_size=train_batch_size,
|
491 |
+
max_train_steps=max_train_steps,
|
492 |
+
save_every_n_epochs=save_every_n_epochs,
|
493 |
+
mixed_precision=mixed_precision,
|
494 |
+
save_precision=save_precision,
|
495 |
+
seed=seed,
|
496 |
+
caption_extension=caption_extension,
|
497 |
+
cache_latents=cache_latents,
|
498 |
+
optimizer=optimizer,
|
499 |
+
optimizer_args=optimizer_args,
|
500 |
+
)
|
501 |
+
|
502 |
+
run_cmd += run_cmd_advanced_training(
|
503 |
+
max_train_epochs=max_train_epochs,
|
504 |
+
max_data_loader_n_workers=max_data_loader_n_workers,
|
505 |
+
max_token_length=max_token_length,
|
506 |
+
resume=resume,
|
507 |
+
save_state=save_state,
|
508 |
+
mem_eff_attn=mem_eff_attn,
|
509 |
+
clip_skip=clip_skip,
|
510 |
+
flip_aug=flip_aug,
|
511 |
+
color_aug=color_aug,
|
512 |
+
shuffle_caption=shuffle_caption,
|
513 |
+
gradient_checkpointing=gradient_checkpointing,
|
514 |
+
full_fp16=full_fp16,
|
515 |
+
xformers=xformers,
|
516 |
+
# use_8bit_adam=use_8bit_adam,
|
517 |
+
keep_tokens=keep_tokens,
|
518 |
+
persistent_data_loader_workers=persistent_data_loader_workers,
|
519 |
+
bucket_no_upscale=bucket_no_upscale,
|
520 |
+
random_crop=random_crop,
|
521 |
+
bucket_reso_steps=bucket_reso_steps,
|
522 |
+
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
523 |
+
caption_dropout_rate=caption_dropout_rate,
|
524 |
+
noise_offset=noise_offset,
|
525 |
+
additional_parameters=additional_parameters,
|
526 |
+
vae_batch_size=vae_batch_size,
|
527 |
+
min_snr_gamma=min_snr_gamma,
|
528 |
+
)
|
529 |
+
|
530 |
+
run_cmd += run_cmd_sample(
|
531 |
+
sample_every_n_steps,
|
532 |
+
sample_every_n_epochs,
|
533 |
+
sample_sampler,
|
534 |
+
sample_prompts,
|
535 |
+
output_dir,
|
536 |
+
)
|
537 |
+
|
538 |
+
print(run_cmd)
|
539 |
+
|
540 |
+
# Run the command
|
541 |
+
if os.name == 'posix':
|
542 |
+
os.system(run_cmd)
|
543 |
+
else:
|
544 |
+
subprocess.run(run_cmd)
|
545 |
+
|
546 |
+
# check if output_dir/last is a folder... therefore it is a diffuser model
|
547 |
+
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
|
548 |
+
|
549 |
+
if not last_dir.is_dir():
|
550 |
+
# Copy inference model for v2 if required
|
551 |
+
save_inference_file(output_dir, v2, v_parameterization, output_name)
|
552 |
+
|
553 |
+
|
554 |
+
def dreambooth_tab(
|
555 |
+
train_data_dir=gr.Textbox(),
|
556 |
+
reg_data_dir=gr.Textbox(),
|
557 |
+
output_dir=gr.Textbox(),
|
558 |
+
logging_dir=gr.Textbox(),
|
559 |
+
):
|
560 |
+
dummy_db_true = gr.Label(value=True, visible=False)
|
561 |
+
dummy_db_false = gr.Label(value=False, visible=False)
|
562 |
+
gr.Markdown('Train a custom model using kohya dreambooth python code...')
|
563 |
+
(
|
564 |
+
button_open_config,
|
565 |
+
button_save_config,
|
566 |
+
button_save_as_config,
|
567 |
+
config_file_name,
|
568 |
+
button_load_config,
|
569 |
+
) = gradio_config()
|
570 |
+
|
571 |
+
(
|
572 |
+
pretrained_model_name_or_path,
|
573 |
+
v2,
|
574 |
+
v_parameterization,
|
575 |
+
save_model_as,
|
576 |
+
model_list,
|
577 |
+
) = gradio_source_model()
|
578 |
+
|
579 |
+
with gr.Tab('Folders'):
|
580 |
+
with gr.Row():
|
581 |
+
train_data_dir = gr.Textbox(
|
582 |
+
label='Image folder',
|
583 |
+
placeholder='Folder where the training folders containing the images are located',
|
584 |
+
)
|
585 |
+
train_data_dir_input_folder = gr.Button(
|
586 |
+
'📂', elem_id='open_folder_small'
|
587 |
+
)
|
588 |
+
train_data_dir_input_folder.click(
|
589 |
+
get_folder_path,
|
590 |
+
outputs=train_data_dir,
|
591 |
+
show_progress=False,
|
592 |
+
)
|
593 |
+
reg_data_dir = gr.Textbox(
|
594 |
+
label='Regularisation folder',
|
595 |
+
placeholder='(Optional) Folder where the regularization folders containing the images are located',
|
596 |
+
)
|
597 |
+
reg_data_dir_input_folder = gr.Button(
|
598 |
+
'📂', elem_id='open_folder_small'
|
599 |
+
)
|
600 |
+
reg_data_dir_input_folder.click(
|
601 |
+
get_folder_path,
|
602 |
+
outputs=reg_data_dir,
|
603 |
+
show_progress=False,
|
604 |
+
)
|
605 |
+
with gr.Row():
|
606 |
+
output_dir = gr.Textbox(
|
607 |
+
label='Model output folder',
|
608 |
+
placeholder='Folder to output trained model',
|
609 |
+
)
|
610 |
+
output_dir_input_folder = gr.Button(
|
611 |
+
'📂', elem_id='open_folder_small'
|
612 |
+
)
|
613 |
+
output_dir_input_folder.click(get_folder_path, outputs=output_dir)
|
614 |
+
logging_dir = gr.Textbox(
|
615 |
+
label='Logging folder',
|
616 |
+
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
617 |
+
)
|
618 |
+
logging_dir_input_folder = gr.Button(
|
619 |
+
'📂', elem_id='open_folder_small'
|
620 |
+
)
|
621 |
+
logging_dir_input_folder.click(
|
622 |
+
get_folder_path,
|
623 |
+
outputs=logging_dir,
|
624 |
+
show_progress=False,
|
625 |
+
)
|
626 |
+
with gr.Row():
|
627 |
+
output_name = gr.Textbox(
|
628 |
+
label='Model output name',
|
629 |
+
placeholder='Name of the model to output',
|
630 |
+
value='last',
|
631 |
+
interactive=True,
|
632 |
+
)
|
633 |
+
train_data_dir.change(
|
634 |
+
remove_doublequote,
|
635 |
+
inputs=[train_data_dir],
|
636 |
+
outputs=[train_data_dir],
|
637 |
+
)
|
638 |
+
reg_data_dir.change(
|
639 |
+
remove_doublequote,
|
640 |
+
inputs=[reg_data_dir],
|
641 |
+
outputs=[reg_data_dir],
|
642 |
+
)
|
643 |
+
output_dir.change(
|
644 |
+
remove_doublequote,
|
645 |
+
inputs=[output_dir],
|
646 |
+
outputs=[output_dir],
|
647 |
+
)
|
648 |
+
logging_dir.change(
|
649 |
+
remove_doublequote,
|
650 |
+
inputs=[logging_dir],
|
651 |
+
outputs=[logging_dir],
|
652 |
+
)
|
653 |
+
with gr.Tab('Training parameters'):
|
654 |
+
(
|
655 |
+
learning_rate,
|
656 |
+
lr_scheduler,
|
657 |
+
lr_warmup,
|
658 |
+
train_batch_size,
|
659 |
+
epoch,
|
660 |
+
save_every_n_epochs,
|
661 |
+
mixed_precision,
|
662 |
+
save_precision,
|
663 |
+
num_cpu_threads_per_process,
|
664 |
+
seed,
|
665 |
+
caption_extension,
|
666 |
+
cache_latents,
|
667 |
+
optimizer,
|
668 |
+
optimizer_args,
|
669 |
+
) = gradio_training(
|
670 |
+
learning_rate_value='1e-5',
|
671 |
+
lr_scheduler_value='cosine',
|
672 |
+
lr_warmup_value='10',
|
673 |
+
)
|
674 |
+
with gr.Row():
|
675 |
+
max_resolution = gr.Textbox(
|
676 |
+
label='Max resolution',
|
677 |
+
value='512,512',
|
678 |
+
placeholder='512,512',
|
679 |
+
)
|
680 |
+
stop_text_encoder_training = gr.Slider(
|
681 |
+
minimum=-1,
|
682 |
+
maximum=100,
|
683 |
+
value=0,
|
684 |
+
step=1,
|
685 |
+
label='Stop text encoder training',
|
686 |
+
)
|
687 |
+
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
688 |
+
with gr.Accordion('Advanced Configuration', open=False):
|
689 |
+
with gr.Row():
|
690 |
+
no_token_padding = gr.Checkbox(
|
691 |
+
label='No token padding', value=False
|
692 |
+
)
|
693 |
+
gradient_accumulation_steps = gr.Number(
|
694 |
+
label='Gradient accumulate steps', value='1'
|
695 |
+
)
|
696 |
+
weighted_captions = gr.Checkbox(
|
697 |
+
label='Weighted captions', value=False
|
698 |
+
)
|
699 |
+
with gr.Row():
|
700 |
+
prior_loss_weight = gr.Number(
|
701 |
+
label='Prior loss weight', value=1.0
|
702 |
+
)
|
703 |
+
vae = gr.Textbox(
|
704 |
+
label='VAE',
|
705 |
+
placeholder='(Optional) path to checkpoint of vae to replace for training',
|
706 |
+
)
|
707 |
+
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
708 |
+
vae_button.click(
|
709 |
+
get_any_file_path,
|
710 |
+
outputs=vae,
|
711 |
+
show_progress=False,
|
712 |
+
)
|
713 |
+
(
|
714 |
+
# use_8bit_adam,
|
715 |
+
xformers,
|
716 |
+
full_fp16,
|
717 |
+
gradient_checkpointing,
|
718 |
+
shuffle_caption,
|
719 |
+
color_aug,
|
720 |
+
flip_aug,
|
721 |
+
clip_skip,
|
722 |
+
mem_eff_attn,
|
723 |
+
save_state,
|
724 |
+
resume,
|
725 |
+
max_token_length,
|
726 |
+
max_train_epochs,
|
727 |
+
max_data_loader_n_workers,
|
728 |
+
keep_tokens,
|
729 |
+
persistent_data_loader_workers,
|
730 |
+
bucket_no_upscale,
|
731 |
+
random_crop,
|
732 |
+
bucket_reso_steps,
|
733 |
+
caption_dropout_every_n_epochs,
|
734 |
+
caption_dropout_rate,
|
735 |
+
noise_offset,
|
736 |
+
additional_parameters,
|
737 |
+
vae_batch_size,
|
738 |
+
min_snr_gamma,
|
739 |
+
) = gradio_advanced_training()
|
740 |
+
color_aug.change(
|
741 |
+
color_aug_changed,
|
742 |
+
inputs=[color_aug],
|
743 |
+
outputs=[cache_latents],
|
744 |
+
)
|
745 |
+
|
746 |
+
(
|
747 |
+
sample_every_n_steps,
|
748 |
+
sample_every_n_epochs,
|
749 |
+
sample_sampler,
|
750 |
+
sample_prompts,
|
751 |
+
) = sample_gradio_config()
|
752 |
+
|
753 |
+
with gr.Tab('Tools'):
|
754 |
+
gr.Markdown(
|
755 |
+
'This section provides Dreambooth tools to help set up your dataset...'
|
756 |
+
)
|
757 |
+
gradio_dreambooth_folder_creation_tab(
|
758 |
+
train_data_dir_input=train_data_dir,
|
759 |
+
reg_data_dir_input=reg_data_dir,
|
760 |
+
output_dir_input=output_dir,
|
761 |
+
logging_dir_input=logging_dir,
|
762 |
+
)
|
763 |
+
|
764 |
+
button_run = gr.Button('Train model', variant='primary')
|
765 |
+
|
766 |
+
# Setup gradio tensorboard buttons
|
767 |
+
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
768 |
+
|
769 |
+
button_start_tensorboard.click(
|
770 |
+
start_tensorboard,
|
771 |
+
inputs=logging_dir,
|
772 |
+
show_progress=False,
|
773 |
+
)
|
774 |
+
|
775 |
+
button_stop_tensorboard.click(
|
776 |
+
stop_tensorboard,
|
777 |
+
show_progress=False,
|
778 |
+
)
|
779 |
+
|
780 |
+
settings_list = [
|
781 |
+
pretrained_model_name_or_path,
|
782 |
+
v2,
|
783 |
+
v_parameterization,
|
784 |
+
logging_dir,
|
785 |
+
train_data_dir,
|
786 |
+
reg_data_dir,
|
787 |
+
output_dir,
|
788 |
+
max_resolution,
|
789 |
+
learning_rate,
|
790 |
+
lr_scheduler,
|
791 |
+
lr_warmup,
|
792 |
+
train_batch_size,
|
793 |
+
epoch,
|
794 |
+
save_every_n_epochs,
|
795 |
+
mixed_precision,
|
796 |
+
save_precision,
|
797 |
+
seed,
|
798 |
+
num_cpu_threads_per_process,
|
799 |
+
cache_latents,
|
800 |
+
caption_extension,
|
801 |
+
enable_bucket,
|
802 |
+
gradient_checkpointing,
|
803 |
+
full_fp16,
|
804 |
+
no_token_padding,
|
805 |
+
stop_text_encoder_training,
|
806 |
+
# use_8bit_adam,
|
807 |
+
xformers,
|
808 |
+
save_model_as,
|
809 |
+
shuffle_caption,
|
810 |
+
save_state,
|
811 |
+
resume,
|
812 |
+
prior_loss_weight,
|
813 |
+
color_aug,
|
814 |
+
flip_aug,
|
815 |
+
clip_skip,
|
816 |
+
vae,
|
817 |
+
output_name,
|
818 |
+
max_token_length,
|
819 |
+
max_train_epochs,
|
820 |
+
max_data_loader_n_workers,
|
821 |
+
mem_eff_attn,
|
822 |
+
gradient_accumulation_steps,
|
823 |
+
model_list,
|
824 |
+
keep_tokens,
|
825 |
+
persistent_data_loader_workers,
|
826 |
+
bucket_no_upscale,
|
827 |
+
random_crop,
|
828 |
+
bucket_reso_steps,
|
829 |
+
caption_dropout_every_n_epochs,
|
830 |
+
caption_dropout_rate,
|
831 |
+
optimizer,
|
832 |
+
optimizer_args,
|
833 |
+
noise_offset,
|
834 |
+
sample_every_n_steps,
|
835 |
+
sample_every_n_epochs,
|
836 |
+
sample_sampler,
|
837 |
+
sample_prompts,
|
838 |
+
additional_parameters,
|
839 |
+
vae_batch_size,
|
840 |
+
min_snr_gamma,
|
841 |
+
weighted_captions,
|
842 |
+
]
|
843 |
+
|
844 |
+
button_open_config.click(
|
845 |
+
open_configuration,
|
846 |
+
inputs=[dummy_db_true, config_file_name] + settings_list,
|
847 |
+
outputs=[config_file_name] + settings_list,
|
848 |
+
show_progress=False,
|
849 |
+
)
|
850 |
+
|
851 |
+
button_load_config.click(
|
852 |
+
open_configuration,
|
853 |
+
inputs=[dummy_db_false, config_file_name] + settings_list,
|
854 |
+
outputs=[config_file_name] + settings_list,
|
855 |
+
show_progress=False,
|
856 |
+
)
|
857 |
+
|
858 |
+
button_save_config.click(
|
859 |
+
save_configuration,
|
860 |
+
inputs=[dummy_db_false, config_file_name] + settings_list,
|
861 |
+
outputs=[config_file_name],
|
862 |
+
show_progress=False,
|
863 |
+
)
|
864 |
+
|
865 |
+
button_save_as_config.click(
|
866 |
+
save_configuration,
|
867 |
+
inputs=[dummy_db_true, config_file_name] + settings_list,
|
868 |
+
outputs=[config_file_name],
|
869 |
+
show_progress=False,
|
870 |
+
)
|
871 |
+
|
872 |
+
button_run.click(
|
873 |
+
train_model,
|
874 |
+
inputs=settings_list,
|
875 |
+
show_progress=False,
|
876 |
+
)
|
877 |
+
|
878 |
+
return (
|
879 |
+
train_data_dir,
|
880 |
+
reg_data_dir,
|
881 |
+
output_dir,
|
882 |
+
logging_dir,
|
883 |
+
)
|
884 |
+
|
885 |
+
|
886 |
+
def UI(**kwargs):
|
887 |
+
css = ''
|
888 |
+
|
889 |
+
if os.path.exists('./style.css'):
|
890 |
+
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
891 |
+
print('Load CSS...')
|
892 |
+
css += file.read() + '\n'
|
893 |
+
|
894 |
+
interface = gr.Blocks(css=css)
|
895 |
+
|
896 |
+
with interface:
|
897 |
+
with gr.Tab('Dreambooth'):
|
898 |
+
(
|
899 |
+
train_data_dir_input,
|
900 |
+
reg_data_dir_input,
|
901 |
+
output_dir_input,
|
902 |
+
logging_dir_input,
|
903 |
+
) = dreambooth_tab()
|
904 |
+
with gr.Tab('Utilities'):
|
905 |
+
utilities_tab(
|
906 |
+
train_data_dir_input=train_data_dir_input,
|
907 |
+
reg_data_dir_input=reg_data_dir_input,
|
908 |
+
output_dir_input=output_dir_input,
|
909 |
+
logging_dir_input=logging_dir_input,
|
910 |
+
enable_copy_info_button=True,
|
911 |
+
)
|
912 |
+
|
913 |
+
# Show the interface
|
914 |
+
launch_kwargs = {}
|
915 |
+
if not kwargs.get('username', None) == '':
|
916 |
+
launch_kwargs['auth'] = (
|
917 |
+
kwargs.get('username', None),
|
918 |
+
kwargs.get('password', None),
|
919 |
+
)
|
920 |
+
if kwargs.get('server_port', 0) > 0:
|
921 |
+
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
922 |
+
if kwargs.get('inbrowser', False):
|
923 |
+
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
924 |
+
print(launch_kwargs)
|
925 |
+
interface.launch(**launch_kwargs)
|
926 |
+
|
927 |
+
|
928 |
+
if __name__ == '__main__':
|
929 |
+
# torch.cuda.set_per_process_memory_fraction(0.48)
|
930 |
+
parser = argparse.ArgumentParser()
|
931 |
+
parser.add_argument(
|
932 |
+
'--username', type=str, default='', help='Username for authentication'
|
933 |
+
)
|
934 |
+
parser.add_argument(
|
935 |
+
'--password', type=str, default='', help='Password for authentication'
|
936 |
+
)
|
937 |
+
parser.add_argument(
|
938 |
+
'--server_port',
|
939 |
+
type=int,
|
940 |
+
default=0,
|
941 |
+
help='Port to run the server listener on',
|
942 |
+
)
|
943 |
+
parser.add_argument(
|
944 |
+
'--inbrowser', action='store_true', help='Open in browser'
|
945 |
+
)
|
946 |
+
|
947 |
+
args = parser.parse_args()
|
948 |
+
|
949 |
+
UI(
|
950 |
+
username=args.username,
|
951 |
+
password=args.password,
|
952 |
+
inbrowser=args.inbrowser,
|
953 |
+
server_port=args.server_port,
|
954 |
+
)
|
fine_tune.py
ADDED
@@ -0,0 +1,440 @@
1 |
+
# training with captions
|
2 |
+
# XXX dropped option: hypernetwork training
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import gc
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
import toml
|
9 |
+
from multiprocessing import Value
|
10 |
+
|
11 |
+
from tqdm import tqdm
|
12 |
+
import torch
|
13 |
+
from accelerate.utils import set_seed
|
14 |
+
import diffusers
|
15 |
+
from diffusers import DDPMScheduler
|
16 |
+
|
17 |
+
import library.train_util as train_util
|
18 |
+
import library.config_util as config_util
|
19 |
+
from library.config_util import (
|
20 |
+
ConfigSanitizer,
|
21 |
+
BlueprintGenerator,
|
22 |
+
)
|
23 |
+
import library.custom_train_functions as custom_train_functions
|
24 |
+
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
25 |
+
|
26 |
+
|
27 |
+
def train(args):
|
28 |
+
train_util.verify_training_args(args)
|
29 |
+
train_util.prepare_dataset_args(args, True)
|
30 |
+
|
31 |
+
cache_latents = args.cache_latents
|
32 |
+
|
33 |
+
if args.seed is not None:
|
34 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
35 |
+
|
36 |
+
tokenizer = train_util.load_tokenizer(args)
|
37 |
+
|
38 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
39 |
+
if args.dataset_config is not None:
|
40 |
+
print(f"Load dataset config from {args.dataset_config}")
|
41 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
42 |
+
ignored = ["train_data_dir", "in_json"]
|
43 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
44 |
+
print(
|
45 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
46 |
+
", ".join(ignored)
|
47 |
+
)
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
user_config = {
|
51 |
+
"datasets": [
|
52 |
+
{
|
53 |
+
"subsets": [
|
54 |
+
{
|
55 |
+
"image_dir": args.train_data_dir,
|
56 |
+
"metadata_file": args.in_json,
|
57 |
+
}
|
58 |
+
]
|
59 |
+
}
|
60 |
+
]
|
61 |
+
}
|
62 |
+
|
63 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
64 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
65 |
+
|
66 |
+
current_epoch = Value("i", 0)
|
67 |
+
current_step = Value("i", 0)
|
68 |
+
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
69 |
+
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
70 |
+
|
71 |
+
if args.debug_dataset:
|
72 |
+
train_util.debug_dataset(train_dataset_group)
|
73 |
+
return
|
74 |
+
if len(train_dataset_group) == 0:
|
75 |
+
print(
|
76 |
+
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
77 |
+
)
|
78 |
+
return
|
79 |
+
|
80 |
+
if cache_latents:
|
81 |
+
assert (
|
82 |
+
train_dataset_group.is_latent_cacheable()
|
83 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
84 |
+
|
85 |
+
# acceleratorを準備する
|
86 |
+
print("prepare accelerator")
|
87 |
+
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
88 |
+
|
89 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
90 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
91 |
+
|
92 |
+
# モデルを読み込む
|
93 |
+
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
94 |
+
|
95 |
+
# verify load/save model formats
|
96 |
+
if load_stable_diffusion_format:
|
97 |
+
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
98 |
+
src_diffusers_model_path = None
|
99 |
+
else:
|
100 |
+
src_stable_diffusion_ckpt = None
|
101 |
+
src_diffusers_model_path = args.pretrained_model_name_or_path
|
102 |
+
|
103 |
+
if args.save_model_as is None:
|
104 |
+
save_stable_diffusion_format = load_stable_diffusion_format
|
105 |
+
use_safetensors = args.use_safetensors
|
106 |
+
else:
|
107 |
+
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
108 |
+
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
109 |
+
|
110 |
+
# Diffusers版のxformers使用フラグを設定する関数
|
111 |
+
def set_diffusers_xformers_flag(model, valid):
|
112 |
+
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
113 |
+
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
|
114 |
+
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
|
115 |
+
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
|
116 |
+
|
117 |
+
# Recursively walk through all the children.
|
118 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
119 |
+
# gets the message
|
120 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
121 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
122 |
+
module.set_use_memory_efficient_attention_xformers(valid)
|
123 |
+
|
124 |
+
for child in module.children():
|
125 |
+
fn_recursive_set_mem_eff(child)
|
126 |
+
|
127 |
+
fn_recursive_set_mem_eff(model)
|
128 |
+
|
129 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
130 |
+
if args.diffusers_xformers:
|
131 |
+
print("Use xformers by Diffusers")
|
132 |
+
set_diffusers_xformers_flag(unet, True)
|
133 |
+
else:
|
134 |
+
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
135 |
+
print("Disable Diffusers' xformers")
|
136 |
+
set_diffusers_xformers_flag(unet, False)
|
137 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
138 |
+
|
139 |
+
# 学習を準備する
|
140 |
+
if cache_latents:
|
141 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
142 |
+
vae.requires_grad_(False)
|
143 |
+
vae.eval()
|
144 |
+
with torch.no_grad():
|
145 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
146 |
+
vae.to("cpu")
|
147 |
+
if torch.cuda.is_available():
|
148 |
+
torch.cuda.empty_cache()
|
149 |
+
gc.collect()
|
150 |
+
|
151 |
+
accelerator.wait_for_everyone()
|
152 |
+
|
153 |
+
# 学習を準備する:モデルを適切な状態にする
|
154 |
+
training_models = []
|
155 |
+
if args.gradient_checkpointing:
|
156 |
+
unet.enable_gradient_checkpointing()
|
157 |
+
training_models.append(unet)
|
158 |
+
|
159 |
+
if args.train_text_encoder:
|
160 |
+
print("enable text encoder training")
|
161 |
+
if args.gradient_checkpointing:
|
162 |
+
text_encoder.gradient_checkpointing_enable()
|
163 |
+
training_models.append(text_encoder)
|
164 |
+
else:
|
165 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
166 |
+
text_encoder.requires_grad_(False) # text encoderは学習しない
|
167 |
+
if args.gradient_checkpointing:
|
168 |
+
text_encoder.gradient_checkpointing_enable()
|
169 |
+
text_encoder.train() # required for gradient_checkpointing
|
170 |
+
else:
|
171 |
+
text_encoder.eval()
|
172 |
+
|
173 |
+
if not cache_latents:
|
174 |
+
vae.requires_grad_(False)
|
175 |
+
vae.eval()
|
176 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
177 |
+
|
178 |
+
for m in training_models:
|
179 |
+
m.requires_grad_(True)
|
180 |
+
params = []
|
181 |
+
for m in training_models:
|
182 |
+
params.extend(m.parameters())
|
183 |
+
params_to_optimize = params
|
184 |
+
|
185 |
+
# 学習に必要なクラスを準備する
|
186 |
+
print("prepare optimizer, data loader etc.")
|
187 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
188 |
+
|
189 |
+
# dataloaderを準備する
|
190 |
+
# DataLoaderのプロセス数:0はメインプロセスになる
|
191 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
192 |
+
train_dataloader = torch.utils.data.DataLoader(
|
193 |
+
train_dataset_group,
|
194 |
+
batch_size=1,
|
195 |
+
shuffle=True,
|
196 |
+
collate_fn=collater,
|
197 |
+
num_workers=n_workers,
|
198 |
+
persistent_workers=args.persistent_data_loader_workers,
|
199 |
+
)
|
200 |
+
|
201 |
+
# 学習ステップ数を計算する
|
202 |
+
if args.max_train_epochs is not None:
|
203 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
204 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
205 |
+
)
|
206 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
207 |
+
|
208 |
+
# データセット側にも学習ステップを送信
|
209 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
210 |
+
|
211 |
+
# lr schedulerを用意する
|
212 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
213 |
+
|
214 |
+
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
215 |
+
if args.full_fp16:
|
216 |
+
assert (
|
217 |
+
args.mixed_precision == "fp16"
|
218 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
219 |
+
print("enable full fp16 training.")
|
220 |
+
unet.to(weight_dtype)
|
221 |
+
text_encoder.to(weight_dtype)
|
222 |
+
|
223 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
224 |
+
if args.train_text_encoder:
|
225 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
226 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
230 |
+
|
231 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
232 |
+
if args.full_fp16:
|
233 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
234 |
+
|
235 |
+
# resumeする
|
236 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
237 |
+
|
238 |
+
# epoch数を計算する
|
239 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
240 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
241 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
242 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
243 |
+
|
244 |
+
# 学習する
|
245 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
246 |
+
print("running training / 学習開始")
|
247 |
+
print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
248 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
249 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
250 |
+
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
251 |
+
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
252 |
+
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
253 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
254 |
+
|
255 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
256 |
+
global_step = 0
|
257 |
+
|
258 |
+
noise_scheduler = DDPMScheduler(
|
259 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
260 |
+
)
|
261 |
+
|
262 |
+
if accelerator.is_main_process:
|
263 |
+
accelerator.init_trackers("finetuning")
|
264 |
+
|
265 |
+
for epoch in range(num_train_epochs):
|
266 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
267 |
+
current_epoch.value = epoch + 1
|
268 |
+
|
269 |
+
for m in training_models:
|
270 |
+
m.train()
|
271 |
+
|
272 |
+
loss_total = 0
|
273 |
+
for step, batch in enumerate(train_dataloader):
|
274 |
+
current_step.value = global_step
|
275 |
+
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
276 |
+
with torch.no_grad():
|
277 |
+
if "latents" in batch and batch["latents"] is not None:
|
278 |
+
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
|
279 |
+
else:
|
280 |
+
# latentに変換
|
281 |
+
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
282 |
+
latents = latents * 0.18215
|
283 |
+
b_size = latents.shape[0]
|
284 |
+
|
285 |
+
with torch.set_grad_enabled(args.train_text_encoder):
|
286 |
+
# Get the text embedding for conditioning
|
287 |
+
if args.weighted_captions:
|
288 |
+
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
|
289 |
+
text_encoder,
|
290 |
+
batch["captions"],
|
291 |
+
accelerator.device,
|
292 |
+
args.max_token_length // 75 if args.max_token_length else 1,
|
293 |
+
clip_skip=args.clip_skip,
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
297 |
+
encoder_hidden_states = train_util.get_hidden_states(
|
298 |
+
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
299 |
+
)
|
300 |
+
|
301 |
+
# Sample noise that we'll add to the latents
|
302 |
+
noise = torch.randn_like(latents, device=latents.device)
|
303 |
+
if args.noise_offset:
|
304 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
305 |
+
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
306 |
+
|
307 |
+
# Sample a random timestep for each image
|
308 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
309 |
+
timesteps = timesteps.long()
|
310 |
+
|
311 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
312 |
+
# (this is the forward diffusion process)
|
313 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
314 |
+
|
315 |
+
# Predict the noise residual
|
316 |
+
with accelerator.autocast():
|
317 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
318 |
+
|
319 |
+
if args.v_parameterization:
|
320 |
+
# v-parameterization training
|
321 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
322 |
+
else:
|
323 |
+
target = noise
|
324 |
+
|
325 |
+
if args.min_snr_gamma:
|
326 |
+
# do not mean over batch dimension for snr weight
|
327 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
328 |
+
loss = loss.mean([1, 2, 3])
|
329 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
330 |
+
loss = loss.mean() # mean over batch dimension
|
331 |
+
else:
|
332 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
333 |
+
|
334 |
+
accelerator.backward(loss)
|
335 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
336 |
+
params_to_clip = []
|
337 |
+
for m in training_models:
|
338 |
+
params_to_clip.extend(m.parameters())
|
339 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
340 |
+
|
341 |
+
optimizer.step()
|
342 |
+
lr_scheduler.step()
|
343 |
+
optimizer.zero_grad(set_to_none=True)
|
344 |
+
|
345 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
346 |
+
if accelerator.sync_gradients:
|
347 |
+
progress_bar.update(1)
|
348 |
+
global_step += 1
|
349 |
+
|
350 |
+
train_util.sample_images(
|
351 |
+
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
352 |
+
)
|
353 |
+
|
354 |
+
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
355 |
+
if args.logging_dir is not None:
|
356 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
357 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
358 |
+
logs["lr/d*lr"] = (
|
359 |
+
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
360 |
+
)
|
361 |
+
accelerator.log(logs, step=global_step)
|
362 |
+
|
363 |
+
# TODO moving averageにする
|
364 |
+
loss_total += current_loss
|
365 |
+
avr_loss = loss_total / (step + 1)
|
366 |
+
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
367 |
+
progress_bar.set_postfix(**logs)
|
368 |
+
|
369 |
+
if global_step >= args.max_train_steps:
|
370 |
+
break
|
371 |
+
|
372 |
+
if args.logging_dir is not None:
|
373 |
+
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
374 |
+
accelerator.log(logs, step=epoch + 1)
|
375 |
+
|
376 |
+
accelerator.wait_for_everyone()
|
377 |
+
|
378 |
+
if args.save_every_n_epochs is not None:
|
379 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
380 |
+
train_util.save_sd_model_on_epoch_end(
|
381 |
+
args,
|
382 |
+
accelerator,
|
383 |
+
src_path,
|
384 |
+
save_stable_diffusion_format,
|
385 |
+
use_safetensors,
|
386 |
+
save_dtype,
|
387 |
+
epoch,
|
388 |
+
num_train_epochs,
|
389 |
+
global_step,
|
390 |
+
unwrap_model(text_encoder),
|
391 |
+
unwrap_model(unet),
|
392 |
+
vae,
|
393 |
+
)
|
394 |
+
|
395 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
396 |
+
|
397 |
+
is_main_process = accelerator.is_main_process
|
398 |
+
if is_main_process:
|
399 |
+
unet = unwrap_model(unet)
|
400 |
+
text_encoder = unwrap_model(text_encoder)
|
401 |
+
|
402 |
+
accelerator.end_training()
|
403 |
+
|
404 |
+
if args.save_state:
|
405 |
+
train_util.save_state_on_train_end(args, accelerator)
|
406 |
+
|
407 |
+
del accelerator # この後メモリを使うのでこれは消す
|
408 |
+
|
409 |
+
if is_main_process:
|
410 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
411 |
+
train_util.save_sd_model_on_train_end(
|
412 |
+
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
413 |
+
)
|
414 |
+
print("model saved.")
|
415 |
+
|
416 |
+
|
417 |
+
def setup_parser() -> argparse.ArgumentParser:
|
418 |
+
parser = argparse.ArgumentParser()
|
419 |
+
|
420 |
+
train_util.add_sd_models_arguments(parser)
|
421 |
+
train_util.add_dataset_arguments(parser, False, True, True)
|
422 |
+
train_util.add_training_arguments(parser, False)
|
423 |
+
train_util.add_sd_saving_arguments(parser)
|
424 |
+
train_util.add_optimizer_arguments(parser)
|
425 |
+
config_util.add_config_arguments(parser)
|
426 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
427 |
+
|
428 |
+
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
429 |
+
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
430 |
+
|
431 |
+
return parser
|
432 |
+
|
433 |
+
|
434 |
+
if __name__ == "__main__":
|
435 |
+
parser = setup_parser()
|
436 |
+
|
437 |
+
args = parser.parse_args()
|
438 |
+
args = train_util.read_config_from_file(args, parser)
|
439 |
+
|
440 |
+
train(args)
|
fine_tune_README.md
ADDED
@@ -0,0 +1,465 @@
1 |
+
This is a fine tuning setup that supports NovelAI's proposed training method, automatic captioning, automatic tagging, a Windows + 12GB VRAM environment (for SD v1.4/1.5), and more.
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
Fine tuning of Stable Diffusion's U-Net using Diffusers. It supports the following improvements from NovelAI's article (for Aspect Ratio Bucketing, I referred to NovelAI's code, but the final code is all original).
|
5 |
+
|
6 |
+
* Use the output of the penultimate layer instead of the last layer of CLIP (Text Encoder).
|
7 |
+
* Learning at non-square resolutions (Aspect Ratio Bucketing).
|
8 |
+
* Extend token length from 75 to 225.
|
9 |
+
* Captioning with BLIP (automatic creation of captions), automatic tagging with DeepDanbooru or WD14Tagger.
|
10 |
+
* Also supports Hypernetwork learning.
|
11 |
+
* Supports Stable Diffusion v2.0 (base and 768/v).
|
12 |
+
* By acquiring the output of VAE in advance and saving it to disk, we aim to save memory and speed up learning.
|
13 |
+
|
14 |
+
The Text Encoder is not trained by default. When fine tuning the whole model, it seems common to train only the U-Net (NovelAI appears to do the same). The Text Encoder can also be trained as an option.
|
15 |
+
|
16 |
+
## Additional features
|
17 |
+
### Change CLIP output
|
18 |
+
CLIP (Text Encoder) converts the text into features so that the prompt is reflected in the image. Stable Diffusion uses the output of the last layer of CLIP, but you can change it to use the output of the penultimate layer. According to NovelAI, this reflects prompts more accurately.
|
19 |
+
It is also possible to use the output of the last layer as is.
|
20 |
+
*Stable Diffusion 2.0 uses the penultimate layer by default. Do not specify the clip_skip option.
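As a rough illustration (not the script's actual code path), this is how the penultimate hidden state can be taken from a CLIP text encoder with the transformers library; the model name is only an example:

```
import torch
from transformers import CLIPTokenizer, CLIPTextModel

# Example checkpoint; the actual one depends on the model being fine tuned.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(
    "a photo of a cat", return_tensors="pt", padding="max_length",
    max_length=tokenizer.model_max_length, truncation=True,
)

with torch.no_grad():
    out = text_encoder(tokens.input_ids, output_hidden_states=True)

last_layer = out.last_hidden_state     # what SD v1 uses by default
penultimate = out.hidden_states[-2]    # "clip skip 2": second-to-last layer
# The penultimate output is usually passed through the final layer norm before use.
penultimate = text_encoder.text_model.final_layer_norm(penultimate)
```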
|
21 |
+
|
22 |
+
### Training in non-square resolutions
|
23 |
+
Stable Diffusion is trained at 512\*512, but here training is also performed at resolutions such as 256\*1024 and 384\*640. This is expected to reduce how much is cropped away and to let the relationship between prompts and images be learned more correctly.
|
24 |
+
The learning resolution is adjusted vertically and horizontally in units of 64 pixels within a range that does not exceed the resolution area (= memory usage) given as a parameter.
|
25 |
+
|
26 |
+
In machine learning, it is common to unify all input sizes, but there are no particular restrictions, and in fact it is okay as long as they are unified within the same batch. NovelAI's bucketing seems to refer to classifying training data in advance for each learning resolution according to the aspect ratio. And by creating a batch with the images in each bucket, the image size of the batch is unified.
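For intuition, here is a simplified sketch of how such buckets could be generated in 64-pixel steps under an area limit, and how an image could be assigned to the nearest aspect-ratio bucket (an illustration, not the script's actual bucketing code):

```
def make_buckets(max_area=512 * 512, step=64, min_size=256, max_size=1024):
    """Enumerate (width, height) pairs in 64-pixel steps whose area fits in max_area."""
    buckets = set()
    for w in range(min_size, max_size + 1, step):
        h = (max_area // w) // step * step
        if min_size <= h <= max_size:
            buckets.add((w, h))
            buckets.add((h, w))
    return sorted(buckets)


def assign_bucket(img_w, img_h, buckets):
    """Pick the bucket whose aspect ratio is closest to the image's aspect ratio."""
    aspect = img_w / img_h
    return min(buckets, key=lambda b: abs(b[0] / b[1] - aspect))


buckets = make_buckets()
print(assign_bucket(1200, 800, buckets))  # a landscape bucket, e.g. (640, 384) with these settings
```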
|
27 |
+
|
28 |
+
### Extending token length from 75 to 225
|
29 |
+
Stable Diffusion has a maximum of 75 tokens (77 tokens including the start and end tokens), but we extend this to 225 tokens.
|
30 |
+
However, the maximum length that CLIP accepts is 75 tokens, so in the case of 225 tokens, we simply divide it into thirds, call CLIP, and then concatenate the results.
|
31 |
+
|
32 |
+
*I'm not sure if this is the preferred implementation. It seems to be working for now. Especially in 2.0, there is no implementation that can be used as a reference, so I have implemented it independently.
|
33 |
+
|
34 |
+
*Automatic1111's Web UI seems to divide the text with commas in mind, but in my case, it's a simple division.
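For illustration, a minimal sketch of the idea, assuming 225 tokens are split into three 75-token chunks that are each wrapped with the start/end tokens before being passed to CLIP (a simplification, not the exact implementation):

```
import torch

def encode_long_prompt(input_ids, text_encoder, bos_id, eos_id, chunk_len=75):
    # input_ids: (batch, 225) token ids without the start/end tokens
    outputs = []
    for i in range(0, input_ids.shape[1], chunk_len):
        chunk = input_ids[:, i:i + chunk_len]
        bos = torch.full((chunk.shape[0], 1), bos_id, dtype=chunk.dtype)
        eos = torch.full((chunk.shape[0], 1), eos_id, dtype=chunk.dtype)
        chunk = torch.cat([bos, chunk, eos], dim=1)   # 75 -> 77 tokens
        hidden = text_encoder(chunk)[0]               # (batch, 77, dim)
        outputs.append(hidden[:, 1:-1])               # drop the start/end embeddings
    return torch.cat(outputs, dim=1)                  # (batch, 225, dim)
```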
|
35 |
+
|
36 |
+
## Environment setup
|
37 |
+
|
38 |
+
See the [README](./README-en.md) in this repository.
|
39 |
+
|
40 |
+
## Preparing teacher data
|
41 |
+
|
42 |
+
Prepare the image data you want to learn and put it in any folder. No prior preparation such as resizing is required.
|
43 |
+
However, for images that are smaller than the training resolution, it is recommended to enlarge them while maintaining the quality using super-resolution.
|
44 |
+
|
45 |
+
It also supports multiple teacher data folders. Preprocessing will be executed for each folder.
|
46 |
+
|
47 |
+
For example, store an image like this:
|
48 |
+
|
49 |
+
![Teacher data folder screenshot](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png)
|
50 |
+
|
51 |
+
## Automatic captioning
|
52 |
+
Skip if you just want to learn tags without captions.
|
53 |
+
|
54 |
+
Also, when preparing captions manually, prepare them in the same directory as the teacher data image, with the same file name, extension .caption, etc. Each file should be a text file with only one line.
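For example, a small sketch that creates a template .caption file next to every image so you can fill the text in by hand (the folder name and extension here are just examples):

```
from pathlib import Path

train_dir = Path("train_data")  # example folder
for image in train_dir.glob("*.png"):
    caption_file = image.with_suffix(".caption")
    if not caption_file.exists():
        caption_file.write_text("a photo of something\n", encoding="utf-8")  # one line per file
```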
|
55 |
+
|
56 |
+
### Captioning with BLIP
|
57 |
+
|
58 |
+
The latest version no longer requires BLIP downloads, weight downloads, and additional virtual environments. Works as-is.
|
59 |
+
|
60 |
+
Run make_captions.py in the finetune folder.
|
61 |
+
|
62 |
+
```
|
63 |
+
python finetune\make_captions.py --batch_size <batch size> <teacher data folder>
|
64 |
+
```
|
65 |
+
|
66 |
+
If the batch size is 8 and the training data is placed in the parent folder train_data, it will be as follows.
|
67 |
+
|
68 |
+
```
|
69 |
+
python finetune\make_captions.py --batch_size 8 ..\train_data
|
70 |
+
```
|
71 |
+
|
72 |
+
A caption file is created in the same directory as the teacher data image with the same file name and extension .caption.
|
73 |
+
|
74 |
+
Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (with 12GB of VRAM you can probably go a little higher).
|
75 |
+
You can specify the maximum length of the caption with the max_length option. The default is 75. You may want to make it longer if the model will be trained with a token length of 225.
|
76 |
+
You can change the caption extension with the caption_extension option. Default is .caption (.txt conflicts with DeepDanbooru described later).
|
77 |
+
|
78 |
+
If there are multiple teacher data folders, execute for each folder.
|
79 |
+
|
80 |
+
Note that the inference is random, so the results will change each time you run it. If you want to fix it, specify a random number seed like "--seed 42" with the --seed option.
|
81 |
+
|
82 |
+
For other options, please refer to the help with --help (there seems to be no documentation for the meaning of the parameters, so you have to look at the source).
|
83 |
+
|
84 |
+
A caption file is generated with the extension .caption by default.
|
85 |
+
|
86 |
+
![Folder where caption is generated](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png)
|
87 |
+
|
88 |
+
For example, with captions like:
|
89 |
+
|
90 |
+
![captions and images](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png)
|
91 |
+
|
92 |
+
## Tagging with DeepDanbooru
|
93 |
+
If you do not want to use Danbooru-style tags, please proceed to "Preprocessing caption and tag information".
|
94 |
+
|
95 |
+
Tagging is done with DeepDanbooru or WD14Tagger. WD14Tagger seems to be more accurate. If you want to tag with WD14Tagger, skip to the next chapter.
|
96 |
+
|
97 |
+
### Environment setup
|
98 |
+
Clone DeepDanbooru https://github.com/KichangKim/DeepDanbooru into your working folder, or download the zip and extract it. I unzipped it.
|
99 |
+
Also, download deepdanbooru-v3-20211112-sgd-e28.zip from Assets of "DeepDanbooru Pretrained Model v3-20211112-sgd-e28" on the DeepDanbooru Releases page https://github.com/KichangKim/DeepDanbooru/releases and extract it to the DeepDanbooru folder.
|
100 |
+
|
101 |
+
Download from below. Click to open Assets and download from there.
|
102 |
+
|
103 |
+
![DeepDanbooru download page](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png)
|
104 |
+
|
105 |
+
Make a directory structure like this
|
106 |
+
|
107 |
+
![DeepDanbooru directory structure](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png)
|
108 |
+
|
109 |
+
Install the necessary libraries for the Diffusers environment. Go to the DeepDanbooru folder and install it (I think it's actually just adding tensorflow-io).
|
110 |
+
|
111 |
+
```
|
112 |
+
pip install -r requirements.txt
|
113 |
+
```
|
114 |
+
|
115 |
+
Next, install DeepDanbooru itself.
|
116 |
+
|
117 |
+
```
|
118 |
+
pip install .
|
119 |
+
```
|
120 |
+
|
121 |
+
This completes the preparation of the environment for tagging.
|
122 |
+
|
123 |
+
### Running tagging
|
124 |
+
Go to DeepDanbooru's folder and run deepdanbooru to tag.
|
125 |
+
|
126 |
+
```
|
127 |
+
deepdanbooru evaluate <teacher data folder> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
128 |
+
```
|
129 |
+
|
130 |
+
If you put the training data in the parent folder train_data, it will be as follows.
|
131 |
+
|
132 |
+
```
|
133 |
+
deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
134 |
+
```
|
135 |
+
|
136 |
+
A tag file is created in the same directory as the teacher data image with the same file name and extension .txt. It is slow because it is processed one by one.
|
137 |
+
|
138 |
+
If there are multiple teacher data folders, execute for each folder.
|
139 |
+
|
140 |
+
It is generated as follows.
|
141 |
+
|
142 |
+
![DeepDanbooru generated files](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png)
|
143 |
+
|
144 |
+
A tag is attached like this (great amount of information...).
|
145 |
+
|
146 |
+
![Deep Danbooru tag and image](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png)
|
147 |
+
|
148 |
+
## Tagging with WD14Tagger
|
149 |
+
This procedure uses WD14Tagger instead of DeepDanbooru.
|
150 |
+
|
151 |
+
Use the tagger used in Mr. Automatic1111's WebUI. I referred to the information on this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger).
|
152 |
+
|
153 |
+
The required modules were already installed during the initial environment setup. The weights are downloaded automatically from Hugging Face.
|
154 |
+
|
155 |
+
### Running tagging
|
156 |
+
Run the script to do the tagging.
|
157 |
+
```
|
158 |
+
python tag_images_by_wd14_tagger.py --batch_size <batch size> <teacher data folder>
|
159 |
+
```
|
160 |
+
|
161 |
+
If you put the training data in the parent folder train_data, it will be as follows.
|
162 |
+
```
|
163 |
+
python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
|
164 |
+
```
|
165 |
+
|
166 |
+
The model file will be automatically downloaded to the wd14_tagger_model folder on first launch (folder can be changed in options). It will be as follows.
|
167 |
+
|
168 |
+
![downloaded file](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png)
|
169 |
+
|
170 |
+
A tag file is created in the same directory as the teacher data image with the same file name and extension .txt.
|
171 |
+
|
172 |
+
![generated tag file](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
|
173 |
+
|
174 |
+
![tags and images](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
|
175 |
+
|
176 |
+
With the thresh option, you can specify the confidence a tag must reach in order to be attached. The default is 0.35, the same as in the WD14Tagger sample. Lower values attach more tags, but with lower accuracy.
|
177 |
+
Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (I think 12GB of VRAM can be a little more). You can change the tag file extension with the caption_extension option. Default is .txt.
|
178 |
+
You can specify the folder where the model is saved with the model_dir option.
|
179 |
+
Also, if you specify the force_download option, the model will be re-downloaded even if there is a save destination folder.
|
180 |
+
|
181 |
+
If there are multiple teacher data folders, execute for each folder.
|
182 |
+
|
183 |
+
## Preprocessing caption and tag information
|
184 |
+
|
185 |
+
Combine captions and tags into a single file as metadata for easy processing from scripts.
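As a purely illustrative sketch (the actual key layout is defined by the merge scripts, so treat the structure below as an assumption), the merged metadata is a single JSON file that maps each image to its caption and tags:

```
import json

# Hypothetical example of what the merged metadata might look like.
metadata = {
    "train_data/img_0001": {
        "caption": "a girl standing in a field",
        "tags": "1girl, outdoors, grass",
    },
}
with open("meta_example.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, indent=2, ensure_ascii=False)
```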
|
186 |
+
|
187 |
+
### Caption preprocessing
|
188 |
+
|
189 |
+
To put captions into the metadata, run the following in your working folder (if you don't use captions for training, you don't need to run this). The command is actually entered on a single line; the same applies to the commands below.
|
190 |
+
|
191 |
+
```
|
192 |
+
python merge_captions_to_metadata.py <teacher data folder>
|
193 |
+
--in_json <metadata file name to read>
|
194 |
+
<metadata file name>
|
195 |
+
```
|
196 |
+
|
197 |
+
The metadata file name is an arbitrary name.
|
198 |
+
If the training data is train_data, there is no metadata file to read, and the metadata file is meta_cap.json, it will be as follows.
|
199 |
+
|
200 |
+
```
|
201 |
+
python merge_captions_to_metadata.py train_data meta_cap.json
|
202 |
+
```
|
203 |
+
|
204 |
+
You can specify the caption extension with the caption_extension option.
|
205 |
+
|
206 |
+
If there are multiple teacher data folders, please specify the full_path argument (metadata will have full path information). Then run it for each folder.
|
207 |
+
|
208 |
+
```
|
209 |
+
python merge_captions_to_metadata.py --full_path
|
210 |
+
train_data1 meta_cap1.json
|
211 |
+
python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
|
212 |
+
train_data2 meta_cap2.json
|
213 |
+
```
|
214 |
+
|
215 |
+
If in_json is omitted and the destination metadata file already exists, it will be read from there and then overwritten in place.
|
216 |
+
|
217 |
+
__*It is safest to change the in_json option and the write destination each time and write to a separate metadata file.__
|
218 |
+
|
219 |
+
### Tag preprocessing
|
220 |
+
|
221 |
+
Similarly, tags are also collected in metadata (no need to do this if tags are not used for learning).
|
222 |
+
```
|
223 |
+
python merge_dd_tags_to_metadata.py <teacher data folder>
|
224 |
+
--in_json <metadata file name to load>
|
225 |
+
<metadata file name to write>
|
226 |
+
```
|
227 |
+
|
228 |
+
With the same directory structure as above, when reading meta_cap.json and writing to meta_cap_dd.json, it will be as follows.
|
229 |
+
```
|
230 |
+
python merge_dd_tags_to_metadata.py train_data --in_json meta_cap.json meta_cap_dd.json
|
231 |
+
```
|
232 |
+
|
233 |
+
If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder.
|
234 |
+
|
235 |
+
```
|
236 |
+
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
|
237 |
+
train_data1 meta_cap_dd1.json
|
238 |
+
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
|
239 |
+
train_data2 meta_cap_dd2.json
|
240 |
+
```
|
241 |
+
|
242 |
+
If in_json is omitted and the destination metadata file already exists, it will be read from there and then overwritten in place.
|
243 |
+
|
244 |
+
__*It is safest to change the in_json option and the write destination each time and write to a separate metadata file.__
|
245 |
+
|
246 |
+
### Cleaning captions and tags
|
247 |
+
Up to this point, captions and DeepDanbooru tags have been collected in the metadata file. However, automatically generated captions vary in wording (*), and tags include underscores and ratings (in the case of DeepDanbooru), so you should use your editor's search-and-replace function or similar to clean up the captions and tags.
|
248 |
+
|
249 |
+
*For example, when learning a girl in an anime picture, there are variations in captions such as girl/girls/woman/women. Also, it may be more appropriate to simply use "girl" for things like "anime girl".
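For illustration only, here is a minimal sketch of the kind of replacements you might apply across the metadata, assuming the caption/tags layout sketched earlier (the replacement rules are hypothetical; the provided clean_captions_and_tags.py is the real starting point):

```
import json
import re

with open("meta_cap_dd.json", encoding="utf-8") as f:
    metadata = json.load(f)

for entry in metadata.values():
    if "caption" in entry:
        # unify wording variations (hypothetical rule)
        entry["caption"] = re.sub(r"\b(girls|woman|women)\b", "girl", entry["caption"])
    if "tags" in entry:
        # drop underscores and rating tags (hypothetical rules)
        tags = [t.strip().replace("_", " ") for t in entry["tags"].split(",")]
        entry["tags"] = ", ".join(t for t in tags if not t.startswith("rating"))

with open("meta_cleaned_example.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, indent=2, ensure_ascii=False)
```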
|
250 |
+
|
251 |
+
A script for cleaning is provided, so please edit the contents of the script according to the situation and use it.
|
252 |
+
|
253 |
+
(It is no longer necessary to specify the teacher data folder. All data in the metadata will be cleaned.)
|
254 |
+
|
255 |
+
```
|
256 |
+
python clean_captions_and_tags.py <metadata file name to read> <metadata file name to write>
|
257 |
+
```
|
258 |
+
|
259 |
+
Please note that --in_json is not included. For example:
|
260 |
+
|
261 |
+
```
|
262 |
+
python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
|
263 |
+
```
|
264 |
+
|
265 |
+
Preprocessing of captions and tags is now complete.
|
266 |
+
|
267 |
+
## Get latents in advance
|
268 |
+
|
269 |
+
In order to speed up the learning, we acquire the latent representation of the image in advance and save it to disk. At the same time, bucketing (classifying the training data according to the aspect ratio) is performed.
|
270 |
+
|
271 |
+
In your working folder, type:
|
272 |
+
```
|
273 |
+
python prepare_buckets_latents.py <teacher data folder>
|
274 |
+
<metadata file name to read> <metadata file name to write>
|
275 |
+
<model name or checkpoint for fine tuning>
|
276 |
+
--batch_size <batch size>
|
277 |
+
--max_resolution <resolution width, height>
|
278 |
+
--mixed_precision <precision>
|
279 |
+
```
|
280 |
+
|
281 |
+
If the model is model.ckpt, batch size 4, training resolution is 512\*512, precision is no (float32), read metadata from meta_clean.json and write to meta_lat.json:
|
282 |
+
|
283 |
+
```
|
284 |
+
python prepare_buckets_latents.py
|
285 |
+
train_data meta_clean.json meta_lat.json model.ckpt
|
286 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
287 |
+
```
|
288 |
+
|
289 |
+
Latents are saved in numpy npz format in the teacher data folder.
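Conceptually, what gets cached is the VAE encoding of each image. A minimal sketch of the idea (the model name, file name and npz key here are illustrative, not necessarily what prepare_buckets_latents.py writes):

```
import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.eval()

image = Image.open("train_data/img_0001.png").convert("RGB").resize((512, 512))
x = torch.from_numpy(np.array(image)).float() / 127.5 - 1.0  # scale to [-1, 1]
x = x.permute(2, 0, 1).unsqueeze(0)                          # (1, 3, H, W)

with torch.no_grad():
    latents = vae.encode(x).latent_dist.sample() * 0.18215   # same scaling used in training

np.savez("train_data/img_0001.npz", latents=latents.squeeze(0).numpy())
```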
|
290 |
+
|
291 |
+
Specify the --v2 option when loading a Stable Diffusion 2.0 model (--v_parameterization is not required).
|
292 |
+
|
293 |
+
You can specify the minimum resolution size with the --min_bucket_reso option and the maximum size with the --max_bucket_reso option. The defaults are 256 and 1024 respectively. For example, specifying a minimum size of 384 will not use resolutions such as 256\*1024 or 320\*768.
|
294 |
+
If you increase the resolution to something like 768\*768, you should specify something like 1280 for the maximum size.
|
295 |
+
|
296 |
+
If you specify the --flip_aug option, it will perform horizontal flip augmentation (data augmentation). You can artificially double the amount of data, but if you specify it when the data is not left-right symmetrical (for example, character appearance, hairstyle, etc.), learning will not go well.
|
297 |
+
(This is a simple implementation that acquires the latents for the flipped image and saves them as a \*\_flip.npz file. No extra options are required for fine_tune.py; if files with \_flip exist, files with and without \_flip are loaded at random.)
|
298 |
+
|
299 |
+
The batch size may be increased a little more even with 12GB of VRAM.
|
300 |
+
The resolution is a number divisible by 64, specified as "width,height". The resolution is directly linked to memory usage during fine tuning. 512,512 seems to be the limit with 12GB of VRAM (*). With 16GB it may be possible to raise it to 512,704 or 512,768. Even at 256,256 and the like, 8GB of VRAM seems to be difficult (because parameters, optimizer state and so on require a certain amount of memory regardless of resolution).
|
301 |
+
|
302 |
+
*There was also a report that learning batch size 1 worked with 12GB VRAM and 640,640.
|
303 |
+
|
304 |
+
The result of bucketing is displayed as follows.
|
305 |
+
|
306 |
+
![bucketing result](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png)
|
307 |
+
|
308 |
+
If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder.
|
309 |
+
```
|
310 |
+
python prepare_buckets_latents.py --full_path
|
311 |
+
train_data1 meta_clean.json meta_lat1.json model.ckpt
|
312 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
313 |
+
|
314 |
+
python prepare_buckets_latents.py --full_path
|
315 |
+
train_data2 meta_lat1.json meta_lat2.json model.ckpt
|
316 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
317 |
+
|
318 |
+
```
|
319 |
+
It is possible to make the read source and write destination the same, but separate is safer.
|
320 |
+
|
321 |
+
__*It is safest to change the arguments each time and write to a separate metadata file.__
|
322 |
+
|
323 |
+
|
324 |
+
## Run training
|
325 |
+
For example, run the following; these settings are aimed at saving memory.
|
326 |
+
```
|
327 |
+
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
|
328 |
+
--pretrained_model_name_or_path=model.ckpt
|
329 |
+
--in_json meta_lat.json
|
330 |
+
--train_data_dir=train_data
|
331 |
+
--output_dir=fine_tuned
|
332 |
+
--shuffle_caption
|
333 |
+
--train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
|
334 |
+
--use_8bit_adam --xformers --gradient_checkpointing
|
335 |
+
--mixed_precision=bf16
|
336 |
+
--save_every_n_epochs=4
|
337 |
+
```
|
338 |
+
|
339 |
+
It seems to be good to specify the number of CPU cores for num_cpu_threads_per_process of accelerate.
|
340 |
+
|
341 |
+
Specify the model to be trained in pretrained_model_name_or_path (Stable Diffusion checkpoint or Diffusers model). Stable Diffusion checkpoint supports .ckpt and .safetensors (automatically determined by extension).
|
342 |
+
|
343 |
+
Specifies the metadata file when caching latent to in_json.
|
344 |
+
|
345 |
+
Specify the training data folder for train_data_dir and the output destination folder for the trained model for output_dir.
|
346 |
+
|
347 |
+
If shuffle_caption is specified, captions and tags are shuffled and learned in units separated by commas (this is the method used in Waifu Diffusion v1.3).
|
348 |
+
(You can keep some of the leading tokens fixed without shuffling. See keep_tokens for other options.)
|
349 |
+
|
350 |
+
Specify the batch size in train_batch_size. Specify 1 or 2 for VRAM 12GB. The number that can be specified also changes depending on the resolution.
|
351 |
+
The actual amount of data used for training is "batch size x number of steps". When increasing the batch size, the number of steps can be decreased accordingly.
|
352 |
+
|
353 |
+
Specify the learning rate in learning_rate. For example Waifu Diffusion v1.3 seems to be 5e-6.
|
354 |
+
Specify the number of steps in max_train_steps.
|
355 |
+
|
356 |
+
Specify use_8bit_adam to use the 8-bit Adam Optimizer. It saves memory and speeds up, but accuracy may decrease.
|
357 |
+
|
358 |
+
Specifying xformers replaces CrossAttention to save memory and speed up.
|
359 |
+
* As of 11/9, xformers will cause an error in float32 learning, so please use bf16/fp16 or use memory-saving CrossAttention with mem_eff_attn instead (speed is inferior to xformers).
|
360 |
+
|
361 |
+
gradient_checkpointing enables gradient checkpointing (recomputing intermediate activations during the backward pass instead of keeping them). It's slower, but uses less memory.
|
362 |
+
|
363 |
+
Specifies whether to use mixed precision with mixed_precision. Specifying "fp16" or "bf16" saves memory, but accuracy is inferior.
|
364 |
+
"fp16" and "bf16" use almost the same amount of memory, and it is said that bf16 has better learning results (I didn't feel much difference in the range I tried).
|
365 |
+
If "no" is specified, it will not be used (it will be float32).
|
366 |
+
|
367 |
+
* It seems that an error occurs when loading checkpoints trained with bf16 in AUTOMATIC1111's Web UI, apparently because the bfloat16 data type trips up the Web UI's model safety checker. Save in fp16 or float32 format with the save_precision option, or save in safetensors format.
|
368 |
+
|
369 |
+
Specifying save_every_n_epochs will save the model being trained every time that many epochs have passed.
|
370 |
+
|
371 |
+
### Supports Stable Diffusion 2.0
|
372 |
+
Specify the --v2 option when using Hugging Face's stable-diffusion-2-base, and specify both the --v2 and --v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt.
|
373 |
+
|
374 |
+
### Increase accuracy and speed when memory is available
|
375 |
+
First, removing gradient_checkpointing will speed it up. However, the batch size that can be set is reduced, so please set while looking at the balance between accuracy and speed.
|
376 |
+
|
377 |
+
Increasing the batch size increases speed and accuracy. Increase it while checking the speed per sample within the range where memory is sufficient (the speed may actually decrease when memory is at its limit).
|
378 |
+
|
379 |
+
### Change CLIP output used
|
380 |
+
Specifying 2 for the clip_skip option uses the output of the next-to-last layer. If 1 is specified or the option is omitted, the last layer is used.
|
381 |
+
The learned model should be able to be inferred by Automatic1111's web UI.
|
382 |
+
|
383 |
+
*SD2.0 uses the second layer from the back by default, so please do not specify it when learning SD2.0.
|
384 |
+
|
385 |
+
If the model being trained was originally trained to use the second layer, 2 is a good value.
|
386 |
+
|
387 |
+
If you were using the last layer instead, the entire model would have been trained on that assumption. Therefore, if you train again using the second layer, you may need a certain number of teacher data and longer learning to obtain the desired learning result.
|
388 |
+
|
389 |
+
### Extending Token Length
|
390 |
+
You can learn by extending the token length by specifying 150 or 225 for max_token_length.
|
391 |
+
The learned model should be able to be inferred by Automatic1111's web UI.
|
392 |
+
|
393 |
+
As with clip_skip, learning with a length different from the learning state of the model may require a certain amount of teacher data and a longer learning time.
|
394 |
+
|
395 |
+
### Save learning log
|
396 |
+
Specify the log save destination folder in the logging_dir option. Logs in TensorBoard format are saved.
|
397 |
+
|
398 |
+
For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in the date/time folder.
|
399 |
+
Also, if you specify the --log_prefix option, the specified string will be added before the date and time. Use "--logging_dir=logs --log_prefix=fine_tune_style1" for identification.
|
400 |
+
|
401 |
+
To check the log with TensorBoard, open another command prompt and enter the following in the working folder (tensorboard should already be installed together with Diffusers; if it is not, install it with pip install tensorboard).
|
402 |
+
```
|
403 |
+
tensorboard --logdir=logs
|
404 |
+
```
|
405 |
+
|
406 |
+
### Learning Hypernetworks
|
407 |
+
It will be explained in another article.
|
408 |
+
|
409 |
+
### Learning with fp16 gradient (experimental feature)
|
410 |
+
The full_fp16 option will change the gradient from normal float32 to float16 (fp16) and learn (it seems to be full fp16 learning instead of mixed precision). As a result, it seems that the SD1.x 512*512 size can be learned with a VRAM usage of less than 8GB, and the SD2.x 512*512 size can be learned with a VRAM usage of less than 12GB.
|
411 |
+
|
412 |
+
Specify fp16 in advance in accelerate config and optionally set mixed_precision="fp16" (does not work with bf16).
|
413 |
+
|
414 |
+
To minimize memory usage, use the xformers, use_8bit_adam, gradient_checkpointing options and set train_batch_size to 1.
|
415 |
+
(If you can afford it, increasing the train_batch_size step by step should improve the accuracy a little.)
|
416 |
+
|
417 |
+
It is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). The accuracy will drop considerably, and the probability of learning failure on the way will also increase. The setting of the learning rate and the number of steps seems to be severe. Please be aware of them and use them at your own risk.
|
418 |
+
|
419 |
+
### Other Options
|
420 |
+
|
421 |
+
#### keep_tokens
|
422 |
+
If a number is specified, the specified number of tokens (comma-separated strings) from the beginning of the caption are fixed without being shuffled.
|
423 |
+
|
424 |
+
If there are both captions and tags, the prompts during training are concatenated like "caption, tag 1, tag 2, ...", so if you set "--keep_tokens=1", the caption will always come first during training.
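A small sketch of what shuffling with keep_tokens means (simplified; the real implementation lives in the training scripts):

```
import random

def shuffle_with_keep_tokens(caption, keep_tokens=1):
    tokens = [t.strip() for t in caption.split(",")]
    head, tail = tokens[:keep_tokens], tokens[keep_tokens:]
    random.shuffle(tail)
    return ", ".join(head + tail)

# The caption stays first; the remaining tags come in random order.
print(shuffle_with_keep_tokens("a photo of a girl, 1girl, outdoors, grass", keep_tokens=1))
```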
|
425 |
+
|
426 |
+
#### dataset_repeats
|
427 |
+
If the dataset is extremely small, each epoch ends quickly (and there is some overhead at every epoch boundary), so specify a number to repeat the data that many times and make each epoch longer.
|
428 |
+
|
429 |
+
#### train_text_encoder
|
430 |
+
Text Encoder is also a learning target. Slightly increased memory usage.
|
431 |
+
|
432 |
+
In normal fine tuning, the Text Encoder is not a training target (probably because the U-Net is trained to follow the output of the Text Encoder), but when the amount of training data is small, training the Text Encoder as in DreamBooth also seems to be effective.
|
433 |
+
|
434 |
+
#### save_precision
|
435 |
+
The data format when saving checkpoints can be chosen from float, fp16, and bf16 (if not specified, it is the same as the data format used during training). fp16 and bf16 save disk space, but the model will produce slightly different results. Also, if you specify float or fp16, you should be able to load the model in AUTOMATIC1111's Web UI.
|
436 |
+
|
437 |
+
*For VAE, the data format of the original checkpoint will remain, so the model size may not be reduced to a little over 2GB even with fp16.
|
438 |
+
|
439 |
+
#### save_model_as
|
440 |
+
Specify the save format of the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors.
|
441 |
+
|
442 |
+
When a Stable Diffusion format checkpoint (ckpt or safetensors) is read and saved in Diffusers format, missing information is supplemented by downloading the v1.5 or v2.1 information from Hugging Face.
|
443 |
+
|
444 |
+
#### use_safetensors
|
445 |
+
This option saves checkpoints in safetensors format. The model format itself remains the default (the same format as the loaded model).
|
446 |
+
|
447 |
+
#### save_state and resume
|
448 |
+
The save_state option saves the training state of the optimizer etc. in addition to the checkpoint, both at intermediate saves and at the final save. This avoids a drop in accuracy when training is resumed after being interrupted (since the optimizer optimizes while holding internal state, resetting that state forces optimization to start over from the initial state). Note that the number of steps is not saved due to Accelerate specifications.
|
449 |
+
|
450 |
+
When starting the script, you can resume by specifying the folder where the state is saved with the resume option.
|
451 |
+
|
452 |
+
Please note that the learning state will be about 5 GB per save, so please be careful of the disk capacity.
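Under the hood this relies on Accelerate's state saving. A minimal sketch of that mechanism (the directory name is an example, and the script wraps this with its own save/resume logic):

```
from accelerate import Accelerator

accelerator = Accelerator()
# ... model, optimizer, dataloader, lr_scheduler = accelerator.prepare(...) ...

# Save the full training state (model, optimizer, RNG states, scaler, ...).
accelerator.save_state("output/epoch-10-state")

# Resume later by restoring the saved state before continuing the loop.
accelerator.load_state("output/epoch-10-state")
```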
|
453 |
+
|
454 |
+
#### gradient_accumulation_steps
|
455 |
+
Accumulates gradients over the specified number of steps before applying an update. Has a similar effect to increasing the batch size, but consumes slightly more memory.
|
456 |
+
|
457 |
+
*The Accelerate specification does not support multiple learning models, so if you set Text Encoder as the learning target and specify a value of 2 or more for this option, an error may occur.
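Conceptually, Accelerate only applies an optimizer update once every N micro-batches. A minimal, self-contained sketch of that behaviour (the toy model and data are illustrative):

```
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)  # example value

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=4)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()       # the prepared optimizer really steps only every 4th batch
        optimizer.zero_grad()
```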
|
458 |
+
|
459 |
+
#### lr_scheduler / lr_warmup_steps
|
460 |
+
You can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup with the lr_scheduler option. Default is constant.
|
461 |
+
|
462 |
+
With lr_warmup_steps, you can specify the number of steps to warm up the scheduler (gradually changing the learning rate). Please do your own research for details.
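The named schedulers correspond to the ones provided by Diffusers (the script uses its own wrapper around them). A minimal sketch of how warmup plus a cosine schedule behaves, with example values:

```
import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

# e.g. --lr_scheduler=cosine --lr_warmup_steps=100 with 10000 total steps
lr_scheduler = get_scheduler(
    "cosine", optimizer=optimizer, num_warmup_steps=100, num_training_steps=10000
)

for step in range(10000):
    optimizer.step()
    lr_scheduler.step()
    if step in (0, 99, 5000, 9999):
        print(step, lr_scheduler.get_last_lr()[0])  # ramps up over 100 steps, then decays
```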
|
463 |
+
|
464 |
+
#### diffusers_xformers
|
465 |
+
Uses Diffusers' xformers feature rather than the script's own xformers replacement feature. Hypernetwork learning is no longer possible.
|
fine_tune_README_ja.md
ADDED
@@ -0,0 +1,140 @@
1 |
+
NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません)
|
2 |
+
|
3 |
+
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
4 |
+
|
5 |
+
# 概要
|
6 |
+
|
7 |
+
Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。
|
8 |
+
|
9 |
+
* CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。
|
10 |
+
* 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。
|
11 |
+
* トークン長を75から225に拡張する。
|
12 |
+
* BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。
|
13 |
+
* Hypernetworkの学習にも対応する。
|
14 |
+
* Stable Diffusion v2.0(baseおよび768/v)に対応。
|
15 |
+
* VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。
|
16 |
+
|
17 |
+
デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。
|
18 |
+
|
19 |
+
# 追加機能について
|
20 |
+
|
21 |
+
## CLIPの出力の変更
|
22 |
+
|
23 |
+
プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。
|
24 |
+
元のまま、最後の層の出力を用いることも可能です。
|
25 |
+
|
26 |
+
※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。
|
27 |
+
|
28 |
+
## 正方形以外の解像度での学習
|
29 |
+
|
30 |
+
Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。
|
31 |
+
学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。
|
32 |
+
|
33 |
+
機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
|
34 |
+
|
35 |
+
## トークン長の75から225への拡張
|
36 |
+
|
37 |
+
Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。
|
38 |
+
ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。
|
39 |
+
|
40 |
+
※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。
|
41 |
+
|
42 |
+
※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。
|
43 |
+
|
44 |
+
# 学習の手順
|
45 |
+
|
46 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
47 |
+
|
48 |
+
## データの準備
|
49 |
+
|
50 |
+
[学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。
|
51 |
+
|
52 |
+
## 学習の実行
|
53 |
+
たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。
|
54 |
+
|
55 |
+
```
|
56 |
+
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
|
57 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
58 |
+
--output_dir=<学習したモデルの出力先フォルダ>
|
59 |
+
--output_name=<学習したモデル出力時のファイル名>
|
60 |
+
--dataset_config=<データ準備で作成した.tomlファイル>
|
61 |
+
--save_model_as=safetensors
|
62 |
+
--learning_rate=5e-6 --max_train_steps=10000
|
63 |
+
--use_8bit_adam --xformers --gradient_checkpointing
|
64 |
+
--mixed_precision=fp16
|
65 |
+
```
|
66 |
+
|
67 |
+
`num_cpu_threads_per_process` には通常は1を指定するとよいようです。
|
68 |
+
|
69 |
+
`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
|
70 |
+
|
71 |
+
`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
|
72 |
+
|
73 |
+
`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
|
74 |
+
|
75 |
+
学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。
|
76 |
+
|
77 |
+
省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
|
78 |
+
|
79 |
+
オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
|
80 |
+
|
81 |
+
`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
|
82 |
+
|
83 |
+
ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。
|
84 |
+
|
85 |
+
### よく使われるオプションについて
|
86 |
+
|
87 |
+
以下の場合にはオプションに関するドキュメントを参照してください。
|
88 |
+
|
89 |
+
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
90 |
+
- clip skipを2以上を前提としたモデルを学習する
|
91 |
+
- 75トークンを超えたキャプションで学習する
|
92 |
+
|
93 |
+
### バッチサイズについて
|
94 |
+
|
95 |
+
モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。
|
96 |
+
|
97 |
+
### 学習率について
|
98 |
+
|
99 |
+
1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。
|
100 |
+
|
101 |
+
### 以前の形式のデータセット指定をした場合のコマンドライン
|
102 |
+
|
103 |
+
解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
|
104 |
+
|
105 |
+
```
|
106 |
+
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
|
107 |
+
--pretrained_model_name_or_path=model.ckpt
|
108 |
+
--in_json meta_lat.json
|
109 |
+
--train_data_dir=train_data
|
110 |
+
--output_dir=fine_tuned
|
111 |
+
--shuffle_caption
|
112 |
+
--train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
|
113 |
+
--use_8bit_adam --xformers --gradient_checkpointing
|
114 |
+
--mixed_precision=bf16
|
115 |
+
--save_every_n_epochs=4
|
116 |
+
```
|
117 |
+
|
118 |
+
<!--
|
119 |
+
### 勾配をfp16とした学習(実験的機能)
|
120 |
+
full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。これによりSD1.xの512*512サイズでは8GB未満、SD2.xの512*512サイズで12GB未満のVRAM使用量で学習できるようです。
|
121 |
+
|
122 |
+
あらかじめaccelerate configでfp16を指定し、オプションでmixed_precision="fp16"としてください(bf16では動作しません)。
|
123 |
+
|
124 |
+
メモリ使用量を最小化するためには、xformers、use_8bit_adam、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
|
125 |
+
(余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
|
126 |
+
|
127 |
+
PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
|
128 |
+
-->
|
129 |
+
|
130 |
+
# fine tuning特有のその他の主なオプション
|
131 |
+
|
132 |
+
すべてのオプションについては別文書を参照してください。
|
133 |
+
|
134 |
+
## `train_text_encoder`
|
135 |
+
Text Encoderも学習対象とします。メモリ使用量が若干増加します。
|
136 |
+
|
137 |
+
通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。
|
138 |
+
|
139 |
+
## `diffusers_xformers`
|
140 |
+
スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。
|
finetune_gui.py
ADDED
@@ -0,0 +1,900 @@
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import subprocess
|
6 |
+
import pathlib
|
7 |
+
import argparse
|
8 |
+
from library.common_gui import (
|
9 |
+
get_folder_path,
|
10 |
+
get_file_path,
|
11 |
+
get_saveasfile_path,
|
12 |
+
save_inference_file,
|
13 |
+
gradio_advanced_training,
|
14 |
+
run_cmd_advanced_training,
|
15 |
+
gradio_training,
|
16 |
+
run_cmd_advanced_training,
|
17 |
+
gradio_config,
|
18 |
+
gradio_source_model,
|
19 |
+
color_aug_changed,
|
20 |
+
run_cmd_training,
|
21 |
+
# set_legacy_8bitadam,
|
22 |
+
update_my_data,
|
23 |
+
check_if_model_exist,
|
24 |
+
)
|
25 |
+
from library.tensorboard_gui import (
|
26 |
+
gradio_tensorboard,
|
27 |
+
start_tensorboard,
|
28 |
+
stop_tensorboard,
|
29 |
+
)
|
30 |
+
from library.utilities import utilities_tab
|
31 |
+
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
32 |
+
from easygui import msgbox
|
33 |
+
|
34 |
+
folder_symbol = '\U0001f4c2' # 📂
|
35 |
+
refresh_symbol = '\U0001f504' # 🔄
|
36 |
+
save_style_symbol = '\U0001f4be' # 💾
|
37 |
+
document_symbol = '\U0001F4C4' # 📄
|
38 |
+
|
39 |
+
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
40 |
+
|
41 |
+
|
42 |
+
def save_configuration(
|
43 |
+
save_as,
|
44 |
+
file_path,
|
45 |
+
pretrained_model_name_or_path,
|
46 |
+
v2,
|
47 |
+
v_parameterization,
|
48 |
+
train_dir,
|
49 |
+
image_folder,
|
50 |
+
output_dir,
|
51 |
+
logging_dir,
|
52 |
+
max_resolution,
|
53 |
+
min_bucket_reso,
|
54 |
+
max_bucket_reso,
|
55 |
+
batch_size,
|
56 |
+
flip_aug,
|
57 |
+
caption_metadata_filename,
|
58 |
+
latent_metadata_filename,
|
59 |
+
full_path,
|
60 |
+
learning_rate,
|
61 |
+
lr_scheduler,
|
62 |
+
lr_warmup,
|
63 |
+
dataset_repeats,
|
64 |
+
train_batch_size,
|
65 |
+
epoch,
|
66 |
+
save_every_n_epochs,
|
67 |
+
mixed_precision,
|
68 |
+
save_precision,
|
69 |
+
seed,
|
70 |
+
num_cpu_threads_per_process,
|
71 |
+
train_text_encoder,
|
72 |
+
create_caption,
|
73 |
+
create_buckets,
|
74 |
+
save_model_as,
|
75 |
+
caption_extension,
|
76 |
+
# use_8bit_adam,
|
77 |
+
xformers,
|
78 |
+
clip_skip,
|
79 |
+
save_state,
|
80 |
+
resume,
|
81 |
+
gradient_checkpointing,
|
82 |
+
gradient_accumulation_steps,
|
83 |
+
mem_eff_attn,
|
84 |
+
shuffle_caption,
|
85 |
+
output_name,
|
86 |
+
max_token_length,
|
87 |
+
max_train_epochs,
|
88 |
+
max_data_loader_n_workers,
|
89 |
+
full_fp16,
|
90 |
+
color_aug,
|
91 |
+
model_list,
|
92 |
+
cache_latents,
|
93 |
+
use_latent_files,
|
94 |
+
keep_tokens,
|
95 |
+
persistent_data_loader_workers,
|
96 |
+
bucket_no_upscale,
|
97 |
+
random_crop,
|
98 |
+
bucket_reso_steps,
|
99 |
+
caption_dropout_every_n_epochs,
|
100 |
+
caption_dropout_rate,
|
101 |
+
optimizer,
|
102 |
+
optimizer_args,
|
103 |
+
noise_offset,
|
104 |
+
sample_every_n_steps,
|
105 |
+
sample_every_n_epochs,
|
106 |
+
sample_sampler,
|
107 |
+
sample_prompts,
|
108 |
+
additional_parameters,
|
109 |
+
vae_batch_size,
|
110 |
+
min_snr_gamma,weighted_captions,
|
111 |
+
):
|
112 |
+
# Get list of function parameters and values
|
113 |
+
parameters = list(locals().items())
|
114 |
+
|
115 |
+
original_file_path = file_path
|
116 |
+
|
117 |
+
save_as_bool = True if save_as.get('label') == 'True' else False
|
118 |
+
|
119 |
+
if save_as_bool:
|
120 |
+
print('Save as...')
|
121 |
+
file_path = get_saveasfile_path(file_path)
|
122 |
+
else:
|
123 |
+
print('Save...')
|
124 |
+
if file_path == None or file_path == '':
|
125 |
+
file_path = get_saveasfile_path(file_path)
|
126 |
+
|
127 |
+
# print(file_path)
|
128 |
+
|
129 |
+
if file_path == None or file_path == '':
|
130 |
+
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
131 |
+
|
132 |
+
# Return the values of the variables as a dictionary
|
133 |
+
variables = {
|
134 |
+
name: value
|
135 |
+
for name, value in parameters # locals().items()
|
136 |
+
if name
|
137 |
+
not in [
|
138 |
+
'file_path',
|
139 |
+
'save_as',
|
140 |
+
]
|
141 |
+
}
|
142 |
+
|
143 |
+
# Extract the destination directory from the file path
|
144 |
+
destination_directory = os.path.dirname(file_path)
|
145 |
+
|
146 |
+
# Create the destination directory if it doesn't exist
|
147 |
+
if not os.path.exists(destination_directory):
|
148 |
+
os.makedirs(destination_directory)
|
149 |
+
|
150 |
+
# Save the data to the selected file
|
151 |
+
with open(file_path, 'w') as file:
|
152 |
+
json.dump(variables, file, indent=2)
|
153 |
+
|
154 |
+
return file_path
|
155 |
+
|
156 |
+
|
157 |
+
def open_configuration(
|
158 |
+
ask_for_file,
|
159 |
+
file_path,
|
160 |
+
pretrained_model_name_or_path,
|
161 |
+
v2,
|
162 |
+
v_parameterization,
|
163 |
+
train_dir,
|
164 |
+
image_folder,
|
165 |
+
output_dir,
|
166 |
+
logging_dir,
|
167 |
+
max_resolution,
|
168 |
+
min_bucket_reso,
|
169 |
+
max_bucket_reso,
|
170 |
+
batch_size,
|
171 |
+
flip_aug,
|
172 |
+
caption_metadata_filename,
|
173 |
+
latent_metadata_filename,
|
174 |
+
full_path,
|
175 |
+
learning_rate,
|
176 |
+
lr_scheduler,
|
177 |
+
lr_warmup,
|
178 |
+
dataset_repeats,
|
179 |
+
train_batch_size,
|
180 |
+
epoch,
|
181 |
+
save_every_n_epochs,
|
182 |
+
mixed_precision,
|
183 |
+
save_precision,
|
184 |
+
seed,
|
185 |
+
num_cpu_threads_per_process,
|
186 |
+
train_text_encoder,
|
187 |
+
create_caption,
|
188 |
+
create_buckets,
|
189 |
+
save_model_as,
|
190 |
+
caption_extension,
|
191 |
+
# use_8bit_adam,
|
192 |
+
xformers,
|
193 |
+
clip_skip,
|
194 |
+
save_state,
|
195 |
+
resume,
|
196 |
+
gradient_checkpointing,
|
197 |
+
gradient_accumulation_steps,
|
198 |
+
mem_eff_attn,
|
199 |
+
shuffle_caption,
|
200 |
+
output_name,
|
201 |
+
max_token_length,
|
202 |
+
max_train_epochs,
|
203 |
+
max_data_loader_n_workers,
|
204 |
+
full_fp16,
|
205 |
+
color_aug,
|
206 |
+
model_list,
|
207 |
+
cache_latents,
|
208 |
+
use_latent_files,
|
209 |
+
keep_tokens,
|
210 |
+
persistent_data_loader_workers,
|
211 |
+
bucket_no_upscale,
|
212 |
+
random_crop,
|
213 |
+
bucket_reso_steps,
|
214 |
+
caption_dropout_every_n_epochs,
|
215 |
+
caption_dropout_rate,
|
216 |
+
optimizer,
|
217 |
+
optimizer_args,
|
218 |
+
noise_offset,
|
219 |
+
sample_every_n_steps,
|
220 |
+
sample_every_n_epochs,
|
221 |
+
sample_sampler,
|
222 |
+
sample_prompts,
|
223 |
+
additional_parameters,
|
224 |
+
vae_batch_size,
|
225 |
+
min_snr_gamma,weighted_captions,
|
226 |
+
):
|
227 |
+
# Get list of function parameters and values
|
228 |
+
parameters = list(locals().items())
|
229 |
+
|
230 |
+
ask_for_file = True if ask_for_file.get('label') == 'True' else False
|
231 |
+
|
232 |
+
original_file_path = file_path
|
233 |
+
|
234 |
+
if ask_for_file:
|
235 |
+
file_path = get_file_path(file_path)
|
236 |
+
|
237 |
+
if not file_path == '' and not file_path == None:
|
238 |
+
# load variables from JSON file
|
239 |
+
with open(file_path, 'r') as f:
|
240 |
+
my_data = json.load(f)
|
241 |
+
print('Loading config...')
|
242 |
+
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
|
243 |
+
my_data = update_my_data(my_data)
|
244 |
+
else:
|
245 |
+
file_path = original_file_path # In case a file_path was provided and the user decided to cancel the open action
|
246 |
+
my_data = {}
|
247 |
+
|
248 |
+
values = [file_path]
|
249 |
+
for key, value in parameters:
|
250 |
+
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
251 |
+
if not key in ['ask_for_file', 'file_path']:
|
252 |
+
values.append(my_data.get(key, value))
|
253 |
+
return tuple(values)
|
254 |
+
|
255 |
+
|
256 |
+
def train_model(
|
257 |
+
pretrained_model_name_or_path,
|
258 |
+
v2,
|
259 |
+
v_parameterization,
|
260 |
+
train_dir,
|
261 |
+
image_folder,
|
262 |
+
output_dir,
|
263 |
+
logging_dir,
|
264 |
+
max_resolution,
|
265 |
+
min_bucket_reso,
|
266 |
+
max_bucket_reso,
|
267 |
+
batch_size,
|
268 |
+
flip_aug,
|
269 |
+
caption_metadata_filename,
|
270 |
+
latent_metadata_filename,
|
271 |
+
full_path,
|
272 |
+
learning_rate,
|
273 |
+
lr_scheduler,
|
274 |
+
lr_warmup,
|
275 |
+
dataset_repeats,
|
276 |
+
train_batch_size,
|
277 |
+
epoch,
|
278 |
+
save_every_n_epochs,
|
279 |
+
mixed_precision,
|
280 |
+
save_precision,
|
281 |
+
seed,
|
282 |
+
num_cpu_threads_per_process,
|
283 |
+
train_text_encoder,
|
284 |
+
generate_caption_database,
|
285 |
+
generate_image_buckets,
|
286 |
+
save_model_as,
|
287 |
+
caption_extension,
|
288 |
+
# use_8bit_adam,
|
289 |
+
xformers,
|
290 |
+
clip_skip,
|
291 |
+
save_state,
|
292 |
+
resume,
|
293 |
+
gradient_checkpointing,
|
294 |
+
gradient_accumulation_steps,
|
295 |
+
mem_eff_attn,
|
296 |
+
shuffle_caption,
|
297 |
+
output_name,
|
298 |
+
max_token_length,
|
299 |
+
max_train_epochs,
|
300 |
+
max_data_loader_n_workers,
|
301 |
+
full_fp16,
|
302 |
+
color_aug,
|
303 |
+
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
304 |
+
cache_latents,
|
305 |
+
use_latent_files,
|
306 |
+
keep_tokens,
|
307 |
+
persistent_data_loader_workers,
|
308 |
+
bucket_no_upscale,
|
309 |
+
random_crop,
|
310 |
+
bucket_reso_steps,
|
311 |
+
caption_dropout_every_n_epochs,
|
312 |
+
caption_dropout_rate,
|
313 |
+
optimizer,
|
314 |
+
optimizer_args,
|
315 |
+
noise_offset,
|
316 |
+
sample_every_n_steps,
|
317 |
+
sample_every_n_epochs,
|
318 |
+
sample_sampler,
|
319 |
+
sample_prompts,
|
320 |
+
additional_parameters,
|
321 |
+
vae_batch_size,
|
322 |
+
min_snr_gamma,weighted_captions,
|
323 |
+
):
|
324 |
+
if check_if_model_exist(output_name, output_dir, save_model_as):
|
325 |
+
return
|
326 |
+
|
327 |
+
if optimizer == 'Adafactor' and lr_warmup != '0':
|
328 |
+
msgbox("Warning: the optimizer is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning")
|
329 |
+
lr_warmup = '0'
|
330 |
+
|
331 |
+
# create caption json file
|
332 |
+
if generate_caption_database:
|
333 |
+
if not os.path.exists(train_dir):
|
334 |
+
os.mkdir(train_dir)
|
335 |
+
|
336 |
+
run_cmd = f'{PYTHON} finetune/merge_captions_to_metadata.py'
|
337 |
+
if caption_extension == '':
|
338 |
+
run_cmd += f' --caption_extension=".caption"'
|
339 |
+
else:
|
340 |
+
run_cmd += f' --caption_extension={caption_extension}'
|
341 |
+
run_cmd += f' "{image_folder}"'
|
342 |
+
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
|
343 |
+
if full_path:
|
344 |
+
run_cmd += f' --full_path'
|
345 |
+
|
346 |
+
print(run_cmd)
|
347 |
+
|
348 |
+
# Run the command
|
349 |
+
if os.name == 'posix':
|
350 |
+
os.system(run_cmd)
|
351 |
+
else:
|
352 |
+
subprocess.run(run_cmd)
|
353 |
+
|
354 |
+
# create image buckets
|
355 |
+
if generate_image_buckets:
|
356 |
+
run_cmd = f'{PYTHON} finetune/prepare_buckets_latents.py'
|
357 |
+
run_cmd += f' "{image_folder}"'
|
358 |
+
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
|
359 |
+
run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
|
360 |
+
run_cmd += f' "{pretrained_model_name_or_path}"'
|
361 |
+
run_cmd += f' --batch_size={batch_size}'
|
362 |
+
run_cmd += f' --max_resolution={max_resolution}'
|
363 |
+
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
|
364 |
+
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
|
365 |
+
run_cmd += f' --mixed_precision={mixed_precision}'
|
366 |
+
# if flip_aug:
|
367 |
+
# run_cmd += f' --flip_aug'
|
368 |
+
if full_path:
|
369 |
+
run_cmd += f' --full_path'
|
370 |
+
|
371 |
+
print(run_cmd)
|
372 |
+
|
373 |
+
# Run the command
|
374 |
+
if os.name == 'posix':
|
375 |
+
os.system(run_cmd)
|
376 |
+
else:
|
377 |
+
subprocess.run(run_cmd)
|
378 |
+
|
379 |
+
image_num = len(
|
380 |
+
[
|
381 |
+
f
|
382 |
+
for f, lower_f in (
|
383 |
+
(file, file.lower()) for file in os.listdir(image_folder)
|
384 |
+
)
|
385 |
+
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
386 |
+
]
|
387 |
+
)
|
388 |
+
print(f'image_num = {image_num}')
|
389 |
+
|
390 |
+
repeats = int(image_num) * int(dataset_repeats)
|
391 |
+
print(f'repeats = {str(repeats)}')
|
392 |
+
|
393 |
+
# calculate max_train_steps
|
394 |
+
max_train_steps = int(
|
395 |
+
math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
|
396 |
+
)
|
397 |
+
|
398 |
+
# Divide by two because flip augmentation creates two copies of the source images
|
399 |
+
if flip_aug:
|
400 |
+
max_train_steps = int(math.ceil(float(max_train_steps) / 2))
|
401 |
+
|
402 |
+
print(f'max_train_steps = {max_train_steps}')
|
403 |
+
|
404 |
+
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
405 |
+
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
406 |
+
|
407 |
+
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
|
408 |
+
if v2:
|
409 |
+
run_cmd += ' --v2'
|
410 |
+
if v_parameterization:
|
411 |
+
run_cmd += ' --v_parameterization'
|
412 |
+
if train_text_encoder:
|
413 |
+
run_cmd += ' --train_text_encoder'
|
414 |
+
if weighted_captions:
|
415 |
+
run_cmd += ' --weighted_captions'
|
416 |
+
run_cmd += (
|
417 |
+
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
418 |
+
)
|
419 |
+
if use_latent_files == 'Yes':
|
420 |
+
run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"'
|
421 |
+
else:
|
422 |
+
run_cmd += f' --in_json="{train_dir}/{caption_metadata_filename}"'
|
423 |
+
run_cmd += f' --train_data_dir="{image_folder}"'
|
424 |
+
run_cmd += f' --output_dir="{output_dir}"'
|
425 |
+
if not logging_dir == '':
|
426 |
+
run_cmd += f' --logging_dir="{logging_dir}"'
|
427 |
+
run_cmd += f' --dataset_repeats={dataset_repeats}'
|
428 |
+
run_cmd += f' --learning_rate={learning_rate}'
|
429 |
+
|
430 |
+
run_cmd += ' --enable_bucket'
|
431 |
+
run_cmd += f' --resolution={max_resolution}'
|
432 |
+
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
|
433 |
+
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
|
434 |
+
|
435 |
+
if not save_model_as == 'same as source model':
|
436 |
+
run_cmd += f' --save_model_as={save_model_as}'
|
437 |
+
if int(gradient_accumulation_steps) > 1:
|
438 |
+
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
439 |
+
# if save_state:
|
440 |
+
# run_cmd += ' --save_state'
|
441 |
+
# if not resume == '':
|
442 |
+
# run_cmd += f' --resume={resume}'
|
443 |
+
if not output_name == '':
|
444 |
+
run_cmd += f' --output_name="{output_name}"'
|
445 |
+
if int(max_token_length) > 75:
|
446 |
+
run_cmd += f' --max_token_length={max_token_length}'
|
447 |
+
|
448 |
+
run_cmd += run_cmd_training(
|
449 |
+
learning_rate=learning_rate,
|
450 |
+
lr_scheduler=lr_scheduler,
|
451 |
+
lr_warmup_steps=lr_warmup_steps,
|
452 |
+
train_batch_size=train_batch_size,
|
453 |
+
max_train_steps=max_train_steps,
|
454 |
+
save_every_n_epochs=save_every_n_epochs,
|
455 |
+
mixed_precision=mixed_precision,
|
456 |
+
save_precision=save_precision,
|
457 |
+
seed=seed,
|
458 |
+
caption_extension=caption_extension,
|
459 |
+
cache_latents=cache_latents,
|
460 |
+
optimizer=optimizer,
|
461 |
+
optimizer_args=optimizer_args,
|
462 |
+
)
|
463 |
+
|
464 |
+
run_cmd += run_cmd_advanced_training(
|
465 |
+
max_train_epochs=max_train_epochs,
|
466 |
+
max_data_loader_n_workers=max_data_loader_n_workers,
|
467 |
+
max_token_length=max_token_length,
|
468 |
+
resume=resume,
|
469 |
+
save_state=save_state,
|
470 |
+
mem_eff_attn=mem_eff_attn,
|
471 |
+
clip_skip=clip_skip,
|
472 |
+
flip_aug=flip_aug,
|
473 |
+
color_aug=color_aug,
|
474 |
+
shuffle_caption=shuffle_caption,
|
475 |
+
gradient_checkpointing=gradient_checkpointing,
|
476 |
+
full_fp16=full_fp16,
|
477 |
+
xformers=xformers,
|
478 |
+
# use_8bit_adam=use_8bit_adam,
|
479 |
+
keep_tokens=keep_tokens,
|
480 |
+
persistent_data_loader_workers=persistent_data_loader_workers,
|
481 |
+
bucket_no_upscale=bucket_no_upscale,
|
482 |
+
random_crop=random_crop,
|
483 |
+
bucket_reso_steps=bucket_reso_steps,
|
484 |
+
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
485 |
+
caption_dropout_rate=caption_dropout_rate,
|
486 |
+
noise_offset=noise_offset,
|
487 |
+
additional_parameters=additional_parameters,
|
488 |
+
vae_batch_size=vae_batch_size,
|
489 |
+
min_snr_gamma=min_snr_gamma,
|
490 |
+
)
|
491 |
+
|
492 |
+
run_cmd += run_cmd_sample(
|
493 |
+
sample_every_n_steps,
|
494 |
+
sample_every_n_epochs,
|
495 |
+
sample_sampler,
|
496 |
+
sample_prompts,
|
497 |
+
output_dir,
|
498 |
+
)
|
499 |
+
|
500 |
+
print(run_cmd)
|
501 |
+
|
502 |
+
# Run the command
|
503 |
+
if os.name == 'posix':
|
504 |
+
os.system(run_cmd)
|
505 |
+
else:
|
506 |
+
subprocess.run(run_cmd)
|
507 |
+
|
508 |
+
# Check whether output_dir/output_name is a folder, i.e. whether the model was saved in diffusers format
|
509 |
+
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
|
510 |
+
|
511 |
+
if not last_dir.is_dir():
|
512 |
+
# Copy inference model for v2 if required
|
513 |
+
save_inference_file(output_dir, v2, v_parameterization, output_name)
|
514 |
+
|
515 |
+
|
516 |
+
def remove_doublequote(file_path):
|
517 |
+
if file_path != None:
|
518 |
+
file_path = file_path.replace('"', '')
|
519 |
+
|
520 |
+
return file_path
|
521 |
+
|
522 |
+
|
523 |
+
def finetune_tab():
|
524 |
+
dummy_db_true = gr.Label(value=True, visible=False)
|
525 |
+
dummy_db_false = gr.Label(value=False, visible=False)
|
526 |
+
gr.Markdown('Train a custom model using kohya finetune python code...')
|
527 |
+
|
528 |
+
(
|
529 |
+
button_open_config,
|
530 |
+
button_save_config,
|
531 |
+
button_save_as_config,
|
532 |
+
config_file_name,
|
533 |
+
button_load_config,
|
534 |
+
) = gradio_config()
|
535 |
+
|
536 |
+
(
|
537 |
+
pretrained_model_name_or_path,
|
538 |
+
v2,
|
539 |
+
v_parameterization,
|
540 |
+
save_model_as,
|
541 |
+
model_list,
|
542 |
+
) = gradio_source_model()
|
543 |
+
|
544 |
+
with gr.Tab('Folders'):
|
545 |
+
with gr.Row():
|
546 |
+
train_dir = gr.Textbox(
|
547 |
+
label='Training config folder',
|
548 |
+
placeholder='folder where the training configuration files will be saved',
|
549 |
+
)
|
550 |
+
train_dir_folder = gr.Button(
|
551 |
+
folder_symbol, elem_id='open_folder_small'
|
552 |
+
)
|
553 |
+
train_dir_folder.click(
|
554 |
+
get_folder_path,
|
555 |
+
outputs=train_dir,
|
556 |
+
show_progress=False,
|
557 |
+
)
|
558 |
+
|
559 |
+
image_folder = gr.Textbox(
|
560 |
+
label='Training Image folder',
|
561 |
+
placeholder='folder where the training images are located',
|
562 |
+
)
|
563 |
+
image_folder_input_folder = gr.Button(
|
564 |
+
folder_symbol, elem_id='open_folder_small'
|
565 |
+
)
|
566 |
+
image_folder_input_folder.click(
|
567 |
+
get_folder_path,
|
568 |
+
outputs=image_folder,
|
569 |
+
show_progress=False,
|
570 |
+
)
|
571 |
+
with gr.Row():
|
572 |
+
output_dir = gr.Textbox(
|
573 |
+
label='Model output folder',
|
574 |
+
placeholder='folder where the model will be saved',
|
575 |
+
)
|
576 |
+
output_dir_input_folder = gr.Button(
|
577 |
+
folder_symbol, elem_id='open_folder_small'
|
578 |
+
)
|
579 |
+
output_dir_input_folder.click(
|
580 |
+
get_folder_path,
|
581 |
+
outputs=output_dir,
|
582 |
+
show_progress=False,
|
583 |
+
)
|
584 |
+
|
585 |
+
logging_dir = gr.Textbox(
|
586 |
+
label='Logging folder',
|
587 |
+
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
588 |
+
)
|
589 |
+
logging_dir_input_folder = gr.Button(
|
590 |
+
folder_symbol, elem_id='open_folder_small'
|
591 |
+
)
|
592 |
+
logging_dir_input_folder.click(
|
593 |
+
get_folder_path,
|
594 |
+
outputs=logging_dir,
|
595 |
+
show_progress=False,
|
596 |
+
)
|
597 |
+
with gr.Row():
|
598 |
+
output_name = gr.Textbox(
|
599 |
+
label='Model output name',
|
600 |
+
placeholder='Name of the model to output',
|
601 |
+
value='last',
|
602 |
+
interactive=True,
|
603 |
+
)
|
604 |
+
train_dir.change(
|
605 |
+
remove_doublequote,
|
606 |
+
inputs=[train_dir],
|
607 |
+
outputs=[train_dir],
|
608 |
+
)
|
609 |
+
image_folder.change(
|
610 |
+
remove_doublequote,
|
611 |
+
inputs=[image_folder],
|
612 |
+
outputs=[image_folder],
|
613 |
+
)
|
614 |
+
output_dir.change(
|
615 |
+
remove_doublequote,
|
616 |
+
inputs=[output_dir],
|
617 |
+
outputs=[output_dir],
|
618 |
+
)
|
619 |
+
with gr.Tab('Dataset preparation'):
|
620 |
+
with gr.Row():
|
621 |
+
max_resolution = gr.Textbox(
|
622 |
+
label='Resolution (width,height)', value='512,512'
|
623 |
+
)
|
624 |
+
min_bucket_reso = gr.Textbox(
|
625 |
+
label='Min bucket resolution', value='256'
|
626 |
+
)
|
627 |
+
max_bucket_reso = gr.Textbox(
|
628 |
+
label='Max bucket resolution', value='1024'
|
629 |
+
)
|
630 |
+
batch_size = gr.Textbox(label='Batch size', value='1')
|
631 |
+
with gr.Row():
|
632 |
+
create_caption = gr.Checkbox(
|
633 |
+
label='Generate caption metadata', value=True
|
634 |
+
)
|
635 |
+
create_buckets = gr.Checkbox(
|
636 |
+
label='Generate image buckets metadata', value=True
|
637 |
+
)
|
638 |
+
use_latent_files = gr.Dropdown(
|
639 |
+
label='Use latent files',
|
640 |
+
choices=[
|
641 |
+
'No',
|
642 |
+
'Yes',
|
643 |
+
],
|
644 |
+
value='Yes',
|
645 |
+
)
|
646 |
+
with gr.Accordion('Advanced parameters', open=False):
|
647 |
+
with gr.Row():
|
648 |
+
caption_metadata_filename = gr.Textbox(
|
649 |
+
label='Caption metadata filename', value='meta_cap.json'
|
650 |
+
)
|
651 |
+
latent_metadata_filename = gr.Textbox(
|
652 |
+
label='Latent metadata filename', value='meta_lat.json'
|
653 |
+
)
|
654 |
+
with gr.Row():
|
655 |
+
full_path = gr.Checkbox(label='Use full path', value=True)
|
656 |
+
weighted_captions = gr.Checkbox(
|
657 |
+
label='Weighted captions', value=False
|
658 |
+
)
|
659 |
+
with gr.Tab('Training parameters'):
|
660 |
+
(
|
661 |
+
learning_rate,
|
662 |
+
lr_scheduler,
|
663 |
+
lr_warmup,
|
664 |
+
train_batch_size,
|
665 |
+
epoch,
|
666 |
+
save_every_n_epochs,
|
667 |
+
mixed_precision,
|
668 |
+
save_precision,
|
669 |
+
num_cpu_threads_per_process,
|
670 |
+
seed,
|
671 |
+
caption_extension,
|
672 |
+
cache_latents,
|
673 |
+
optimizer,
|
674 |
+
optimizer_args,
|
675 |
+
) = gradio_training(learning_rate_value='1e-5')
|
676 |
+
with gr.Row():
|
677 |
+
dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
|
678 |
+
train_text_encoder = gr.Checkbox(
|
679 |
+
label='Train text encoder', value=True
|
680 |
+
)
|
681 |
+
with gr.Accordion('Advanced parameters', open=False):
|
682 |
+
with gr.Row():
|
683 |
+
gradient_accumulation_steps = gr.Number(
|
684 |
+
label='Gradient accumulate steps', value='1'
|
685 |
+
)
|
686 |
+
(
|
687 |
+
# use_8bit_adam,
|
688 |
+
xformers,
|
689 |
+
full_fp16,
|
690 |
+
gradient_checkpointing,
|
691 |
+
shuffle_caption,
|
692 |
+
color_aug,
|
693 |
+
flip_aug,
|
694 |
+
clip_skip,
|
695 |
+
mem_eff_attn,
|
696 |
+
save_state,
|
697 |
+
resume,
|
698 |
+
max_token_length,
|
699 |
+
max_train_epochs,
|
700 |
+
max_data_loader_n_workers,
|
701 |
+
keep_tokens,
|
702 |
+
persistent_data_loader_workers,
|
703 |
+
bucket_no_upscale,
|
704 |
+
random_crop,
|
705 |
+
bucket_reso_steps,
|
706 |
+
caption_dropout_every_n_epochs,
|
707 |
+
caption_dropout_rate,
|
708 |
+
noise_offset,
|
709 |
+
additional_parameters,
|
710 |
+
vae_batch_size,
|
711 |
+
min_snr_gamma,
|
712 |
+
) = gradio_advanced_training()
|
713 |
+
color_aug.change(
|
714 |
+
color_aug_changed,
|
715 |
+
inputs=[color_aug],
|
716 |
+
outputs=[cache_latents], # Not applicable to fine_tune.py
|
717 |
+
)
|
718 |
+
|
719 |
+
(
|
720 |
+
sample_every_n_steps,
|
721 |
+
sample_every_n_epochs,
|
722 |
+
sample_sampler,
|
723 |
+
sample_prompts,
|
724 |
+
) = sample_gradio_config()
|
725 |
+
|
726 |
+
button_run = gr.Button('Train model', variant='primary')
|
727 |
+
|
728 |
+
# Setup gradio tensorboard buttons
|
729 |
+
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
730 |
+
|
731 |
+
button_start_tensorboard.click(
|
732 |
+
start_tensorboard,
|
733 |
+
inputs=logging_dir,
|
734 |
+
)
|
735 |
+
|
736 |
+
button_stop_tensorboard.click(
|
737 |
+
stop_tensorboard,
|
738 |
+
show_progress=False,
|
739 |
+
)
|
740 |
+
|
741 |
+
settings_list = [
|
742 |
+
pretrained_model_name_or_path,
|
743 |
+
v2,
|
744 |
+
v_parameterization,
|
745 |
+
train_dir,
|
746 |
+
image_folder,
|
747 |
+
output_dir,
|
748 |
+
logging_dir,
|
749 |
+
max_resolution,
|
750 |
+
min_bucket_reso,
|
751 |
+
max_bucket_reso,
|
752 |
+
batch_size,
|
753 |
+
flip_aug,
|
754 |
+
caption_metadata_filename,
|
755 |
+
latent_metadata_filename,
|
756 |
+
full_path,
|
757 |
+
learning_rate,
|
758 |
+
lr_scheduler,
|
759 |
+
lr_warmup,
|
760 |
+
dataset_repeats,
|
761 |
+
train_batch_size,
|
762 |
+
epoch,
|
763 |
+
save_every_n_epochs,
|
764 |
+
mixed_precision,
|
765 |
+
save_precision,
|
766 |
+
seed,
|
767 |
+
num_cpu_threads_per_process,
|
768 |
+
train_text_encoder,
|
769 |
+
create_caption,
|
770 |
+
create_buckets,
|
771 |
+
save_model_as,
|
772 |
+
caption_extension,
|
773 |
+
# use_8bit_adam,
|
774 |
+
xformers,
|
775 |
+
clip_skip,
|
776 |
+
save_state,
|
777 |
+
resume,
|
778 |
+
gradient_checkpointing,
|
779 |
+
gradient_accumulation_steps,
|
780 |
+
mem_eff_attn,
|
781 |
+
shuffle_caption,
|
782 |
+
output_name,
|
783 |
+
max_token_length,
|
784 |
+
max_train_epochs,
|
785 |
+
max_data_loader_n_workers,
|
786 |
+
full_fp16,
|
787 |
+
color_aug,
|
788 |
+
model_list,
|
789 |
+
cache_latents,
|
790 |
+
use_latent_files,
|
791 |
+
keep_tokens,
|
792 |
+
persistent_data_loader_workers,
|
793 |
+
bucket_no_upscale,
|
794 |
+
random_crop,
|
795 |
+
bucket_reso_steps,
|
796 |
+
caption_dropout_every_n_epochs,
|
797 |
+
caption_dropout_rate,
|
798 |
+
optimizer,
|
799 |
+
optimizer_args,
|
800 |
+
noise_offset,
|
801 |
+
sample_every_n_steps,
|
802 |
+
sample_every_n_epochs,
|
803 |
+
sample_sampler,
|
804 |
+
sample_prompts,
|
805 |
+
additional_parameters,
|
806 |
+
vae_batch_size,
|
807 |
+
min_snr_gamma,
|
808 |
+
weighted_captions,
|
809 |
+
]
|
810 |
+
|
811 |
+
button_run.click(train_model, inputs=settings_list)
|
812 |
+
|
813 |
+
button_open_config.click(
|
814 |
+
open_configuration,
|
815 |
+
inputs=[dummy_db_true, config_file_name] + settings_list,
|
816 |
+
outputs=[config_file_name] + settings_list,
|
817 |
+
show_progress=False,
|
818 |
+
)
|
819 |
+
|
820 |
+
button_load_config.click(
|
821 |
+
open_configuration,
|
822 |
+
inputs=[dummy_db_false, config_file_name] + settings_list,
|
823 |
+
outputs=[config_file_name] + settings_list,
|
824 |
+
show_progress=False,
|
825 |
+
)
|
826 |
+
|
827 |
+
button_save_config.click(
|
828 |
+
save_configuration,
|
829 |
+
inputs=[dummy_db_false, config_file_name] + settings_list,
|
830 |
+
outputs=[config_file_name],
|
831 |
+
show_progress=False,
|
832 |
+
)
|
833 |
+
|
834 |
+
button_save_as_config.click(
|
835 |
+
save_configuration,
|
836 |
+
inputs=[dummy_db_true, config_file_name] + settings_list,
|
837 |
+
outputs=[config_file_name],
|
838 |
+
show_progress=False,
|
839 |
+
)
|
840 |
+
|
841 |
+
|
842 |
+
def UI(**kwargs):
|
843 |
+
|
844 |
+
css = ''
|
845 |
+
|
846 |
+
if os.path.exists('./style.css'):
|
847 |
+
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
848 |
+
print('Load CSS...')
|
849 |
+
css += file.read() + '\n'
|
850 |
+
|
851 |
+
interface = gr.Blocks(css=css)
|
852 |
+
|
853 |
+
with interface:
|
854 |
+
with gr.Tab('Finetune'):
|
855 |
+
finetune_tab()
|
856 |
+
with gr.Tab('Utilities'):
|
857 |
+
utilities_tab(enable_dreambooth_tab=False)
|
858 |
+
|
859 |
+
# Show the interface
|
860 |
+
launch_kwargs = {}
|
861 |
+
if not kwargs.get('username', None) == '':
|
862 |
+
launch_kwargs['auth'] = (
|
863 |
+
kwargs.get('username', None),
|
864 |
+
kwargs.get('password', None),
|
865 |
+
)
|
866 |
+
if kwargs.get('server_port', 0) > 0:
|
867 |
+
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
868 |
+
if kwargs.get('inbrowser', False):
|
869 |
+
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
870 |
+
print(launch_kwargs)
|
871 |
+
interface.launch(**launch_kwargs)
|
872 |
+
|
873 |
+
|
874 |
+
if __name__ == '__main__':
|
875 |
+
# torch.cuda.set_per_process_memory_fraction(0.48)
|
876 |
+
parser = argparse.ArgumentParser()
|
877 |
+
parser.add_argument(
|
878 |
+
'--username', type=str, default='', help='Username for authentication'
|
879 |
+
)
|
880 |
+
parser.add_argument(
|
881 |
+
'--password', type=str, default='', help='Password for authentication'
|
882 |
+
)
|
883 |
+
parser.add_argument(
|
884 |
+
'--server_port',
|
885 |
+
type=int,
|
886 |
+
default=0,
|
887 |
+
help='Port to run the server listener on',
|
888 |
+
)
|
889 |
+
parser.add_argument(
|
890 |
+
'--inbrowser', action='store_true', help='Open in browser'
|
891 |
+
)
|
892 |
+
|
893 |
+
args = parser.parse_args()
|
894 |
+
|
895 |
+
UI(
|
896 |
+
username=args.username,
|
897 |
+
password=args.password,
|
898 |
+
inbrowser=args.inbrowser,
|
899 |
+
server_port=args.server_port,
|
900 |
+
)
|
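A note on the schedule arithmetic in `train_model` above: `max_train_steps` and `lr_warmup_steps` are derived from the image count, the dataset repeats, the batch size, the epoch count and the warmup percentage. The following standalone sketch reproduces that calculation for reference; the function name and the sample numbers are illustrative and not part of this commit.

```python
import math

def estimate_schedule(image_num, dataset_repeats, train_batch_size, epoch,
                      lr_warmup_pct, flip_aug=False):
    # Samples seen per epoch: every image is repeated dataset_repeats times.
    repeats = int(image_num) * int(dataset_repeats)

    # Optimizer steps: samples per epoch divided by batch size, times epochs.
    max_train_steps = int(math.ceil(float(repeats) / int(train_batch_size) * int(epoch)))

    # Flip augmentation doubles the effective dataset, so the GUI halves the steps.
    if flip_aug:
        max_train_steps = int(math.ceil(float(max_train_steps) / 2))

    # LR warmup is expressed as a percentage of the total number of steps.
    lr_warmup_steps = round(int(lr_warmup_pct) * max_train_steps / 100)
    return max_train_steps, lr_warmup_steps

# Example: 100 images, 40 repeats, batch size 2, 2 epochs, 10% warmup -> (4000, 400)
print(estimate_schedule(100, 40, 2, 2, 10))
```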
gen_img_diffusers.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
gui.bat
ADDED
@@ -0,0 +1,16 @@
|
1 |
+
@echo off
|
2 |
+
|
3 |
+
:: Activate the virtual environment
|
4 |
+
call .\venv\Scripts\activate.bat
|
5 |
+
set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib
|
6 |
+
|
7 |
+
:: Debug info about system
|
8 |
+
python.exe .\tools\debug_info.py
|
9 |
+
|
10 |
+
:: Validate the requirements and store the exit code
|
11 |
+
python.exe .\tools\validate_requirements.py
|
12 |
+
|
13 |
+
:: If the exit code is 0, run the kohya_gui.py script with the command-line arguments
|
14 |
+
if %errorlevel% equ 0 (
|
15 |
+
python.exe kohya_gui.py %*
|
16 |
+
)
|
gui.ps1
ADDED
@@ -0,0 +1,21 @@
|
1 |
+
# Activate the virtual environment
|
2 |
+
& .\venv\Scripts\activate
|
3 |
+
$env:PATH += ";$($MyInvocation.MyCommand.Path)\venv\Lib\site-packages\torch\lib"
|
4 |
+
|
5 |
+
# Debug info about system
|
6 |
+
python.exe .\tools\debug_info.py
|
7 |
+
|
8 |
+
# Validate the requirements and store the exit code
|
9 |
+
python.exe .\tools\validate_requirements.py
|
10 |
+
|
11 |
+
# If the exit code is 0, read arguments from gui_parameters.txt (if it exists)
|
12 |
+
# and run the kohya_gui.py script with the command-line arguments
|
13 |
+
if ($LASTEXITCODE -eq 0) {
|
14 |
+
$argsFromFile = @()
|
15 |
+
if (Test-Path .\gui_parameters.txt) {
|
16 |
+
$argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " }
|
17 |
+
}
|
18 |
+
$args_combo = $argsFromFile + $args
|
19 |
+
Write-Host "The arguments passed to this script were: $args_combo"
|
20 |
+
python.exe kohya_gui.py $args_combo
|
21 |
+
}
|
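As gui.ps1 above shows, extra command-line flags can optionally be read from a `gui_parameters.txt` file placed next to the script: lines starting with `#` are ignored, and the remaining tokens are split on spaces and appended to the arguments forwarded to kohya_gui.py. A hypothetical example of such a file (the flag values are placeholders, not defaults from this repository):

```
# gui_parameters.txt - extra flags forwarded to kohya_gui.py
--listen 0.0.0.0 --server_port 7860 --inbrowser
```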
gui.sh
ADDED
@@ -0,0 +1,15 @@
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# This gets the directory the script is run from so pathing can work relative to the script where needed.
|
4 |
+
SCRIPT_DIR=$(cd -- "$(dirname -- "$0")" && pwd)
|
5 |
+
|
6 |
+
# Step into GUI local directory
|
7 |
+
cd "$SCRIPT_DIR"
|
8 |
+
|
9 |
+
# Activate the virtual environment
|
10 |
+
source "$SCRIPT_DIR/venv/bin/activate"
|
11 |
+
|
12 |
+
# If the requirements are validated, run the kohya_gui.py script with the command-line arguments
|
13 |
+
if python "$SCRIPT_DIR"/tools/validate_requirements.py -r "$SCRIPT_DIR"/requirements.txt; then
|
14 |
+
python "$SCRIPT_DIR/kohya_gui.py" "$@"
|
15 |
+
fi
|
kohya_gui.py
ADDED
@@ -0,0 +1,114 @@
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
from dreambooth_gui import dreambooth_tab
|
5 |
+
from finetune_gui import finetune_tab
|
6 |
+
from textual_inversion_gui import ti_tab
|
7 |
+
from library.utilities import utilities_tab
|
8 |
+
from library.extract_lora_gui import gradio_extract_lora_tab
|
9 |
+
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
10 |
+
from library.merge_lora_gui import gradio_merge_lora_tab
|
11 |
+
from library.resize_lora_gui import gradio_resize_lora_tab
|
12 |
+
from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab
|
13 |
+
from library.merge_lycoris_gui import gradio_merge_lycoris_tab
|
14 |
+
from lora_gui import lora_tab
|
15 |
+
|
16 |
+
|
17 |
+
def UI(**kwargs):
|
18 |
+
css = ''
|
19 |
+
|
20 |
+
if os.path.exists('./style.css'):
|
21 |
+
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
22 |
+
print('Load CSS...')
|
23 |
+
css += file.read() + '\n'
|
24 |
+
|
25 |
+
interface = gr.Blocks(css=css, title='Kohya_ss GUI', theme=gr.themes.Default())
|
26 |
+
|
27 |
+
with interface:
|
28 |
+
with gr.Tab('Dreambooth'):
|
29 |
+
(
|
30 |
+
train_data_dir_input,
|
31 |
+
reg_data_dir_input,
|
32 |
+
output_dir_input,
|
33 |
+
logging_dir_input,
|
34 |
+
) = dreambooth_tab()
|
35 |
+
with gr.Tab('Dreambooth LoRA'):
|
36 |
+
lora_tab()
|
37 |
+
with gr.Tab('Dreambooth TI'):
|
38 |
+
ti_tab()
|
39 |
+
with gr.Tab('Finetune'):
|
40 |
+
finetune_tab()
|
41 |
+
with gr.Tab('Utilities'):
|
42 |
+
utilities_tab(
|
43 |
+
train_data_dir_input=train_data_dir_input,
|
44 |
+
reg_data_dir_input=reg_data_dir_input,
|
45 |
+
output_dir_input=output_dir_input,
|
46 |
+
logging_dir_input=logging_dir_input,
|
47 |
+
enable_copy_info_button=True,
|
48 |
+
)
|
49 |
+
gradio_extract_dylora_tab()
|
50 |
+
gradio_extract_lora_tab()
|
51 |
+
gradio_extract_lycoris_locon_tab()
|
52 |
+
gradio_merge_lora_tab()
|
53 |
+
gradio_merge_lycoris_tab()
|
54 |
+
gradio_resize_lora_tab()
|
55 |
+
|
56 |
+
# Show the interface
|
57 |
+
launch_kwargs = {}
|
58 |
+
username = kwargs.get('username')
|
59 |
+
password = kwargs.get('password')
|
60 |
+
server_port = kwargs.get('server_port', 0)
|
61 |
+
inbrowser = kwargs.get('inbrowser', False)
|
62 |
+
share = kwargs.get('share', False)
|
63 |
+
server_name = kwargs.get('listen')
|
64 |
+
|
65 |
+
launch_kwargs['server_name'] = server_name
|
66 |
+
if username and password:
|
67 |
+
launch_kwargs['auth'] = (username, password)
|
68 |
+
if server_port > 0:
|
69 |
+
launch_kwargs['server_port'] = server_port
|
70 |
+
if inbrowser:
|
71 |
+
launch_kwargs['inbrowser'] = inbrowser
|
72 |
+
if share:
|
73 |
+
launch_kwargs['share'] = share
|
74 |
+
interface.launch(**launch_kwargs)
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == '__main__':
|
78 |
+
# torch.cuda.set_per_process_memory_fraction(0.48)
|
79 |
+
parser = argparse.ArgumentParser()
|
80 |
+
parser.add_argument(
|
81 |
+
'--listen',
|
82 |
+
type=str,
|
83 |
+
default='127.0.0.1',
|
84 |
+
help='IP to listen on for connections to Gradio',
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
'--username', type=str, default='', help='Username for authentication'
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
'--password', type=str, default='', help='Password for authentication'
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
'--server_port',
|
94 |
+
type=int,
|
95 |
+
default=0,
|
96 |
+
help='Port to run the server listener on',
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
'--inbrowser', action='store_true', help='Open in browser'
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
'--share', action='store_true', help='Share the gradio UI'
|
103 |
+
)
|
104 |
+
|
105 |
+
args = parser.parse_args()
|
106 |
+
|
107 |
+
UI(
|
108 |
+
username=args.username,
|
109 |
+
password=args.password,
|
110 |
+
inbrowser=args.inbrowser,
|
111 |
+
server_port=args.server_port,
|
112 |
+
share=args.share,
|
113 |
+
listen=args.listen,
|
114 |
+
)
|
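Based on the argparse options defined in kohya_gui.py above, the GUI can also be launched directly from an activated environment; the host, port and credentials below are placeholders chosen for illustration:

```
# Bind to all interfaces on port 7860 and open a browser tab
python kohya_gui.py --listen 0.0.0.0 --server_port 7860 --inbrowser

# Require basic authentication and create a public Gradio share link
python kohya_gui.py --username admin --password change-me --share
```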
lora_gui.py
ADDED
@@ -0,0 +1,1294 @@
|
1 |
+
# v1: initial release
|
2 |
+
# v2: add open and save folder icons
|
3 |
+
# v3: Add new Utilities tab for Dreambooth folder preparation
|
4 |
+
# v3.1: Adding captioning of images to utilities
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import easygui
|
8 |
+
import json
|
9 |
+
import math
|
10 |
+
import os
|
11 |
+
import subprocess
|
12 |
+
import pathlib
|
13 |
+
import argparse
|
14 |
+
from library.common_gui import (
|
15 |
+
get_folder_path,
|
16 |
+
remove_doublequote,
|
17 |
+
get_file_path,
|
18 |
+
get_any_file_path,
|
19 |
+
get_saveasfile_path,
|
20 |
+
color_aug_changed,
|
21 |
+
save_inference_file,
|
22 |
+
gradio_advanced_training,
|
23 |
+
run_cmd_advanced_training,
|
24 |
+
gradio_training,
|
25 |
+
gradio_config,
|
26 |
+
gradio_source_model,
|
27 |
+
run_cmd_training,
|
28 |
+
# set_legacy_8bitadam,
|
29 |
+
update_my_data,
|
30 |
+
check_if_model_exist,
|
31 |
+
)
|
32 |
+
from library.dreambooth_folder_creation_gui import (
|
33 |
+
gradio_dreambooth_folder_creation_tab,
|
34 |
+
)
|
35 |
+
from library.tensorboard_gui import (
|
36 |
+
gradio_tensorboard,
|
37 |
+
start_tensorboard,
|
38 |
+
stop_tensorboard,
|
39 |
+
)
|
40 |
+
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
41 |
+
from library.utilities import utilities_tab
|
42 |
+
from library.merge_lora_gui import gradio_merge_lora_tab
|
43 |
+
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
|
44 |
+
from library.verify_lora_gui import gradio_verify_lora_tab
|
45 |
+
from library.resize_lora_gui import gradio_resize_lora_tab
|
46 |
+
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
47 |
+
from easygui import msgbox
|
48 |
+
|
49 |
+
folder_symbol = '\U0001f4c2' # 📂
|
50 |
+
refresh_symbol = '\U0001f504' # 🔄
|
51 |
+
save_style_symbol = '\U0001f4be' # 💾
|
52 |
+
document_symbol = '\U0001F4C4' # 📄
|
53 |
+
path_of_this_folder = os.getcwd()
|
54 |
+
|
55 |
+
|
56 |
+
def save_configuration(
|
57 |
+
save_as,
|
58 |
+
file_path,
|
59 |
+
pretrained_model_name_or_path,
|
60 |
+
v2,
|
61 |
+
v_parameterization,
|
62 |
+
logging_dir,
|
63 |
+
train_data_dir,
|
64 |
+
reg_data_dir,
|
65 |
+
output_dir,
|
66 |
+
max_resolution,
|
67 |
+
learning_rate,
|
68 |
+
lr_scheduler,
|
69 |
+
lr_warmup,
|
70 |
+
train_batch_size,
|
71 |
+
epoch,
|
72 |
+
save_every_n_epochs,
|
73 |
+
mixed_precision,
|
74 |
+
save_precision,
|
75 |
+
seed,
|
76 |
+
num_cpu_threads_per_process,
|
77 |
+
cache_latents,
|
78 |
+
caption_extension,
|
79 |
+
enable_bucket,
|
80 |
+
gradient_checkpointing,
|
81 |
+
full_fp16,
|
82 |
+
no_token_padding,
|
83 |
+
stop_text_encoder_training,
|
84 |
+
# use_8bit_adam,
|
85 |
+
xformers,
|
86 |
+
save_model_as,
|
87 |
+
shuffle_caption,
|
88 |
+
save_state,
|
89 |
+
resume,
|
90 |
+
prior_loss_weight,
|
91 |
+
text_encoder_lr,
|
92 |
+
unet_lr,
|
93 |
+
network_dim,
|
94 |
+
lora_network_weights,
|
95 |
+
color_aug,
|
96 |
+
flip_aug,
|
97 |
+
clip_skip,
|
98 |
+
gradient_accumulation_steps,
|
99 |
+
mem_eff_attn,
|
100 |
+
output_name,
|
101 |
+
model_list,
|
102 |
+
max_token_length,
|
103 |
+
max_train_epochs,
|
104 |
+
max_data_loader_n_workers,
|
105 |
+
network_alpha,
|
106 |
+
training_comment,
|
107 |
+
keep_tokens,
|
108 |
+
lr_scheduler_num_cycles,
|
109 |
+
lr_scheduler_power,
|
110 |
+
persistent_data_loader_workers,
|
111 |
+
bucket_no_upscale,
|
112 |
+
random_crop,
|
113 |
+
bucket_reso_steps,
|
114 |
+
caption_dropout_every_n_epochs,
|
115 |
+
caption_dropout_rate,
|
116 |
+
optimizer,
|
117 |
+
optimizer_args,
|
118 |
+
noise_offset,
|
119 |
+
LoRA_type,
|
120 |
+
conv_dim,
|
121 |
+
conv_alpha,
|
122 |
+
sample_every_n_steps,
|
123 |
+
sample_every_n_epochs,
|
124 |
+
sample_sampler,
|
125 |
+
sample_prompts,
|
126 |
+
additional_parameters,
|
127 |
+
vae_batch_size,
|
128 |
+
min_snr_gamma,
|
129 |
+
down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas,
|
130 |
+
weighted_captions,unit,
|
131 |
+
):
|
132 |
+
# Get list of function parameters and values
|
133 |
+
parameters = list(locals().items())
|
134 |
+
|
135 |
+
original_file_path = file_path
|
136 |
+
|
137 |
+
save_as_bool = True if save_as.get('label') == 'True' else False
|
138 |
+
|
139 |
+
if save_as_bool:
|
140 |
+
print('Save as...')
|
141 |
+
file_path = get_saveasfile_path(file_path)
|
142 |
+
else:
|
143 |
+
print('Save...')
|
144 |
+
if file_path == None or file_path == '':
|
145 |
+
file_path = get_saveasfile_path(file_path)
|
146 |
+
|
147 |
+
# print(file_path)
|
148 |
+
|
149 |
+
if file_path == None or file_path == '':
|
150 |
+
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
151 |
+
|
152 |
+
# Return the values of the variables as a dictionary
|
153 |
+
variables = {
|
154 |
+
name: value
|
155 |
+
for name, value in parameters # locals().items()
|
156 |
+
if name
|
157 |
+
not in [
|
158 |
+
'file_path',
|
159 |
+
'save_as',
|
160 |
+
]
|
161 |
+
}
|
162 |
+
|
163 |
+
# Extract the destination directory from the file path
|
164 |
+
destination_directory = os.path.dirname(file_path)
|
165 |
+
|
166 |
+
# Create the destination directory if it doesn't exist
|
167 |
+
if not os.path.exists(destination_directory):
|
168 |
+
os.makedirs(destination_directory)
|
169 |
+
|
170 |
+
# Save the data to the selected file
|
171 |
+
with open(file_path, 'w') as file:
|
172 |
+
json.dump(variables, file, indent=2)
|
173 |
+
|
174 |
+
return file_path
|
175 |
+
|
176 |
+
|
177 |
+
def open_configuration(
|
178 |
+
ask_for_file,
|
179 |
+
file_path,
|
180 |
+
pretrained_model_name_or_path,
|
181 |
+
v2,
|
182 |
+
v_parameterization,
|
183 |
+
logging_dir,
|
184 |
+
train_data_dir,
|
185 |
+
reg_data_dir,
|
186 |
+
output_dir,
|
187 |
+
max_resolution,
|
188 |
+
learning_rate,
|
189 |
+
lr_scheduler,
|
190 |
+
lr_warmup,
|
191 |
+
train_batch_size,
|
192 |
+
epoch,
|
193 |
+
save_every_n_epochs,
|
194 |
+
mixed_precision,
|
195 |
+
save_precision,
|
196 |
+
seed,
|
197 |
+
num_cpu_threads_per_process,
|
198 |
+
cache_latents,
|
199 |
+
caption_extension,
|
200 |
+
enable_bucket,
|
201 |
+
gradient_checkpointing,
|
202 |
+
full_fp16,
|
203 |
+
no_token_padding,
|
204 |
+
stop_text_encoder_training,
|
205 |
+
# use_8bit_adam,
|
206 |
+
xformers,
|
207 |
+
save_model_as,
|
208 |
+
shuffle_caption,
|
209 |
+
save_state,
|
210 |
+
resume,
|
211 |
+
prior_loss_weight,
|
212 |
+
text_encoder_lr,
|
213 |
+
unet_lr,
|
214 |
+
network_dim,
|
215 |
+
lora_network_weights,
|
216 |
+
color_aug,
|
217 |
+
flip_aug,
|
218 |
+
clip_skip,
|
219 |
+
gradient_accumulation_steps,
|
220 |
+
mem_eff_attn,
|
221 |
+
output_name,
|
222 |
+
model_list,
|
223 |
+
max_token_length,
|
224 |
+
max_train_epochs,
|
225 |
+
max_data_loader_n_workers,
|
226 |
+
network_alpha,
|
227 |
+
training_comment,
|
228 |
+
keep_tokens,
|
229 |
+
lr_scheduler_num_cycles,
|
230 |
+
lr_scheduler_power,
|
231 |
+
persistent_data_loader_workers,
|
232 |
+
bucket_no_upscale,
|
233 |
+
random_crop,
|
234 |
+
bucket_reso_steps,
|
235 |
+
caption_dropout_every_n_epochs,
|
236 |
+
caption_dropout_rate,
|
237 |
+
optimizer,
|
238 |
+
optimizer_args,
|
239 |
+
noise_offset,
|
240 |
+
LoRA_type,
|
241 |
+
conv_dim,
|
242 |
+
conv_alpha,
|
243 |
+
sample_every_n_steps,
|
244 |
+
sample_every_n_epochs,
|
245 |
+
sample_sampler,
|
246 |
+
sample_prompts,
|
247 |
+
additional_parameters,
|
248 |
+
vae_batch_size,
|
249 |
+
min_snr_gamma,
|
250 |
+
down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas,
|
251 |
+
weighted_captions,unit,
|
252 |
+
):
|
253 |
+
# Get list of function parameters and values
|
254 |
+
parameters = list(locals().items())
|
255 |
+
|
256 |
+
ask_for_file = True if ask_for_file.get('label') == 'True' else False
|
257 |
+
|
258 |
+
original_file_path = file_path
|
259 |
+
|
260 |
+
if ask_for_file:
|
261 |
+
file_path = get_file_path(file_path)
|
262 |
+
|
263 |
+
if not file_path == '' and not file_path == None:
|
264 |
+
# load variables from JSON file
|
265 |
+
with open(file_path, 'r') as f:
|
266 |
+
my_data = json.load(f)
|
267 |
+
print('Loading config...')
|
268 |
+
|
269 |
+
# Update values to fix deprecated use_8bit_adam checkbox, set appropriate optimizer if it is set to True, etc.
|
270 |
+
my_data = update_my_data(my_data)
|
271 |
+
else:
|
272 |
+
file_path = original_file_path # In case a file_path was provided and the user decided to cancel the open action
|
273 |
+
my_data = {}
|
274 |
+
|
275 |
+
values = [file_path]
|
276 |
+
for key, value in parameters:
|
277 |
+
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
278 |
+
if not key in ['ask_for_file', 'file_path']:
|
279 |
+
values.append(my_data.get(key, value))
|
280 |
+
|
281 |
+
# This next section makes the LoCon parameters visible when LoRA_type is 'LoCon'
|
282 |
+
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
|
283 |
+
values.append(gr.Row.update(visible=True))
|
284 |
+
else:
|
285 |
+
values.append(gr.Row.update(visible=False))
|
286 |
+
|
287 |
+
return tuple(values)
|
288 |
+
|
289 |
+
|
290 |
+
def train_model(
|
291 |
+
print_only,
|
292 |
+
pretrained_model_name_or_path,
|
293 |
+
v2,
|
294 |
+
v_parameterization,
|
295 |
+
logging_dir,
|
296 |
+
train_data_dir,
|
297 |
+
reg_data_dir,
|
298 |
+
output_dir,
|
299 |
+
max_resolution,
|
300 |
+
learning_rate,
|
301 |
+
lr_scheduler,
|
302 |
+
lr_warmup,
|
303 |
+
train_batch_size,
|
304 |
+
epoch,
|
305 |
+
save_every_n_epochs,
|
306 |
+
mixed_precision,
|
307 |
+
save_precision,
|
308 |
+
seed,
|
309 |
+
num_cpu_threads_per_process,
|
310 |
+
cache_latents,
|
311 |
+
caption_extension,
|
312 |
+
enable_bucket,
|
313 |
+
gradient_checkpointing,
|
314 |
+
full_fp16,
|
315 |
+
no_token_padding,
|
316 |
+
stop_text_encoder_training_pct,
|
317 |
+
# use_8bit_adam,
|
318 |
+
xformers,
|
319 |
+
save_model_as,
|
320 |
+
shuffle_caption,
|
321 |
+
save_state,
|
322 |
+
resume,
|
323 |
+
prior_loss_weight,
|
324 |
+
text_encoder_lr,
|
325 |
+
unet_lr,
|
326 |
+
network_dim,
|
327 |
+
lora_network_weights,
|
328 |
+
color_aug,
|
329 |
+
flip_aug,
|
330 |
+
clip_skip,
|
331 |
+
gradient_accumulation_steps,
|
332 |
+
mem_eff_attn,
|
333 |
+
output_name,
|
334 |
+
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
335 |
+
max_token_length,
|
336 |
+
max_train_epochs,
|
337 |
+
max_data_loader_n_workers,
|
338 |
+
network_alpha,
|
339 |
+
training_comment,
|
340 |
+
keep_tokens,
|
341 |
+
lr_scheduler_num_cycles,
|
342 |
+
lr_scheduler_power,
|
343 |
+
persistent_data_loader_workers,
|
344 |
+
bucket_no_upscale,
|
345 |
+
random_crop,
|
346 |
+
bucket_reso_steps,
|
347 |
+
caption_dropout_every_n_epochs,
|
348 |
+
caption_dropout_rate,
|
349 |
+
optimizer,
|
350 |
+
optimizer_args,
|
351 |
+
noise_offset,
|
352 |
+
LoRA_type,
|
353 |
+
conv_dim,
|
354 |
+
conv_alpha,
|
355 |
+
sample_every_n_steps,
|
356 |
+
sample_every_n_epochs,
|
357 |
+
sample_sampler,
|
358 |
+
sample_prompts,
|
359 |
+
additional_parameters,
|
360 |
+
vae_batch_size,
|
361 |
+
min_snr_gamma,
|
362 |
+
down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas,
|
363 |
+
weighted_captions,unit,
|
364 |
+
):
|
365 |
+
print_only_bool = True if print_only.get('label') == 'True' else False
|
366 |
+
|
367 |
+
if pretrained_model_name_or_path == '':
|
368 |
+
msgbox('Source model information is missing')
|
369 |
+
return
|
370 |
+
|
371 |
+
if train_data_dir == '':
|
372 |
+
msgbox('Image folder path is missing')
|
373 |
+
return
|
374 |
+
|
375 |
+
if not os.path.exists(train_data_dir):
|
376 |
+
msgbox('Image folder does not exist')
|
377 |
+
return
|
378 |
+
|
379 |
+
if reg_data_dir != '':
|
380 |
+
if not os.path.exists(reg_data_dir):
|
381 |
+
msgbox('Regularisation folder does not exist')
|
382 |
+
return
|
383 |
+
|
384 |
+
if output_dir == '':
|
385 |
+
msgbox('Output folder path is missing')
|
386 |
+
return
|
387 |
+
|
388 |
+
if int(bucket_reso_steps) < 1:
|
389 |
+
msgbox('Bucket resolution steps need to be greater than 0')
|
390 |
+
return
|
391 |
+
|
392 |
+
if not os.path.exists(output_dir):
|
393 |
+
os.makedirs(output_dir)
|
394 |
+
|
395 |
+
if stop_text_encoder_training_pct > 0:
|
396 |
+
msgbox(
|
397 |
+
'Option "stop text encoder training" is not yet supported. Ignoring'
|
398 |
+
)
|
399 |
+
stop_text_encoder_training_pct = 0
|
400 |
+
|
401 |
+
if check_if_model_exist(output_name, output_dir, save_model_as):
|
402 |
+
return
|
403 |
+
|
404 |
+
if optimizer == 'Adafactor' and lr_warmup != '0':
|
405 |
+
msgbox("Warning: the optimizer is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning")
|
406 |
+
lr_warmup = '0'
|
407 |
+
|
408 |
+
# If string is empty set string to 0.
|
409 |
+
if text_encoder_lr == '':
|
410 |
+
text_encoder_lr = 0
|
411 |
+
if unet_lr == '':
|
412 |
+
unet_lr = 0
|
413 |
+
|
414 |
+
# if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0):
|
415 |
+
# msgbox(
|
416 |
+
# 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided'
|
417 |
+
# )
|
418 |
+
# return
|
419 |
+
|
420 |
+
# Get a list of all subfolders in train_data_dir
|
421 |
+
subfolders = [
|
422 |
+
f
|
423 |
+
for f in os.listdir(train_data_dir)
|
424 |
+
if os.path.isdir(os.path.join(train_data_dir, f))
|
425 |
+
]
|
426 |
+
|
427 |
+
total_steps = 0
|
428 |
+
|
429 |
+
# Loop through each subfolder and extract the number of repeats
|
430 |
+
for folder in subfolders:
|
431 |
+
try:
|
432 |
+
# Extract the number of repeats from the folder name
|
433 |
+
repeats = int(folder.split('_')[0])
|
434 |
+
|
435 |
+
# Count the number of images in the folder
|
436 |
+
num_images = len(
|
437 |
+
[
|
438 |
+
f
|
439 |
+
for f, lower_f in (
|
440 |
+
(file, file.lower())
|
441 |
+
for file in os.listdir(
|
442 |
+
os.path.join(train_data_dir, folder)
|
443 |
+
)
|
444 |
+
)
|
445 |
+
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
446 |
+
]
|
447 |
+
)
|
448 |
+
|
449 |
+
print(f'Folder {folder}: {num_images} images found')
|
450 |
+
|
451 |
+
# Calculate the total number of steps for this folder
|
452 |
+
steps = repeats * num_images
|
453 |
+
|
454 |
+
# Print the result
|
455 |
+
print(f'Folder {folder}: {steps} steps')
|
456 |
+
|
457 |
+
total_steps += steps
|
458 |
+
|
459 |
+
except ValueError:
|
460 |
+
# Handle the case where the folder name does not contain an underscore
|
461 |
+
print(f"Error: '{folder}' does not contain an underscore, skipping...")
|
462 |
+
|
463 |
+
# calculate max_train_steps
|
464 |
+
max_train_steps = int(
|
465 |
+
math.ceil(
|
466 |
+
float(total_steps)
|
467 |
+
/ int(train_batch_size)
|
468 |
+
* int(epoch)
|
469 |
+
# * int(reg_factor)
|
470 |
+
)
|
471 |
+
)
|
472 |
+
print(f'max_train_steps = {max_train_steps}')
|
473 |
+
|
474 |
+
# calculate stop encoder training
|
475 |
+
if stop_text_encoder_training_pct == None:
|
476 |
+
stop_text_encoder_training = 0
|
477 |
+
else:
|
478 |
+
stop_text_encoder_training = math.ceil(
|
479 |
+
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
|
480 |
+
)
|
481 |
+
print(f'stop_text_encoder_training = {stop_text_encoder_training}')
|
482 |
+
|
483 |
+
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
484 |
+
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
485 |
+
|
486 |
+
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
|
487 |
+
|
488 |
+
if v2:
|
489 |
+
run_cmd += ' --v2'
|
490 |
+
if v_parameterization:
|
491 |
+
run_cmd += ' --v_parameterization'
|
492 |
+
if enable_bucket:
|
493 |
+
run_cmd += ' --enable_bucket'
|
494 |
+
if no_token_padding:
|
495 |
+
run_cmd += ' --no_token_padding'
|
496 |
+
if weighted_captions:
|
497 |
+
run_cmd += ' --weighted_captions'
|
498 |
+
run_cmd += (
|
499 |
+
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
500 |
+
)
|
501 |
+
run_cmd += f' --train_data_dir="{train_data_dir}"'
|
502 |
+
if len(reg_data_dir):
|
503 |
+
run_cmd += f' --reg_data_dir="{reg_data_dir}"'
|
504 |
+
run_cmd += f' --resolution={max_resolution}'
|
505 |
+
run_cmd += f' --output_dir="{output_dir}"'
|
506 |
+
run_cmd += f' --logging_dir="{logging_dir}"'
|
507 |
+
run_cmd += f' --network_alpha="{network_alpha}"'
|
508 |
+
if not training_comment == '':
|
509 |
+
run_cmd += f' --training_comment="{training_comment}"'
|
510 |
+
if not stop_text_encoder_training == 0:
|
511 |
+
run_cmd += (
|
512 |
+
f' --stop_text_encoder_training={stop_text_encoder_training}'
|
513 |
+
)
|
514 |
+
if not save_model_as == 'same as source model':
|
515 |
+
run_cmd += f' --save_model_as={save_model_as}'
|
516 |
+
if not float(prior_loss_weight) == 1.0:
|
517 |
+
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
518 |
+
if LoRA_type == 'LoCon' or LoRA_type == 'LyCORIS/LoCon':
|
519 |
+
try:
|
520 |
+
import lycoris
|
521 |
+
except ModuleNotFoundError:
|
522 |
+
print(
|
523 |
+
"\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program."
|
524 |
+
)
|
525 |
+
return
|
526 |
+
run_cmd += f' --network_module=lycoris.kohya'
|
527 |
+
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"'
|
528 |
+
if LoRA_type == 'LyCORIS/LoHa':
|
529 |
+
try:
|
530 |
+
import lycoris
|
531 |
+
except ModuleNotFoundError:
|
532 |
+
print(
|
533 |
+
"\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program."
|
534 |
+
)
|
535 |
+
return
|
536 |
+
run_cmd += f' --network_module=lycoris.kohya'
|
537 |
+
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"'
|
538 |
+
|
539 |
+
|
540 |
+
if LoRA_type in ['Kohya LoCon', 'Standard']:
|
541 |
+
kohya_lora_var_list = ['down_lr_weight', 'mid_lr_weight', 'up_lr_weight', 'block_lr_zero_threshold', 'block_dims', 'block_alphas', 'conv_dims', 'conv_alphas']
|
542 |
+
|
543 |
+
run_cmd += f' --network_module=networks.lora'
|
544 |
+
kohya_lora_vars = {key: value for key, value in vars().items() if key in kohya_lora_var_list and value}
|
545 |
+
|
546 |
+
network_args = ''
|
547 |
+
if LoRA_type == 'Kohya LoCon':
|
548 |
+
network_args += f' "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
|
549 |
+
|
550 |
+
for key, value in kohya_lora_vars.items():
|
551 |
+
if value:
|
552 |
+
network_args += f' {key}="{value}"'
|
553 |
+
|
554 |
+
if network_args:
|
555 |
+
run_cmd += f' --network_args{network_args}'
|
556 |
+
|
557 |
+
if LoRA_type in ['Kohya DyLoRA']:
|
558 |
+
kohya_lora_var_list = ['conv_dim', 'conv_alpha', 'down_lr_weight', 'mid_lr_weight', 'up_lr_weight', 'block_lr_zero_threshold', 'block_dims', 'block_alphas', 'conv_dims', 'conv_alphas', 'unit']
|
559 |
+
|
560 |
+
run_cmd += f' --network_module=networks.dylora'
|
561 |
+
kohya_lora_vars = {key: value for key, value in vars().items() if key in kohya_lora_var_list and value}
|
562 |
+
|
563 |
+
network_args = ''
|
564 |
+
|
565 |
+
for key, value in kohya_lora_vars.items():
|
566 |
+
if value:
|
567 |
+
network_args += f' {key}="{value}"'
|
568 |
+
|
569 |
+
if network_args:
|
570 |
+
run_cmd += f' --network_args{network_args}'
|
571 |
+
|
572 |
+
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
|
573 |
+
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
|
574 |
+
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
|
575 |
+
run_cmd += f' --unet_lr={unet_lr}'
|
576 |
+
elif not (float(text_encoder_lr) == 0):
|
577 |
+
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
|
578 |
+
run_cmd += f' --network_train_text_encoder_only'
|
579 |
+
else:
|
580 |
+
run_cmd += f' --unet_lr={unet_lr}'
|
581 |
+
run_cmd += f' --network_train_unet_only'
|
582 |
+
else:
|
583 |
+
if float(text_encoder_lr) == 0:
|
584 |
+
msgbox('Please input learning rate values.')
|
585 |
+
return
|
586 |
+
|
587 |
+
run_cmd += f' --network_dim={network_dim}'
|
588 |
+
|
589 |
+
if not lora_network_weights == '':
|
590 |
+
run_cmd += f' --network_weights="{lora_network_weights}"'
|
591 |
+
if int(gradient_accumulation_steps) > 1:
|
592 |
+
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
593 |
+
if not output_name == '':
|
594 |
+
run_cmd += f' --output_name="{output_name}"'
|
595 |
+
if not lr_scheduler_num_cycles == '':
|
596 |
+
run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
|
597 |
+
else:
|
598 |
+
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
|
599 |
+
if not lr_scheduler_power == '':
|
600 |
+
run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
|
601 |
+
|
602 |
+
run_cmd += run_cmd_training(
|
603 |
+
learning_rate=learning_rate,
|
604 |
+
lr_scheduler=lr_scheduler,
|
605 |
+
lr_warmup_steps=lr_warmup_steps,
|
606 |
+
train_batch_size=train_batch_size,
|
607 |
+
max_train_steps=max_train_steps,
|
608 |
+
save_every_n_epochs=save_every_n_epochs,
|
609 |
+
mixed_precision=mixed_precision,
|
610 |
+
save_precision=save_precision,
|
611 |
+
seed=seed,
|
612 |
+
caption_extension=caption_extension,
|
613 |
+
cache_latents=cache_latents,
|
614 |
+
optimizer=optimizer,
|
615 |
+
optimizer_args=optimizer_args,
|
616 |
+
)
|
617 |
+
|
618 |
+
run_cmd += run_cmd_advanced_training(
|
619 |
+
max_train_epochs=max_train_epochs,
|
620 |
+
max_data_loader_n_workers=max_data_loader_n_workers,
|
621 |
+
max_token_length=max_token_length,
|
622 |
+
resume=resume,
|
623 |
+
save_state=save_state,
|
624 |
+
mem_eff_attn=mem_eff_attn,
|
625 |
+
clip_skip=clip_skip,
|
626 |
+
flip_aug=flip_aug,
|
627 |
+
color_aug=color_aug,
|
628 |
+
shuffle_caption=shuffle_caption,
|
629 |
+
gradient_checkpointing=gradient_checkpointing,
|
630 |
+
full_fp16=full_fp16,
|
631 |
+
xformers=xformers,
|
632 |
+
# use_8bit_adam=use_8bit_adam,
|
633 |
+
keep_tokens=keep_tokens,
|
634 |
+
persistent_data_loader_workers=persistent_data_loader_workers,
|
635 |
+
bucket_no_upscale=bucket_no_upscale,
|
636 |
+
random_crop=random_crop,
|
637 |
+
bucket_reso_steps=bucket_reso_steps,
|
638 |
+
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
639 |
+
caption_dropout_rate=caption_dropout_rate,
|
640 |
+
noise_offset=noise_offset,
|
641 |
+
additional_parameters=additional_parameters,
|
642 |
+
vae_batch_size=vae_batch_size,
|
643 |
+
min_snr_gamma=min_snr_gamma,
|
644 |
+
)
|
645 |
+
|
646 |
+
run_cmd += run_cmd_sample(
|
647 |
+
sample_every_n_steps,
|
648 |
+
sample_every_n_epochs,
|
649 |
+
sample_sampler,
|
650 |
+
sample_prompts,
|
651 |
+
output_dir,
|
652 |
+
)
|
653 |
+
|
654 |
+
# if not down_lr_weight == '':
|
655 |
+
# run_cmd += f' --down_lr_weight="{down_lr_weight}"'
|
656 |
+
# if not mid_lr_weight == '':
|
657 |
+
# run_cmd += f' --mid_lr_weight="{mid_lr_weight}"'
|
658 |
+
# if not up_lr_weight == '':
|
659 |
+
# run_cmd += f' --up_lr_weight="{up_lr_weight}"'
|
660 |
+
# if not block_lr_zero_threshold == '':
|
661 |
+
# run_cmd += f' --block_lr_zero_threshold="{block_lr_zero_threshold}"'
|
662 |
+
# if not block_dims == '':
|
663 |
+
# run_cmd += f' --block_dims="{block_dims}"'
|
664 |
+
# if not block_alphas == '':
|
665 |
+
# run_cmd += f' --block_alphas="{block_alphas}"'
|
666 |
+
# if not conv_dims == '':
|
667 |
+
# run_cmd += f' --conv_dims="{conv_dims}"'
|
668 |
+
# if not conv_alphas == '':
|
669 |
+
# run_cmd += f' --conv_alphas="{conv_alphas}"'
|
670 |
+
|
671 |
+
|
672 |
+
|
673 |
+
|
674 |
+
if print_only_bool:
|
675 |
+
print(
|
676 |
+
'\033[93m\nHere is the trainer command as a reference. It will not be executed:\033[0m\n'
|
677 |
+
)
|
678 |
+
print('\033[96m' + run_cmd + '\033[0m\n')
|
679 |
+
else:
|
680 |
+
print(run_cmd)
|
681 |
+
# Run the command
|
682 |
+
if os.name == 'posix':
|
683 |
+
os.system(run_cmd)
|
684 |
+
else:
|
685 |
+
subprocess.run(run_cmd)
|
686 |
+
|
687 |
+
# check if output_dir/last is a folder... therefore it is a diffuser model
|
688 |
+
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
|
689 |
+
|
690 |
+
if not last_dir.is_dir():
|
691 |
+
# Copy inference model for v2 if required
|
692 |
+
save_inference_file(
|
693 |
+
output_dir, v2, v_parameterization, output_name
|
694 |
+
)
|
695 |
+
|
696 |
+
|
697 |
+
def lora_tab(
|
698 |
+
train_data_dir_input=gr.Textbox(),
|
699 |
+
reg_data_dir_input=gr.Textbox(),
|
700 |
+
output_dir_input=gr.Textbox(),
|
701 |
+
logging_dir_input=gr.Textbox(),
|
702 |
+
):
|
703 |
+
dummy_db_true = gr.Label(value=True, visible=False)
|
704 |
+
dummy_db_false = gr.Label(value=False, visible=False)
|
705 |
+
gr.Markdown(
|
706 |
+
'Train a custom model using kohya train network LoRA python code...'
|
707 |
+
)
|
708 |
+
(
|
709 |
+
button_open_config,
|
710 |
+
button_save_config,
|
711 |
+
button_save_as_config,
|
712 |
+
config_file_name,
|
713 |
+
button_load_config,
|
714 |
+
) = gradio_config()
|
715 |
+
|
716 |
+
(
|
717 |
+
pretrained_model_name_or_path,
|
718 |
+
v2,
|
719 |
+
v_parameterization,
|
720 |
+
save_model_as,
|
721 |
+
model_list,
|
722 |
+
) = gradio_source_model(
|
723 |
+
save_model_as_choices=[
|
724 |
+
'ckpt',
|
725 |
+
'safetensors',
|
726 |
+
]
|
727 |
+
)
|
728 |
+
|
729 |
+
with gr.Tab('Folders'):
|
730 |
+
with gr.Row():
|
731 |
+
train_data_dir = gr.Textbox(
|
732 |
+
label='Image folder',
|
733 |
+
placeholder='Folder where the training folders containing the images are located',
|
734 |
+
)
|
735 |
+
train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
736 |
+
train_data_dir_folder.click(
|
737 |
+
get_folder_path,
|
738 |
+
outputs=train_data_dir,
|
739 |
+
show_progress=False,
|
740 |
+
)
|
741 |
+
reg_data_dir = gr.Textbox(
|
742 |
+
label='Regularisation folder',
|
743 |
+
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
744 |
+
)
|
745 |
+
reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
746 |
+
reg_data_dir_folder.click(
|
747 |
+
get_folder_path,
|
748 |
+
outputs=reg_data_dir,
|
749 |
+
show_progress=False,
|
750 |
+
)
|
751 |
+
with gr.Row():
|
752 |
+
output_dir = gr.Textbox(
|
753 |
+
label='Output folder',
|
754 |
+
placeholder='Folder to output trained model',
|
755 |
+
)
|
756 |
+
output_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
757 |
+
output_dir_folder.click(
|
758 |
+
get_folder_path,
|
759 |
+
outputs=output_dir,
|
760 |
+
show_progress=False,
|
761 |
+
)
|
762 |
+
logging_dir = gr.Textbox(
|
763 |
+
label='Logging folder',
|
764 |
+
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
765 |
+
)
|
766 |
+
logging_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
767 |
+
logging_dir_folder.click(
|
768 |
+
get_folder_path,
|
769 |
+
outputs=logging_dir,
|
770 |
+
show_progress=False,
|
771 |
+
)
|
772 |
+
with gr.Row():
|
773 |
+
output_name = gr.Textbox(
|
774 |
+
label='Model output name',
|
775 |
+
placeholder='(Name of the model to output)',
|
776 |
+
value='last',
|
777 |
+
interactive=True,
|
778 |
+
)
|
779 |
+
training_comment = gr.Textbox(
|
780 |
+
label='Training comment',
|
781 |
+
placeholder='(Optional) Add training comment to be included in metadata',
|
782 |
+
interactive=True,
|
783 |
+
)
|
784 |
+
train_data_dir.change(
|
785 |
+
remove_doublequote,
|
786 |
+
inputs=[train_data_dir],
|
787 |
+
outputs=[train_data_dir],
|
788 |
+
)
|
789 |
+
reg_data_dir.change(
|
790 |
+
remove_doublequote,
|
791 |
+
inputs=[reg_data_dir],
|
792 |
+
outputs=[reg_data_dir],
|
793 |
+
)
|
794 |
+
output_dir.change(
|
795 |
+
remove_doublequote,
|
796 |
+
inputs=[output_dir],
|
797 |
+
outputs=[output_dir],
|
798 |
+
)
|
799 |
+
logging_dir.change(
|
800 |
+
remove_doublequote,
|
801 |
+
inputs=[logging_dir],
|
802 |
+
outputs=[logging_dir],
|
803 |
+
)
|
804 |
+
with gr.Tab('Training parameters'):
|
805 |
+
with gr.Row():
|
806 |
+
LoRA_type = gr.Dropdown(
|
807 |
+
label='LoRA type',
|
808 |
+
choices=[
|
809 |
+
'Kohya DyLoRA',
|
810 |
+
'Kohya LoCon',
|
811 |
+
# 'LoCon',
|
812 |
+
'LyCORIS/LoCon',
|
813 |
+
'LyCORIS/LoHa',
|
814 |
+
'Standard',
|
815 |
+
],
|
816 |
+
value='Standard',
|
817 |
+
)
|
818 |
+
lora_network_weights = gr.Textbox(
|
819 |
+
label='LoRA network weights',
|
820 |
+
placeholder='{Optional) Path to existing LoRA network weights to resume training',
|
821 |
+
)
|
822 |
+
lora_network_weights_file = gr.Button(
|
823 |
+
document_symbol, elem_id='open_folder_small'
|
824 |
+
)
|
825 |
+
lora_network_weights_file.click(
|
826 |
+
get_any_file_path,
|
827 |
+
inputs=[lora_network_weights],
|
828 |
+
outputs=lora_network_weights,
|
829 |
+
show_progress=False,
|
830 |
+
)
|
831 |
+
(
|
832 |
+
learning_rate,
|
833 |
+
lr_scheduler,
|
834 |
+
lr_warmup,
|
835 |
+
train_batch_size,
|
836 |
+
epoch,
|
837 |
+
save_every_n_epochs,
|
838 |
+
mixed_precision,
|
839 |
+
save_precision,
|
840 |
+
num_cpu_threads_per_process,
|
841 |
+
seed,
|
842 |
+
caption_extension,
|
843 |
+
cache_latents,
|
844 |
+
optimizer,
|
845 |
+
optimizer_args,
|
846 |
+
) = gradio_training(
|
847 |
+
learning_rate_value='0.0001',
|
848 |
+
lr_scheduler_value='cosine',
|
849 |
+
lr_warmup_value='10',
|
850 |
+
)
|
851 |
+
|
852 |
+
with gr.Row():
|
853 |
+
text_encoder_lr = gr.Textbox(
|
854 |
+
label='Text Encoder learning rate',
|
855 |
+
value='5e-5',
|
856 |
+
placeholder='Optional',
|
857 |
+
)
|
858 |
+
unet_lr = gr.Textbox(
|
859 |
+
label='Unet learning rate',
|
860 |
+
value='0.0001',
|
861 |
+
placeholder='Optional',
|
862 |
+
)
|
863 |
+
network_dim = gr.Slider(
|
864 |
+
minimum=1,
|
865 |
+
maximum=1024,
|
866 |
+
label='Network Rank (Dimension)',
|
867 |
+
value=8,
|
868 |
+
step=1,
|
869 |
+
interactive=True,
|
870 |
+
)
|
871 |
+
network_alpha = gr.Slider(
|
872 |
+
minimum=0.1,
|
873 |
+
maximum=1024,
|
874 |
+
label='Network Alpha',
|
875 |
+
value=1,
|
876 |
+
step=0.1,
|
877 |
+
interactive=True,
|
878 |
+
info='alpha for LoRA weight scaling',
|
879 |
+
)
|
880 |
+
with gr.Row(visible=False) as LoCon_row:
|
881 |
+
|
882 |
+
# locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False)
|
883 |
+
conv_dim = gr.Slider(
|
884 |
+
minimum=1,
|
885 |
+
maximum=512,
|
886 |
+
value=1,
|
887 |
+
step=1,
|
888 |
+
label='Convolution Rank (Dimension)',
|
889 |
+
)
|
890 |
+
conv_alpha = gr.Slider(
|
891 |
+
minimum=0.1,
|
892 |
+
maximum=512,
|
893 |
+
value=1,
|
894 |
+
step=0.1,
|
895 |
+
label='Convolution Alpha',
|
896 |
+
)
|
897 |
+
with gr.Row(visible=False) as kohya_dylora:
|
898 |
+
unit = gr.Slider(
|
899 |
+
minimum=1,
|
900 |
+
maximum=64,
|
901 |
+
label='DyLoRA Unit',
|
902 |
+
value=1,
|
903 |
+
step=1,
|
904 |
+
interactive=True,
|
905 |
+
)
|
906 |
+
|
907 |
+
# Show of hide LoCon conv settings depending on LoRA type selection
|
908 |
+
def update_LoRA_settings(LoRA_type):
|
909 |
+
# Print a message when LoRA type is changed
|
910 |
+
print('LoRA type changed...')
|
911 |
+
|
912 |
+
# Determine if LoCon_row should be visible based on LoRA_type
|
913 |
+
LoCon_row = LoRA_type in {'LoCon', 'Kohya DyLoRA', 'Kohya LoCon', 'LyCORIS/LoHa', 'LyCORIS/LoCon'}
|
914 |
+
|
915 |
+
# Determine if LoRA_type_change should be visible based on LoRA_type
|
916 |
+
LoRA_type_change = LoRA_type in {'Standard', 'Kohya DyLoRA', 'Kohya LoCon'}
|
917 |
+
|
918 |
+
# Determine if kohya_dylora_visible should be visible based on LoRA_type
|
919 |
+
kohya_dylora_visible = LoRA_type == 'Kohya DyLoRA'
|
920 |
+
|
921 |
+
# Return the updated visibility settings for the groups
|
922 |
+
return (
|
923 |
+
gr.Group.update(visible=LoCon_row),
|
924 |
+
gr.Group.update(visible=LoRA_type_change),
|
925 |
+
gr.Group.update(visible=kohya_dylora_visible),
|
926 |
+
)
|
927 |
+
|
928 |
+
|
929 |
+
with gr.Row():
|
930 |
+
max_resolution = gr.Textbox(
|
931 |
+
label='Max resolution',
|
932 |
+
value='512,512',
|
933 |
+
placeholder='512,512',
|
934 |
+
info='The maximum resolution of dataset images. W,H',
|
935 |
+
)
|
936 |
+
stop_text_encoder_training = gr.Slider(
|
937 |
+
minimum=0,
|
938 |
+
maximum=100,
|
939 |
+
value=0,
|
940 |
+
step=1,
|
941 |
+
label='Stop text encoder training',
|
942 |
+
info='After what % of steps should the text encoder stop being trained. 0 = train for all steps.',
|
943 |
+
)
|
944 |
+
enable_bucket = gr.Checkbox(label='Enable buckets', value=True,
|
945 |
+
info='Allow non similar resolution dataset images to be trained on.',)
|
946 |
+
|
947 |
+
with gr.Accordion('Advanced Configuration', open=False):
|
948 |
+
with gr.Row(visible=True) as kohya_advanced_lora:
|
949 |
+
with gr.Tab(label='Weights'):
|
950 |
+
with gr.Row(visible=True):
|
951 |
+
down_lr_weight = gr.Textbox(
|
952 |
+
label='Down LR weights',
|
953 |
+
placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1',
|
954 |
+
info='Specify the learning rate weight of the down blocks of U-Net.'
|
955 |
+
)
|
956 |
+
mid_lr_weight = gr.Textbox(
|
957 |
+
label='Mid LR weights',
|
958 |
+
placeholder='(Optional) eg: 0.5',
|
959 |
+
info='Specify the learning rate weight of the mid block of U-Net.'
|
960 |
+
)
|
961 |
+
up_lr_weight = gr.Textbox(
|
962 |
+
label='Up LR weights',
|
963 |
+
placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1',
|
964 |
+
info='Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.'
|
965 |
+
)
|
966 |
+
block_lr_zero_threshold = gr.Textbox(
|
967 |
+
label='Blocks LR zero threshold',
|
968 |
+
placeholder='(Optional) eg: 0.1',
|
969 |
+
info='If the weight is not more than this value, the LoRA module is not created. The default is 0.'
|
970 |
+
)
|
971 |
+
with gr.Tab(label='Blocks'):
|
972 |
+
with gr.Row(visible=True):
|
973 |
+
block_dims = gr.Textbox(
|
974 |
+
label='Block dims',
|
975 |
+
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
|
976 |
+
info='Specify the dim (rank) of each block. Specify 25 numbers.'
|
977 |
+
)
|
978 |
+
block_alphas = gr.Textbox(
|
979 |
+
label='Block alphas',
|
980 |
+
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
|
981 |
+
info='Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.'
|
982 |
+
)
|
983 |
+
with gr.Tab(label='Conv'):
|
984 |
+
with gr.Row(visible=True):
|
985 |
+
conv_dims = gr.Textbox(
|
986 |
+
label='Conv dims',
|
987 |
+
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
|
988 |
+
info='Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. Specify 25 numbers.'
|
989 |
+
)
|
990 |
+
conv_alphas = gr.Textbox(
|
991 |
+
label='Conv alphas',
|
992 |
+
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
|
993 |
+
info='Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.'
|
994 |
+
)
|
995 |
+
with gr.Row():
|
996 |
+
no_token_padding = gr.Checkbox(
|
997 |
+
label='No token padding', value=False
|
998 |
+
)
|
999 |
+
gradient_accumulation_steps = gr.Number(
|
1000 |
+
label='Gradient accumulate steps', value='1'
|
1001 |
+
)
|
1002 |
+
weighted_captions = gr.Checkbox(
|
1003 |
+
label='Weighted captions', value=False, info='Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.',
|
1004 |
+
)
|
1005 |
+
with gr.Row():
|
1006 |
+
prior_loss_weight = gr.Number(
|
1007 |
+
label='Prior loss weight', value=1.0
|
1008 |
+
)
|
1009 |
+
lr_scheduler_num_cycles = gr.Textbox(
|
1010 |
+
label='LR number of cycles',
|
1011 |
+
placeholder='(Optional) For Cosine with restart and polynomial only',
|
1012 |
+
)
|
1013 |
+
|
1014 |
+
lr_scheduler_power = gr.Textbox(
|
1015 |
+
label='LR power',
|
1016 |
+
placeholder='(Optional) For Cosine with restart and polynomial only',
|
1017 |
+
)
|
1018 |
+
(
|
1019 |
+
# use_8bit_adam,
|
1020 |
+
xformers,
|
1021 |
+
full_fp16,
|
1022 |
+
gradient_checkpointing,
|
1023 |
+
shuffle_caption,
|
1024 |
+
color_aug,
|
1025 |
+
flip_aug,
|
1026 |
+
clip_skip,
|
1027 |
+
mem_eff_attn,
|
1028 |
+
save_state,
|
1029 |
+
resume,
|
1030 |
+
max_token_length,
|
1031 |
+
max_train_epochs,
|
1032 |
+
max_data_loader_n_workers,
|
1033 |
+
keep_tokens,
|
1034 |
+
persistent_data_loader_workers,
|
1035 |
+
bucket_no_upscale,
|
1036 |
+
random_crop,
|
1037 |
+
bucket_reso_steps,
|
1038 |
+
caption_dropout_every_n_epochs,
|
1039 |
+
caption_dropout_rate,
|
1040 |
+
noise_offset,
|
1041 |
+
additional_parameters,
|
1042 |
+
vae_batch_size,
|
1043 |
+
min_snr_gamma,
|
1044 |
+
) = gradio_advanced_training()
|
1045 |
+
color_aug.change(
|
1046 |
+
color_aug_changed,
|
1047 |
+
inputs=[color_aug],
|
1048 |
+
outputs=[cache_latents],
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
(
|
1052 |
+
sample_every_n_steps,
|
1053 |
+
sample_every_n_epochs,
|
1054 |
+
sample_sampler,
|
1055 |
+
sample_prompts,
|
1056 |
+
) = sample_gradio_config()
|
1057 |
+
|
1058 |
+
LoRA_type.change(
|
1059 |
+
update_LoRA_settings, inputs=[LoRA_type], outputs=[LoCon_row, kohya_advanced_lora, kohya_dylora]
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
with gr.Tab('Tools'):
|
1063 |
+
gr.Markdown(
|
1064 |
+
'This section provide Dreambooth tools to help setup your dataset...'
|
1065 |
+
)
|
1066 |
+
gradio_dreambooth_folder_creation_tab(
|
1067 |
+
train_data_dir_input=train_data_dir,
|
1068 |
+
reg_data_dir_input=reg_data_dir,
|
1069 |
+
output_dir_input=output_dir,
|
1070 |
+
logging_dir_input=logging_dir,
|
1071 |
+
)
|
1072 |
+
gradio_dataset_balancing_tab()
|
1073 |
+
gradio_merge_lora_tab()
|
1074 |
+
gradio_svd_merge_lora_tab()
|
1075 |
+
gradio_resize_lora_tab()
|
1076 |
+
gradio_verify_lora_tab()
|
1077 |
+
|
1078 |
+
button_run = gr.Button('Train model', variant='primary')
|
1079 |
+
|
1080 |
+
button_print = gr.Button('Print training command')
|
1081 |
+
|
1082 |
+
# Setup gradio tensorboard buttons
|
1083 |
+
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
1084 |
+
|
1085 |
+
button_start_tensorboard.click(
|
1086 |
+
start_tensorboard,
|
1087 |
+
inputs=logging_dir,
|
1088 |
+
show_progress=False,
|
1089 |
+
)
|
1090 |
+
|
1091 |
+
button_stop_tensorboard.click(
|
1092 |
+
stop_tensorboard,
|
1093 |
+
show_progress=False,
|
1094 |
+
)
|
1095 |
+
|
1096 |
+
settings_list = [
|
1097 |
+
pretrained_model_name_or_path,
|
1098 |
+
v2,
|
1099 |
+
v_parameterization,
|
1100 |
+
logging_dir,
|
1101 |
+
train_data_dir,
|
1102 |
+
reg_data_dir,
|
1103 |
+
output_dir,
|
1104 |
+
max_resolution,
|
1105 |
+
learning_rate,
|
1106 |
+
lr_scheduler,
|
1107 |
+
lr_warmup,
|
1108 |
+
train_batch_size,
|
1109 |
+
epoch,
|
1110 |
+
save_every_n_epochs,
|
1111 |
+
mixed_precision,
|
1112 |
+
save_precision,
|
1113 |
+
seed,
|
1114 |
+
num_cpu_threads_per_process,
|
1115 |
+
cache_latents,
|
1116 |
+
caption_extension,
|
1117 |
+
enable_bucket,
|
1118 |
+
gradient_checkpointing,
|
1119 |
+
full_fp16,
|
1120 |
+
no_token_padding,
|
1121 |
+
stop_text_encoder_training,
|
1122 |
+
# use_8bit_adam,
|
1123 |
+
xformers,
|
1124 |
+
save_model_as,
|
1125 |
+
shuffle_caption,
|
1126 |
+
save_state,
|
1127 |
+
resume,
|
1128 |
+
prior_loss_weight,
|
1129 |
+
text_encoder_lr,
|
1130 |
+
unet_lr,
|
1131 |
+
network_dim,
|
1132 |
+
lora_network_weights,
|
1133 |
+
color_aug,
|
1134 |
+
flip_aug,
|
1135 |
+
clip_skip,
|
1136 |
+
gradient_accumulation_steps,
|
1137 |
+
mem_eff_attn,
|
1138 |
+
output_name,
|
1139 |
+
model_list,
|
1140 |
+
max_token_length,
|
1141 |
+
max_train_epochs,
|
1142 |
+
max_data_loader_n_workers,
|
1143 |
+
network_alpha,
|
1144 |
+
training_comment,
|
1145 |
+
keep_tokens,
|
1146 |
+
lr_scheduler_num_cycles,
|
1147 |
+
lr_scheduler_power,
|
1148 |
+
persistent_data_loader_workers,
|
1149 |
+
bucket_no_upscale,
|
1150 |
+
random_crop,
|
1151 |
+
bucket_reso_steps,
|
1152 |
+
caption_dropout_every_n_epochs,
|
1153 |
+
caption_dropout_rate,
|
1154 |
+
optimizer,
|
1155 |
+
optimizer_args,
|
1156 |
+
noise_offset,
|
1157 |
+
LoRA_type,
|
1158 |
+
conv_dim,
|
1159 |
+
conv_alpha,
|
1160 |
+
sample_every_n_steps,
|
1161 |
+
sample_every_n_epochs,
|
1162 |
+
sample_sampler,
|
1163 |
+
sample_prompts,
|
1164 |
+
additional_parameters,
|
1165 |
+
vae_batch_size,
|
1166 |
+
min_snr_gamma,
|
1167 |
+
down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas,
|
1168 |
+
weighted_captions, unit,
|
1169 |
+
]
|
1170 |
+
|
1171 |
+
button_open_config.click(
|
1172 |
+
open_configuration,
|
1173 |
+
inputs=[dummy_db_true, config_file_name] + settings_list,
|
1174 |
+
outputs=[config_file_name] + settings_list + [LoCon_row],
|
1175 |
+
show_progress=False,
|
1176 |
+
)
|
1177 |
+
|
1178 |
+
button_load_config.click(
|
1179 |
+
open_configuration,
|
1180 |
+
inputs=[dummy_db_false, config_file_name] + settings_list,
|
1181 |
+
outputs=[config_file_name] + settings_list + [LoCon_row],
|
1182 |
+
show_progress=False,
|
1183 |
+
)
|
1184 |
+
|
1185 |
+
button_save_config.click(
|
1186 |
+
save_configuration,
|
1187 |
+
inputs=[dummy_db_false, config_file_name] + settings_list,
|
1188 |
+
outputs=[config_file_name],
|
1189 |
+
show_progress=False,
|
1190 |
+
)
|
1191 |
+
|
1192 |
+
button_save_as_config.click(
|
1193 |
+
save_configuration,
|
1194 |
+
inputs=[dummy_db_true, config_file_name] + settings_list,
|
1195 |
+
outputs=[config_file_name],
|
1196 |
+
show_progress=False,
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
button_run.click(
|
1200 |
+
train_model,
|
1201 |
+
inputs=[dummy_db_false] + settings_list,
|
1202 |
+
show_progress=False,
|
1203 |
+
)
|
1204 |
+
|
1205 |
+
button_print.click(
|
1206 |
+
train_model,
|
1207 |
+
inputs=[dummy_db_true] + settings_list,
|
1208 |
+
show_progress=False,
|
1209 |
+
)
|
1210 |
+
|
1211 |
+
return (
|
1212 |
+
train_data_dir,
|
1213 |
+
reg_data_dir,
|
1214 |
+
output_dir,
|
1215 |
+
logging_dir,
|
1216 |
+
)
|
1217 |
+
|
1218 |
+
|
1219 |
+
def UI(**kwargs):
|
1220 |
+
css = ''
|
1221 |
+
|
1222 |
+
if os.path.exists('./style.css'):
|
1223 |
+
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
1224 |
+
print('Load CSS...')
|
1225 |
+
css += file.read() + '\n'
|
1226 |
+
|
1227 |
+
interface = gr.Blocks(css=css)
|
1228 |
+
|
1229 |
+
with interface:
|
1230 |
+
with gr.Tab('LoRA'):
|
1231 |
+
(
|
1232 |
+
train_data_dir_input,
|
1233 |
+
reg_data_dir_input,
|
1234 |
+
output_dir_input,
|
1235 |
+
logging_dir_input,
|
1236 |
+
) = lora_tab()
|
1237 |
+
with gr.Tab('Utilities'):
|
1238 |
+
utilities_tab(
|
1239 |
+
train_data_dir_input=train_data_dir_input,
|
1240 |
+
reg_data_dir_input=reg_data_dir_input,
|
1241 |
+
output_dir_input=output_dir_input,
|
1242 |
+
logging_dir_input=logging_dir_input,
|
1243 |
+
enable_copy_info_button=True,
|
1244 |
+
)
|
1245 |
+
|
1246 |
+
# Show the interface
|
1247 |
+
launch_kwargs = {}
|
1248 |
+
if not kwargs.get('username', None) == '':
|
1249 |
+
launch_kwargs['auth'] = (
|
1250 |
+
kwargs.get('username', None),
|
1251 |
+
kwargs.get('password', None),
|
1252 |
+
)
|
1253 |
+
if kwargs.get('server_port', 0) > 0:
|
1254 |
+
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
1255 |
+
if kwargs.get('inbrowser', False):
|
1256 |
+
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
1257 |
+
if kwargs.get('listen', True):
|
1258 |
+
launch_kwargs['server_name'] = '0.0.0.0'
|
1259 |
+
print(launch_kwargs)
|
1260 |
+
interface.launch(**launch_kwargs)
|
1261 |
+
|
1262 |
+
|
1263 |
+
if __name__ == '__main__':
|
1264 |
+
# torch.cuda.set_per_process_memory_fraction(0.48)
|
1265 |
+
parser = argparse.ArgumentParser()
|
1266 |
+
parser.add_argument(
|
1267 |
+
'--username', type=str, default='', help='Username for authentication'
|
1268 |
+
)
|
1269 |
+
parser.add_argument(
|
1270 |
+
'--password', type=str, default='', help='Password for authentication'
|
1271 |
+
)
|
1272 |
+
parser.add_argument(
|
1273 |
+
'--server_port',
|
1274 |
+
type=int,
|
1275 |
+
default=0,
|
1276 |
+
help='Port to run the server listener on',
|
1277 |
+
)
|
1278 |
+
parser.add_argument(
|
1279 |
+
'--inbrowser', action='store_true', help='Open in browser'
|
1280 |
+
)
|
1281 |
+
parser.add_argument(
|
1282 |
+
'--listen',
|
1283 |
+
action='store_true',
|
1284 |
+
help='Launch gradio with server name 0.0.0.0, allowing LAN access',
|
1285 |
+
)
|
1286 |
+
|
1287 |
+
args = parser.parse_args()
|
1288 |
+
|
1289 |
+
UI(
|
1290 |
+
username=args.username,
|
1291 |
+
password=args.password,
|
1292 |
+
inbrowser=args.inbrowser,
|
1293 |
+
server_port=args.server_port,
|
1294 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.15.0
|
2 |
+
albumentations==1.3.0
|
3 |
+
altair==4.2.2
|
4 |
+
bitsandbytes==0.35.0
|
5 |
+
dadaptation==1.5
|
6 |
+
diffusers[torch]==0.10.2
|
7 |
+
easygui==0.98.3
|
8 |
+
einops==0.6.0
|
9 |
+
ftfy==6.1.1
|
10 |
+
gradio==3.27.0; sys_platform != 'darwin'
|
11 |
+
# gradio==3.19.1; sys_platform != 'darwin'
|
12 |
+
gradio==3.23.0; sys_platform == 'darwin'
|
13 |
+
lion-pytorch==0.0.6
|
14 |
+
opencv-python==4.7.0.68
|
15 |
+
pytorch-lightning==1.9.0
|
16 |
+
safetensors==0.2.6
|
17 |
+
tensorboard==2.10.1 ; sys_platform != 'darwin'
|
18 |
+
tensorboard==2.12.1 ; sys_platform == 'darwin'
|
19 |
+
tk==0.1.0
|
20 |
+
toml==0.10.2
|
21 |
+
transformers==4.26.0
|
22 |
+
voluptuous==0.13.1
|
23 |
+
# for BLIP captioning
|
24 |
+
fairscale==0.4.13
|
25 |
+
requests==2.28.2
|
26 |
+
timm==0.6.12
|
27 |
+
# tensorflow<2.11
|
28 |
+
huggingface-hub==0.13.3; sys_platform != 'darwin'
|
29 |
+
huggingface-hub==0.13.0; sys_platform == 'darwin'
|
30 |
+
tensorflow==2.10.1; sys_platform != 'darwin'
|
31 |
+
# For locon support
|
32 |
+
lycoris_lora==0.1.4
|
33 |
+
# for kohya_ss library
|
34 |
+
.
|
setup.bat
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
|
3 |
+
set PYTHON_VER=3.10.9
|
4 |
+
|
5 |
+
REM Check if Python version meets the recommended version
|
6 |
+
python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
|
7 |
+
if errorlevel 1 (
|
8 |
+
echo Warning: Python version %PYTHON_VER% is recommended.
|
9 |
+
)
|
10 |
+
|
11 |
+
IF NOT EXIST venv (
|
12 |
+
python -m venv venv
|
13 |
+
) ELSE (
|
14 |
+
echo venv folder already exists, skipping creation...
|
15 |
+
)
|
16 |
+
call .\venv\Scripts\activate.bat
|
17 |
+
|
18 |
+
echo Do you want to uninstall previous versions of torch and associated files before installing? Usefull if you are upgrading from torch 1.12.1 to torch 2.0.0 or if you are downgrading from torch 2.0.0 to torch 1.12.1.
|
19 |
+
echo [1] - Yes
|
20 |
+
echo [2] - No (recommanded for most)
|
21 |
+
set /p uninstall_choice="Enter your choice (1 or 2): "
|
22 |
+
|
23 |
+
if %uninstall_choice%==1 (
|
24 |
+
pip uninstall -y xformers
|
25 |
+
pip uninstall -y torch torchvision
|
26 |
+
)
|
27 |
+
|
28 |
+
echo Please choose the version of torch you want to install:
|
29 |
+
echo [1] - v1 (torch 1.12.1) (Recommended)
|
30 |
+
echo [2] - v2 (torch 2.0.0) (Experimental)
|
31 |
+
set /p choice="Enter your choice (1 or 2): "
|
32 |
+
|
33 |
+
if %choice%==1 (
|
34 |
+
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
35 |
+
pip install --use-pep517 --upgrade -r requirements.txt
|
36 |
+
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
37 |
+
) else (
|
38 |
+
pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
|
39 |
+
pip install --use-pep517 --upgrade -r requirements.txt
|
40 |
+
pip install --upgrade xformers==0.0.17
|
41 |
+
rem pip install -U -I --no-deps https://files.pythonhosted.org/packages/d6/f7/02662286419a2652c899e2b3d1913c47723fc164b4ac06a85f769c291013/xformers-0.0.17rc482-cp310-cp310-win_amd64.whl
|
42 |
+
)
|
43 |
+
|
44 |
+
copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
45 |
+
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
46 |
+
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
47 |
+
|
48 |
+
accelerate config
|
setup.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
import subprocess
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
# Call the create_user_files.py script
|
7 |
+
script_path = os.path.join("tools", "create_user_files.py")
|
8 |
+
subprocess.run([sys.executable, script_path])
|
9 |
+
|
10 |
+
setup(name="library", version="1.0.3", packages=find_packages())
|
setup.sh
ADDED
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# This file will be the host environment setup file for all operating systems other than base Windows.
|
4 |
+
|
5 |
+
# Set the required package versions here.
|
6 |
+
# They will be appended to the requirements.txt file in the installation directory.
|
7 |
+
TENSORFLOW_VERSION="2.12.0"
|
8 |
+
TENSORFLOW_MACOS_VERSION="2.12.0"
|
9 |
+
TENSORFLOW_METAL_VERSION="0.8.0"
|
10 |
+
|
11 |
+
display_help() {
|
12 |
+
cat <<EOF
|
13 |
+
Kohya_SS Installation Script for POSIX operating systems.
|
14 |
+
|
15 |
+
Usage:
|
16 |
+
# Specifies custom branch, install directory, and git repo
|
17 |
+
setup.sh -b dev -d /workspace/kohya_ss -g https://mycustom.repo.tld/custom_fork.git
|
18 |
+
|
19 |
+
# Same as example 1, but uses long options
|
20 |
+
setup.sh --branch=dev --dir=/workspace/kohya_ss --git-repo=https://mycustom.repo.tld/custom_fork.git
|
21 |
+
|
22 |
+
# Maximum verbosity, fully automated installation in a runpod environment skipping the runpod env checks
|
23 |
+
setup.sh -vvv --skip-space-check --runpod
|
24 |
+
|
25 |
+
Options:
|
26 |
+
-b BRANCH, --branch=BRANCH Select which branch of kohya to check out on new installs.
|
27 |
+
-d DIR, --dir=DIR The full path you want kohya_ss installed to.
|
28 |
+
-g REPO, --git_repo=REPO You can optionally provide a git repo to check out for runpod installation. Useful for custom forks.
|
29 |
+
-h, --help Show this screen.
|
30 |
+
-i, --interactive Interactively configure accelerate instead of using default config file.
|
31 |
+
-n, --no-git-update Do not update kohya_ss repo. No git pull or clone operations.
|
32 |
+
-p, --public Expose public URL in runpod mode. Won't have an effect in other modes.
|
33 |
+
-r, --runpod Forces a runpod installation. Useful if detection fails for any reason.
|
34 |
+
-s, --skip-space-check Skip the 10Gb minimum storage space check.
|
35 |
+
-u, --no-gui Skips launching the GUI.
|
36 |
+
-v, --verbose Increase verbosity levels up to 3.
|
37 |
+
EOF
|
38 |
+
}
|
39 |
+
|
40 |
+
# Checks to see if variable is set and non-empty.
|
41 |
+
# This is defined first, so we can use the function for some default variable values
|
42 |
+
env_var_exists() {
|
43 |
+
if [[ -n "${!1}" ]]; then
|
44 |
+
return 0
|
45 |
+
else
|
46 |
+
return 1
|
47 |
+
fi
|
48 |
+
}
|
49 |
+
|
50 |
+
# Need RUNPOD to have a default value before first access
|
51 |
+
RUNPOD=false
|
52 |
+
if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then
|
53 |
+
RUNPOD=true
|
54 |
+
fi
|
55 |
+
|
56 |
+
# This gets the directory the script is run from so pathing can work relative to the script where needed.
|
57 |
+
SCRIPT_DIR="$(cd -- $(dirname -- "$0") && pwd)"
|
58 |
+
|
59 |
+
# Variables defined before the getopts loop, so we have sane default values.
|
60 |
+
# Default installation locations based on OS and environment
|
61 |
+
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
|
62 |
+
if [ "$RUNPOD" = true ]; then
|
63 |
+
DIR="/workspace/kohya_ss"
|
64 |
+
elif [ -d "$SCRIPT_DIR/.git" ]; then
|
65 |
+
DIR="$SCRIPT_DIR"
|
66 |
+
elif [ -w "/opt" ]; then
|
67 |
+
DIR="/opt/kohya_ss"
|
68 |
+
elif env_var_exists HOME; then
|
69 |
+
DIR="$HOME/kohya_ss"
|
70 |
+
else
|
71 |
+
# The last fallback is simply PWD
|
72 |
+
DIR="$(PWD)"
|
73 |
+
fi
|
74 |
+
else
|
75 |
+
if [ -d "$SCRIPT_DIR/.git" ]; then
|
76 |
+
DIR="$SCRIPT_DIR"
|
77 |
+
elif env_var_exists HOME; then
|
78 |
+
DIR="$HOME/kohya_ss"
|
79 |
+
else
|
80 |
+
# The last fallback is simply PWD
|
81 |
+
DIR="$(PWD)"
|
82 |
+
fi
|
83 |
+
fi
|
84 |
+
|
85 |
+
VERBOSITY=2 #Start counting at 2 so that any increase to this will result in a minimum of file descriptor 3. You should leave this alone.
|
86 |
+
MAXVERBOSITY=6 #The highest verbosity we use / allow to be displayed. Feel free to adjust.
|
87 |
+
|
88 |
+
BRANCH="master"
|
89 |
+
GIT_REPO="https://github.com/bmaltais/kohya_ss.git"
|
90 |
+
INTERACTIVE=false
|
91 |
+
PUBLIC=false
|
92 |
+
SKIP_SPACE_CHECK=false
|
93 |
+
SKIP_GIT_UPDATE=false
|
94 |
+
SKIP_GUI=false
|
95 |
+
|
96 |
+
while getopts ":vb:d:g:inprus-:" opt; do
|
97 |
+
# support long options: https://stackoverflow.com/a/28466267/519360
|
98 |
+
if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG
|
99 |
+
opt="${OPTARG%%=*}" # extract long option name
|
100 |
+
OPTARG="${OPTARG#$opt}" # extract long option argument (may be empty)
|
101 |
+
OPTARG="${OPTARG#=}" # if long option argument, remove assigning `=`
|
102 |
+
fi
|
103 |
+
case $opt in
|
104 |
+
b | branch) BRANCH="$OPTARG" ;;
|
105 |
+
d | dir) DIR="$OPTARG" ;;
|
106 |
+
g | git-repo) GIT_REPO="$OPTARG" ;;
|
107 |
+
i | interactive) INTERACTIVE=true ;;
|
108 |
+
n | no-git-update) SKIP_GIT_UPDATE=true ;;
|
109 |
+
p | public) PUBLIC=true ;;
|
110 |
+
r | runpod) RUNPOD=true ;;
|
111 |
+
s | skip-space-check) SKIP_SPACE_CHECK=true ;;
|
112 |
+
u | no-gui) SKIP_GUI=true ;;
|
113 |
+
v) ((VERBOSITY = VERBOSITY + 1)) ;;
|
114 |
+
h) display_help && exit 0 ;;
|
115 |
+
*) display_help && exit 0 ;;
|
116 |
+
esac
|
117 |
+
done
|
118 |
+
shift $((OPTIND - 1))
|
119 |
+
|
120 |
+
# Just in case someone puts in a relative path into $DIR,
|
121 |
+
# we're going to get the absolute path of that.
|
122 |
+
if [[ "$DIR" != /* ]] && [[ "$DIR" != ~* ]]; then
|
123 |
+
DIR="$(
|
124 |
+
cd "$(dirname "$DIR")" || exit 1
|
125 |
+
pwd
|
126 |
+
)/$(basename "$DIR")"
|
127 |
+
fi
|
128 |
+
|
129 |
+
for v in $( #Start counting from 3 since 1 and 2 are standards (stdout/stderr).
|
130 |
+
seq 3 $VERBOSITY
|
131 |
+
); do
|
132 |
+
(("$v" <= "$MAXVERBOSITY")) && eval exec "$v>&2" #Don't change anything higher than the maximum verbosity allowed.
|
133 |
+
done
|
134 |
+
|
135 |
+
for v in $( #From the verbosity level one higher than requested, through the maximum;
|
136 |
+
seq $((VERBOSITY + 1)) $MAXVERBOSITY
|
137 |
+
); do
|
138 |
+
(("$v" > "2")) && eval exec "$v>/dev/null" #Redirect these to bitbucket, provided that they don't match stdout and stderr.
|
139 |
+
done
|
140 |
+
|
141 |
+
# Example of how to use the verbosity levels.
|
142 |
+
# printf "%s\n" "This message is seen at verbosity level 1 and above." >&3
|
143 |
+
# printf "%s\n" "This message is seen at verbosity level 2 and above." >&4
|
144 |
+
# printf "%s\n" "This message is seen at verbosity level 3 and above." >&5
|
145 |
+
|
146 |
+
# Debug variable dump at max verbosity
|
147 |
+
echo "BRANCH: $BRANCH
|
148 |
+
DIR: $DIR
|
149 |
+
GIT_REPO: $GIT_REPO
|
150 |
+
INTERACTIVE: $INTERACTIVE
|
151 |
+
PUBLIC: $PUBLIC
|
152 |
+
RUNPOD: $RUNPOD
|
153 |
+
SKIP_SPACE_CHECK: $SKIP_SPACE_CHECK
|
154 |
+
VERBOSITY: $VERBOSITY
|
155 |
+
Script directory is ${SCRIPT_DIR}." >&5
|
156 |
+
|
157 |
+
# This must be set after the getopts loop to account for $DIR changes.
|
158 |
+
PARENT_DIR="$(dirname "${DIR}")"
|
159 |
+
VENV_DIR="$DIR/venv"
|
160 |
+
|
161 |
+
if [ -w "$PARENT_DIR" ] && [ ! -d "$DIR" ]; then
|
162 |
+
echo "Creating install folder ${DIR}."
|
163 |
+
mkdir "$DIR"
|
164 |
+
fi
|
165 |
+
|
166 |
+
if [ ! -w "$DIR" ]; then
|
167 |
+
echo "We cannot write to ${DIR}."
|
168 |
+
echo "Please ensure the install directory is accurate and you have the correct permissions."
|
169 |
+
exit 1
|
170 |
+
fi
|
171 |
+
|
172 |
+
# Shared functions
|
173 |
+
# This checks for free space on the installation drive and returns that in Gb.
|
174 |
+
size_available() {
|
175 |
+
local folder
|
176 |
+
if [ -d "$DIR" ]; then
|
177 |
+
folder="$DIR"
|
178 |
+
elif [ -d "$PARENT_DIR" ]; then
|
179 |
+
folder="$PARENT_DIR"
|
180 |
+
elif [ -d "$(echo "$DIR" | cut -d "/" -f2)" ]; then
|
181 |
+
folder="$(echo "$DIR" | cut -d "/" -f2)"
|
182 |
+
else
|
183 |
+
echo "We are assuming a root drive install for space-checking purposes."
|
184 |
+
folder='/'
|
185 |
+
fi
|
186 |
+
|
187 |
+
local FREESPACEINKB
|
188 |
+
FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')"
|
189 |
+
echo "Detected available space in Kb: $FREESPACEINKB" >&5
|
190 |
+
local FREESPACEINGB
|
191 |
+
FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024))
|
192 |
+
echo "$FREESPACEINGB"
|
193 |
+
}
|
194 |
+
|
195 |
+
# The expected usage is create_symlinks symlink target_file
|
196 |
+
create_symlinks() {
|
197 |
+
echo "Checking symlinks now."
|
198 |
+
# Next line checks for valid symlink
|
199 |
+
if [ -L "$1" ]; then
|
200 |
+
# Check if the linked file exists and points to the expected file
|
201 |
+
if [ -e "$1" ] && [ "$(readlink "$1")" == "$2" ]; then
|
202 |
+
echo "$(basename "$1") symlink looks fine. Skipping."
|
203 |
+
else
|
204 |
+
if [ -f "$2" ]; then
|
205 |
+
echo "Broken symlink detected. Recreating $(basename "$1")."
|
206 |
+
rm "$1" &&
|
207 |
+
ln -s "$2" "$1"
|
208 |
+
else
|
209 |
+
echo "$2 does not exist. Nothing to link."
|
210 |
+
fi
|
211 |
+
fi
|
212 |
+
else
|
213 |
+
echo "Linking $(basename "$1")."
|
214 |
+
ln -s "$2" "$1"
|
215 |
+
fi
|
216 |
+
}
|
217 |
+
|
218 |
+
install_python_dependencies() {
|
219 |
+
# Switch to local virtual env
|
220 |
+
echo "Switching to virtual Python environment."
|
221 |
+
if ! inDocker; then
|
222 |
+
if command -v python3 >/dev/null; then
|
223 |
+
python3 -m venv "$DIR/venv"
|
224 |
+
elif command -v python3.10 >/dev/null; then
|
225 |
+
python3.10 -m venv "$DIR/venv"
|
226 |
+
else
|
227 |
+
echo "Valid python3 or python3.10 binary not found."
|
228 |
+
echo "Cannot proceed with the python steps."
|
229 |
+
return 1
|
230 |
+
fi
|
231 |
+
|
232 |
+
# Activate the virtual environment
|
233 |
+
source "$DIR/venv/bin/activate"
|
234 |
+
fi
|
235 |
+
|
236 |
+
# Updating pip if there is one
|
237 |
+
echo "Checking for pip updates before Python operations."
|
238 |
+
pip install --upgrade pip >&3
|
239 |
+
|
240 |
+
echo "Installing python dependencies. This could take a few minutes as it downloads files."
|
241 |
+
echo "If this operation ever runs too long, you can rerun this script in verbose mode to check."
|
242 |
+
case "$OSTYPE" in
|
243 |
+
"linux-gnu"*) pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 \
|
244 |
+
--extra-index-url https://download.pytorch.org/whl/cu116 >&3 &&
|
245 |
+
pip install -U -I --no-deps \
|
246 |
+
https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl >&3 ;;
|
247 |
+
"darwin"*) pip install torch==2.0.0 torchvision==0.15.1 \
|
248 |
+
-f https://download.pytorch.org/whl/cpu/torch_stable.html >&3 ;;
|
249 |
+
"cygwin")
|
250 |
+
:
|
251 |
+
;;
|
252 |
+
"msys")
|
253 |
+
:
|
254 |
+
;;
|
255 |
+
esac
|
256 |
+
|
257 |
+
if [ "$RUNPOD" = true ]; then
|
258 |
+
echo "Installing tenssort."
|
259 |
+
pip install tensorrt >&3
|
260 |
+
fi
|
261 |
+
|
262 |
+
# DEBUG ONLY (Update this version number to whatever PyCharm recommends)
|
263 |
+
# pip install pydevd-pycharm~=223.8836.43
|
264 |
+
|
265 |
+
#This will copy our requirements.txt file out and make the khoya_ss lib a dynamic location then cleanup.
|
266 |
+
local TEMP_REQUIREMENTS_FILE="$DIR/requirements_tmp_for_setup.txt"
|
267 |
+
echo "Copying $DIR/requirements.txt to $TEMP_REQUIREMENTS_FILE" >&3
|
268 |
+
echo "Replacing the . for lib to our DIR variable in $TEMP_REQUIREMENTS_FILE." >&3
|
269 |
+
awk -v dir="$DIR" '/#.*kohya_ss.*library/{print; getline; sub(/^\.$/, dir)}1' "$DIR/requirements.txt" >"$TEMP_REQUIREMENTS_FILE"
|
270 |
+
|
271 |
+
# This will check if macOS is running then determine if M1+ or Intel CPU.
|
272 |
+
# It will append the appropriate packages to the requirements.txt file.
|
273 |
+
# Other OSs won't be affected and the version variables are at the top of this file.
|
274 |
+
if [[ "$(uname)" == "Darwin" ]]; then
|
275 |
+
# Check if the processor is Apple Silicon (arm64)
|
276 |
+
if [[ "$(uname -m)" == "arm64" ]]; then
|
277 |
+
echo "tensorflow-macos==$TENSORFLOW_MACOS_VERSION" >>"$TEMP_REQUIREMENTS_FILE"
|
278 |
+
echo "tensorflow-metal==$TENSORFLOW_METAL_VERSION" >>"$TEMP_REQUIREMENTS_FILE"
|
279 |
+
# Check if the processor is Intel (x86_64)
|
280 |
+
elif [[ "$(uname -m)" == "x86_64" ]]; then
|
281 |
+
echo "tensorflow==$TENSORFLOW_VERSION" >>"$TEMP_REQUIREMENTS_FILE"
|
282 |
+
fi
|
283 |
+
fi
|
284 |
+
|
285 |
+
if [ $VERBOSITY == 2 ]; then
|
286 |
+
python -m pip install --quiet --use-pep517 --upgrade -r "$TEMP_REQUIREMENTS_FILE" >&3
|
287 |
+
else
|
288 |
+
python -m pip install --use-pep517 --upgrade -r "$TEMP_REQUIREMENTS_FILE" >&3
|
289 |
+
fi
|
290 |
+
|
291 |
+
echo "Removing the temp requirements file."
|
292 |
+
if [ -f "$TEMP_REQUIREMENTS_FILE" ]; then
|
293 |
+
rm -f "$TEMP_REQUIREMENTS_FILE"
|
294 |
+
fi
|
295 |
+
|
296 |
+
if [ -n "$VIRTUAL_ENV" ] && ! inDocker; then
|
297 |
+
if command -v deactivate >/dev/null; then
|
298 |
+
echo "Exiting Python virtual environment."
|
299 |
+
deactivate
|
300 |
+
else
|
301 |
+
echo "deactivate command not found. Could still be in the Python virtual environment."
|
302 |
+
fi
|
303 |
+
fi
|
304 |
+
}
|
305 |
+
|
306 |
+
# Attempt to non-interactively install a default accelerate config file unless specified otherwise.
|
307 |
+
# Documentation for order of precedence locations for configuration file for automated installation:
|
308 |
+
# https://huggingface.co/docs/accelerate/basic_tutorials/launch#custom-configurations
|
309 |
+
configure_accelerate() {
|
310 |
+
echo "Source accelerate config location: $DIR/config_files/accelerate/default_config.yaml" >&3
|
311 |
+
if [ "$INTERACTIVE" = true ]; then
|
312 |
+
accelerate config
|
313 |
+
else
|
314 |
+
if env_var_exists HF_HOME; then
|
315 |
+
if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then
|
316 |
+
mkdir -p "$HF_HOME/accelerate/" &&
|
317 |
+
echo "Target accelerate config location: $HF_HOME/accelerate/default_config.yaml" >&3
|
318 |
+
cp "$DIR/config_files/accelerate/default_config.yaml" "$HF_HOME/accelerate/default_config.yaml" &&
|
319 |
+
echo "Copied accelerate config file to: $HF_HOME/accelerate/default_config.yaml"
|
320 |
+
fi
|
321 |
+
elif env_var_exists XDG_CACHE_HOME; then
|
322 |
+
if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then
|
323 |
+
mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" &&
|
324 |
+
echo "Target accelerate config location: $XDG_CACHE_HOME/accelerate/default_config.yaml" >&3
|
325 |
+
cp "$DIR/config_files/accelerate/default_config.yaml" "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" &&
|
326 |
+
echo "Copied accelerate config file to: $XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml"
|
327 |
+
fi
|
328 |
+
elif env_var_exists HOME; then
|
329 |
+
if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then
|
330 |
+
mkdir -p "$HOME/.cache/huggingface/accelerate" &&
|
331 |
+
echo "Target accelerate config location: $HOME/accelerate/default_config.yaml" >&3
|
332 |
+
cp "$DIR/config_files/accelerate/default_config.yaml" "$HOME/.cache/huggingface/accelerate/default_config.yaml" &&
|
333 |
+
echo "Copying accelerate config file to: $HOME/.cache/huggingface/accelerate/default_config.yaml"
|
334 |
+
fi
|
335 |
+
else
|
336 |
+
echo "Could not place the accelerate configuration file. Please configure manually."
|
337 |
+
sleep 2
|
338 |
+
accelerate config
|
339 |
+
fi
|
340 |
+
fi
|
341 |
+
}
|
342 |
+
|
343 |
+
# Offer a warning and opportunity to cancel the installation if < 10Gb of Free Space detected
|
344 |
+
check_storage_space() {
|
345 |
+
if [ "$SKIP_SPACE_CHECK" = false ]; then
|
346 |
+
if [ "$(size_available)" -lt 10 ]; then
|
347 |
+
echo "You have less than 10Gb of free space. This installation may fail."
|
348 |
+
MSGTIMEOUT=10 # In seconds
|
349 |
+
MESSAGE="Continuing in..."
|
350 |
+
echo "Press control-c to cancel the installation."
|
351 |
+
for ((i = MSGTIMEOUT; i >= 0; i--)); do
|
352 |
+
printf "\r${MESSAGE} %ss. " "${i}"
|
353 |
+
sleep 1
|
354 |
+
done
|
355 |
+
fi
|
356 |
+
fi
|
357 |
+
}
|
358 |
+
|
359 |
+
isContainerOrPod() {
|
360 |
+
local cgroup=/proc/1/cgroup
|
361 |
+
test -f $cgroup && (grep -qE ':cpuset:/(docker|kubepods)' $cgroup || grep -q ':/docker/' $cgroup)
|
362 |
+
}
|
363 |
+
|
364 |
+
isDockerBuildkit() {
|
365 |
+
local cgroup=/proc/1/cgroup
|
366 |
+
test -f $cgroup && grep -q ':cpuset:/docker/buildkit' $cgroup
|
367 |
+
}
|
368 |
+
|
369 |
+
isDockerContainer() {
|
370 |
+
[ -e /.dockerenv ]
|
371 |
+
}
|
372 |
+
|
373 |
+
inDocker() {
|
374 |
+
if isContainerOrPod || isDockerBuildkit || isDockerContainer; then
|
375 |
+
return 0
|
376 |
+
else
|
377 |
+
return 1
|
378 |
+
fi
|
379 |
+
}
|
380 |
+
|
381 |
+
# These are the git operations that will run to update or clone the repo
|
382 |
+
update_kohya_ss() {
|
383 |
+
if [ "$SKIP_GIT_UPDATE" = false ]; then
|
384 |
+
if command -v git >/dev/null; then
|
385 |
+
# First, we make sure there are no changes that need to be made in git, so no work is lost.
|
386 |
+
if [ "$(git -C "$DIR" status --porcelain=v1 2>/dev/null | wc -l)" -gt 0 ] &&
|
387 |
+
echo "These files need to be committed or discarded: " >&4 &&
|
388 |
+
git -C "$DIR" status >&4; then
|
389 |
+
echo "There are changes that need to be committed or discarded in the repo in $DIR."
|
390 |
+
echo "Commit those changes or run this script with -n to skip git operations entirely."
|
391 |
+
exit 1
|
392 |
+
fi
|
393 |
+
|
394 |
+
echo "Attempting to clone $GIT_REPO."
|
395 |
+
if [ ! -d "$DIR/.git" ]; then
|
396 |
+
echo "Cloning and switching to $GIT_REPO:$BRANCH" >&4
|
397 |
+
git -C "$PARENT_DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3
|
398 |
+
git -C "$DIR" switch "$BRANCH" >&4
|
399 |
+
else
|
400 |
+
echo "git repo detected. Attempting to update repository instead."
|
401 |
+
echo "Updating: $GIT_REPO"
|
402 |
+
git -C "$DIR" pull "$GIT_REPO" "$BRANCH" >&3
|
403 |
+
if ! git -C "$DIR" switch "$BRANCH" >&4; then
|
404 |
+
echo "Branch $BRANCH did not exist. Creating it." >&4
|
405 |
+
git -C "$DIR" switch -c "$BRANCH" >&4
|
406 |
+
fi
|
407 |
+
fi
|
408 |
+
else
|
409 |
+
echo "You need to install git."
|
410 |
+
echo "Rerun this after installing git or run this script with -n to skip the git operations."
|
411 |
+
fi
|
412 |
+
else
|
413 |
+
echo "Skipping git operations."
|
414 |
+
fi
|
415 |
+
}
|
416 |
+
|
417 |
+
# Start OS-specific detection and work
|
418 |
+
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
|
419 |
+
# Check if root or sudo
|
420 |
+
root=false
|
421 |
+
if [ "$EUID" = 0 ]; then
|
422 |
+
root=true
|
423 |
+
elif command -v id >/dev/null && [ "$(id -u)" = 0 ]; then
|
424 |
+
root=true
|
425 |
+
elif [ "$UID" = 0 ]; then
|
426 |
+
root=true
|
427 |
+
fi
|
428 |
+
|
429 |
+
get_distro_name() {
|
430 |
+
local line
|
431 |
+
if [ -f /etc/os-release ]; then
|
432 |
+
# We search for the line starting with ID=
|
433 |
+
# Then we remove the ID= prefix to get the name itself
|
434 |
+
line="$(grep -Ei '^ID=' /etc/os-release)"
|
435 |
+
echo "Raw detected os-release distro line: $line" >&5
|
436 |
+
line=${line##*=}
|
437 |
+
echo "$line"
|
438 |
+
return 0
|
439 |
+
elif command -v python >/dev/null; then
|
440 |
+
line="$(python -mplatform)"
|
441 |
+
echo "$line"
|
442 |
+
return 0
|
443 |
+
elif command -v python3 >/dev/null; then
|
444 |
+
line="$(python3 -mplatform)"
|
445 |
+
echo "$line"
|
446 |
+
return 0
|
447 |
+
else
|
448 |
+
line="None"
|
449 |
+
echo "$line"
|
450 |
+
return 1
|
451 |
+
fi
|
452 |
+
}
|
453 |
+
|
454 |
+
# We search for the line starting with ID_LIKE=
|
455 |
+
# Then we remove the ID_LIKE= prefix to get the name itself
|
456 |
+
# This is the "type" of distro. For example, Ubuntu returns "debian".
|
457 |
+
get_distro_family() {
|
458 |
+
local line
|
459 |
+
if [ -f /etc/os-release ]; then
|
460 |
+
if grep -Eiq '^ID_LIKE=' /etc/os-release >/dev/null; then
|
461 |
+
line="$(grep -Ei '^ID_LIKE=' /etc/os-release)"
|
462 |
+
echo "Raw detected os-release distro family line: $line" >&5
|
463 |
+
line=${line##*=}
|
464 |
+
echo "$line"
|
465 |
+
return 0
|
466 |
+
else
|
467 |
+
line="None"
|
468 |
+
echo "$line"
|
469 |
+
return 1
|
470 |
+
fi
|
471 |
+
else
|
472 |
+
line="None"
|
473 |
+
echo "$line"
|
474 |
+
return 1
|
475 |
+
fi
|
476 |
+
}
|
477 |
+
|
478 |
+
check_storage_space
|
479 |
+
update_kohya_ss
|
480 |
+
|
481 |
+
distro=get_distro_name
|
482 |
+
family=get_distro_family
|
483 |
+
echo "Raw detected distro string: $distro" >&4
|
484 |
+
echo "Raw detected distro family string: $family" >&4
|
485 |
+
|
486 |
+
echo "Installing Python TK if not found on the system."
|
487 |
+
if "$distro" | grep -qi "Ubuntu" || "$family" | grep -qi "Ubuntu"; then
|
488 |
+
echo "Ubuntu detected."
|
489 |
+
if [ $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") = 0 ]; then
|
490 |
+
if [ "$root" = true ]; then
|
491 |
+
apt update -y >&3 && apt install -y python3-tk >&3
|
492 |
+
else
|
493 |
+
echo "This script needs to be run as root or via sudo to install packages."
|
494 |
+
exit 1
|
495 |
+
fi
|
496 |
+
else
|
497 |
+
echo "Python TK found! Skipping install!"
|
498 |
+
fi
|
499 |
+
elif "$distro" | grep -Eqi "Fedora|CentOS|Redhat"; then
|
500 |
+
echo "Redhat or Redhat base detected."
|
501 |
+
if ! rpm -qa | grep -qi python3-tkinter; then
|
502 |
+
if [ "$root" = true ]; then
|
503 |
+
dnf install python3-tkinter -y >&3
|
504 |
+
else
|
505 |
+
echo "This script needs to be run as root or via sudo to install packages."
|
506 |
+
exit 1
|
507 |
+
fi
|
508 |
+
fi
|
509 |
+
elif "$distro" | grep -Eqi "arch" || "$family" | grep -qi "arch"; then
|
510 |
+
echo "Arch Linux or Arch base detected."
|
511 |
+
if ! pacman -Qi tk >/dev/null; then
|
512 |
+
if [ "$root" = true ]; then
|
513 |
+
pacman --noconfirm -S tk >&3
|
514 |
+
else
|
515 |
+
echo "This script needs to be run as root or via sudo to install packages."
|
516 |
+
exit 1
|
517 |
+
fi
|
518 |
+
fi
|
519 |
+
elif "$distro" | grep -Eqi "opensuse" || "$family" | grep -qi "opensuse"; then
|
520 |
+
echo "OpenSUSE detected."
|
521 |
+
if ! rpm -qa | grep -qi python-tk; then
|
522 |
+
if [ "$root" = true ]; then
|
523 |
+
zypper install -y python-tk >&3
|
524 |
+
else
|
525 |
+
echo "This script needs to be run as root or via sudo to install packages."
|
526 |
+
exit 1
|
527 |
+
fi
|
528 |
+
fi
|
529 |
+
elif [ "$distro" = "None" ] || [ "$family" = "None" ]; then
|
530 |
+
if [ "$distro" = "None" ]; then
|
531 |
+
echo "We could not detect your distribution of Linux. Please file a bug report on github with the contents of your /etc/os-release file."
|
532 |
+
fi
|
533 |
+
|
534 |
+
if [ "$family" = "None" ]; then
|
535 |
+
echo "We could not detect the family of your Linux distribution. Please file a bug report on github with the contents of your /etc/os-release file."
|
536 |
+
fi
|
537 |
+
fi
|
538 |
+
|
539 |
+
install_python_dependencies
|
540 |
+
|
541 |
+
# We need just a little bit more setup for non-interactive environments
|
542 |
+
if [ "$RUNPOD" = true ]; then
|
543 |
+
if inDocker; then
|
544 |
+
# We get the site-packages from python itself, then cut the string, so no other code changes required.
|
545 |
+
VENV_DIR=$(python -c "import site; print(site.getsitepackages()[0])")
|
546 |
+
VENV_DIR="${VENV_DIR%/lib/python3.10/site-packages}"
|
547 |
+
fi
|
548 |
+
|
549 |
+
# Symlink paths
|
550 |
+
libnvinfer_plugin_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7"
|
551 |
+
libnvinfer_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7"
|
552 |
+
libcudart_symlink="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0"
|
553 |
+
|
554 |
+
#Target file paths
|
555 |
+
libnvinfer_plugin_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.8"
|
556 |
+
libnvinfer_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.8"
|
557 |
+
libcudart_target="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12"
|
558 |
+
|
559 |
+
echo "Checking symlinks now."
|
560 |
+
create_symlinks "$libnvinfer_plugin_symlink" "$libnvinfer_plugin_target"
|
561 |
+
create_symlinks "$libnvinfer_symlink" "$libnvinfer_target"
|
562 |
+
create_symlinks "$libcudart_symlink" "$libcudart_target"
|
563 |
+
|
564 |
+
if [ -d "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" ]; then
|
565 |
+
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${VENV_DIR}/lib/python3.10/site-packages/tensorrt/"
|
566 |
+
else
|
567 |
+
echo "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/ not found; not linking library."
|
568 |
+
fi
|
569 |
+
|
570 |
+
if [ -d "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" ]; then
|
571 |
+
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${VENV_DIR}/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/"
|
572 |
+
else
|
573 |
+
echo "${VENV_DIR}/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ not found; not linking library."
|
574 |
+
fi
|
575 |
+
|
576 |
+
configure_accelerate
|
577 |
+
|
578 |
+
# This is a non-interactive environment, so just directly call gui.sh after all setup steps are complete.
|
579 |
+
if [ "$SKIP_GUI" = false ]; then
|
580 |
+
if command -v bash >/dev/null; then
|
581 |
+
if [ "$PUBLIC" = false ]; then
|
582 |
+
bash "$DIR"/gui.sh
|
583 |
+
exit 0
|
584 |
+
else
|
585 |
+
bash "$DIR"/gui.sh --share
|
586 |
+
exit 0
|
587 |
+
fi
|
588 |
+
else
|
589 |
+
# This shouldn't happen, but we're going to try to help.
|
590 |
+
if [ "$PUBLIC" = false ]; then
|
591 |
+
sh "$DIR"/gui.sh
|
592 |
+
exit 0
|
593 |
+
else
|
594 |
+
sh "$DIR"/gui.sh --share
|
595 |
+
exit 0
|
596 |
+
fi
|
597 |
+
fi
|
598 |
+
fi
|
599 |
+
fi
|
600 |
+
|
601 |
+
echo -e "Setup finished! Run \e[0;92m./gui.sh\e[0m to start."
|
602 |
+
echo "Please note if you'd like to expose your public server you need to run ./gui.sh --share"
|
603 |
+
elif [[ "$OSTYPE" == "darwin"* ]]; then
|
604 |
+
# The initial setup script to prep the environment on macOS
|
605 |
+
# xformers has been omitted as that is for Nvidia GPUs only
|
606 |
+
|
607 |
+
if ! command -v brew >/dev/null; then
|
608 |
+
echo "Please install homebrew first. This is a requirement for the remaining setup."
|
609 |
+
echo "You can find that here: https://brew.sh"
|
610 |
+
#shellcheck disable=SC2016
|
611 |
+
echo 'The "brew" command should be in $PATH to be detected.'
|
612 |
+
exit 1
|
613 |
+
fi
|
614 |
+
|
615 |
+
check_storage_space
|
616 |
+
|
617 |
+
# Install base python packages
|
618 |
+
echo "Installing Python 3.10 if not found."
|
619 |
+
if ! brew ls --versions python@3.10 >/dev/null; then
|
620 |
+
echo "Installing Python 3.10."
|
621 |
+
brew install python@3.10 >&3
|
622 |
+
else
|
623 |
+
echo "Python 3.10 found!"
|
624 |
+
fi
|
625 |
+
echo "Installing Python-TK 3.10 if not found."
|
626 |
+
if ! brew ls --versions python-tk@3.10 >/dev/null; then
|
627 |
+
echo "Installing Python TK 3.10."
|
628 |
+
brew install python-tk@3.10 >&3
|
629 |
+
else
|
630 |
+
echo "Python Tkinter 3.10 found!"
|
631 |
+
fi
|
632 |
+
|
633 |
+
update_kohya_ss
|
634 |
+
|
635 |
+
if ! install_python_dependencies; then
|
636 |
+
echo "You may need to install Python. The command for this is brew install python@3.10."
|
637 |
+
fi
|
638 |
+
|
639 |
+
configure_accelerate
|
640 |
+
echo -e "Setup finished! Run ./gui.sh to start."
|
641 |
+
elif [[ "$OSTYPE" == "cygwin" ]]; then
|
642 |
+
# Cygwin is a standalone suite of Linux utilities on Windows
|
643 |
+
echo "This hasn't been validated on cygwin yet."
|
644 |
+
elif [[ "$OSTYPE" == "msys" ]]; then
|
645 |
+
# MinGW has the msys environment which is a standalone suite of Linux utilities on Windows
|
646 |
+
# "git bash" on Windows may also be detected as msys.
|
647 |
+
echo "This hasn't been validated in msys (mingw) on Windows yet."
|
648 |
+
fi
|
style.css
ADDED
@@ -0,0 +1,21 @@
1 |
+
#open_folder_small{
|
2 |
+
height: auto;
|
3 |
+
min-width: auto;
|
4 |
+
flex-grow: 0;
|
5 |
+
padding-left: 0.25em;
|
6 |
+
padding-right: 0.25em;
|
7 |
+
}
|
8 |
+
|
9 |
+
#open_folder{
|
10 |
+
height: auto;
|
11 |
+
flex-grow: 0;
|
12 |
+
padding-left: 0.25em;
|
13 |
+
padding-right: 0.25em;
|
14 |
+
}
|
15 |
+
|
16 |
+
#number_input{
|
17 |
+
min-width: min-content;
|
18 |
+
flex-grow: 0.3;
|
19 |
+
padding-left: 0.75em;
|
20 |
+
padding-right: 0.75em;
|
21 |
+
}
|
textual_inversion_gui.py
ADDED
@@ -0,0 +1,1014 @@
1 |
+
# v1: initial release
|
2 |
+
# v2: add open and save folder icons
|
3 |
+
# v3: Add new Utilities tab for Dreambooth folder preparation
|
4 |
+
# v3.1: Adding captioning of images to utilities
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import json
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import pathlib
|
12 |
+
import argparse
|
13 |
+
from library.common_gui import (
|
14 |
+
get_folder_path,
|
15 |
+
remove_doublequote,
|
16 |
+
get_file_path,
|
17 |
+
get_any_file_path,
|
18 |
+
get_saveasfile_path,
|
19 |
+
color_aug_changed,
|
20 |
+
save_inference_file,
|
21 |
+
gradio_advanced_training,
|
22 |
+
run_cmd_advanced_training,
|
23 |
+
run_cmd_training,
|
24 |
+
gradio_training,
|
25 |
+
gradio_config,
|
26 |
+
gradio_source_model,
|
27 |
+
# set_legacy_8bitadam,
|
28 |
+
update_my_data,
|
29 |
+
check_if_model_exist,
|
30 |
+
)
|
31 |
+
from library.tensorboard_gui import (
|
32 |
+
gradio_tensorboard,
|
33 |
+
start_tensorboard,
|
34 |
+
stop_tensorboard,
|
35 |
+
)
|
36 |
+
from library.dreambooth_folder_creation_gui import (
|
37 |
+
gradio_dreambooth_folder_creation_tab,
|
38 |
+
)
|
39 |
+
from library.utilities import utilities_tab
|
40 |
+
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
41 |
+
from easygui import msgbox
|
42 |
+
|
43 |
+
folder_symbol = '\U0001f4c2' # 📂
|
44 |
+
refresh_symbol = '\U0001f504' # 🔄
|
45 |
+
save_style_symbol = '\U0001f4be' # 💾
|
46 |
+
document_symbol = '\U0001F4C4' # 📄
|
47 |
+
|
48 |
+
|
49 |
+
def save_configuration(
|
50 |
+
save_as,
|
51 |
+
file_path,
|
52 |
+
pretrained_model_name_or_path,
|
53 |
+
v2,
|
54 |
+
v_parameterization,
|
55 |
+
logging_dir,
|
56 |
+
train_data_dir,
|
57 |
+
reg_data_dir,
|
58 |
+
output_dir,
|
59 |
+
max_resolution,
|
60 |
+
learning_rate,
|
61 |
+
lr_scheduler,
|
62 |
+
lr_warmup,
|
63 |
+
train_batch_size,
|
64 |
+
epoch,
|
65 |
+
save_every_n_epochs,
|
66 |
+
mixed_precision,
|
67 |
+
save_precision,
|
68 |
+
seed,
|
69 |
+
num_cpu_threads_per_process,
|
70 |
+
cache_latents,
|
71 |
+
caption_extension,
|
72 |
+
enable_bucket,
|
73 |
+
gradient_checkpointing,
|
74 |
+
full_fp16,
|
75 |
+
no_token_padding,
|
76 |
+
stop_text_encoder_training,
|
77 |
+
# use_8bit_adam,
|
78 |
+
xformers,
|
79 |
+
save_model_as,
|
80 |
+
shuffle_caption,
|
81 |
+
save_state,
|
82 |
+
resume,
|
83 |
+
prior_loss_weight,
|
84 |
+
color_aug,
|
85 |
+
flip_aug,
|
86 |
+
clip_skip,
|
87 |
+
vae,
|
88 |
+
output_name,
|
89 |
+
max_token_length,
|
90 |
+
max_train_epochs,
|
91 |
+
max_data_loader_n_workers,
|
92 |
+
mem_eff_attn,
|
93 |
+
gradient_accumulation_steps,
|
94 |
+
model_list,
|
95 |
+
token_string,
|
96 |
+
init_word,
|
97 |
+
num_vectors_per_token,
|
98 |
+
max_train_steps,
|
99 |
+
weights,
|
100 |
+
template,
|
101 |
+
keep_tokens,
|
102 |
+
persistent_data_loader_workers,
|
103 |
+
bucket_no_upscale,
|
104 |
+
random_crop,
|
105 |
+
bucket_reso_steps,
|
106 |
+
caption_dropout_every_n_epochs,
|
107 |
+
caption_dropout_rate,
|
108 |
+
optimizer,
|
109 |
+
optimizer_args,
|
110 |
+
noise_offset,
|
111 |
+
sample_every_n_steps,
|
112 |
+
sample_every_n_epochs,
|
113 |
+
sample_sampler,
|
114 |
+
sample_prompts,
|
115 |
+
additional_parameters,
|
116 |
+
vae_batch_size,
|
117 |
+
min_snr_gamma,
|
118 |
+
):
|
119 |
+
# Get list of function parameters and values
|
120 |
+
parameters = list(locals().items())
|
121 |
+
|
122 |
+
original_file_path = file_path
|
123 |
+
|
124 |
+
save_as_bool = True if save_as.get('label') == 'True' else False
|
125 |
+
|
126 |
+
if save_as_bool:
|
127 |
+
print('Save as...')
|
128 |
+
file_path = get_saveasfile_path(file_path)
|
129 |
+
else:
|
130 |
+
print('Save...')
|
131 |
+
if file_path == None or file_path == '':
|
132 |
+
file_path = get_saveasfile_path(file_path)
|
133 |
+
|
134 |
+
# print(file_path)
|
135 |
+
|
136 |
+
if file_path == None or file_path == '':
|
137 |
+
return original_file_path # In case a file_path was provided and the user decides to cancel the save action
|
138 |
+
|
139 |
+
# Return the values of the variables as a dictionary
|
140 |
+
variables = {
|
141 |
+
name: value
|
142 |
+
for name, value in parameters # locals().items()
|
143 |
+
if name
|
144 |
+
not in [
|
145 |
+
'file_path',
|
146 |
+
'save_as',
|
147 |
+
]
|
148 |
+
}
|
149 |
+
|
150 |
+
# Extract the destination directory from the file path
|
151 |
+
destination_directory = os.path.dirname(file_path)
|
152 |
+
|
153 |
+
# Create the destination directory if it doesn't exist
|
154 |
+
if not os.path.exists(destination_directory):
|
155 |
+
os.makedirs(destination_directory)
|
156 |
+
|
157 |
+
# Save the data to the selected file
|
158 |
+
with open(file_path, 'w') as file:
|
159 |
+
json.dump(variables, file, indent=2)
|
160 |
+
|
161 |
+
return file_path
|
162 |
+
|
163 |
+
|
164 |
+
def open_configuration(
|
165 |
+
ask_for_file,
|
166 |
+
file_path,
|
167 |
+
pretrained_model_name_or_path,
|
168 |
+
v2,
|
169 |
+
v_parameterization,
|
170 |
+
logging_dir,
|
171 |
+
train_data_dir,
|
172 |
+
reg_data_dir,
|
173 |
+
output_dir,
|
174 |
+
max_resolution,
|
175 |
+
learning_rate,
|
176 |
+
lr_scheduler,
|
177 |
+
lr_warmup,
|
178 |
+
train_batch_size,
|
179 |
+
epoch,
|
180 |
+
save_every_n_epochs,
|
181 |
+
mixed_precision,
|
182 |
+
save_precision,
|
183 |
+
seed,
|
184 |
+
num_cpu_threads_per_process,
|
185 |
+
cache_latents,
|
186 |
+
caption_extension,
|
187 |
+
enable_bucket,
|
188 |
+
gradient_checkpointing,
|
189 |
+
full_fp16,
|
190 |
+
no_token_padding,
|
191 |
+
stop_text_encoder_training,
|
192 |
+
# use_8bit_adam,
|
193 |
+
xformers,
|
194 |
+
save_model_as,
|
195 |
+
shuffle_caption,
|
196 |
+
save_state,
|
197 |
+
resume,
|
198 |
+
prior_loss_weight,
|
199 |
+
color_aug,
|
200 |
+
flip_aug,
|
201 |
+
clip_skip,
|
202 |
+
vae,
|
203 |
+
output_name,
|
204 |
+
max_token_length,
|
205 |
+
max_train_epochs,
|
206 |
+
max_data_loader_n_workers,
|
207 |
+
mem_eff_attn,
|
208 |
+
gradient_accumulation_steps,
|
209 |
+
model_list,
|
210 |
+
token_string,
|
211 |
+
init_word,
|
212 |
+
num_vectors_per_token,
|
213 |
+
max_train_steps,
|
214 |
+
weights,
|
215 |
+
template,
|
216 |
+
keep_tokens,
|
217 |
+
persistent_data_loader_workers,
|
218 |
+
bucket_no_upscale,
|
219 |
+
random_crop,
|
220 |
+
bucket_reso_steps,
|
221 |
+
caption_dropout_every_n_epochs,
|
222 |
+
caption_dropout_rate,
|
223 |
+
optimizer,
|
224 |
+
optimizer_args,
|
225 |
+
noise_offset,
|
226 |
+
sample_every_n_steps,
|
227 |
+
sample_every_n_epochs,
|
228 |
+
sample_sampler,
|
229 |
+
sample_prompts,
|
230 |
+
additional_parameters,
|
231 |
+
vae_batch_size,
|
232 |
+
min_snr_gamma,
|
233 |
+
):
|
234 |
+
# Get list of function parameters and values
|
235 |
+
parameters = list(locals().items())
|
236 |
+
|
237 |
+
ask_for_file = True if ask_for_file.get('label') == 'True' else False
|
238 |
+
|
239 |
+
original_file_path = file_path
|
240 |
+
|
241 |
+
if ask_for_file:
|
242 |
+
file_path = get_file_path(file_path)
|
243 |
+
|
244 |
+
if not file_path == '' and not file_path == None:
|
245 |
+
# load variables from JSON file
|
246 |
+
with open(file_path, 'r') as f:
|
247 |
+
my_data = json.load(f)
|
248 |
+
print('Loading config...')
|
249 |
+
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
|
250 |
+
my_data = update_my_data(my_data)
|
251 |
+
else:
|
252 |
+
file_path = original_file_path # In case a file_path was provided and the user decides to cancel the open action
|
253 |
+
my_data = {}
|
254 |
+
|
255 |
+
values = [file_path]
|
256 |
+
for key, value in parameters:
|
257 |
+
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
258 |
+
if not key in ['ask_for_file', 'file_path']:
|
259 |
+
values.append(my_data.get(key, value))
|
260 |
+
return tuple(values)
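# Note: the order of `values` must stay aligned with the Gradio outputs list
# ([config_file_name] + settings_list) wired to the open/load buttons, which is why
# file_path comes first and 'ask_for_file'/'file_path' are skipped in the loop above.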
|
261 |
+
|
262 |
+
|
263 |
+
def train_model(
|
264 |
+
pretrained_model_name_or_path,
|
265 |
+
v2,
|
266 |
+
v_parameterization,
|
267 |
+
logging_dir,
|
268 |
+
train_data_dir,
|
269 |
+
reg_data_dir,
|
270 |
+
output_dir,
|
271 |
+
max_resolution,
|
272 |
+
learning_rate,
|
273 |
+
lr_scheduler,
|
274 |
+
lr_warmup,
|
275 |
+
train_batch_size,
|
276 |
+
epoch,
|
277 |
+
save_every_n_epochs,
|
278 |
+
mixed_precision,
|
279 |
+
save_precision,
|
280 |
+
seed,
|
281 |
+
num_cpu_threads_per_process,
|
282 |
+
cache_latents,
|
283 |
+
caption_extension,
|
284 |
+
enable_bucket,
|
285 |
+
gradient_checkpointing,
|
286 |
+
full_fp16,
|
287 |
+
no_token_padding,
|
288 |
+
stop_text_encoder_training_pct,
|
289 |
+
# use_8bit_adam,
|
290 |
+
xformers,
|
291 |
+
save_model_as,
|
292 |
+
shuffle_caption,
|
293 |
+
save_state,
|
294 |
+
resume,
|
295 |
+
prior_loss_weight,
|
296 |
+
color_aug,
|
297 |
+
flip_aug,
|
298 |
+
clip_skip,
|
299 |
+
vae,
|
300 |
+
output_name,
|
301 |
+
max_token_length,
|
302 |
+
max_train_epochs,
|
303 |
+
max_data_loader_n_workers,
|
304 |
+
mem_eff_attn,
|
305 |
+
gradient_accumulation_steps,
|
306 |
+
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
307 |
+
token_string,
|
308 |
+
init_word,
|
309 |
+
num_vectors_per_token,
|
310 |
+
max_train_steps,
|
311 |
+
weights,
|
312 |
+
template,
|
313 |
+
keep_tokens,
|
314 |
+
persistent_data_loader_workers,
|
315 |
+
bucket_no_upscale,
|
316 |
+
random_crop,
|
317 |
+
bucket_reso_steps,
|
318 |
+
caption_dropout_every_n_epochs,
|
319 |
+
caption_dropout_rate,
|
320 |
+
optimizer,
|
321 |
+
optimizer_args,
|
322 |
+
noise_offset,
|
323 |
+
sample_every_n_steps,
|
324 |
+
sample_every_n_epochs,
|
325 |
+
sample_sampler,
|
326 |
+
sample_prompts,
|
327 |
+
additional_parameters,
|
328 |
+
vae_batch_size,
|
329 |
+
min_snr_gamma,
|
330 |
+
):
|
331 |
+
if pretrained_model_name_or_path == '':
|
332 |
+
msgbox('Source model information is missing')
|
333 |
+
return
|
334 |
+
|
335 |
+
if train_data_dir == '':
|
336 |
+
msgbox('Image folder path is missing')
|
337 |
+
return
|
338 |
+
|
339 |
+
if not os.path.exists(train_data_dir):
|
340 |
+
msgbox('Image folder does not exist')
|
341 |
+
return
|
342 |
+
|
343 |
+
if reg_data_dir != '':
|
344 |
+
if not os.path.exists(reg_data_dir):
|
345 |
+
msgbox('Regularisation folder does not exist')
|
346 |
+
return
|
347 |
+
|
348 |
+
if output_dir == '':
|
349 |
+
msgbox('Output folder path is missing')
|
350 |
+
return
|
351 |
+
|
352 |
+
if token_string == '':
|
353 |
+
msgbox('Token string is missing')
|
354 |
+
return
|
355 |
+
|
356 |
+
if init_word == '':
|
357 |
+
msgbox('Init word is missing')
|
358 |
+
return
|
359 |
+
|
360 |
+
if not os.path.exists(output_dir):
|
361 |
+
os.makedirs(output_dir)
|
362 |
+
|
363 |
+
if check_if_model_exist(output_name, output_dir, save_model_as):
|
364 |
+
return
|
365 |
+
|
366 |
+
if optimizer == 'Adafactor' and lr_warmup != '0':
|
367 |
+
msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning")
|
368 |
+
lr_warmup = '0'
|
369 |
+
|
370 |
+
# Get a list of all subfolders in train_data_dir
|
371 |
+
subfolders = [
|
372 |
+
f
|
373 |
+
for f in os.listdir(train_data_dir)
|
374 |
+
if os.path.isdir(os.path.join(train_data_dir, f))
|
375 |
+
]
|
376 |
+
|
377 |
+
total_steps = 0
|
378 |
+
|
379 |
+
# Loop through each subfolder and extract the number of repeats
|
380 |
+
for folder in subfolders:
|
381 |
+
# Extract the number of repeats from the folder name
|
382 |
+
repeats = int(folder.split('_')[0])
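# Assumes the kohya folder naming convention '<repeats>_<identifier> <class>',
# e.g. '20_sls frog' -> 20 repeats; a folder without a leading number would raise a
# ValueError here.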
|
383 |
+
|
384 |
+
# Count the number of images in the folder
|
385 |
+
num_images = len(
|
386 |
+
[
|
387 |
+
f
|
388 |
+
for f, lower_f in (
|
389 |
+
(file, file.lower())
|
390 |
+
for file in os.listdir(
|
391 |
+
os.path.join(train_data_dir, folder)
|
392 |
+
)
|
393 |
+
)
|
394 |
+
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
395 |
+
]
|
396 |
+
)
|
397 |
+
|
398 |
+
# Calculate the total number of steps for this folder
|
399 |
+
steps = repeats * num_images
|
400 |
+
total_steps += steps
|
401 |
+
|
402 |
+
# Print the result
|
403 |
+
print(f'Folder {folder}: {steps} steps')
|
404 |
+
|
405 |
+
# Print the result
|
406 |
+
# print(f"{total_steps} total steps")
|
407 |
+
|
408 |
+
if reg_data_dir == '':
|
409 |
+
reg_factor = 1
|
410 |
+
else:
|
411 |
+
print(
|
412 |
+
'Regularisation images are used... Will double the number of steps required...'
|
413 |
+
)
|
414 |
+
reg_factor = 2
|
415 |
+
|
416 |
+
# calculate max_train_steps
|
417 |
+
if max_train_steps == '':
|
418 |
+
max_train_steps = int(
|
419 |
+
math.ceil(
|
420 |
+
float(total_steps)
|
421 |
+
/ int(train_batch_size)
|
422 |
+
* int(epoch)
|
423 |
+
* int(reg_factor)
|
424 |
+
)
|
425 |
+
)
|
426 |
+
else:
|
427 |
+
max_train_steps = int(max_train_steps)
|
428 |
+
|
429 |
+
print(f'max_train_steps = {max_train_steps}')
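# Worked example (illustrative numbers only): 100 images x 10 repeats = 1000 total steps;
# with train_batch_size=2, epoch=1 and no regularisation images this gives
# ceil(1000 / 2 * 1 * 1) = 500 max_train_steps.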
|
430 |
+
|
431 |
+
# calculate stop encoder training
|
432 |
+
if stop_text_encoder_training_pct == None:
|
433 |
+
stop_text_encoder_training = 0
|
434 |
+
else:
|
435 |
+
stop_text_encoder_training = math.ceil(
|
436 |
+
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
|
437 |
+
)
|
438 |
+
print(f'stop_text_encoder_training = {stop_text_encoder_training}')
|
439 |
+
|
440 |
+
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
441 |
+
print(f'lr_warmup_steps = {lr_warmup_steps}')
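# e.g. lr_warmup='10' (percent) with max_train_steps=500 -> lr_warmup_steps=50.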
|
442 |
+
|
443 |
+
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_textual_inversion.py"'
|
444 |
+
if v2:
|
445 |
+
run_cmd += ' --v2'
|
446 |
+
if v_parameterization:
|
447 |
+
run_cmd += ' --v_parameterization'
|
448 |
+
if enable_bucket:
|
449 |
+
run_cmd += ' --enable_bucket'
|
450 |
+
if no_token_padding:
|
451 |
+
run_cmd += ' --no_token_padding'
|
452 |
+
run_cmd += (
|
453 |
+
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
454 |
+
)
|
455 |
+
run_cmd += f' --train_data_dir="{train_data_dir}"'
|
456 |
+
if len(reg_data_dir):
|
457 |
+
run_cmd += f' --reg_data_dir="{reg_data_dir}"'
|
458 |
+
run_cmd += f' --resolution={max_resolution}'
|
459 |
+
run_cmd += f' --output_dir="{output_dir}"'
|
460 |
+
run_cmd += f' --logging_dir="{logging_dir}"'
|
461 |
+
if not stop_text_encoder_training == 0:
|
462 |
+
run_cmd += (
|
463 |
+
f' --stop_text_encoder_training={stop_text_encoder_training}'
|
464 |
+
)
|
465 |
+
if not save_model_as == 'same as source model':
|
466 |
+
run_cmd += f' --save_model_as={save_model_as}'
|
467 |
+
# if not resume == '':
|
468 |
+
# run_cmd += f' --resume={resume}'
|
469 |
+
if not float(prior_loss_weight) == 1.0:
|
470 |
+
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
471 |
+
if not vae == '':
|
472 |
+
run_cmd += f' --vae="{vae}"'
|
473 |
+
if not output_name == '':
|
474 |
+
run_cmd += f' --output_name="{output_name}"'
|
475 |
+
if int(max_token_length) > 75:
|
476 |
+
run_cmd += f' --max_token_length={max_token_length}'
|
477 |
+
if not max_train_epochs == '':
|
478 |
+
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
479 |
+
if not max_data_loader_n_workers == '':
|
480 |
+
run_cmd += (
|
481 |
+
f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
482 |
+
)
|
483 |
+
if int(gradient_accumulation_steps) > 1:
|
484 |
+
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
|
485 |
+
|
486 |
+
run_cmd += run_cmd_training(
|
487 |
+
learning_rate=learning_rate,
|
488 |
+
lr_scheduler=lr_scheduler,
|
489 |
+
lr_warmup_steps=lr_warmup_steps,
|
490 |
+
train_batch_size=train_batch_size,
|
491 |
+
max_train_steps=max_train_steps,
|
492 |
+
save_every_n_epochs=save_every_n_epochs,
|
493 |
+
mixed_precision=mixed_precision,
|
494 |
+
save_precision=save_precision,
|
495 |
+
seed=seed,
|
496 |
+
caption_extension=caption_extension,
|
497 |
+
cache_latents=cache_latents,
|
498 |
+
optimizer=optimizer,
|
499 |
+
optimizer_args=optimizer_args,
|
500 |
+
)
|
501 |
+
|
502 |
+
run_cmd += run_cmd_advanced_training(
|
503 |
+
max_train_epochs=max_train_epochs,
|
504 |
+
max_data_loader_n_workers=max_data_loader_n_workers,
|
505 |
+
max_token_length=max_token_length,
|
506 |
+
resume=resume,
|
507 |
+
save_state=save_state,
|
508 |
+
mem_eff_attn=mem_eff_attn,
|
509 |
+
clip_skip=clip_skip,
|
510 |
+
flip_aug=flip_aug,
|
511 |
+
color_aug=color_aug,
|
512 |
+
shuffle_caption=shuffle_caption,
|
513 |
+
gradient_checkpointing=gradient_checkpointing,
|
514 |
+
full_fp16=full_fp16,
|
515 |
+
xformers=xformers,
|
516 |
+
# use_8bit_adam=use_8bit_adam,
|
517 |
+
keep_tokens=keep_tokens,
|
518 |
+
persistent_data_loader_workers=persistent_data_loader_workers,
|
519 |
+
bucket_no_upscale=bucket_no_upscale,
|
520 |
+
random_crop=random_crop,
|
521 |
+
bucket_reso_steps=bucket_reso_steps,
|
522 |
+
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
523 |
+
caption_dropout_rate=caption_dropout_rate,
|
524 |
+
noise_offset=noise_offset,
|
525 |
+
additional_parameters=additional_parameters,
|
526 |
+
vae_batch_size=vae_batch_size,
|
527 |
+
min_snr_gamma=min_snr_gamma,
|
528 |
+
)
|
529 |
+
run_cmd += f' --token_string="{token_string}"'
|
530 |
+
run_cmd += f' --init_word="{init_word}"'
|
531 |
+
run_cmd += f' --num_vectors_per_token={num_vectors_per_token}'
|
532 |
+
if not weights == '':
|
533 |
+
run_cmd += f' --weights="{weights}"'
|
534 |
+
if template == 'object template':
|
535 |
+
run_cmd += f' --use_object_template'
|
536 |
+
elif template == 'style template':
|
537 |
+
run_cmd += f' --use_style_template'
|
538 |
+
|
539 |
+
run_cmd += run_cmd_sample(
|
540 |
+
sample_every_n_steps,
|
541 |
+
sample_every_n_epochs,
|
542 |
+
sample_sampler,
|
543 |
+
sample_prompts,
|
544 |
+
output_dir,
|
545 |
+
)
|
546 |
+
|
547 |
+
print(run_cmd)
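# Illustrative example of the assembled command (paths and values are placeholders, not defaults):
# accelerate launch --num_cpu_threads_per_process=2 "train_textual_inversion.py"
#   --pretrained_model_name_or_path="D:/models/model.safetensors" --train_data_dir="D:/data/img"
#   --resolution=512,512 --output_dir="D:/out" --logging_dir="D:/log"
#   --token_string="shs" --init_word="*" --num_vectors_per_token=1 ... (plus the options added above)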
|
548 |
+
|
549 |
+
# Run the command
|
550 |
+
if os.name == 'posix':
|
551 |
+
os.system(run_cmd)
|
552 |
+
else:
|
553 |
+
subprocess.run(run_cmd)
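# On POSIX the assembled string is executed through the shell via os.system; on Windows,
# subprocess.run receives the full command string and passes it to the OS as-is.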
|
554 |
+
|
555 |
+
# check if output_dir/output_name is a folder... therefore it is a diffusers model
|
556 |
+
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
|
557 |
+
|
558 |
+
if not last_dir.is_dir():
|
559 |
+
# Copy inference model for v2 if required
|
560 |
+
save_inference_file(output_dir, v2, v_parameterization, output_name)
|
561 |
+
|
562 |
+
|
563 |
+
def ti_tab(
|
564 |
+
train_data_dir=gr.Textbox(),
|
565 |
+
reg_data_dir=gr.Textbox(),
|
566 |
+
output_dir=gr.Textbox(),
|
567 |
+
logging_dir=gr.Textbox(),
|
568 |
+
):
|
569 |
+
dummy_db_true = gr.Label(value=True, visible=False)
|
570 |
+
dummy_db_false = gr.Label(value=False, visible=False)
|
571 |
+
gr.Markdown('Train a TI using kohya textual inversion python code...')
|
572 |
+
(
|
573 |
+
button_open_config,
|
574 |
+
button_save_config,
|
575 |
+
button_save_as_config,
|
576 |
+
config_file_name,
|
577 |
+
button_load_config,
|
578 |
+
) = gradio_config()
|
579 |
+
|
580 |
+
(
|
581 |
+
pretrained_model_name_or_path,
|
582 |
+
v2,
|
583 |
+
v_parameterization,
|
584 |
+
save_model_as,
|
585 |
+
model_list,
|
586 |
+
) = gradio_source_model(
|
587 |
+
save_model_as_choices=[
|
588 |
+
'ckpt',
|
589 |
+
'safetensors',
|
590 |
+
]
|
591 |
+
)
|
592 |
+
|
593 |
+
with gr.Tab('Folders'):
|
594 |
+
with gr.Row():
|
595 |
+
train_data_dir = gr.Textbox(
|
596 |
+
label='Image folder',
|
597 |
+
placeholder='Folder where the training folders containing the images are located',
|
598 |
+
)
|
599 |
+
train_data_dir_input_folder = gr.Button(
|
600 |
+
'📂', elem_id='open_folder_small'
|
601 |
+
)
|
602 |
+
train_data_dir_input_folder.click(
|
603 |
+
get_folder_path,
|
604 |
+
outputs=train_data_dir,
|
605 |
+
show_progress=False,
|
606 |
+
)
|
607 |
+
reg_data_dir = gr.Textbox(
|
608 |
+
label='Regularisation folder',
|
609 |
+
placeholder='(Optional) Folder where the regularization folders containing the images are located',
|
610 |
+
)
|
611 |
+
reg_data_dir_input_folder = gr.Button(
|
612 |
+
'📂', elem_id='open_folder_small'
|
613 |
+
)
|
614 |
+
reg_data_dir_input_folder.click(
|
615 |
+
get_folder_path,
|
616 |
+
outputs=reg_data_dir,
|
617 |
+
show_progress=False,
|
618 |
+
)
|
619 |
+
with gr.Row():
|
620 |
+
output_dir = gr.Textbox(
|
621 |
+
label='Model output folder',
|
622 |
+
placeholder='Folder to output trained model',
|
623 |
+
)
|
624 |
+
output_dir_input_folder = gr.Button(
|
625 |
+
'📂', elem_id='open_folder_small'
|
626 |
+
)
|
627 |
+
output_dir_input_folder.click(
|
628 |
+
get_folder_path,
|
629 |
+
outputs=output_dir,
|
630 |
+
show_progress=False,
|
631 |
+
)
|
632 |
+
logging_dir = gr.Textbox(
|
633 |
+
label='Logging folder',
|
634 |
+
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
635 |
+
)
|
636 |
+
logging_dir_input_folder = gr.Button(
|
637 |
+
'📂', elem_id='open_folder_small'
|
638 |
+
)
|
639 |
+
logging_dir_input_folder.click(
|
640 |
+
get_folder_path,
|
641 |
+
outputs=logging_dir,
|
642 |
+
show_progress=False,
|
643 |
+
)
|
644 |
+
with gr.Row():
|
645 |
+
output_name = gr.Textbox(
|
646 |
+
label='Model output name',
|
647 |
+
placeholder='Name of the model to output',
|
648 |
+
value='last',
|
649 |
+
interactive=True,
|
650 |
+
)
|
651 |
+
train_data_dir.change(
|
652 |
+
remove_doublequote,
|
653 |
+
inputs=[train_data_dir],
|
654 |
+
outputs=[train_data_dir],
|
655 |
+
)
|
656 |
+
reg_data_dir.change(
|
657 |
+
remove_doublequote,
|
658 |
+
inputs=[reg_data_dir],
|
659 |
+
outputs=[reg_data_dir],
|
660 |
+
)
|
661 |
+
output_dir.change(
|
662 |
+
remove_doublequote,
|
663 |
+
inputs=[output_dir],
|
664 |
+
outputs=[output_dir],
|
665 |
+
)
|
666 |
+
logging_dir.change(
|
667 |
+
remove_doublequote,
|
668 |
+
inputs=[logging_dir],
|
669 |
+
outputs=[logging_dir],
|
670 |
+
)
|
671 |
+
with gr.Tab('Training parameters'):
|
672 |
+
with gr.Row():
|
673 |
+
weights = gr.Textbox(
|
674 |
+
label='Resume TI training',
|
675 |
+
placeholder='(Optional) Path to existing TI embedding file to keep training',
|
676 |
+
)
|
677 |
+
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
678 |
+
weights_file_input.click(
|
679 |
+
get_file_path,
|
680 |
+
outputs=weights,
|
681 |
+
show_progress=False,
|
682 |
+
)
|
683 |
+
with gr.Row():
|
684 |
+
token_string = gr.Textbox(
|
685 |
+
label='Token string',
|
686 |
+
placeholder='eg: cat',
|
687 |
+
)
|
688 |
+
init_word = gr.Textbox(
|
689 |
+
label='Init word',
|
690 |
+
value='*',
|
691 |
+
)
|
692 |
+
num_vectors_per_token = gr.Slider(
|
693 |
+
minimum=1,
|
694 |
+
maximum=75,
|
695 |
+
value=1,
|
696 |
+
step=1,
|
697 |
+
label='Vectors',
|
698 |
+
)
|
699 |
+
max_train_steps = gr.Textbox(
|
700 |
+
label='Max train steps',
|
701 |
+
placeholder='(Optional) Maximum number of steps',
|
702 |
+
)
|
703 |
+
template = gr.Dropdown(
|
704 |
+
label='Template',
|
705 |
+
choices=[
|
706 |
+
'caption',
|
707 |
+
'object template',
|
708 |
+
'style template',
|
709 |
+
],
|
710 |
+
value='caption',
|
711 |
+
)
|
712 |
+
(
|
713 |
+
learning_rate,
|
714 |
+
lr_scheduler,
|
715 |
+
lr_warmup,
|
716 |
+
train_batch_size,
|
717 |
+
epoch,
|
718 |
+
save_every_n_epochs,
|
719 |
+
mixed_precision,
|
720 |
+
save_precision,
|
721 |
+
num_cpu_threads_per_process,
|
722 |
+
seed,
|
723 |
+
caption_extension,
|
724 |
+
cache_latents,
|
725 |
+
optimizer,
|
726 |
+
optimizer_args,
|
727 |
+
) = gradio_training(
|
728 |
+
learning_rate_value='1e-5',
|
729 |
+
lr_scheduler_value='cosine',
|
730 |
+
lr_warmup_value='10',
|
731 |
+
)
|
732 |
+
with gr.Row():
|
733 |
+
max_resolution = gr.Textbox(
|
734 |
+
label='Max resolution',
|
735 |
+
value='512,512',
|
736 |
+
placeholder='512,512',
|
737 |
+
)
|
738 |
+
stop_text_encoder_training = gr.Slider(
|
739 |
+
minimum=0,
|
740 |
+
maximum=100,
|
741 |
+
value=0,
|
742 |
+
step=1,
|
743 |
+
label='Stop text encoder training',
|
744 |
+
)
|
745 |
+
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
746 |
+
with gr.Accordion('Advanced Configuration', open=False):
|
747 |
+
with gr.Row():
|
748 |
+
no_token_padding = gr.Checkbox(
|
749 |
+
label='No token padding', value=False
|
750 |
+
)
|
751 |
+
gradient_accumulation_steps = gr.Number(
|
752 |
+
label='Gradient accumulate steps', value='1'
|
753 |
+
)
|
754 |
+
with gr.Row():
|
755 |
+
prior_loss_weight = gr.Number(
|
756 |
+
label='Prior loss weight', value=1.0
|
757 |
+
)
|
758 |
+
vae = gr.Textbox(
|
759 |
+
label='VAE',
|
760 |
+
placeholder='(Optional) path to checkpoint of VAE to replace for training',
|
761 |
+
)
|
762 |
+
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
763 |
+
vae_button.click(
|
764 |
+
get_any_file_path,
|
765 |
+
outputs=vae,
|
766 |
+
show_progress=False,
|
767 |
+
)
|
768 |
+
(
|
769 |
+
# use_8bit_adam,
|
770 |
+
xformers,
|
771 |
+
full_fp16,
|
772 |
+
gradient_checkpointing,
|
773 |
+
shuffle_caption,
|
774 |
+
color_aug,
|
775 |
+
flip_aug,
|
776 |
+
clip_skip,
|
777 |
+
mem_eff_attn,
|
778 |
+
save_state,
|
779 |
+
resume,
|
780 |
+
max_token_length,
|
781 |
+
max_train_epochs,
|
782 |
+
max_data_loader_n_workers,
|
783 |
+
keep_tokens,
|
784 |
+
persistent_data_loader_workers,
|
785 |
+
bucket_no_upscale,
|
786 |
+
random_crop,
|
787 |
+
bucket_reso_steps,
|
788 |
+
caption_dropout_every_n_epochs,
|
789 |
+
caption_dropout_rate,
|
790 |
+
noise_offset,
|
791 |
+
additional_parameters,
|
792 |
+
vae_batch_size,
|
793 |
+
min_snr_gamma,
|
794 |
+
) = gradio_advanced_training()
|
795 |
+
color_aug.change(
|
796 |
+
color_aug_changed,
|
797 |
+
inputs=[color_aug],
|
798 |
+
outputs=[cache_latents],
|
799 |
+
)
|
800 |
+
|
801 |
+
(
|
802 |
+
sample_every_n_steps,
|
803 |
+
sample_every_n_epochs,
|
804 |
+
sample_sampler,
|
805 |
+
sample_prompts,
|
806 |
+
) = sample_gradio_config()
|
807 |
+
|
808 |
+
with gr.Tab('Tools'):
|
809 |
+
gr.Markdown(
|
810 |
+
'This section provides Dreambooth tools to help set up your dataset...'
|
811 |
+
)
|
812 |
+
gradio_dreambooth_folder_creation_tab(
|
813 |
+
train_data_dir_input=train_data_dir,
|
814 |
+
reg_data_dir_input=reg_data_dir,
|
815 |
+
output_dir_input=output_dir,
|
816 |
+
logging_dir_input=logging_dir,
|
817 |
+
)
|
818 |
+
|
819 |
+
button_run = gr.Button('Train model', variant='primary')
|
820 |
+
|
821 |
+
# Setup gradio tensorboard buttons
|
822 |
+
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
823 |
+
|
824 |
+
button_start_tensorboard.click(
|
825 |
+
start_tensorboard,
|
826 |
+
inputs=logging_dir,
|
827 |
+
show_progress=False,
|
828 |
+
)
|
829 |
+
|
830 |
+
button_stop_tensorboard.click(
|
831 |
+
stop_tensorboard,
|
832 |
+
show_progress=False,
|
833 |
+
)
|
834 |
+
|
835 |
+
settings_list = [
|
836 |
+
pretrained_model_name_or_path,
|
837 |
+
v2,
|
838 |
+
v_parameterization,
|
839 |
+
logging_dir,
|
840 |
+
train_data_dir,
|
841 |
+
reg_data_dir,
|
842 |
+
output_dir,
|
843 |
+
max_resolution,
|
844 |
+
learning_rate,
|
845 |
+
lr_scheduler,
|
846 |
+
lr_warmup,
|
847 |
+
train_batch_size,
|
848 |
+
epoch,
|
849 |
+
save_every_n_epochs,
|
850 |
+
mixed_precision,
|
851 |
+
save_precision,
|
852 |
+
seed,
|
853 |
+
num_cpu_threads_per_process,
|
854 |
+
cache_latents,
|
855 |
+
caption_extension,
|
856 |
+
enable_bucket,
|
857 |
+
gradient_checkpointing,
|
858 |
+
full_fp16,
|
859 |
+
no_token_padding,
|
860 |
+
stop_text_encoder_training,
|
861 |
+
# use_8bit_adam,
|
862 |
+
xformers,
|
863 |
+
save_model_as,
|
864 |
+
shuffle_caption,
|
865 |
+
save_state,
|
866 |
+
resume,
|
867 |
+
prior_loss_weight,
|
868 |
+
color_aug,
|
869 |
+
flip_aug,
|
870 |
+
clip_skip,
|
871 |
+
vae,
|
872 |
+
output_name,
|
873 |
+
max_token_length,
|
874 |
+
max_train_epochs,
|
875 |
+
max_data_loader_n_workers,
|
876 |
+
mem_eff_attn,
|
877 |
+
gradient_accumulation_steps,
|
878 |
+
model_list,
|
879 |
+
token_string,
|
880 |
+
init_word,
|
881 |
+
num_vectors_per_token,
|
882 |
+
max_train_steps,
|
883 |
+
weights,
|
884 |
+
template,
|
885 |
+
keep_tokens,
|
886 |
+
persistent_data_loader_workers,
|
887 |
+
bucket_no_upscale,
|
888 |
+
random_crop,
|
889 |
+
bucket_reso_steps,
|
890 |
+
caption_dropout_every_n_epochs,
|
891 |
+
caption_dropout_rate,
|
892 |
+
optimizer,
|
893 |
+
optimizer_args,
|
894 |
+
noise_offset,
|
895 |
+
sample_every_n_steps,
|
896 |
+
sample_every_n_epochs,
|
897 |
+
sample_sampler,
|
898 |
+
sample_prompts,
|
899 |
+
additional_parameters,
|
900 |
+
vae_batch_size,
|
901 |
+
min_snr_gamma,
|
902 |
+
]
|
903 |
+
|
904 |
+
button_open_config.click(
|
905 |
+
open_configuration,
|
906 |
+
inputs=[dummy_db_true, config_file_name] + settings_list,
|
907 |
+
outputs=[config_file_name] + settings_list,
|
908 |
+
show_progress=False,
|
909 |
+
)
|
910 |
+
|
911 |
+
button_load_config.click(
|
912 |
+
open_configuration,
|
913 |
+
inputs=[dummy_db_false, config_file_name] + settings_list,
|
914 |
+
outputs=[config_file_name] + settings_list,
|
915 |
+
show_progress=False,
|
916 |
+
)
|
917 |
+
|
918 |
+
button_save_config.click(
|
919 |
+
save_configuration,
|
920 |
+
inputs=[dummy_db_false, config_file_name] + settings_list,
|
921 |
+
outputs=[config_file_name],
|
922 |
+
show_progress=False,
|
923 |
+
)
|
924 |
+
|
925 |
+
button_save_as_config.click(
|
926 |
+
save_configuration,
|
927 |
+
inputs=[dummy_db_true, config_file_name] + settings_list,
|
928 |
+
outputs=[config_file_name],
|
929 |
+
show_progress=False,
|
930 |
+
)
|
931 |
+
|
932 |
+
button_run.click(
|
933 |
+
train_model,
|
934 |
+
inputs=settings_list,
|
935 |
+
show_progress=False,
|
936 |
+
)
|
937 |
+
|
938 |
+
return (
|
939 |
+
train_data_dir,
|
940 |
+
reg_data_dir,
|
941 |
+
output_dir,
|
942 |
+
logging_dir,
|
943 |
+
)
|
944 |
+
|
945 |
+
|
946 |
+
def UI(**kwargs):
|
947 |
+
css = ''
|
948 |
+
|
949 |
+
if os.path.exists('./style.css'):
|
950 |
+
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
951 |
+
print('Load CSS...')
|
952 |
+
css += file.read() + '\n'
|
953 |
+
|
954 |
+
interface = gr.Blocks(css=css)
|
955 |
+
|
956 |
+
with interface:
|
957 |
+
with gr.Tab('Dreambooth TI'):
|
958 |
+
(
|
959 |
+
train_data_dir_input,
|
960 |
+
reg_data_dir_input,
|
961 |
+
output_dir_input,
|
962 |
+
logging_dir_input,
|
963 |
+
) = ti_tab()
|
964 |
+
with gr.Tab('Utilities'):
|
965 |
+
utilities_tab(
|
966 |
+
train_data_dir_input=train_data_dir_input,
|
967 |
+
reg_data_dir_input=reg_data_dir_input,
|
968 |
+
output_dir_input=output_dir_input,
|
969 |
+
logging_dir_input=logging_dir_input,
|
970 |
+
enable_copy_info_button=True,
|
971 |
+
)
|
972 |
+
|
973 |
+
# Show the interface
|
974 |
+
launch_kwargs = {}
|
975 |
+
if not kwargs.get('username', None) == '':
|
976 |
+
launch_kwargs['auth'] = (
|
977 |
+
kwargs.get('username', None),
|
978 |
+
kwargs.get('password', None),
|
979 |
+
)
|
980 |
+
if kwargs.get('server_port', 0) > 0:
|
981 |
+
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
982 |
+
if kwargs.get('inbrowser', False):
|
983 |
+
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
984 |
+
print(launch_kwargs)
|
985 |
+
interface.launch(**launch_kwargs)
|
986 |
+
|
987 |
+
|
988 |
+
if __name__ == '__main__':
|
989 |
+
# torch.cuda.set_per_process_memory_fraction(0.48)
|
990 |
+
parser = argparse.ArgumentParser()
|
991 |
+
parser.add_argument(
|
992 |
+
'--username', type=str, default='', help='Username for authentication'
|
993 |
+
)
|
994 |
+
parser.add_argument(
|
995 |
+
'--password', type=str, default='', help='Password for authentication'
|
996 |
+
)
|
997 |
+
parser.add_argument(
|
998 |
+
'--server_port',
|
999 |
+
type=int,
|
1000 |
+
default=0,
|
1001 |
+
help='Port to run the server listener on',
|
1002 |
+
)
|
1003 |
+
parser.add_argument(
|
1004 |
+
'--inbrowser', action='store_true', help='Open in browser'
|
1005 |
+
)
|
1006 |
+
|
1007 |
+
args = parser.parse_args()
|
1008 |
+
|
1009 |
+
UI(
|
1010 |
+
username=args.username,
|
1011 |
+
password=args.password,
|
1012 |
+
inbrowser=args.inbrowser,
|
1013 |
+
server_port=args.server_port,
|
1014 |
+
)
|
train_README-ja.md
ADDED
@@ -0,0 +1,945 @@
1 |
+
__ドキュメント更新中のため記述に誤りがあるかもしれません。__
|
2 |
+
|
3 |
+
# 学習について、共通編
|
4 |
+
|
5 |
+
当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversion([XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)を含む)の学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
|
6 |
+
|
7 |
+
# 概要
|
8 |
+
|
9 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
10 |
+
|
11 |
+
|
12 |
+
以下について説明します。
|
13 |
+
|
14 |
+
1. 学習データの準備について(設定ファイルを用いる新形式)
|
15 |
+
1. 学習で使われる用語のごく簡単な解説
|
16 |
+
1. 以前の指定形式(設定ファイルを用いずコマンドラインから指定)
|
17 |
+
1. 学習途中のサンプル画像生成
|
18 |
+
1. 各スクリプトで共通の、よく使われるオプション
|
19 |
+
1. fine tuning 方式のメタデータ準備:キャプションニングなど
|
20 |
+
|
21 |
+
1.だけ実行すればとりあえず学習は可能です(学習については各スクリプトのドキュメントを参照)。2.以降は必要に応じて参照してください。
|
22 |
+
|
23 |
+
|
24 |
+
# 学習データの準備について
|
25 |
+
|
26 |
+
任意のフォルダ(複数でも可)に学習データの画像ファイルを用意しておきます。`.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` をサポートします。リサイズなどの前処理は基本的に必要ありません。
|
27 |
+
|
28 |
+
ただし学習解像度(後述)よりも極端に小さい画像は使わないか、あらかじめ超解像AIなどで拡大しておくことをお勧めします。また極端に大きな画像(3000x3000ピクセル程度?)よりも大きな画像はエラーになる場合があるようですので事前に縮小してください。
|
29 |
+
|
30 |
+
学習時には、モデルに学ばせる画像データを整理し、スクリプトに対して指定する必要があります。学習データの数、学習対象、キャプション(画像の説明)が用意できるか否かなどにより、いくつかの方法で学習データを指定できます。以下の方式があります(それぞれの名前は一般的なものではなく、当リポジトリ独自の定義です)。正則化画像については後述します。
|
31 |
+
|
32 |
+
1. DreamBooth、class+identifier方式(正則化画像使用可)
|
33 |
+
|
34 |
+
特定の単語 (identifier) に学習対象を紐づけるように学習します。キャプションを用意する必要はありません。たとえば特定のキャラを学ばせる場合に使うとキャプションを用意する必要がない分、手軽ですが、髪型や服装、背景など学習データの全要素が identifier に紐づけられて学習されるため、生成時のプロンプトで服が変えられない、といった事態も起こりえます。
|
35 |
+
|
36 |
+
1. DreamBooth、キャプション方式(正則化画像使用可)
|
37 |
+
|
38 |
+
画像ごとにキャプションが記録されたテキストファイルを用意して学習します。たとえば特定のキャラを学ばせると、画像の詳細をキャプションに記述することで(白い服を着たキャラA、赤い服を着たキャラA、など)キャラとそれ以外の要素が分離され、より厳密にモデルがキャラだけを学ぶことが期待できます。
|
39 |
+
|
40 |
+
1. fine tuning方式(正則化画像使用不可)
|
41 |
+
|
42 |
+
あらかじめキャプションをメタデータファイルにまとめます。タグとキャプションを分けて管理したり、学習を高速化するためlatentsを事前キャッシュしたりなどの機能をサポートします(いずれも別文書で説明しています)。(fine tuning方式という名前ですが fine tuning 以外でも使えます。)
|
43 |
+
|
44 |
+
学習したいものと使用できる指定方法の組み合わせは以下の通りです。
|
45 |
+
|
46 |
+
| 学習対象または方法 | スクリプト | DB / class+identifier | DB / キャプション | fine tuning |
|
47 |
+
| ----- | ----- | ----- | ----- | ----- |
|
48 |
+
| モデルをfine tuning | `fine_tune.py`| x | x | o |
|
49 |
+
| モデルをDreamBooth | `train_db.py`| o | o | x |
|
50 |
+
| LoRA | `train_network.py`| o | o | o |
|
51 |
+
| Textual Inversion | `train_textual_inversion.py`| o | o | o |
|
52 |
+
|
53 |
+
## どれを選ぶか
|
54 |
+
|
55 |
+
LoRA、Textual Inversionについては、手軽にキャプションファイルを用意せずに学習したい場合はDreamBooth class+identifier、用意できるならDreamBooth キャプション方式がよいでしょう。学習データの枚数が多く、かつ正則化画像を使用しない場合はfine tuning方式も検討してください。
|
56 |
+
|
57 |
+
DreamBoothについても同様ですが、fine tuning方式は使えません。fine tuningの場合はfine tuning方式のみです。
|
58 |
+
|
59 |
+
# 各方式の指定方法について
|
60 |
+
|
61 |
+
ここではそれぞれの指定方法で典型的なパターンについてだけ説明します。より詳細な指定方法については [データセット設定](./config_README-ja.md) をご覧ください。
|
62 |
+
|
63 |
+
# DreamBooth、class+identifier方式(正則化画像使用可)
|
64 |
+
|
65 |
+
この方式では、各画像は `class identifier` というキャプションで学習されたのと同じことになります(`shs dog` など)。
|
66 |
+
|
67 |
+
## step 1. identifierとclassを決める
|
68 |
+
|
69 |
+
学ばせたい対象を結びつける単語identifierと、対象の属するclassを決めます。
|
70 |
+
|
71 |
+
(instanceなどいろいろな呼び方がありますが、とりあえず元の論文に合わせます。)
|
72 |
+
|
73 |
+
以下ごく簡単に説明します(詳しくは調べてください)。
|
74 |
+
|
75 |
+
classは学習対象の一般的な種別です。たとえば特定の犬種を学ばせる場合には、classはdogになります。アニメキャラならモデルによりboyやgirl、1boyや1girlになるでしょう。
|
76 |
+
|
77 |
+
identifierは学習対象を識別して学習するためのものです。任意の単語で構いませんが、元論文によると「tokenizerで1トークンになる3文字以下でレアな単語」が良いとのことです。
|
78 |
+
|
79 |
+
identifierとclassを使い、たとえば「shs dog」などでモデルを学習することで、学習させたい対象をclassから識別して学習できます。
|
80 |
+
|
81 |
+
画像生成時には「shs dog」とすれば学ばせた犬種の画像が生成されます。
|
82 |
+
|
83 |
+
(identifierとして私が最近使っているものを参考までに挙げると、``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny`` などです。本当は Danbooru Tag に含まれないやつがより望ましいです。)
|
84 |
+
|
85 |
+
## step 2. 正則化画像を使うか否かを決め、使う場合には正則化画像を生成する
|
86 |
+
|
87 |
+
正則化画像とは、前述のclass全体が、学習対象に引っ張られることを防ぐための画像です(language drift)。正則化画像を使わないと、たとえば `shs 1girl` で特定のキャラクタを学ばせると、単なる `1girl` というプロンプトで生成してもそのキャラに似てきます。これは `1girl` が学習時のキャプションに含まれているためです。
|
88 |
+
|
89 |
+
学習対象の画像と正則化画像を同時に学ばせることで、class は class のままで留まり、identifier をプロンプトにつけた時だけ学習対象が生成されるようになります。
|
90 |
+
|
91 |
+
LoRAやDreamBoothで特定のキャラだけ出てくればよい場合は、正則化画像を用いなくても良いといえます。
|
92 |
+
|
93 |
+
Textual Inversionでは用いなくてよいでしょう(学ばせる token string がキャプションに含まれない場合はなにも学習されないため)。
|
94 |
+
|
95 |
+
正則化画像としては、学習対象のモデルで、class 名だけで生成した画像を用いるのが一般的です(たとえば `1girl`)。ただし生成画像の品質が悪い場合には、プロンプトを工夫したり、ネットから別途ダウンロードした画像を用いることもできます。
|
96 |
+
|
97 |
+
(正則化画像も学習されるため、その品質はモデルに影響します。)
|
98 |
+
|
99 |
+
一般的には数百枚程度、用意するのが望ましいようです(枚数が少ないと class 画像が一般化されずそれらの特徴を学んでしまいます)。
|
100 |
+
|
101 |
+
生成画像を使う場合、通常、生成画像のサイズは学習解像度(より正確にはbucketの解像度、後述)にあわせてください。
|
102 |
+
|
103 |
+
## step 3. 設定ファイルの記述
|
104 |
+
|
105 |
+
テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
|
106 |
+
|
107 |
+
(`#` で始まっている部分はコメントですので、このままコピペしてそのままでもよいですし、削除しても問題ありません。)
|
108 |
+
|
109 |
+
```toml
|
110 |
+
[general]
|
111 |
+
enable_bucket = true # Aspect Ratio Bucketingを使うか否か
|
112 |
+
|
113 |
+
[[datasets]]
|
114 |
+
resolution = 512 # 学習解像度
|
115 |
+
batch_size = 4 # バッチサイズ
|
116 |
+
|
117 |
+
[[datasets.subsets]]
|
118 |
+
image_dir = 'C:\hoge' # 学習用画像を入れたフォルダを指定
|
119 |
+
class_tokens = 'hoge girl' # identifier class を指定
|
120 |
+
num_repeats = 10 # 学習用画像の繰り返し回数
|
121 |
+
|
122 |
+
# 以下は正則化画像を用いる場合のみ記述する。用いない場合は削除する
|
123 |
+
[[datasets.subsets]]
|
124 |
+
is_reg = true
|
125 |
+
image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定
|
126 |
+
class_tokens = 'girl' # class を指定
|
127 |
+
num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい
|
128 |
+
```
|
129 |
+
|
130 |
+
基本的には以下の場所のみ書き換えれば学習できます。
|
131 |
+
|
132 |
+
1. 学習解像度
|
133 |
+
|
134 |
+
数値1つを指定すると正方形(`512`なら512x512)、鍵カッコカンマ区切りで2つ指定すると横×縦(`[512,768]`なら512x768)になります。SD1.x系ではもともとの学習解像度は512です。`[512,768]` 等の大きめの解像度を指定すると縦長、横長画像生成時の破綻を小さくできるかもしれません。SD2.x 768系では `768` です。
|
135 |
+
|
136 |
+
1. バッチサイズ
|
137 |
+
|
138 |
+
同時に何件のデータを学習するかを指定します。GPUのVRAMサイズ、学習解像度によって変わってきます。詳しくは後述します。またfine tuning/DreamBooth/LoRA等でも変わってきますので各スクリプトの説明もご覧ください。
|
139 |
+
|
140 |
+
1. フォルダ指定
|
141 |
+
|
142 |
+
学習用画像、正則化画像(使用する場合のみ)のフォルダを指定します。画像データが含まれているフォルダそのものを指定します。
|
143 |
+
|
144 |
+
1. identifier と class の指定
|
145 |
+
|
146 |
+
前述のサンプルの通りです。
|
147 |
+
|
148 |
+
1. 繰り返し回数
|
149 |
+
|
150 |
+
後述します。
|
151 |
+
|
152 |
+
### 繰り返し回数について
|
153 |
+
|
154 |
+
繰り返し回数は、正則化画像の枚数と学習用画像の枚数を調整するために用いられます。正則化画像の枚数は学習用画像よりも多いため、学習用画像を繰り返して枚数を合わせ、1対1の比率で学習できるようにします。
|
155 |
+
|
156 |
+
繰り返し回数は「 __学習用画像の繰り返し回数×学習用画像の枚数≧正則化画像の繰り返し回数×正則化画像の枚数__ 」となるように指定してください。
|
157 |
+
|
158 |
+
(1 epoch(データが一周すると1 epoch)のデータ数が「学習用画像の繰り返し回数×学習用画像の枚数」となります。正則化画像の枚数がそれより多いと、余った部分の正則化画像は使用されません。)
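(A worked example with illustrative numbers: 20 training images x 10 repeats = 200, and 200 regularisation images x 1 repeat = 200, so the condition is satisfied and the two sides are balanced 1:1 per epoch.)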
|
159 |
+
|
160 |
+
## step 4. 学習
|
161 |
+
|
162 |
+
それぞれのドキュメントを参考に学習を行ってください。
|
163 |
+
|
164 |
+
# DreamBooth、キャプション方式(正則化画像使用可)
|
165 |
+
|
166 |
+
この方式では各画像はキャプションで学習されます。
|
167 |
+
|
168 |
+
## step 1. キャプションファイルを準備する
|
169 |
+
|
170 |
+
学習用画像のフォルダに、画像と同じファイル名で、拡張子 `.caption`(設定で変えられます)のファイルを置いてください。それぞれのファイルは1行のみとしてください。エンコーディングは `UTF-8` です。
|
171 |
+
|
172 |
+
## step 2. 正則化画像を使うか否かを決め、使う場合には正則化画像を生成する
|
173 |
+
|
174 |
+
class+identifier形式と同様です。なお正則化画像にもキャプションを付けることができますが、通常は不要でしょう。
|
175 |
+
|
176 |
+
## step 3. 設定ファイルの記述
|
177 |
+
|
178 |
+
テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
|
179 |
+
|
180 |
+
```toml
|
181 |
+
[general]
|
182 |
+
enable_bucket = true # Aspect Ratio Bucketingを使うか否か
|
183 |
+
|
184 |
+
[[datasets]]
|
185 |
+
resolution = 512 # 学習解像度
|
186 |
+
batch_size = 4 # バッチサイズ
|
187 |
+
|
188 |
+
[[datasets.subsets]]
|
189 |
+
image_dir = 'C:\hoge' # 学習用画像を入れたフォルダを指定
|
190 |
+
caption_extension = '.caption' # キャプションファイルの拡張子 .txt を使う場合には書き換える
|
191 |
+
num_repeats = 10 # 学習用画像の繰り返し回数
|
192 |
+
|
193 |
+
# 以下は正則化画像を用いる場合のみ記述する。用いない場合は削除する
|
194 |
+
[[datasets.subsets]]
|
195 |
+
is_reg = true
|
196 |
+
image_dir = 'C:\reg' # 正則化画像を入れたフォルダを指定
|
197 |
+
class_tokens = 'girl' # class を指定
|
198 |
+
num_repeats = 1 # 正則化画像の繰り返し回数、基本的には1でよい
|
199 |
+
```
|
200 |
+
|
201 |
+
基本的には以下の場所のみ書き換えれば学習できます。特に記述がない部分は class+identifier 方式と同じです。
|
202 |
+
|
203 |
+
1. 学習解像度
|
204 |
+
1. バッチサイズ
|
205 |
+
1. フォルダ指定
|
206 |
+
1. キャプションファイルの拡張子
|
207 |
+
|
208 |
+
任意の拡張子を指定できます。
|
209 |
+
1. 繰り返し回数
|
210 |
+
|
211 |
+
## step 4. 学習
|
212 |
+
|
213 |
+
それぞれのドキュメントを参考に学習を行ってください。
|
214 |
+
|
215 |
+
# fine tuning 方式
|
216 |
+
|
217 |
+
## step 1. メタデータを準備する
|
218 |
+
|
219 |
+
キャプションやタグをまとめた管理用ファイルをメタデータと呼びます。json形式で拡張子は `.json`
|
220 |
+
です。作成方法は長くなりますのでこの文書の末尾に書きました。
|
221 |
+
|
222 |
+
## step 2. 設定ファイルの記述
|
223 |
+
|
224 |
+
テキストファイルを作成し、拡張子を `.toml` にします。たとえば以下のように記述します。
|
225 |
+
|
226 |
+
```toml
|
227 |
+
[general]
|
228 |
+
shuffle_caption = true
|
229 |
+
keep_tokens = 1
|
230 |
+
|
231 |
+
[[datasets]]
|
232 |
+
resolution = 512 # 学習解像度
|
233 |
+
batch_size = 4 # バッチサイズ
|
234 |
+
|
235 |
+
[[datasets.subsets]]
|
236 |
+
image_dir = 'C:\piyo' # 学習用画像を入れたフォルダを指定
|
237 |
+
metadata_file = 'C:\piyo\piyo_md.json' # メタデータファイル名
|
238 |
+
```
|
239 |
+
|
240 |
+
基本的には以下の場所のみ書き換えれば学習できます。特に記述がない部分は DreamBooth, class+identifier 方式と同じです。
|
241 |
+
|
242 |
+
1. 学習解像度
|
243 |
+
1. バッチサイズ
|
244 |
+
1. フォルダ指定
|
245 |
+
1. メタデータファイル名
|
246 |
+
|
247 |
+
後述の方法で作成したメタデータファイルを指定します。
|
248 |
+
|
249 |
+
|
250 |
+
## step 3. 学習
|
251 |
+
|
252 |
+
それぞれのドキュメントを参考に学習を行ってください。
|
253 |
+
|
254 |
+
# 学習で使われる用語のごく簡単な解説
|
255 |
+
|
256 |
+
細かいことは省略していますし私も完全には理解していないため、詳しくは各自お調べください。
|
257 |
+
|
258 |
+
## fine tuning(ファインチューニング)
|
259 |
+
|
260 |
+
モデルを学習して微調整することを指します。使われ方によって意味が異なってきますが、狭義のfine tuningはStable Diffusionの場合、モデルを画像とキャプションで学習することです。DreamBoothは狭義のfine tuningのひとつの特殊なやり方と言えます。広義のfine tuningは、LoRAやTextual Inversion、Hypernetworksなどを含み、モデルを学習することすべてを含みます。
|
261 |
+
|
262 |
+
## ステップ
|
263 |
+
|
264 |
+
ざっくりいうと学習データで1回計算すると1ステップです。「学習データのキャプションを今のモデルに流してみて、出てくる画像を学習データの画像と比較し、学習データに近づくようにモデルをわずかに変更する」のが1ステップです。
|
265 |
+
|
266 |
+
## バッチサイズ
|
267 |
+
|
268 |
+
バッチサイズは1ステップで何件のデータをまとめて計算するかを指定する値です。まとめて計算するため速度は相対的に向上します。また一般的には精度も高くなるといわれています。
|
269 |
+
|
270 |
+
`バッチサイズ×ステップ数` が学習に使われるデータの件数になります。そのため、バッチサイズを増やした分だけステップ数を減らすとよいでしょう。
|
271 |
+
|
272 |
+
(ただし、たとえば「バッチサイズ1で1600ステップ」と「バッチサイズ4で400ステップ」は同じ結果にはなりません。同じ学習率の場合、一般的には後者のほうが学習不足になります。学習率を多少大きくするか(たとえば `2e-6` など)、ステップ数をたとえば500ステップにするなどして工夫してください。)
|
273 |
+
|
274 |
+
バッチサイズを大きくするとその分だけGPUメモリを消費します。メモリが足りなくなるとエラーになりますし、エラーにならないギリギリでは学習速度が低下します。タスクマネージャーや `nvidia-smi` コマンドで使用メモリ量を確認しながら調整するとよいでしょう。
|
275 |
+
|
276 |
+
なお、バッチは「一塊のデータ」位の意味です。
|
277 |
+
|
278 |
+
## 学習率
|
279 |
+
|
280 |
+
ざっくりいうと1ステップごとにどのくらい変化させるかを表します。大きな値を指定するとそれだけ速く学習が進みますが、変化しすぎてモデルが壊れたり、最適な状態にまで至れない場合があります。小さい値を指定すると学習速度は遅くなり、また最適な状態にやはり至れない場合があります。
|
281 |
+
|
282 |
+
fine tuning、DreamBoooth、LoRAそれぞれで大きく異なり、また学習データや学習させたいモデル、バッチサイズやステップ数によっても変わってきます。一般的な値から初めて学習状態を見ながら増減してください。
|
283 |
+
|
284 |
+
デフォルトでは学習全体を通して学習率は固定です。スケジューラの指定で学習率をどう変化させるか決められますので、それらによっても結果は変わってきます。
|
285 |
+
|
286 |
+
## エポック(epoch)
|
287 |
+
|
288 |
+
学習データが一通り学習されると(データが一周すると)1 epochです。繰り返し回数を指定した場合は、その繰り返し後のデータが一周すると1 epochです。
|
289 |
+
|
290 |
+
1 epochのステップ数は、基本的には `データ件数÷バッチサイズ` ですが、Aspect Ratio Bucketing を使うと微妙に増えます(異なるbucketのデータは同じバッチにできないため、ステップ数が増えます)。
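(For example, with illustrative numbers: 400 data items at batch size 4 gives 100 steps per epoch, before any increase caused by bucketing.)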
|
291 |
+
|
292 |
+
## Aspect Ratio Bucketing
|
293 |
+
|
294 |
+
Stable Diffusion のv1は512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくキャプションと画像の関係が学習されることが期待されます。
|
295 |
+
|
296 |
+
また任意の解像度で学習するため、事前に画像データの縦横比を統一しておく必要がなくなります。
|
297 |
+
|
298 |
+
設定で有効、無効が切り替えられますが、ここまでの設定ファイルの記述例では有効になっています(`true` が設定されています)。
|
299 |
+
|
300 |
+
学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位(デフォルト、変更可)で縦横に調整、作成されます。
|
301 |
+
|
302 |
+
機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
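As a rough illustration of the idea only (a minimal sketch, not the repository's actual bucketing code; the step size, minimum and maximum sides below are assumptions), buckets can be enumerated like this:

```python
# Enumerate aspect-ratio buckets: width/height move in 64-px steps while the area stays
# at or below the area of the base training resolution (512x512 here).
def enumerate_buckets(base=512, step=64, min_side=256, max_side=1024):
    max_area = base * base
    buckets = set()
    for w in range(min_side, max_side + 1, step):
        h = (max_area // w) // step * step   # largest multiple of `step` with w*h <= max_area
        if h >= min_side:
            buckets.add((w, h))
            buckets.add((h, w))              # also keep the rotated (portrait/landscape) bucket
    return sorted(buckets)

print(enumerate_buckets())   # includes (256, 1024), (384, 640), (512, 512), ...
```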
|
303 |
+
|
304 |
+
# 以前の指定形式(設定ファイルを用いずコマンドラインから指定)
|
305 |
+
|
306 |
+
`.toml` ファイルを指定せずコマンドラインオプションで指定する方法です。DreamBooth class+identifier方式、DreamBooth キャプション方式、fine tuning方式があります。
|
307 |
+
|
308 |
+
## DreamBooth、class+identifier方式
|
309 |
+
|
310 |
+
フォルダ名で繰り返し回数を指定します。また `train_data_dir` オプションと `reg_data_dir` オプションを用います。
|
311 |
+
|
312 |
+
### step 1. 学習用画像の準備
|
313 |
+
|
314 |
+
学習用画像を格納するフォルダを作成します。 __さらにその中に__ 、以下の名前でディレクトリを作成します。
|
315 |
+
|
316 |
+
```
|
317 |
+
<繰り返し回数>_<identifier> <class>
|
318 |
+
```
|
319 |
+
|
320 |
+
間の``_``を忘れないでください。
|
321 |
+
|
322 |
+
たとえば「sls frog」というプロンプトで、データを20回繰り返す場合、「20_sls frog」となります。以下のようになります。
|
323 |
+
|
324 |
+
![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png)
|
325 |
+
|
326 |
+
### 複数class、複数対象(identifier)の学習
|
327 |
+
|
328 |
+
方法は単純で、学習用画像のフォルダ内に ``繰り返し回数_<identifier> <class>`` のフォルダを複数、正則化画像フォルダにも同様に ``繰り返し回数_<class>`` のフォルダを複数、用意してください。
|
329 |
+
|
330 |
+
たとえば「sls frog」と「cpc rabbit」を同時に学習する場合、以下のようになります。
|
331 |
+
|
332 |
+
![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png)
|
333 |
+
|
334 |
+
classがひとつで対象が複数の場合、正則化画像フォルダはひとつで構いません。たとえば1girlにキャラAとキャラBがいる場合は次のようにします。
|
335 |
+
|
336 |
+
- train_girls
|
337 |
+
- 10_sls 1girl
|
338 |
+
- 10_cpc 1girl
|
339 |
+
- reg_girls
|
340 |
+
- 1_1girl
|
341 |
+
|
342 |
+
### step 2. 正則化画像の準備
|
343 |
+
|
344 |
+
正則化画像を使う場合の手順です。
|
345 |
+
|
346 |
+
正則化画像を格納するフォルダを作成します。 __さらにその中に__ ``<繰り返し回数>_<class>`` という名前でディレクトリを作成します。
|
347 |
+
|
348 |
+
たとえば「frog」というプロンプトで、データを繰り返さない(1回だけ)場合、以下のようになります。
|
349 |
+
|
350 |
+
![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png)
|
351 |
+
|
352 |
+
|
353 |
+
### step 3. 学習の実行
|
354 |
+
|
355 |
+
各学習スクリプトを実行します。 `--train_data_dir` オプションで前述の学習用データのフォルダを(__画像を含むフォルダではなく、その親フォルダ__)、`--reg_data_dir` オプションで正則化画像のフォルダ(__画像を含むフォルダではなく、その親フォルダ__)を指定してください。
|
356 |
+
|
357 |
+
## DreamBooth、キャプション方式
|
358 |
+
|
359 |
+
学習用画像、正則化画像のフォルダに、画像と同じファイル名で、拡張子.caption(オプションで変えられます)のファイルを置くと、そのファイルからキャプションを読み込みプロンプトとして学習します。
|
360 |
+
|
361 |
+
※それらの画像の学習に、フォルダ名(identifier class)は使用されなくなります。
|
362 |
+
|
363 |
+
キャプションファイルの拡張子はデフォルトで.captionです。学習スクリプトの `--caption_extension` オプションで変更できます。`--shuffle_caption` オプションで学習時のキャプションについて、カンマ区切りの各部分をシャッフルしながら学習します。
|
364 |
+
|
365 |
+
## fine tuning 方式
|
366 |
+
|
367 |
+
メタデータを作るところまでは設定ファイルを使う場合と同様です。`in_json` オプションでメタデータファイルを指定します。
|
368 |
+
|
369 |
+
# 学習途中でのサンプル出力
|
370 |
+
|
371 |
+
学習中のモデルで試しに画像生成することで学習の進み方を確認できます。学習スクリプトに以下のオプションを指定します。
|
372 |
+
|
373 |
+
- `--sample_every_n_steps` / `--sample_every_n_epochs`
|
374 |
+
|
375 |
+
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
|
376 |
+
|
377 |
+
- `--sample_prompts`
|
378 |
+
|
379 |
+
サンプル出力用プロンプトのファイルを指定します。
|
380 |
+
|
381 |
+
- `--sample_sampler`
|
382 |
+
|
383 |
+
サンプル出力に使うサンプラーを指定します。
|
384 |
+
`'ddim', 'pndm', 'heun', 'dpmsolver', 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'`が選べます。
|
385 |
+
|
386 |
+
サンプル出力を行うにはあらかじめプロンプトを記述したテキストファイルを用意しておく必要があります。1行につき1プロンプトで記述します。
|
387 |
+
|
388 |
+
たとえば以下のようになります。
|
389 |
+
|
390 |
+
```txt
|
391 |
+
# prompt 1
|
392 |
+
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
393 |
+
|
394 |
+
# prompt 2
|
395 |
+
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
396 |
+
```
|
397 |
+
|
398 |
+
先頭が `#` の行はコメントになります。`--n` のように 「`--` + 英小文字」で生成画像へのオプションを指定できます。以下が使えます。
|
399 |
+
|
400 |
+
- `--n` 次のオプションまでをネガティブプロンプトとします。
|
401 |
+
- `--w` 生成画像の横幅を指定します。
|
402 |
+
- `--h` 生成画像の高さを指定します。
|
403 |
+
- `--d` 生成画像のseedを指定します。
|
404 |
+
- `--l` 生成画像のCFG scaleを指定します。
|
405 |
+
- `--s` 生成時のステップ数を指定します。
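まとめると、たとえば100ステップごとに `prompts.txt` のプロンプトでサンプル出力する場合、学習コマンドに次のようなオプションを追加するイメージです(ステップ数やファイル名は仮の値です)。

```
--sample_every_n_steps=100 --sample_prompts=prompts.txt --sample_sampler=k_euler_a
```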
|
406 |
+
|
407 |
+
|
408 |
+
# 各スクリプトで共通の、よく使われるオプション
|
409 |
+
|
410 |
+
スクリプトの更新後、ドキュメントの更新が追い付いていない場合があります。その場合は `--help` オプションで使用できるオプションを確認してください。
|
411 |
+
|
412 |
+
## 学習に使うモデル指定
|
413 |
+
|
414 |
+
- `--v2` / `--v_parameterization`
|
415 |
+
|
416 |
+
学習対象モデルとしてHugging Faceのstable-diffusion-2-base、またはそこからのfine tuningモデルを使う場合(推論時に `v2-inference.yaml` を使うように指示されているモデルの場合)は `--v2` オプションを、stable-diffusion-2や768-v-ema.ckpt、およびそれらのfine tuningモデルを使う場合(推論時に `v2-inference-v.yaml` を使うモデルの場合)は `--v2` と `--v_parameterization` の両方のオプションを指定してください。
|
417 |
+
|
418 |
+
Stable Diffusion 2.0では大きく以下の点が変わっています。
|
419 |
+
|
420 |
+
1. 使用するTokenizer
|
421 |
+
2. 使用するText Encoderおよび使用する出力層(2.0は最後から二番目の層を使う)
|
422 |
+
3. Text Encoderの出力次元数(768->1024)
|
423 |
+
4. U-Netの構造(CrossAttentionのhead数など)
|
424 |
+
5. v-parameterization(サンプリング方法が変更されているらしい)
|
425 |
+
|
426 |
+
このうちbaseでは1~4が、baseのつかない方(768-v)では1~5が採用されています。1~4を有効にするのがv2オプション、5を有効にするのがv_parameterizationオプションです。
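指定の組み合わせを整理すると次のようになります(お使いのモデルがどちらに該当するかは配布元の指示に従ってください)。

```
# stable-diffusion-2-base 系(推論時に v2-inference.yaml を使うモデル)
--v2
# stable-diffusion-2、768-v-ema.ckpt 系(推論時に v2-inference-v.yaml を使うモデル)
--v2 --v_parameterization
```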
|
427 |
+
|
428 |
+
- `--pretrained_model_name_or_path`
|
429 |
+
|
430 |
+
追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
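指定のしかたの例です(いずれかひとつを指定します。ファイル名やパスは仮のものです)。

```
--pretrained_model_name_or_path=model.safetensors
--pretrained_model_name_or_path=..\models\my_diffusers_model
--pretrained_model_name_or_path=stabilityai/stable-diffusion-2
```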
|
431 |
+
|
432 |
+
## 学習に関する設定
|
433 |
+
|
434 |
+
- `--output_dir`
|
435 |
+
|
436 |
+
学習後のモデルを保存するフォルダを指定します。
|
437 |
+
|
438 |
+
- `--output_name`
|
439 |
+
|
440 |
+
モデルのファイル名を拡張子を除いて指定します。
|
441 |
+
|
442 |
+
- `--dataset_config`
|
443 |
+
|
444 |
+
データセットの設定を記述した `.toml` ファイルを指定します。
|
445 |
+
|
446 |
+
- `--max_train_steps` / `--max_train_epochs`
|
447 |
+
|
448 |
+
学習するステップ数やエポック数を指定します。両方指定するとエポック数のほうが優先されます。
|
449 |
+
|
450 |
+
- `--mixed_precision`
|
451 |
+
|
452 |
+
省メモリ化のため mixed precision (混合精度)で学習します。`--mixed_precision="fp16"` のように指定します。mixed precision なし(デフォルト)と比べて精度が低くなる可能性がありますが、学習に必要なGPUメモリ量が大きく減ります。
|
453 |
+
|
454 |
+
(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。
|
455 |
+
|
456 |
+
- `--gradient_checkpointing`
|
457 |
+
|
458 |
+
学習時の重みの計算をまとめて行うのではなく少しずつ行うことで、学習に必要なGPUメモリ量を減らします。オンオフは精度には影響しませんが、オンにするとバッチサイズを大きくできるため、そちらでの影響はあります。
|
459 |
+
|
460 |
+
また一般的にはオンにすると速度は低下しますが、バッチサイズを大きくできるので、トータルでの学習時間はむしろ速くなるかもしれません。
|
461 |
+
|
462 |
+
- `--xformers` / `--mem_eff_attn`
|
463 |
+
|
464 |
+
xformersオプションを指定するとxformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(xformersよりも速度は遅くなります)。
|
465 |
+
|
466 |
+
- `--save_precision`
|
467 |
+
|
468 |
+
保存時のデータ精度を指定します。save_precisionオプションにfloat、fp16、bf16のいずれかを指定すると、その形式でモデルを保存します(DreamBooth、fine tuningでDiffusers形式でモデルを保存する場合は無効です)。モデルのサイズを削減したい場合などにお使いください。
|
469 |
+
|
470 |
+
- `--save_every_n_epochs` / `--save_state` / `--resume`
|
471 |
+
save_every_n_epochsオプションに数値を指定すると、そのエポックごとに学習途中のモデルを保存します。
|
472 |
+
|
473 |
+
save_stateオプションを同時に指定すると、optimizer等の状態も含めた学習状態を合わせて保存します(保存したモデルからも学習再開できますが、それに比べると精度の向上、学習時間の短縮が期待できます)。保存先はフォルダになります。
|
474 |
+
|
475 |
+
学習状態は保存先フォルダに `<output_name>-??????-state`(??????はエポック数)という名前のフォルダで出力されます。長時間にわたる学習時にご利用ください。
|
476 |
+
|
477 |
+
保存された学習状態から学習を再開するにはresumeオプションを使います。学習状態のフォルダ(`output_dir` ではなくその中のstateのフォルダ)を指定してください。
|
478 |
+
|
479 |
+
なおAcceleratorの仕様により、エポック数、global stepは保存されておらず、resumeしたときにも1からになりますがご容赦ください。
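たとえば `--output_name=my_model --save_every_n_epochs=1 --save_state` を指定して学習し、5エポック目の学習状態から再開する場合は次のようなイメージです(フォルダ名・エポック数は仮の例です)。

```
--resume=outputs/my_model-000005-state
```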
|
480 |
+
|
481 |
+
- `--save_model_as` (DreamBooth, fine tuning のみ)
|
482 |
+
|
483 |
+
モデルの保存形式を`ckpt, safetensors, diffusers, diffusers_safetensors` から選べます。
|
484 |
+
|
485 |
+
`--save_model_as=safetensors` のように指定します。Stable Diffusion形式(ckptまたはsafetensors)を読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。
|
486 |
+
|
487 |
+
- `--clip_skip`
|
488 |
+
|
489 |
+
`2` を指定すると、Text Encoder (CLIP) の後ろから二番目の層の出力を用います。1またはオプション省略時は最後の層を用います。
|
490 |
+
|
491 |
+
※SD2.0はデフォルトで後ろから二番目の層を使うため、SD2.0の学習では指定しないでください。
|
492 |
+
|
493 |
+
学習対象のモデルがもともと二番目の層を使うように学習されている場合は、2を指定するとよいでしょう。
|
494 |
+
|
495 |
+
そうではなく最後の層を使用していた場合はモデル全体がそれを前提に学習されています。そのため改めて二番目の層を使用して学習すると、望ましい学習結果を得るにはある程度の枚数の教師データ、長めの学習が必要になるかもしれません。
|
496 |
+
|
497 |
+
- `--max_token_length`
|
498 |
+
|
499 |
+
デフォルトは75です。`150` または `225` を指定することでトークン長を拡張して学習できます。長いキャプションで学習する場合に指定してください。
|
500 |
+
|
501 |
+
ただし学習時のトークン拡張の仕様は Automatic1111 氏のWeb UIとは微妙に異なるため(分割の仕様など)、必要なければ75で学習することをお勧めします。
|
502 |
+
|
503 |
+
clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。
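たとえば後ろから二番目の層を使う前提のモデルを、長いキャプションで学習する場合は次のように指定するイメージです。

```
--clip_skip=2 --max_token_length=225
```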
|
504 |
+
|
505 |
+
- `--persistent_data_loader_workers`
|
506 |
+
|
507 |
+
Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
|
508 |
+
|
509 |
+
- `--max_data_loader_n_workers`
|
510 |
+
|
511 |
+
データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
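指定例です(数値は一例で、環境に合わせて調整してください)。

```
# メインメモリに余裕がない場合やGPU使用率が十分高い場合の例
--max_data_loader_n_workers=2
# Windows環境でエポック間の待ち時間を短縮したい場合の例
--persistent_data_loader_workers
```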
|
512 |
+
|
513 |
+
- `--logging_dir` / `--log_prefix`
|
514 |
+
|
515 |
+
学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
|
516 |
+
|
517 |
+
たとえば--logging_dir=logsと指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。
|
518 |
+
また--log_prefixオプションを指定すると、日時の前に指定した文字列が追加されます。「--logging_dir=logs --log_prefix=db_style1_」などとして識別用にお使いください。
|
519 |
+
|
520 |
+
TensorBoardでログを確認するには、別のコマンドプロンプトを開き、作業フォルダで以下のように入力します。
|
521 |
+
|
522 |
+
```
|
523 |
+
tensorboard --logdir=logs
|
524 |
+
```
|
525 |
+
|
526 |
+
(tensorboardは環境整備時にあわせてインストールされると思いますが、もし入っていないなら `pip install tensorboard` で入れてください。)
|
527 |
+
|
528 |
+
その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
|
529 |
+
|
530 |
+
- `--noise_offset`
|
531 |
+
|
532 |
+
こちらの記事の実装になります: https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
533 |
+
|
534 |
+
全体的に暗い、明るい画像の生成結果が良くなる可能性があるようです。LoRA学習でも有効なようです。`0.1` 程度の値を指定するとよいようです。
|
535 |
+
|
536 |
+
- `--debug_dataset`
|
537 |
+
|
538 |
+
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。`S`キーで次のステップ(バッチ)、`E`キーで次のエポックに進みます。
|
539 |
+
|
540 |
+
※Linux環境(Colabを含む)では画像は表示されません。
|
541 |
+
|
542 |
+
- `--vae`
|
543 |
+
|
544 |
+
vaeオプションにStable Diffusionのcheckpoint、VAEのcheckpointファイル、DiffusesのモデルまたはVAE(ともにローカルまたはHugging FaceのモデルIDが指定できます)のいずれかを指定すると、そのVAEを使って学習します(latentsのキャッシュ時または学習中のlatents取得時)。
|
545 |
+
|
546 |
+
DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。
|
547 |
+
|
548 |
+
- `--cache_latents`
|
549 |
+
|
550 |
+
使用VRAMを減らすためVAEの出力をメインメモリにキャッシュします。`flip_aug` 以外のaugmentationは使えなくなります。また全体の学習速度が若干速くなります。
|
551 |
+
|
552 |
+
- `--min_snr_gamma`
|
553 |
+
|
554 |
+
Min-SNR Weighting strategyを指定します。詳細は[こちら](https://github.com/kohya-ss/sd-scripts/pull/308)を参照してください。論文では`5`が推奨されています。
|
555 |
+
|
556 |
+
## オプティマイザ関係
|
557 |
+
|
558 |
+
- `--optimizer_type`
|
559 |
+
オプティマイザの種類を指定します。以下が指定できます。
|
560 |
+
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
|
561 |
+
- 過去のバージョンのオプション未指定時と同じ
|
562 |
+
- AdamW8bit : 引数は同上
|
563 |
+
- 過去のバージョンの--use_8bit_adam指定時と同じ
|
564 |
+
- Lion : https://github.com/lucidrains/lion-pytorch
|
565 |
+
- 過去のバージョンの--use_lion_optimizer指定時と同じ
|
566 |
+
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
|
567 |
+
- SGDNesterov8bit : 引数は同上
|
568 |
+
- DAdaptation : https://github.com/facebookresearch/dadaptation
|
569 |
+
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
570 |
+
- 任意のオプティマイザ
|
571 |
+
|
572 |
+
- `--learning_rate`
|
573 |
+
|
574 |
+
学習率を指定します。適切な学習率は学習スクリプトにより異なりますので、それぞれの説明を参照してください。
|
575 |
+
|
576 |
+
- `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power`
|
577 |
+
|
578 |
+
学習率のスケジューラ関連の指定です。
|
579 |
+
|
580 |
+
lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup, 任意のスケジューラから選べます。デフォルトはconstantです。
|
581 |
+
|
582 |
+
lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。
|
583 |
+
|
584 |
+
lr_scheduler_num_cycles は cosine with restartsスケジューラでのリスタート回数、lr_scheduler_power は polynomialスケジューラでのpolynomial power です。
|
585 |
+
|
586 |
+
詳細については各自お調べください。
|
587 |
+
|
588 |
+
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。
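たとえば cosine_with_restarts をウォームアップ付きで使う場合の指定例です(ステップ数・リスタート回数は仮の値です)。

```
--lr_scheduler=cosine_with_restarts --lr_warmup_steps=500 --lr_scheduler_num_cycles=4
```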
|
589 |
+
|
590 |
+
### オプティマイザの指定について
|
591 |
+
|
592 |
+
オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。
|
593 |
+
|
594 |
+
オプション引数を指定する場合は、それぞれのオプティマイザの仕様をご確認ください。
|
595 |
+
|
596 |
+
一部のオプティマイザでは必須の引数があり、省略すると自動的に追加されます(SGDNesterovのmomentumなど)。コンソールの出力を確認してください。
|
597 |
+
|
598 |
+
D-Adaptationオプティマイザは学習率を自動調整します。学習率のオプションに指定した値は学習率そのものではなくD-Adaptationが決定した学習率の適用率になりますので、通常は1.0を指定してください。Text EncoderにU-Netの半分の学習率を指定したい場合は、``--text_encoder_lr=0.5 --unet_lr=1.0``と指定します。
|
599 |
+
|
600 |
+
AdaFactorオプティマイザはrelative_step=Trueを指定すると学習率を自動調整できます(省略時はデフォルトで追加されます)。自動調整する場合は学習率のスケジューラにはadafactor_schedulerが強制的に使用されます。またscale_parameterとwarmup_initを指定するとよいようです。
|
601 |
+
|
602 |
+
自動調整する場合のオプション指定はたとえば ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"`` のようになります。
|
603 |
+
|
604 |
+
学習率を自動調整しない場合はオプション引数 ``relative_step=False`` を追加してください。その場合、学習率のスケジューラにはconstant_with_warmupが、また勾配のclip normをしないことが推奨されているようです。そのため引数は ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`` のようになります。
|
605 |
+
|
606 |
+
### 任意のオプティマイザを使う
|
607 |
+
|
608 |
+
``torch.optim`` のオプティマイザを使う場合にはクラス名のみを(``--optimizer_type=RMSprop``など)、他のモジュールのオプティマイザを使う時は「モジュール名.クラス名」を指定してください(``--optimizer_type=bitsandbytes.optim.lamb.LAMB``など)。
|
609 |
+
|
610 |
+
(内部でimportlibしているだけで動作は未確認です。必要ならパッケージをインストールしてください。)
|
611 |
+
|
612 |
+
|
613 |
+
<!--
|
614 |
+
## 任意サイズの画像での学習 --resolution
|
615 |
+
正方形以外で学習できます。resolutionに「448,640」のように「幅,高さ」で指定してください。幅と高さは64で割り切れる必要があります。学習用画像、正則化画像のサイズを合わせてください。
|
616 |
+
|
617 |
+
個人的には縦長の画像を生成することが多いため「448,640」などで学習することもあります。
|
618 |
+
|
619 |
+
## Aspect Ratio Bucketing --enable_bucket / --min_bucket_reso / --max_bucket_reso
|
620 |
+
enable_bucketオプションを指定すると有効になります。Stable Diffusionは512x512で学習されていますが、それに加えて256x768や384x640といった解像度でも学習します。
|
621 |
+
|
622 |
+
このオプションを指定した場合は、学習用画像、正則化画像を特定の解像度に統一する必要はありません。いくつかの解像度(アスペクト比)から最適なものを選び、その解像度で学習します。
|
623 |
+
解像度は64ピクセル単位のため、元画像とアスペクト比が完全に一致しない場合がありますが、その場合は、はみ出した部分がわずかにトリミングされます。
|
624 |
+
|
625 |
+
解像度の最小サイズをmin_bucket_resoオプションで、最大サイズをmax_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。
|
626 |
+
たとえば最小サイズに384を指定すると、256x1024や320x768などの解像度は使わなくなります。
|
627 |
+
解像度を768x768のように大きくした場合、最大サイズに1280などを指定しても良いかもしれません。
|
628 |
+
|
629 |
+
なおAspect Ratio Bucketingを有効にするときには、正則化画像についても、学習用画像と似た傾向の様々な解像度を用意した方がいいかもしれません。
|
630 |
+
|
631 |
+
(ひとつのバッチ内の画像が学習用画像、正則化画像に偏らなくなるため。そこまで大きな影響はないと思いますが……。)
|
632 |
+
|
633 |
+
## augmentation --color_aug / --flip_aug
|
634 |
+
augmentationは学習時に動的にデータを変化させることで、モデルの性能を上げる手法です。color_augで色合いを微妙に変えつつ、flip_augで左右反転をしつつ、学習します。
|
635 |
+
|
636 |
+
動的にデータを変化させるため、cache_latentsオプションと同時に指定できません。
|
637 |
+
|
638 |
+
|
639 |
+
## 勾配をfp16とした学習(実験的機能) --full_fp16
|
640 |
+
full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。
|
641 |
+
これによりSD1.xの512x512サイズでは8GB未満、SD2.xの512x512サイズで12GB未満のVRAM使用量で学習できるようです。
|
642 |
+
|
643 |
+
あらかじめaccelerate configでfp16を指定し、オプションで ``mixed_precision="fp16"`` としてください(bf16では動作しません)。
|
644 |
+
|
645 |
+
メモリ使用量を最小化するためには、xformers、use_8bit_adam、cache_latents、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
|
646 |
+
|
647 |
+
(余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
|
648 |
+
|
649 |
+
PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。
|
650 |
+
学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
|
651 |
+
|
652 |
+
-->
|
653 |
+
|
654 |
+
# メタデータファイルの作成
|
655 |
+
|
656 |
+
## 教師データの用意
|
657 |
+
|
658 |
+
前述のように学習させたい画像データを用意し、任意のフォルダに入れてください。
|
659 |
+
|
660 |
+
たとえば以下のように画像を格納します。
|
661 |
+
|
662 |
+
![教師データフォルダのスクショ](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png)
|
663 |
+
|
664 |
+
## 自動キャプショニング
|
665 |
+
|
666 |
+
キャプションを使わずタグだけで学習する場合はスキップしてください。
|
667 |
+
|
668 |
+
また手動でキャプションを用意する場合、キャプションは教師データ画像と同じディレクトリに、同じファイル名、拡張子.caption等で用意してください。各ファイルは1行のみのテキストファイルとします。
|
669 |
+
|
670 |
+
### BLIPによるキャプショニング
|
671 |
+
|
672 |
+
最新版ではBLIPのダウンロード、重みのダウンロード、仮想環境の追加は不要になりました。そのままで動作します。
|
673 |
+
|
674 |
+
finetuneフォルダ内のmake_captions.pyを実行します。
|
675 |
+
|
676 |
+
```
|
677 |
+
python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ>
|
678 |
+
```
|
679 |
+
|
680 |
+
バッチサイズ8、教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
|
681 |
+
|
682 |
+
```
|
683 |
+
python finetune\make_captions.py --batch_size 8 ..\train_data
|
684 |
+
```
|
685 |
+
|
686 |
+
キャプションファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.captionで作成されます。
|
687 |
+
|
688 |
+
batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。
|
689 |
+
max_lengthオプションでキャプションの最大長を指定できます。デフォルトは75です。モデルをトークン長225で学習する場合には長くしても良いかもしれません。
|
690 |
+
caption_extensionオプションでキャプションの拡張子を変更できます。デフォルトは.captionです(.txtにすると後述のDeepDanbooruと競合します)。
|
691 |
+
|
692 |
+
複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
|
693 |
+
|
694 |
+
なお、推論にランダム性があるため、実行するたびに結果が変わります。固定する場合には--seedオプションで `--seed 42` のように乱数seedを指定してください。
|
695 |
+
|
696 |
+
その他のオプションは `--help` でヘルプをご参照ください(パラメータの意味についてはドキュメントがまとまっていないようで、ソースを見るしかないようです)。
|
697 |
+
|
698 |
+
デフォルトでは拡張子.captionでキャプションファイルが生成されます。
|
699 |
+
|
700 |
+
![captionが生成されたフォルダ](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png)
|
701 |
+
|
702 |
+
たとえば以下のようなキャプションが付きます。
|
703 |
+
|
704 |
+
![キャプションと画像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png)
|
705 |
+
|
706 |
+
## DeepDanbooruによるタグ付け
|
707 |
+
|
708 |
+
danbooruタグのタグ付け自体を行わない場合は「キャプションとタグ情報の前処理」に進んでください。
|
709 |
+
|
710 |
+
タグ付けはDeepDanbooruまたはWD14Taggerで行います。WD14Taggerのほうが精度が良いようです。WD14Taggerでタグ付けする場合は、次の章へ進んでください。
|
711 |
+
|
712 |
+
### 環境整備
|
713 |
+
|
714 |
+
DeepDanbooru https://github.com/KichangKim/DeepDanbooru を作業フォルダにcloneしてくるか、zipをダウンロードして展開します。私はzipで展開しました。
|
715 |
+
またDeepDanbooruのReleasesのページ https://github.com/KichangKim/DeepDanbooru/releases の「DeepDanbooru Pretrained Model v3-20211112-sgd-e28」のAssetsから、deepdanbooru-v3-20211112-sgd-e28.zipをダウンロードしてきてDeepDanbooruのフォルダに展開します。
|
716 |
+
|
717 |
+
以下からダウンロードします。Assetsをクリックして開き、そこからダウンロードします。
|
718 |
+
|
719 |
+
![DeepDanbooruダウンロードページ](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png)
|
720 |
+
|
721 |
+
以下のようなディレクトリ構造にしてください。
|
722 |
+
|
723 |
+
![DeepDanbooruのディレクトリ構造](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png)
|
724 |
+
|
725 |
+
Diffusersの環境に必要なライブラリをインストールします。DeepDanbooruのフォルダに移動してインストールします(実質的にはtensorflow-ioが追加されるだけだと思います)。
|
726 |
+
|
727 |
+
```
|
728 |
+
pip install -r requirements.txt
|
729 |
+
```
|
730 |
+
|
731 |
+
続いてDeepDanbooru自体をインストールします。
|
732 |
+
|
733 |
+
```
|
734 |
+
pip install .
|
735 |
+
```
|
736 |
+
|
737 |
+
以上でタグ付けの環境整備は完了です。
|
738 |
+
|
739 |
+
### タグ付けの実施
|
740 |
+
DeepDanbooruのフォルダに移動し、deepdanbooruを実行してタグ付けを行います。
|
741 |
+
|
742 |
+
```
|
743 |
+
deepdanbooru evaluate <教師データフォルダ> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
744 |
+
```
|
745 |
+
|
746 |
+
教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
|
747 |
+
|
748 |
+
```
|
749 |
+
deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
750 |
+
```
|
751 |
+
|
752 |
+
タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。1件ずつ処理されるためわりと遅いです。
|
753 |
+
|
754 |
+
複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
|
755 |
+
|
756 |
+
以下のように生成されます。
|
757 |
+
|
758 |
+
![DeepDanbooruの生成ファイル](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png)
|
759 |
+
|
760 |
+
こんな感じにタグが付きます(すごい情報量……)。
|
761 |
+
|
762 |
+
![DeepDanbooruタグと画像](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png)
|
763 |
+
|
764 |
+
## WD14Taggerによるタグ付け
|
765 |
+
|
766 |
+
DeepDanbooruの代わりにWD14Taggerを用いる手順です。
|
767 |
+
|
768 |
+
Automatic1111氏のWebUIで使用しているtaggerを利用します。こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。
|
769 |
+
|
770 |
+
最初の環境整備で必要なモジュールはインストール済みです。また重みはHugging Faceから自動的にダウンロードしてきます。
|
771 |
+
|
772 |
+
### タグ付けの実施
|
773 |
+
|
774 |
+
スクリプトを実行してタグ付けを行います。
|
775 |
+
```
|
776 |
+
python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ>
|
777 |
+
```
|
778 |
+
|
779 |
+
教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。
|
780 |
+
```
|
781 |
+
python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
|
782 |
+
```
|
783 |
+
|
784 |
+
初回起動時にはモデルファイルがwd14_tagger_modelフォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。以下のようになります。
|
785 |
+
|
786 |
+
![ダウンロードされたファイル](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png)
|
787 |
+
|
788 |
+
タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。
|
789 |
+
|
790 |
+
![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
|
791 |
+
|
792 |
+
![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
|
793 |
+
|
794 |
+
threshオプションで、判定されたタグのconfidence(確信度)がいくつ以上でタグをつけるかが指定できます。デフォルトはWD14Taggerのサンプルと同じ0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
|
795 |
+
|
796 |
+
batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。caption_extensionオプションでタグファイルの拡張子を変更できます。デフォルトは.txtです。
|
797 |
+
|
798 |
+
model_dirオプションでモデルの保存先フォルダを指定できます。
|
799 |
+
|
800 |
+
またforce_downloadオプションを指定すると保存先フォルダがあってもモデルを再ダウンロードします。
|
801 |
+
|
802 |
+
複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。
|
803 |
+
|
804 |
+
## キャプションとタグ情報の前処理
|
805 |
+
|
806 |
+
スクリプトから処理しやすいようにキャプションとタグをメタデータとしてひとつのファイルにまとめます。
|
807 |
+
|
808 |
+
### キャプションの前処理
|
809 |
+
|
810 |
+
キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。`--full_path` オプションを指定してメタデータに画像ファイルの場所をフルパスで格納します。このオプションを省略すると相対パスで記録されますが、フォルダ指定が `.toml` ファイル内で別途必要になります。
|
811 |
+
|
812 |
+
```
|
813 |
+
python merge_captions_to_metadata.py --full_path <教師データフォルダ>
|
814 |
+
--in_json <読み込むメタデータファイル名> <メタデータファイル名>
|
815 |
+
```
|
816 |
+
|
817 |
+
メタデータファイル名は任意の名前です。
|
818 |
+
教師データがtrain_data、読み込むメタデータファイルなし、メタデータファイルがmeta_cap.jsonの場合、以下のようになります。
|
819 |
+
|
820 |
+
```
|
821 |
+
python merge_captions_to_metadata.py --full_path train_data meta_cap.json
|
822 |
+
```
|
823 |
+
|
824 |
+
caption_extensionオプションでキャプションの拡張子を指定できます。
|
825 |
+
|
826 |
+
複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
|
827 |
+
|
828 |
+
```
|
829 |
+
python merge_captions_to_metadata.py --full_path
|
830 |
+
train_data1 meta_cap1.json
|
831 |
+
python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
|
832 |
+
train_data2 meta_cap2.json
|
833 |
+
```
|
834 |
+
|
835 |
+
in_jsonを省略した場合、書き込み先のメタデータファイルが既に存在すればそこから読み込み、同じファイルに上書きします。
|
836 |
+
|
837 |
+
__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__
|
838 |
+
|
839 |
+
### タグの前処理
|
840 |
+
|
841 |
+
同様にタグもメタデータにまとめます(タグを学習に使わない場合は実行不要です)。
|
842 |
+
```
|
843 |
+
python merge_dd_tags_to_metadata.py --full_path <教師データフォルダ>
|
844 |
+
--in_json <読み込むメタデータファイル名> <書き込むメタデータファイル名>
|
845 |
+
```
|
846 |
+
|
847 |
+
先と同じディレクトリ構成で、meta_cap.jsonを読み、meta_cap_dd.jsonに書きだす場合、以下となります。
|
848 |
+
```
|
849 |
+
python merge_dd_tags_to_metadata.py --full_path train_data --in_json meta_cap.json meta_cap_dd.json
|
850 |
+
```
|
851 |
+
|
852 |
+
複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
|
853 |
+
|
854 |
+
```
|
855 |
+
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
|
856 |
+
train_data1 meta_cap_dd1.json
|
857 |
+
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
|
858 |
+
train_data2 meta_cap_dd2.json
|
859 |
+
```
|
860 |
+
|
861 |
+
in_jsonを省略した場合、書き込み先のメタデータファイルが既に存在すればそこから読み込み、同じファイルに上書きします。
|
862 |
+
|
863 |
+
__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__
|
864 |
+
|
865 |
+
### キャプションとタグのクリーニング
|
866 |
+
|
867 |
+
ここまででメタデータファイルにキャプションとDeepDanbooruのタグがまとめられています。ただ自動キャプショニングにしたキャプションは表記ゆれなどがあり微妙(※)ですし、タグにはアンダースコアが含まれていたりratingが付いていたりしますので(DeepDanbooruの場合)、エディタの置換機能などを用いてキャプションとタグのクリーニングをしたほうがいいでしょう。
|
868 |
+
|
869 |
+
※たとえばアニメ絵の少女を学習する場合、キャプションにはgirl/girls/woman/womenなどのばらつきがあります。また「anime girl」なども単に「girl」としたほうが適切かもしれません。
|
870 |
+
|
871 |
+
クリーニング用のスクリプトが用意してありますので、スクリプトの内容を状況に応じて編集してお使いください。
|
872 |
+
|
873 |
+
(教師データフォルダの指定は不要になりました。メタデータ内の全データをクリーニングします。)
|
874 |
+
|
875 |
+
```
|
876 |
+
python clean_captions_and_tags.py <読み込むメタデータファイル名> <書き込むメタデータファイル名>
|
877 |
+
```
|
878 |
+
|
879 |
+
--in_jsonは付きませんのでご注意ください。たとえば次のようになります。
|
880 |
+
|
881 |
+
```
|
882 |
+
python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
|
883 |
+
```
|
884 |
+
|
885 |
+
以上でキャプションとタグの前処理は完了です。
|
886 |
+
|
887 |
+
## latentsの事前取得
|
888 |
+
|
889 |
+
※ このステップは必須ではありません。省略しても学習時にlatentsを取得しながら学習できます。
|
890 |
+
また学習時に `random_crop` や `color_aug` などを行う場合にはlatentsの事前取得はできません(画像を毎回変えながら学習するため)。事前取得をしない場合、ここまでのメタデータで学習できます。
|
891 |
+
|
892 |
+
あらかじめ画像の潜在表現を取得しディスクに保存しておきます。それにより、学習を高速に進めることができます。あわせてbucketing(教師データをアスペクト比に応じて分類する)を行います。
|
893 |
+
|
894 |
+
作業フォルダで以下のように入力してください。
|
895 |
+
```
|
896 |
+
python prepare_buckets_latents.py --full_path <教師データフォルダ>
|
897 |
+
<読み込むメタデータファイル名> <書き込むメタデータファイル名>
|
898 |
+
<fine tuningするモデル名またはcheckpoint>
|
899 |
+
--batch_size <バッチサイズ>
|
900 |
+
--max_resolution <解像度 幅,高さ>
|
901 |
+
--mixed_precision <精度>
|
902 |
+
```
|
903 |
+
|
904 |
+
モデルがmodel.ckpt、バッチサイズ4、学習解像度は512\*512、精度no(float32)で、meta_clean.jsonからメタデータを読み込み、meta_lat.jsonに書き込む場合、以下のようになります。
|
905 |
+
|
906 |
+
```
|
907 |
+
python prepare_buckets_latents.py --full_path
|
908 |
+
train_data meta_clean.json meta_lat.json model.ckpt
|
909 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
910 |
+
```
|
911 |
+
|
912 |
+
教師データフォルダにnumpyのnpz形式でlatentsが保存されます。
|
913 |
+
|
914 |
+
解像度の最小サイズを--min_bucket_resoオプションで、最大サイズを--max_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。たとえば最小サイズに384を指定すると、256\*1024や320\*768などの解像度は使わなくなります。
|
915 |
+
解像度を768\*768のように大きくした場合、最大サイズに1280などを指定すると良いでしょう。
|
916 |
+
|
917 |
+
--flip_augオプションを指定すると左右反転のaugmentation(データ拡張)を行います。疑似的にデータ量を二倍に増やすことができますが、データが左右対称でない場合に指定すると(例えばキャラクタの外見、髪型など)学習がうまく行かなくなります。
|
918 |
+
|
919 |
+
|
920 |
+
(反転した画像についてもlatentsを取得し、\*\_flip.npzファイルを保存する単純な実装です。fine_tune.pyには特にオプション指定は必要ありません。\_flip付きのファイルがある場合、flip付き・なしのファイルを、ランダムに読み込みます。)
|
921 |
+
|
922 |
+
バッチサイズはVRAM 12GBでももう少し増やせるかもしれません。
|
923 |
+
解像度は64で割り切れる数字で、"幅,高さ"で指定します。解像度はfine tuning時のメモリサイズに直結します。VRAM 12GBでは512,512が限界と思われます(※)。16GBなら512,704や512,768まで上げられるかもしれません。なお256,256等にしてもVRAM 8GBでは厳しいようです(パラメータやoptimizerなどは解像度に関係せず一定のメモリが必要なため)。
|
924 |
+
|
925 |
+
※batch size 1の学習で12GB VRAM、640,640で動いたとの報告もありました。
|
926 |
+
|
927 |
+
以下のようにbucketingの結果が表示されます。
|
928 |
+
|
929 |
+
![bucketingの結果](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png)
|
930 |
+
|
931 |
+
複数の教師データフォルダがある場合には、full_path引数を指定しつつ、それぞれのフォルダに対して実行してください。
|
932 |
+
```
|
933 |
+
python prepare_buckets_latents.py --full_path
|
934 |
+
train_data1 meta_clean.json meta_lat1.json model.ckpt
|
935 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
936 |
+
|
937 |
+
python prepare_buckets_latents.py --full_path
|
938 |
+
train_data2 meta_lat1.json meta_lat2.json model.ckpt
|
939 |
+
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
940 |
+
|
941 |
+
```
|
942 |
+
読み込み元と書き込み先を同じにすることも可能ですが別々の方が安全です。
|
943 |
+
|
944 |
+
__※引数を都度書き換えて、別のメタデータファイルに書き込むと安全です。__
|
945 |
+
|
train_db.py
ADDED
@@ -0,0 +1,437 @@
1 |
+
# DreamBooth training
|
2 |
+
# XXX dropped option: fine_tune
|
3 |
+
|
4 |
+
import gc
|
5 |
+
import time
|
6 |
+
import argparse
|
7 |
+
import itertools
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import toml
|
11 |
+
from multiprocessing import Value
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
import torch
|
15 |
+
from accelerate.utils import set_seed
|
16 |
+
import diffusers
|
17 |
+
from diffusers import DDPMScheduler
|
18 |
+
|
19 |
+
import library.train_util as train_util
|
20 |
+
import library.config_util as config_util
|
21 |
+
from library.config_util import (
|
22 |
+
ConfigSanitizer,
|
23 |
+
BlueprintGenerator,
|
24 |
+
)
|
25 |
+
import library.custom_train_functions as custom_train_functions
|
26 |
+
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
27 |
+
|
28 |
+
def train(args):
|
29 |
+
train_util.verify_training_args(args)
|
30 |
+
train_util.prepare_dataset_args(args, False)
|
31 |
+
|
32 |
+
cache_latents = args.cache_latents
|
33 |
+
|
34 |
+
if args.seed is not None:
|
35 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
36 |
+
|
37 |
+
tokenizer = train_util.load_tokenizer(args)
|
38 |
+
|
39 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
40 |
+
if args.dataset_config is not None:
|
41 |
+
print(f"Load dataset config from {args.dataset_config}")
|
42 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
43 |
+
ignored = ["train_data_dir", "reg_data_dir"]
|
44 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
45 |
+
print(
|
46 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
47 |
+
", ".join(ignored)
|
48 |
+
)
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
user_config = {
|
52 |
+
"datasets": [
|
53 |
+
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
54 |
+
]
|
55 |
+
}
|
56 |
+
|
57 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
58 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
59 |
+
|
60 |
+
current_epoch = Value("i", 0)
|
61 |
+
current_step = Value("i", 0)
|
62 |
+
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
63 |
+
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
64 |
+
|
65 |
+
if args.no_token_padding:
|
66 |
+
train_dataset_group.disable_token_padding()
|
67 |
+
|
68 |
+
if args.debug_dataset:
|
69 |
+
train_util.debug_dataset(train_dataset_group)
|
70 |
+
return
|
71 |
+
|
72 |
+
if cache_latents:
|
73 |
+
assert (
|
74 |
+
train_dataset_group.is_latent_cacheable()
|
75 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
76 |
+
|
77 |
+
# acceleratorを準備する
|
78 |
+
print("prepare accelerator")
|
79 |
+
|
80 |
+
if args.gradient_accumulation_steps > 1:
|
81 |
+
print(
|
82 |
+
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
|
83 |
+
)
|
84 |
+
print(
|
85 |
+
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
86 |
+
)
|
87 |
+
|
88 |
+
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
89 |
+
|
90 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
91 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
92 |
+
|
93 |
+
# モデルを読み込む
|
94 |
+
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
95 |
+
|
96 |
+
# verify load/save model formats
|
97 |
+
if load_stable_diffusion_format:
|
98 |
+
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
99 |
+
src_diffusers_model_path = None
|
100 |
+
else:
|
101 |
+
src_stable_diffusion_ckpt = None
|
102 |
+
src_diffusers_model_path = args.pretrained_model_name_or_path
|
103 |
+
|
104 |
+
if args.save_model_as is None:
|
105 |
+
save_stable_diffusion_format = load_stable_diffusion_format
|
106 |
+
use_safetensors = args.use_safetensors
|
107 |
+
else:
|
108 |
+
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
109 |
+
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
110 |
+
|
111 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
112 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
113 |
+
|
114 |
+
# 学習を準備する
|
115 |
+
if cache_latents:
|
116 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
117 |
+
vae.requires_grad_(False)
|
118 |
+
vae.eval()
|
119 |
+
with torch.no_grad():
|
120 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
121 |
+
vae.to("cpu")
|
122 |
+
if torch.cuda.is_available():
|
123 |
+
torch.cuda.empty_cache()
|
124 |
+
gc.collect()
|
125 |
+
|
126 |
+
accelerator.wait_for_everyone()
|
127 |
+
|
128 |
+
# 学習を準備する:モデルを適切な状態にする
|
129 |
+
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
130 |
+
unet.requires_grad_(True) # 念のため追加
|
131 |
+
text_encoder.requires_grad_(train_text_encoder)
|
132 |
+
if not train_text_encoder:
|
133 |
+
print("Text Encoder is not trained.")
|
134 |
+
|
135 |
+
if args.gradient_checkpointing:
|
136 |
+
unet.enable_gradient_checkpointing()
|
137 |
+
text_encoder.gradient_checkpointing_enable()
|
138 |
+
|
139 |
+
if not cache_latents:
|
140 |
+
vae.requires_grad_(False)
|
141 |
+
vae.eval()
|
142 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
143 |
+
|
144 |
+
# 学習に必要なクラスを準備する
|
145 |
+
print("prepare optimizer, data loader etc.")
|
146 |
+
if train_text_encoder:
|
147 |
+
trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
|
148 |
+
else:
|
149 |
+
trainable_params = unet.parameters()
|
150 |
+
|
151 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
152 |
+
|
153 |
+
# dataloaderを準備する
|
154 |
+
# DataLoaderのプロセス数:0はメインプロセスになる
|
155 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
156 |
+
train_dataloader = torch.utils.data.DataLoader(
|
157 |
+
train_dataset_group,
|
158 |
+
batch_size=1,
|
159 |
+
shuffle=True,
|
160 |
+
collate_fn=collater,
|
161 |
+
num_workers=n_workers,
|
162 |
+
persistent_workers=args.persistent_data_loader_workers,
|
163 |
+
)
|
164 |
+
|
165 |
+
# 学習ステップ数を計算する
|
166 |
+
if args.max_train_epochs is not None:
|
167 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
168 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
169 |
+
)
|
170 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
171 |
+
|
172 |
+
# データセット側にも学習ステップを送信
|
173 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
174 |
+
|
175 |
+
if args.stop_text_encoder_training is None:
|
176 |
+
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
177 |
+
|
178 |
+
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
179 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
180 |
+
|
181 |
+
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
182 |
+
if args.full_fp16:
|
183 |
+
assert (
|
184 |
+
args.mixed_precision == "fp16"
|
185 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
186 |
+
print("enable full fp16 training.")
|
187 |
+
unet.to(weight_dtype)
|
188 |
+
text_encoder.to(weight_dtype)
|
189 |
+
|
190 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
191 |
+
if train_text_encoder:
|
192 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
193 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
194 |
+
)
|
195 |
+
else:
|
196 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
197 |
+
|
198 |
+
if not train_text_encoder:
|
199 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
200 |
+
|
201 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
202 |
+
if args.full_fp16:
|
203 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
204 |
+
|
205 |
+
# resumeする
|
206 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
207 |
+
|
208 |
+
# epoch数を計算する
|
209 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
210 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
211 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
212 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
213 |
+
|
214 |
+
# 学習する
|
215 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
216 |
+
print("running training / 学習開始")
|
217 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
218 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
219 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
220 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
221 |
+
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
222 |
+
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
223 |
+
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
224 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
225 |
+
|
226 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
227 |
+
global_step = 0
|
228 |
+
|
229 |
+
noise_scheduler = DDPMScheduler(
|
230 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
231 |
+
)
|
232 |
+
|
233 |
+
if accelerator.is_main_process:
|
234 |
+
accelerator.init_trackers("dreambooth")
|
235 |
+
|
236 |
+
loss_list = []
|
237 |
+
loss_total = 0.0
|
238 |
+
for epoch in range(num_train_epochs):
|
239 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
240 |
+
current_epoch.value = epoch + 1
|
241 |
+
|
242 |
+
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
243 |
+
unet.train()
|
244 |
+
# train==True is required to enable gradient_checkpointing
|
245 |
+
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
|
246 |
+
text_encoder.train()
|
247 |
+
|
248 |
+
for step, batch in enumerate(train_dataloader):
|
249 |
+
current_step.value = global_step
|
250 |
+
# 指定したステップ数でText Encoderの学習を止める
|
251 |
+
if global_step == args.stop_text_encoder_training:
|
252 |
+
print(f"stop text encoder training at step {global_step}")
|
253 |
+
if not args.gradient_checkpointing:
|
254 |
+
text_encoder.train(False)
|
255 |
+
text_encoder.requires_grad_(False)
|
256 |
+
|
257 |
+
with accelerator.accumulate(unet):
|
258 |
+
with torch.no_grad():
|
259 |
+
# latentに変換
|
260 |
+
if cache_latents:
|
261 |
+
latents = batch["latents"].to(accelerator.device)
|
262 |
+
else:
|
263 |
+
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
264 |
+
latents = latents * 0.18215
|
265 |
+
b_size = latents.shape[0]
|
266 |
+
|
267 |
+
# Sample noise that we'll add to the latents
|
268 |
+
noise = torch.randn_like(latents, device=latents.device)
|
269 |
+
if args.noise_offset:
|
270 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
271 |
+
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
272 |
+
|
273 |
+
# Get the text embedding for conditioning
|
274 |
+
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
275 |
+
if args.weighted_captions:
|
276 |
+
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
|
277 |
+
text_encoder,
|
278 |
+
batch["captions"],
|
279 |
+
accelerator.device,
|
280 |
+
args.max_token_length // 75 if args.max_token_length else 1,
|
281 |
+
clip_skip=args.clip_skip,
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
285 |
+
encoder_hidden_states = train_util.get_hidden_states(
|
286 |
+
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
287 |
+
)
|
288 |
+
|
289 |
+
# Sample a random timestep for each image
|
290 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
291 |
+
timesteps = timesteps.long()
|
292 |
+
|
293 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
294 |
+
# (this is the forward diffusion process)
|
295 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
296 |
+
|
297 |
+
# Predict the noise residual
|
298 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
299 |
+
|
300 |
+
if args.v_parameterization:
|
301 |
+
# v-parameterization training
|
302 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
303 |
+
else:
|
304 |
+
target = noise
|
305 |
+
|
306 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
307 |
+
loss = loss.mean([1, 2, 3])
|
308 |
+
|
309 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
310 |
+
loss = loss * loss_weights
|
311 |
+
|
312 |
+
if args.min_snr_gamma:
|
313 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
314 |
+
|
315 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
316 |
+
|
317 |
+
accelerator.backward(loss)
|
318 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
319 |
+
if train_text_encoder:
|
320 |
+
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
|
321 |
+
else:
|
322 |
+
params_to_clip = unet.parameters()
|
323 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
324 |
+
|
325 |
+
optimizer.step()
|
326 |
+
lr_scheduler.step()
|
327 |
+
optimizer.zero_grad(set_to_none=True)
|
328 |
+
|
329 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
330 |
+
if accelerator.sync_gradients:
|
331 |
+
progress_bar.update(1)
|
332 |
+
global_step += 1
|
333 |
+
|
334 |
+
train_util.sample_images(
|
335 |
+
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
336 |
+
)
|
337 |
+
|
338 |
+
current_loss = loss.detach().item()
|
339 |
+
if args.logging_dir is not None:
|
340 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
341 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
342 |
+
logs["lr/d*lr"] = (
|
343 |
+
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
344 |
+
)
|
345 |
+
accelerator.log(logs, step=global_step)
|
346 |
+
|
347 |
+
if epoch == 0:
|
348 |
+
loss_list.append(current_loss)
|
349 |
+
else:
|
350 |
+
loss_total -= loss_list[step]
|
351 |
+
loss_list[step] = current_loss
|
352 |
+
loss_total += current_loss
|
353 |
+
avr_loss = loss_total / len(loss_list)
|
354 |
+
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
355 |
+
progress_bar.set_postfix(**logs)
|
356 |
+
|
357 |
+
if global_step >= args.max_train_steps:
|
358 |
+
break
|
359 |
+
|
360 |
+
if args.logging_dir is not None:
|
361 |
+
logs = {"loss/epoch": loss_total / len(loss_list)}
|
362 |
+
accelerator.log(logs, step=epoch + 1)
|
363 |
+
|
364 |
+
accelerator.wait_for_everyone()
|
365 |
+
|
366 |
+
if args.save_every_n_epochs is not None:
|
367 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
368 |
+
train_util.save_sd_model_on_epoch_end(
|
369 |
+
args,
|
370 |
+
accelerator,
|
371 |
+
src_path,
|
372 |
+
save_stable_diffusion_format,
|
373 |
+
use_safetensors,
|
374 |
+
save_dtype,
|
375 |
+
epoch,
|
376 |
+
num_train_epochs,
|
377 |
+
global_step,
|
378 |
+
unwrap_model(text_encoder),
|
379 |
+
unwrap_model(unet),
|
380 |
+
vae,
|
381 |
+
)
|
382 |
+
|
383 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
384 |
+
|
385 |
+
is_main_process = accelerator.is_main_process
|
386 |
+
if is_main_process:
|
387 |
+
unet = unwrap_model(unet)
|
388 |
+
text_encoder = unwrap_model(text_encoder)
|
389 |
+
|
390 |
+
accelerator.end_training()
|
391 |
+
|
392 |
+
if args.save_state:
|
393 |
+
train_util.save_state_on_train_end(args, accelerator)
|
394 |
+
|
395 |
+
del accelerator # この後メモリを使うのでこれは消す
|
396 |
+
|
397 |
+
if is_main_process:
|
398 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
399 |
+
train_util.save_sd_model_on_train_end(
|
400 |
+
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
401 |
+
)
|
402 |
+
print("model saved.")
|
403 |
+
|
404 |
+
|
405 |
+
def setup_parser() -> argparse.ArgumentParser:
|
406 |
+
parser = argparse.ArgumentParser()
|
407 |
+
|
408 |
+
train_util.add_sd_models_arguments(parser)
|
409 |
+
train_util.add_dataset_arguments(parser, True, False, True)
|
410 |
+
train_util.add_training_arguments(parser, True)
|
411 |
+
train_util.add_sd_saving_arguments(parser)
|
412 |
+
train_util.add_optimizer_arguments(parser)
|
413 |
+
config_util.add_config_arguments(parser)
|
414 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
415 |
+
|
416 |
+
parser.add_argument(
|
417 |
+
"--no_token_padding",
|
418 |
+
action="store_true",
|
419 |
+
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)",
|
420 |
+
)
|
421 |
+
parser.add_argument(
|
422 |
+
"--stop_text_encoder_training",
|
423 |
+
type=int,
|
424 |
+
default=None,
|
425 |
+
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
|
426 |
+
)
|
427 |
+
|
428 |
+
return parser
|
429 |
+
|
430 |
+
|
431 |
+
if __name__ == "__main__":
|
432 |
+
parser = setup_parser()
|
433 |
+
|
434 |
+
args = parser.parse_args()
|
435 |
+
args = train_util.read_config_from_file(args, parser)
|
436 |
+
|
437 |
+
train(args)
|
train_db_README-ja.md
ADDED
@@ -0,0 +1,167 @@
1 |
+
DreamBoothのガイドです。
|
2 |
+
|
3 |
+
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
4 |
+
|
5 |
+
# 概要
|
6 |
+
|
7 |
+
DreamBoothとは、画像生成モデルに特定の主題を追加学習し、それを特定の識別子で生成する技術です。[論文はこちら](https://arxiv.org/abs/2208.12242)。
|
8 |
+
|
9 |
+
具体的には、Stable Diffusionのモデルにキャラや画風などを学ばせ、それを `shs` のような特定の単語で呼び出せる(生成画像に出現させる)ことができます。
|
10 |
+
|
11 |
+
スクリプトは[DiffusersのDreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)を元にしていますが、以下のような機能追加を行っています(いくつかの機能は元のスクリプト側もその後対応しています)。
|
12 |
+
|
13 |
+
スクリプトの主な機能は以下の通りです。
|
14 |
+
|
15 |
+
- 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化([Shivam Shrirao氏版](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth)と同様)。
|
16 |
+
- xformersによる省メモリ化。
|
17 |
+
- 512x512だけではなく任意サイズでの学習。
|
18 |
+
- augmentationによる品質の向上。
|
19 |
+
- DreamBoothだけではなくText Encoder+U-Netのfine tuningに対応。
|
20 |
+
- Stable Diffusion形式でのモデルの読み書き。
|
21 |
+
- Aspect Ratio Bucketing。
|
22 |
+
- Stable Diffusion v2.0対応。
|
23 |
+
|
24 |
+
# 学習の手順
|
25 |
+
|
26 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
27 |
+
|
28 |
+
## データの準備
|
29 |
+
|
30 |
+
[学習データの準備について](./train_README-ja.md) を参照してください。
|
31 |
+
|
32 |
+
## 学習の実行
|
33 |
+
|
34 |
+
スクリプトを実行します。最大限、メモリを節約したコマンドは以下のようになります(実際には1行で入力します)。それぞれの行を必要に応じて書き換えてください。12GB程度のVRAMで動作するようです。
|
35 |
+
|
36 |
+
```
|
37 |
+
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
38 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensorsまたはDiffusers版モデルのディレクトリ>
|
39 |
+
--dataset_config=<データ準備で作成した.tomlファイル>
|
40 |
+
--output_dir=<学習したモデルの出力先フォルダ>
|
41 |
+
--output_name=<学習したモデル出力時のファイル名>
|
42 |
+
--save_model_as=safetensors
|
43 |
+
--prior_loss_weight=1.0
|
44 |
+
--max_train_steps=1600
|
45 |
+
--learning_rate=1e-6
|
46 |
+
--optimizer_type="AdamW8bit"
|
47 |
+
--xformers
|
48 |
+
--mixed_precision="fp16"
|
49 |
+
--cache_latents
|
50 |
+
--gradient_checkpointing
|
51 |
+
```
|
52 |
+
|
53 |
+
`num_cpu_threads_per_process` には通常は1を指定するとよいようです。
|
54 |
+
|
55 |
+
`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
|
56 |
+
|
57 |
+
`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
|
58 |
+
|
59 |
+
`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
|
60 |
+
|
61 |
+
`prior_loss_weight` は正則化画像のlossの重みです。通常は1.0を指定します。
|
62 |
+
|
63 |
+
学習させるステップ数 `max_train_steps` を1600とします。学習率 `learning_rate` はここでは1e-6を指定しています。
|
64 |
+
|
65 |
+
省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
|
66 |
+
|
67 |
+
オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
|
68 |
+
|
69 |
+
`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
|
70 |
+
|
71 |
+
省メモリ化のため `cache_latents` オプションを指定してVAEの出力をキャッシュします。
|
72 |
+
|
73 |
+
ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。また `cache_latents` を外すことで augmentation が可能になります。
|
74 |
+
|
75 |
+
### よく使われるオプションについて
|
76 |
+
|
77 |
+
以下の場合には [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」を参照してください。
|
78 |
+
|
79 |
+
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
80 |
+
- clip skipを2以上を前提としたモデルを学習する
|
81 |
+
- 75トークンを超えたキャプションで学習する
|
82 |
+
|
83 |
+
### DreamBoothでのステップ数について
|
84 |
+
|
85 |
+
当スクリプトでは省メモリ化のため、ステップ当たりの学習回数が元のスクリプトの半分になっています(対象の画像と正則化画像を同一のバッチではなく別のバッチに分割して学習するため)。
|
86 |
+
|
87 |
+
元のDiffusers版やXavierXiao氏のStable Diffusion版とほぼ同じ学習を行うには、ステップ数を倍にしてください。
|
88 |
+
|
89 |
+
(学習画像と正則化画像をまとめてから shuffle するため厳密にはデータの順番が変わってしまいますが、学習には大きな影響はないと思います。)
|
90 |
+
|
91 |
+
### DreamBoothでのバッチサイズについて
|
92 |
+
|
93 |
+
モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(fine tuningと同じ)。
|
94 |
+
|
95 |
+
### 学習率について
|
96 |
+
|
97 |
+
Diffusers版では5e-6ですがStable Diffusion版は1e-6ですので、上のサンプルでは1e-6を指定しています。
|
98 |
+
|
99 |
+
### 以前の形式のデータセット指定をした場合のコマンドライン
|
100 |
+
|
101 |
+
解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
|
102 |
+
|
103 |
+
```
|
104 |
+
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
105 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensorsまたはDiffusers版モデルのディレクトリ>
|
106 |
+
--train_data_dir=<学習用データのディレクトリ>
|
107 |
+
--reg_data_dir=<正則化画像のディレクトリ>
|
108 |
+
--output_dir=<学習したモデルの出力先ディレクトリ>
|
109 |
+
--output_name=<学習したモデル出力時のファイル名>
|
110 |
+
--prior_loss_weight=1.0
|
111 |
+
--resolution=512
|
112 |
+
--train_batch_size=1
|
113 |
+
--learning_rate=1e-6
|
114 |
+
--max_train_steps=1600
|
115 |
+
--use_8bit_adam
|
116 |
+
--xformers
|
117 |
+
--mixed_precision="bf16"
|
118 |
+
--cache_latents
|
119 |
+
--gradient_checkpointing
|
120 |
+
```
|
121 |
+
|
122 |
+
## 学習したモデルで画像生成する
|
123 |
+
|
124 |
+
学習が終わると指定したフォルダに指定した名前でsafetensorsファイルが出力されます。
|
125 |
+
|
126 |
+
v1.4/1.5およびその他の派生モデルの場合、このモデルでAutomatic1111氏のWebUIなどで推論できます。models\Stable-diffusionフォルダに置いてください。
|
127 |
+
|
128 |
+
v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述された.yamlファイルが別途必要になります。v2.x baseの場合はv2-inference.yamlを、768/vの場合はv2-inference-v.yamlを、同じフォルダに置き、拡張子の前の部分をモデルと同じ名前にしてください。
|
129 |
+
|
130 |
+
![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png)
|
131 |
+
|
132 |
+
各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。
|
133 |
+
|
134 |
+
# DreamBooth特有のその他の主なオプション
|
135 |
+
|
136 |
+
すべてのオプションについては別文書を参照してください。
|
137 |
+
|
138 |
+
## Text Encoderの学習を途中から行わない --stop_text_encoder_training
|
139 |
+
|
140 |
+
stop_text_encoder_trainingオプションに数値を指定すると、そのステップ数以降はText Encoderの学習を行わずU-Netだけ学習します。場合によっては精度の向上が期待できるかもしれません。
|
141 |
+
|
142 |
+
(恐らくText Encoderだけ先に過学習することがあり、それを防げるのではないかと推測していますが、詳細な影響は不明です。)
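たとえば最初の500ステップだけText Encoderも学習し、以降はU-Netのみを学習する場合は次のように指定します(ステップ数は仮の値です)。

```
--stop_text_encoder_training=500
```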
|
143 |
+
|
144 |
+
## Tokenizerのパディングをしない --no_token_padding
|
145 |
+
no_token_paddingオプションを指定するとTokenizerの出力をpaddingしません(Diffusers版の旧DreamBoothと同じ動きになります)。
|
146 |
+
|
147 |
+
|
148 |
+
<!--
|
149 |
+
bucketing(後述)を利用しかつaugmentation(後述)を使う場合の例は以下のようになります。
|
150 |
+
|
151 |
+
```
|
152 |
+
accelerate launch --num_cpu_threads_per_process 8 train_db.py
|
153 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
154 |
+
--train_data_dir=<学習用データのディレクトリ>
|
155 |
+
--reg_data_dir=<正則化画像のディレクトリ>
|
156 |
+
--output_dir=<学習したモデルの出力先ディレクトリ>
|
157 |
+
--resolution=768,512
|
158 |
+
--train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800
|
159 |
+
--use_8bit_adam --xformers --mixed_precision="bf16"
|
160 |
+
--save_every_n_epochs=1 --save_state --save_precision="bf16"
|
161 |
+
--logging_dir=logs
|
162 |
+
--enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280
|
163 |
+
--color_aug --flip_aug --gradient_checkpointing --seed 42
|
164 |
+
```
|
165 |
+
|
166 |
+
|
167 |
+
-->
|
train_db_README.md
ADDED
@@ -0,0 +1,295 @@
1 |
+
A guide to DreamBooth. The same procedure is used for training additional networks such as LoRA.
|
2 |
+
|
3 |
+
# overview
|
4 |
+
|
5 |
+
The main functions of the script are as follows.
|
6 |
+
|
7 |
+
- Memory saving by 8bit Adam optimizer and latent cache (similar to ShivamShirao's version).
|
8 |
+
- Saved memory by xformers.
|
9 |
+
- Training at arbitrary resolutions, not just 512x512.
|
10 |
+
- Quality improvement with augmentation.
|
11 |
+
- Supports fine tuning of Text Encoder+U-Net as well as DreamBooth.
|
12 |
+
- Read and write models in StableDiffusion format.
|
13 |
+
- Aspect Ratio Bucketing.
|
14 |
+
- Supports Stable Diffusion v2.0.
|
15 |
+
|
16 |
+
# learning procedure
|
17 |
+
|
18 |
+
## step 1. Environment setup
|
19 |
+
|
20 |
+
See the README in this repository.
|
21 |
+
|
22 |
+
|
23 |
+
## step 2. Determine identifier and class
|
24 |
+
|
25 |
+
Decide the word identifier that connects the target you want to learn and the class to which the target belongs.
|
26 |
+
|
27 |
+
(There are various names such as instance, but for the time being I will stick to the original paper.)
|
28 |
+
|
29 |
+
Here's a very brief explanation (look it up for more details).
|
30 |
+
|
31 |
+
class is the general type to learn. For example, if you want to learn a specific breed of dog, the class will be dog. Anime characters will be boy, girl, 1boy or 1girl depending on the model.
|
32 |
+
|
33 |
+
The identifier is used to identify and learn the training target. Any word is fine, but according to the original paper, "a rare word of 3 characters or fewer that becomes a single token with the tokenizer" works well.
|
34 |
+
|
35 |
+
By training the model with the identifier and class together, for example "shs dog", the model learns the specific target as distinct from its class.
|
36 |
+
|
37 |
+
When generating an image, if you say "shs dog", an image of the learned dog breed will be generated.
|
38 |
+
|
39 |
+
(For reference, the identifier I use these days is ``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny``.)
|
40 |
+
|
41 |
+
## step 3. Prepare images for training
|
42 |
+
Create a folder to store the training images. __Inside that folder__, create a directory with the following name:
|
43 |
+
|
44 |
+
```
|
45 |
+
<repeat count>_<identifier> <class>
|
46 |
+
```
|
47 |
+
|
48 |
+
Don't forget the ``_`` between them.
|
49 |
+
|
50 |
+
The number of repetitions is specified to match the number of regularized images (described later).
|
51 |
+
|
52 |
+
For example, at the prompt "sls frog", to repeat the data 20 times, it would be "20_sls frog". It will be as follows.
|
53 |
+
|
54 |
+
![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png)
|
55 |
+
|
56 |
+
## step 4. Preparing regularized images
|
57 |
+
This is the procedure for using regularization images. It is also possible to train without regularization images (in that case the entire target class is affected, because the target cannot be distinguished from the rest of the class).
|
58 |
+
|
59 |
+
Create a folder to store the regularization images. __Inside that folder__, create a directory named ``<repeat count>_<class>``.
|
60 |
+
|
61 |
+
For example, with the prompt "frog" and without repeating the data (just once):
|
62 |
+
|
63 |
+
![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png)
|
64 |
+
|
65 |
+
Specify the number of iterations so that " __ number of iterations of training images x number of training images ≥ number of iterations of regularization images x number of regularization images __".
|
66 |
+
|
67 |
+
(The number of data in one epoch is "number of repetitions of training images x number of training images". If the number of regularization images is more than that, the remaining regularization images will not be used.)
|
68 |
+
|
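As a concrete example of that inequality with hypothetical numbers: 15 training images repeated 20 times give 20 x 15 = 300 items of training data per epoch, so up to 300 regularization images (repeat count 1) can all be used. A tiny sketch of the check:

```
# Hypothetical numbers: 15 training images in "20_sls frog", 200 regularization images in "1_frog".
num_train_images, train_repeats = 15, 20
num_reg_images, reg_repeats = 200, 1

data_per_epoch = train_repeats * num_train_images             # 300 items of training data per epoch
reg_used = min(reg_repeats * num_reg_images, data_per_epoch)  # regularization images actually used

assert train_repeats * num_train_images >= reg_repeats * num_reg_images
print(data_per_epoch, reg_used)  # 300 200 -> all 200 regularization images are used
```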
## step 5. Run the training

Run the script. The command that saves the most memory looks like this (it is actually entered on one line):

*For training additional networks such as LoRA, the command is ``train_network.py`` instead of ``train_db.py``. You will also need the additional network_\* options, so please refer to the LoRA guide.

```
accelerate launch --num_cpu_threads_per_process 8 train_db.py
--pretrained_model_name_or_path=<directory of .ckpt or .safetensors or Diffusers model>
--train_data_dir=<training data directory>
--reg_data_dir=<regularization image directory>
--output_dir=<output directory for the trained model>
--prior_loss_weight=1.0
--resolution=512
--train_batch_size=1
--learning_rate=1e-6
--max_train_steps=1600
--use_8bit_adam
--xformers
--mixed_precision="bf16"
--cache_latents
--gradient_checkpointing
```

It seems to be a good idea to set num_cpu_threads_per_process to the number of CPU cores.

Specify the model to be further trained with pretrained_model_name_or_path. You can specify a Stable Diffusion checkpoint file (.ckpt or .safetensors), a local Diffusers model directory, or a Diffusers model ID (such as "stabilityai/stable-diffusion-2"). By default, the trained model is saved in the same format as the original model (this can be changed with the save_model_as option).

prior_loss_weight is the loss weight for the regularization images. Normally specify 1.0.

resolution is the size of the images (width and height). If bucketing (described later) is not used, training images and regularization images must be this size.

train_batch_size is the training batch size. max_train_steps is set to 1600. The learning rate is 5e-6 in the Diffusers version and 1e-6 in the Stable Diffusion version, so 1e-6 is specified here.

Specify mixed_precision="bf16" (or "fp16") and gradient_checkpointing to save memory.

Specify the xformers option to use xformers' CrossAttention. If you don't have xformers installed, or if it gives an error (in my environment it errored without mixed_precision), specify the mem_eff_attn option instead to use the memory-saving version of CrossAttention (it will be slower).

Cache the VAE output with the cache_latents option to save memory.

If you have a reasonable amount of memory, a command like the following can be used instead.

```
accelerate launch --num_cpu_threads_per_process 8 train_db.py
--pretrained_model_name_or_path=<directory of .ckpt or .safetensors or Diffusers model>
--train_data_dir=<training data directory>
--reg_data_dir=<regularization image directory>
--output_dir=<output directory for the trained model>
--prior_loss_weight=1.0
--resolution=512
--train_batch_size=4
--learning_rate=1e-6
--max_train_steps=400
--use_8bit_adam
--xformers
--mixed_precision="bf16"
--cache_latents
```

Remove gradient_checkpointing to speed things up (memory usage will increase). Increase the batch size to improve speed and accuracy.
An example that uses bucketing (see below) and augmentation (see below) looks like this:

```
accelerate launch --num_cpu_threads_per_process 8 train_db.py
--pretrained_model_name_or_path=<directory of .ckpt or .safetensors or Diffusers model>
--train_data_dir=<training data directory>
--reg_data_dir=<regularization image directory>
--output_dir=<output directory for the trained model>
--resolution=768,512
--train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800
--use_8bit_adam --xformers --mixed_precision="bf16"
--save_every_n_epochs=1 --save_state --save_precision="bf16"
--logging_dir=logs
--enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280
--color_aug --flip_aug --gradient_checkpointing --seed 42
```

### About the number of steps

To save memory, this script trains on half as much data per step as train_dreambooth.py, because the training images and the regularization images are put into separate batches instead of the same batch.

Double the number of steps to get training that is almost equivalent to the original Diffusers version and XavierXiao's Stable Diffusion version.

(Strictly speaking, the order of the data changes because of shuffle=True, but I don't think it has a big impact on training.)
## Generating images with the trained model

When training finishes, a checkpoint named last.ckpt is written to the specified output folder (if you trained a Diffusers-format model, it will be a folder named last instead).

For v1.4/1.5 and other derived models, the resulting model can be used for inference in Automatic1111's WebUI and similar tools. Place it in the models\Stable-diffusion folder.

To generate images in the WebUI with a v2.x model, a separate .yaml file describing the model configuration is required. Place v2-inference.yaml for the v2.x base model, or v2-inference-v.yaml for the 768/v model, in the same folder, and give it the same name as the model (everything before the extension).

![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png)

The yaml files can be found in the [Stability AI SD2.0 repository](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion).

# Other training options

## Stable Diffusion 2.0 support --v2 / --v_parameterization

Specify the v2 option when using Hugging Face's stable-diffusion-2-base, and specify both the v2 and v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt.

Also, training SD 2.0 seems to be difficult with 12GB of VRAM because the Text Encoder is larger.

The following points changed significantly in Stable Diffusion 2.0.

1. The tokenizer that is used
2. The Text Encoder that is used and which of its output layers is used (2.0 uses the penultimate layer)
3. The output dimensionality of the Text Encoder (768 -> 1024)
4. The structure of the U-Net (number of CrossAttention heads, etc.)
5. v-parameterization (the sampling method appears to have changed)

Of these, 1 to 4 apply to the base model, and 1 to 5 apply to the non-base model (768-v). The v2 option enables 1 to 4, and the v_parameterization option enables 5.
## Checking the training data --debug_dataset

Adding this option lets you check, before training, which images and captions will be used. Press Esc to exit and return to the command line.

*Note that it seems to hang in environments without a display, such as Colab.

## Stopping Text Encoder training --stop_text_encoder_training

If you give the stop_text_encoder_training option a number, after that many steps only the U-Net is trained and the Text Encoder is no longer trained. In some cases this may improve accuracy.

(The Text Encoder alone may overfit first, and this can probably prevent that, but the exact impact is unknown.)

## Loading a separate VAE for training --vae

If you specify a Stable Diffusion checkpoint, a VAE checkpoint file, a Diffusers model, or a Diffusers VAE (the latter two can be a local path or a Hugging Face model ID) with the vae option, that VAE is used for training (when caching latents or obtaining latents during training).

The saved model will incorporate this VAE.

## Saving during training --save_every_n_epochs / --save_state / --resume

Specifying a number with the save_every_n_epochs option saves the model every that many epochs during training.

If you also specify the save_state option, the training state, including the optimizer state, is saved as well (compared to restarting from a checkpoint, this lets you expect better accuracy and a shorter training time when resuming). The training state is written to a folder named "epoch-??????-state" (?????? is the epoch number) in the output folder. Use it for long training runs.

Use the resume option to resume training from a saved training state. Specify the training state folder.

Note that, due to the way Accelerator works (?), the epoch number and global step are not saved, so they start again from 1 when you resume.

## No tokenizer padding --no_token_padding

The no_token_padding option disables padding of the Tokenizer output (same behavior as the old Diffusers version of DreamBooth).

## Training with images of arbitrary size --resolution

You can train on non-square images. Specify "width,height" such as "448,640" for resolution. Width and height must be divisible by 64. Training images and regularization images must match this size.

Personally, I often generate portrait images, so I sometimes train with "448,640".

## Aspect Ratio Bucketing --enable_bucket / --min_bucket_reso / --max_bucket_reso

Enabled by specifying the enable_bucket option. Stable Diffusion is trained at 512x512, but with this option it is also trained at resolutions such as 256x768 and 384x640.

If you specify this option, you do not need to unify the training images and regularization images to a single resolution. Each image is assigned to one of several resolutions (aspect ratios) and trained at that resolution.

Since resolutions come in 64-pixel increments, the aspect ratio may not exactly match the original image.

You can specify the minimum resolution with the min_bucket_reso option and the maximum with max_bucket_reso. The defaults are 256 and 1024 respectively.

For example, specifying a minimum of 384 means resolutions such as 256x1024 or 320x768 will not be used.

If you increase the base resolution to 768x768, you may want to specify 1280 as the maximum.

When Aspect Ratio Bucketing is enabled, it may be better to prepare regularization images at a variety of resolutions similar to the training images.

(This is because then the images in a batch are not biased toward either training images or regularization images.)
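The bucket selection can be pictured roughly as follows: enumerate width/height pairs in 64-pixel steps between min_bucket_reso and max_bucket_reso whose area does not exceed that of the training resolution, then assign each image to the bucket with the closest aspect ratio. This is only an illustrative sketch under those assumptions, not the repository's exact implementation:

```
# Illustrative sketch of Aspect Ratio Bucketing (not the repository's exact algorithm).
def make_buckets(base_reso=512, min_reso=256, max_reso=1024, step=64):
    """Enumerate candidate (width, height) buckets in 64-px steps within the base area."""
    max_area = base_reso * base_reso
    buckets = set()
    for w in range(min_reso, max_reso + 1, step):
        for h in range(min_reso, max_reso + 1, step):
            if w * h <= max_area:
                buckets.add((w, h))
    return sorted(buckets)

def closest_bucket(img_w, img_h, buckets):
    """Pick the bucket whose aspect ratio is closest to the image's."""
    aspect = img_w / img_h
    return min(buckets, key=lambda b: abs(b[0] / b[1] - aspect))

buckets = make_buckets()
print(closest_bucket(1200, 1600, buckets))  # (384, 512) for a 3:4 portrait image
```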
## Augmentation --color_aug / --flip_aug

Augmentation is a technique for improving model performance by dynamically varying the data during training. color_aug subtly varies the hue and flip_aug flips images horizontally while training.

Because the data changes dynamically, these options cannot be combined with the cache_latents option.

## Specifying the precision of saved data --save_precision

Specifying float, fp16, or bf16 for the save_precision option saves the checkpoint in that format (only when saving in Stable Diffusion format). Use it when you want to reduce the size of the checkpoint.

## Saving in an arbitrary format --save_model_as

Specifies the format in which to save the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors.

When loading a Stable Diffusion format model (ckpt or safetensors) and saving in Diffusers format, missing information is filled in by downloading the v1.5 or v2.1 information from Hugging Face.

## Saving training logs --logging_dir / --log_prefix

Specify the log destination folder with the logging_dir option. Logs are saved in TensorBoard format.

For example, specifying --logging_dir=logs creates a logs folder in the working folder, and logs are saved in a date/time subfolder inside it.

If you also specify the --log_prefix option, the specified string is prepended to the date/time. Use something like "--logging_dir=logs --log_prefix=db_style1_" for identification.

To view the logs with TensorBoard, open another command prompt and run the following in the working folder (I believe tensorboard is installed together with Diffusers; if not, install it with pip install tensorboard).

```
tensorboard --logdir=logs
```

Then open your browser and go to http://localhost:6006/ to view them.

## Learning rate scheduler settings --lr_scheduler / --lr_warmup_steps

You can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, and constant_with_warmup with the lr_scheduler option. The default is constant. lr_warmup_steps specifies the number of warmup steps for the scheduler (during which the learning rate is gradually changed). Please do your own research for details.

## Training with fp16 gradients (experimental feature) --full_fp16

The full_fp16 option changes the gradients from the usual float32 to float16 (fp16) for training (this appears to be full fp16 training rather than mixed precision).

With this, SD1.x at 512x512 seems to be trainable with less than 8GB of VRAM, and SD2.x at 512x512 with less than 12GB of VRAM.

Specify fp16 in the accelerate config beforehand, and also set the option ``mixed_precision="fp16"`` (bf16 does not work).

To minimize memory usage, use the xformers, use_8bit_adam, cache_latents, and gradient_checkpointing options, and set train_batch_size to 1.

(If you have room to spare, increasing train_batch_size gradually should improve accuracy a little.)

This is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). Accuracy drops considerably, and the probability that training fails along the way also increases.

The learning rate and the number of steps also seem to be very sensitive. Be aware of these issues and use it at your own risk.
# Other training methods

## Training multiple classes or multiple identifiers

The method is simple: prepare multiple folders named ``<repeat count>_<identifier> <class>`` in the training image folder, and multiple folders named ``<repeat count>_<class>`` in the regularization image folder.

For example, training "sls frog" and "cpc rabbit" at the same time looks like this:

![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png)

If there is one class with multiple targets, a single regularization image folder is enough. For example, if 1girl covers both character A and character B, arrange the folders as follows (a small script that creates this layout is sketched below).

- train_girls
  - 10_sls 1girl
  - 10_cpc 1girl
- reg_girls
  - 1_1girl

If the number of images per target varies, it seems that good results can be obtained by adjusting the repeat counts so that the number of images per class and identifier is roughly equal.
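A few lines of Python are enough to create that folder layout (the paths and identifiers here are just the ones from the example above; adjust them to your own data):

```
import os

# Folder layout from the example above: two identifiers sharing the 1girl class.
layout = {
    "train_girls": ["10_sls 1girl", "10_cpc 1girl"],
    "reg_girls": ["1_1girl"],
}

for parent, subdirs in layout.items():
    for sub in subdirs:
        path = os.path.join(parent, sub)
        os.makedirs(path, exist_ok=True)
        print("created", path)
```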
## Using captions in DreamBooth

If you place a file with the same file name as the image and the extension .caption (changeable with an option) in the training or regularization image folders, the caption is read from that file and used as the prompt for training.

* For those images, the folder name (identifier class) is no longer used for training.

Adding a caption to each image (for example with BLIP) may help clarify the attributes you want to learn.

Caption files have the .caption extension by default; this can be changed with --caption_extension. With the --shuffle_caption option, the comma-separated parts of each caption are shuffled during training.
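If you want to create the caption files in bulk, a small script that writes one .caption file per image is enough. This is only a hypothetical sketch (the folder path and the caption text are placeholders; in practice the captions would come from BLIP or be written by hand):

```
import os

image_dir = "train_data/20_sls frog"              # hypothetical folder
image_exts = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}

for name in os.listdir(image_dir):
    base, ext = os.path.splitext(name)
    if ext.lower() not in image_exts:
        continue
    caption_path = os.path.join(image_dir, base + ".caption")
    with open(caption_path, "w", encoding="utf-8") as f:
        f.write("sls frog, sitting on a leaf")    # placeholder caption text
    print("wrote", caption_path)
```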
train_network.py
ADDED
@@ -0,0 +1,773 @@
1 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
2 |
+
import importlib
|
3 |
+
import argparse
|
4 |
+
import gc
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import time
|
9 |
+
import json
|
10 |
+
import toml
|
11 |
+
from multiprocessing import Value
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
import torch
|
15 |
+
from accelerate.utils import set_seed
|
16 |
+
from diffusers import DDPMScheduler
|
17 |
+
|
18 |
+
import library.train_util as train_util
|
19 |
+
from library.train_util import (
|
20 |
+
DreamBoothDataset,
|
21 |
+
)
|
22 |
+
import library.config_util as config_util
|
23 |
+
from library.config_util import (
|
24 |
+
ConfigSanitizer,
|
25 |
+
BlueprintGenerator,
|
26 |
+
)
|
27 |
+
import library.huggingface_util as huggingface_util
|
28 |
+
import library.custom_train_functions as custom_train_functions
|
29 |
+
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
30 |
+
|
31 |
+
|
32 |
+
# TODO 他のスクリプトと共通化する
|
33 |
+
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
34 |
+
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
35 |
+
|
36 |
+
lrs = lr_scheduler.get_last_lr()
|
37 |
+
|
38 |
+
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
|
39 |
+
if args.network_train_unet_only:
|
40 |
+
logs["lr/unet"] = float(lrs[0])
|
41 |
+
elif args.network_train_text_encoder_only:
|
42 |
+
logs["lr/textencoder"] = float(lrs[0])
|
43 |
+
else:
|
44 |
+
logs["lr/textencoder"] = float(lrs[0])
|
45 |
+
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
|
46 |
+
|
47 |
+
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
48 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
49 |
+
else:
|
50 |
+
idx = 0
|
51 |
+
if not args.network_train_unet_only:
|
52 |
+
logs["lr/textencoder"] = float(lrs[0])
|
53 |
+
idx = 1
|
54 |
+
|
55 |
+
for i in range(idx, len(lrs)):
|
56 |
+
logs[f"lr/group{i}"] = float(lrs[i])
|
57 |
+
if args.optimizer_type.lower() == "DAdaptation".lower():
|
58 |
+
logs[f"lr/d*lr/group{i}"] = (
|
59 |
+
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
60 |
+
)
|
61 |
+
|
62 |
+
return logs
|
63 |
+
|
64 |
+
|
65 |
+
def train(args):
|
66 |
+
session_id = random.randint(0, 2**32)
|
67 |
+
training_started_at = time.time()
|
68 |
+
train_util.verify_training_args(args)
|
69 |
+
train_util.prepare_dataset_args(args, True)
|
70 |
+
|
71 |
+
cache_latents = args.cache_latents
|
72 |
+
use_dreambooth_method = args.in_json is None
|
73 |
+
use_user_config = args.dataset_config is not None
|
74 |
+
|
75 |
+
if args.seed is None:
|
76 |
+
args.seed = random.randint(0, 2**32)
|
77 |
+
set_seed(args.seed)
|
78 |
+
|
79 |
+
tokenizer = train_util.load_tokenizer(args)
|
80 |
+
|
81 |
+
# データセットを準備する
|
82 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
83 |
+
if use_user_config:
|
84 |
+
print(f"Load dataset config from {args.dataset_config}")
|
85 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
86 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
87 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
88 |
+
print(
|
89 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
90 |
+
", ".join(ignored)
|
91 |
+
)
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
if use_dreambooth_method:
|
95 |
+
print("Use DreamBooth method.")
|
96 |
+
user_config = {
|
97 |
+
"datasets": [
|
98 |
+
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
99 |
+
]
|
100 |
+
}
|
101 |
+
else:
|
102 |
+
print("Train with captions.")
|
103 |
+
user_config = {
|
104 |
+
"datasets": [
|
105 |
+
{
|
106 |
+
"subsets": [
|
107 |
+
{
|
108 |
+
"image_dir": args.train_data_dir,
|
109 |
+
"metadata_file": args.in_json,
|
110 |
+
}
|
111 |
+
]
|
112 |
+
}
|
113 |
+
]
|
114 |
+
}
|
115 |
+
|
116 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
117 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
118 |
+
|
119 |
+
current_epoch = Value("i", 0)
|
120 |
+
current_step = Value("i", 0)
|
121 |
+
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
122 |
+
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
123 |
+
|
124 |
+
if args.debug_dataset:
|
125 |
+
train_util.debug_dataset(train_dataset_group)
|
126 |
+
return
|
127 |
+
if len(train_dataset_group) == 0:
|
128 |
+
print(
|
129 |
+
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
130 |
+
)
|
131 |
+
return
|
132 |
+
|
133 |
+
if cache_latents:
|
134 |
+
assert (
|
135 |
+
train_dataset_group.is_latent_cacheable()
|
136 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
137 |
+
|
138 |
+
# acceleratorを準備する
|
139 |
+
print("prepare accelerator")
|
140 |
+
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
141 |
+
is_main_process = accelerator.is_main_process
|
142 |
+
|
143 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
144 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
145 |
+
|
146 |
+
# モデルを読み込む
|
147 |
+
for pi in range(accelerator.state.num_processes):
|
148 |
+
# TODO: modify other training scripts as well
|
149 |
+
if pi == accelerator.state.local_process_index:
|
150 |
+
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
151 |
+
|
152 |
+
text_encoder, vae, unet, _ = train_util.load_target_model(
|
153 |
+
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
154 |
+
)
|
155 |
+
|
156 |
+
# work on low-ram device
|
157 |
+
if args.lowram:
|
158 |
+
text_encoder.to(accelerator.device)
|
159 |
+
unet.to(accelerator.device)
|
160 |
+
vae.to(accelerator.device)
|
161 |
+
|
162 |
+
gc.collect()
|
163 |
+
torch.cuda.empty_cache()
|
164 |
+
accelerator.wait_for_everyone()
|
165 |
+
|
166 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
167 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
168 |
+
|
169 |
+
# 学習を準備する
|
170 |
+
if cache_latents:
|
171 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
172 |
+
vae.requires_grad_(False)
|
173 |
+
vae.eval()
|
174 |
+
with torch.no_grad():
|
175 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
176 |
+
vae.to("cpu")
|
177 |
+
if torch.cuda.is_available():
|
178 |
+
torch.cuda.empty_cache()
|
179 |
+
gc.collect()
|
180 |
+
|
181 |
+
accelerator.wait_for_everyone()
|
182 |
+
|
183 |
+
# prepare network
|
184 |
+
import sys
|
185 |
+
|
186 |
+
sys.path.append(os.path.dirname(__file__))
|
187 |
+
print("import network module:", args.network_module)
|
188 |
+
network_module = importlib.import_module(args.network_module)
|
189 |
+
|
190 |
+
net_kwargs = {}
|
191 |
+
if args.network_args is not None:
|
192 |
+
for net_arg in args.network_args:
|
193 |
+
key, value = net_arg.split("=")
|
194 |
+
net_kwargs[key] = value
|
195 |
+
|
196 |
+
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
|
197 |
+
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
|
198 |
+
if network is None:
|
199 |
+
return
|
200 |
+
|
201 |
+
if hasattr(network, "prepare_network"):
|
202 |
+
network.prepare_network(args)
|
203 |
+
|
204 |
+
train_unet = not args.network_train_text_encoder_only
|
205 |
+
train_text_encoder = not args.network_train_unet_only
|
206 |
+
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
207 |
+
|
208 |
+
if args.network_weights is not None:
|
209 |
+
info = network.load_weights(args.network_weights)
|
210 |
+
print(f"load network weights from {args.network_weights}: {info}")
|
211 |
+
|
212 |
+
if args.gradient_checkpointing:
|
213 |
+
unet.enable_gradient_checkpointing()
|
214 |
+
text_encoder.gradient_checkpointing_enable()
|
215 |
+
network.enable_gradient_checkpointing() # may have no effect
|
216 |
+
|
217 |
+
# 学習に必要なクラスを準備する
|
218 |
+
print("prepare optimizer, data loader etc.")
|
219 |
+
|
220 |
+
# 後方互換性を確保するよ
|
221 |
+
try:
|
222 |
+
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
223 |
+
except TypeError:
|
224 |
+
print(
|
225 |
+
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
226 |
+
)
|
227 |
+
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
228 |
+
|
229 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
230 |
+
|
231 |
+
# dataloaderを準備する
|
232 |
+
# DataLoaderのプロセス数:0はメインプロセスになる
|
233 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
234 |
+
|
235 |
+
train_dataloader = torch.utils.data.DataLoader(
|
236 |
+
train_dataset_group,
|
237 |
+
batch_size=1,
|
238 |
+
shuffle=True,
|
239 |
+
collate_fn=collater,
|
240 |
+
num_workers=n_workers,
|
241 |
+
persistent_workers=args.persistent_data_loader_workers,
|
242 |
+
)
|
243 |
+
|
244 |
+
# 学習ステップ数を計算する
|
245 |
+
if args.max_train_epochs is not None:
|
246 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
247 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
248 |
+
)
|
249 |
+
if is_main_process:
|
250 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
251 |
+
|
252 |
+
# データセット側にも学習ステップを送信
|
253 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
254 |
+
|
255 |
+
# lr schedulerを用意する
|
256 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
257 |
+
|
258 |
+
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
259 |
+
if args.full_fp16:
|
260 |
+
assert (
|
261 |
+
args.mixed_precision == "fp16"
|
262 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
263 |
+
print("enable full fp16 training.")
|
264 |
+
network.to(weight_dtype)
|
265 |
+
|
266 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
267 |
+
if train_unet and train_text_encoder:
|
268 |
+
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
269 |
+
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
|
270 |
+
)
|
271 |
+
elif train_unet:
|
272 |
+
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
273 |
+
unet, network, optimizer, train_dataloader, lr_scheduler
|
274 |
+
)
|
275 |
+
elif train_text_encoder:
|
276 |
+
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
277 |
+
text_encoder, network, optimizer, train_dataloader, lr_scheduler
|
278 |
+
)
|
279 |
+
else:
|
280 |
+
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
281 |
+
|
282 |
+
unet.requires_grad_(False)
|
283 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
284 |
+
text_encoder.requires_grad_(False)
|
285 |
+
text_encoder.to(accelerator.device)
|
286 |
+
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
287 |
+
unet.train()
|
288 |
+
text_encoder.train()
|
289 |
+
|
290 |
+
# set top parameter requires_grad = True for gradient checkpointing works
|
291 |
+
if type(text_encoder) == DDP:
|
292 |
+
text_encoder.module.text_model.embeddings.requires_grad_(True)
|
293 |
+
else:
|
294 |
+
text_encoder.text_model.embeddings.requires_grad_(True)
|
295 |
+
else:
|
296 |
+
unet.eval()
|
297 |
+
text_encoder.eval()
|
298 |
+
|
299 |
+
# support DistributedDataParallel
|
300 |
+
if type(text_encoder) == DDP:
|
301 |
+
text_encoder = text_encoder.module
|
302 |
+
unet = unet.module
|
303 |
+
network = network.module
|
304 |
+
|
305 |
+
network.prepare_grad_etc(text_encoder, unet)
|
306 |
+
|
307 |
+
if not cache_latents:
|
308 |
+
vae.requires_grad_(False)
|
309 |
+
vae.eval()
|
310 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
311 |
+
|
312 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
313 |
+
if args.full_fp16:
|
314 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
315 |
+
|
316 |
+
# resumeする
|
317 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
318 |
+
|
319 |
+
# epoch数を計算する
|
320 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
321 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
322 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
323 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
324 |
+
|
325 |
+
# 学習する
|
326 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
327 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
328 |
+
|
329 |
+
if is_main_process:
|
330 |
+
print("running training / 学習開始")
|
331 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
332 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
333 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
334 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
335 |
+
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
336 |
+
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
337 |
+
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
338 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
339 |
+
|
340 |
+
# TODO refactor metadata creation and move to util
|
341 |
+
metadata = {
|
342 |
+
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
343 |
+
"ss_training_started_at": training_started_at, # unix timestamp
|
344 |
+
"ss_output_name": args.output_name,
|
345 |
+
"ss_learning_rate": args.learning_rate,
|
346 |
+
"ss_text_encoder_lr": args.text_encoder_lr,
|
347 |
+
"ss_unet_lr": args.unet_lr,
|
348 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
349 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
350 |
+
"ss_num_batches_per_epoch": len(train_dataloader),
|
351 |
+
"ss_num_epochs": num_train_epochs,
|
352 |
+
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
353 |
+
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
354 |
+
"ss_max_train_steps": args.max_train_steps,
|
355 |
+
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
356 |
+
"ss_lr_scheduler": args.lr_scheduler,
|
357 |
+
"ss_network_module": args.network_module,
|
358 |
+
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
359 |
+
"ss_network_alpha": args.network_alpha, # some networks may not use this value
|
360 |
+
"ss_mixed_precision": args.mixed_precision,
|
361 |
+
"ss_full_fp16": bool(args.full_fp16),
|
362 |
+
"ss_v2": bool(args.v2),
|
363 |
+
"ss_clip_skip": args.clip_skip,
|
364 |
+
"ss_max_token_length": args.max_token_length,
|
365 |
+
"ss_cache_latents": bool(args.cache_latents),
|
366 |
+
"ss_seed": args.seed,
|
367 |
+
"ss_lowram": args.lowram,
|
368 |
+
"ss_noise_offset": args.noise_offset,
|
369 |
+
"ss_training_comment": args.training_comment, # will not be updated after training
|
370 |
+
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
371 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
372 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
373 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
374 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
375 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
376 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
377 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
378 |
+
"ss_min_snr_gamma": args.min_snr_gamma,
|
379 |
+
}
|
380 |
+
|
381 |
+
if use_user_config:
|
382 |
+
# save metadata of multiple datasets
|
383 |
+
# NOTE: pack "ss_datasets" value as json one time
|
384 |
+
# or should also pack nested collections as json?
|
385 |
+
datasets_metadata = []
|
386 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
387 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
388 |
+
|
389 |
+
for dataset in train_dataset_group.datasets:
|
390 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
391 |
+
dataset_metadata = {
|
392 |
+
"is_dreambooth": is_dreambooth_dataset,
|
393 |
+
"batch_size_per_device": dataset.batch_size,
|
394 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
395 |
+
"num_reg_images": dataset.num_reg_images,
|
396 |
+
"resolution": (dataset.width, dataset.height),
|
397 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
398 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
399 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
400 |
+
"tag_frequency": dataset.tag_frequency,
|
401 |
+
"bucket_info": dataset.bucket_info,
|
402 |
+
}
|
403 |
+
|
404 |
+
subsets_metadata = []
|
405 |
+
for subset in dataset.subsets:
|
406 |
+
subset_metadata = {
|
407 |
+
"img_count": subset.img_count,
|
408 |
+
"num_repeats": subset.num_repeats,
|
409 |
+
"color_aug": bool(subset.color_aug),
|
410 |
+
"flip_aug": bool(subset.flip_aug),
|
411 |
+
"random_crop": bool(subset.random_crop),
|
412 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
413 |
+
"keep_tokens": subset.keep_tokens,
|
414 |
+
}
|
415 |
+
|
416 |
+
image_dir_or_metadata_file = None
|
417 |
+
if subset.image_dir:
|
418 |
+
image_dir = os.path.basename(subset.image_dir)
|
419 |
+
subset_metadata["image_dir"] = image_dir
|
420 |
+
image_dir_or_metadata_file = image_dir
|
421 |
+
|
422 |
+
if is_dreambooth_dataset:
|
423 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
424 |
+
subset_metadata["is_reg"] = subset.is_reg
|
425 |
+
if subset.is_reg:
|
426 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
427 |
+
else:
|
428 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
429 |
+
subset_metadata["metadata_file"] = metadata_file
|
430 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
431 |
+
|
432 |
+
subsets_metadata.append(subset_metadata)
|
433 |
+
|
434 |
+
# merge dataset dir: not reg subset only
|
435 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
436 |
+
if image_dir_or_metadata_file is not None:
|
437 |
+
# datasets may have a certain dir multiple times
|
438 |
+
v = image_dir_or_metadata_file
|
439 |
+
i = 2
|
440 |
+
while v in dataset_dirs_info:
|
441 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
442 |
+
i += 1
|
443 |
+
image_dir_or_metadata_file = v
|
444 |
+
|
445 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
|
446 |
+
|
447 |
+
dataset_metadata["subsets"] = subsets_metadata
|
448 |
+
datasets_metadata.append(dataset_metadata)
|
449 |
+
|
450 |
+
# merge tag frequency:
|
451 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
452 |
+
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
453 |
+
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
454 |
+
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
455 |
+
if ds_dir_name in tag_frequency:
|
456 |
+
continue
|
457 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
458 |
+
|
459 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
460 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
461 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
462 |
+
else:
|
463 |
+
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
464 |
+
assert (
|
465 |
+
len(train_dataset_group.datasets) == 1
|
466 |
+
), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
467 |
+
|
468 |
+
dataset = train_dataset_group.datasets[0]
|
469 |
+
|
470 |
+
dataset_dirs_info = {}
|
471 |
+
reg_dataset_dirs_info = {}
|
472 |
+
if use_dreambooth_method:
|
473 |
+
for subset in dataset.subsets:
|
474 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
475 |
+
info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
|
476 |
+
else:
|
477 |
+
for subset in dataset.subsets:
|
478 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
479 |
+
"n_repeats": subset.num_repeats,
|
480 |
+
"img_count": subset.img_count,
|
481 |
+
}
|
482 |
+
|
483 |
+
metadata.update(
|
484 |
+
{
|
485 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
486 |
+
"ss_total_batch_size": total_batch_size,
|
487 |
+
"ss_resolution": args.resolution,
|
488 |
+
"ss_color_aug": bool(args.color_aug),
|
489 |
+
"ss_flip_aug": bool(args.flip_aug),
|
490 |
+
"ss_random_crop": bool(args.random_crop),
|
491 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
492 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
493 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
494 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
495 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
496 |
+
"ss_keep_tokens": args.keep_tokens,
|
497 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
498 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
499 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
500 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
501 |
+
}
|
502 |
+
)
|
503 |
+
|
504 |
+
# add extra args
|
505 |
+
if args.network_args:
|
506 |
+
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
507 |
+
|
508 |
+
# model name and hash
|
509 |
+
if args.pretrained_model_name_or_path is not None:
|
510 |
+
sd_model_name = args.pretrained_model_name_or_path
|
511 |
+
if os.path.exists(sd_model_name):
|
512 |
+
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
513 |
+
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
|
514 |
+
sd_model_name = os.path.basename(sd_model_name)
|
515 |
+
metadata["ss_sd_model_name"] = sd_model_name
|
516 |
+
|
517 |
+
if args.vae is not None:
|
518 |
+
vae_name = args.vae
|
519 |
+
if os.path.exists(vae_name):
|
520 |
+
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
521 |
+
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
|
522 |
+
vae_name = os.path.basename(vae_name)
|
523 |
+
metadata["ss_vae_name"] = vae_name
|
524 |
+
|
525 |
+
metadata = {k: str(v) for k, v in metadata.items()}
|
526 |
+
|
527 |
+
# make minimum metadata for filtering
|
528 |
+
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
529 |
+
minimum_metadata = {}
|
530 |
+
for key in minimum_keys:
|
531 |
+
if key in metadata:
|
532 |
+
minimum_metadata[key] = metadata[key]
|
533 |
+
|
534 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
535 |
+
global_step = 0
|
536 |
+
|
537 |
+
noise_scheduler = DDPMScheduler(
|
538 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
539 |
+
)
|
540 |
+
if accelerator.is_main_process:
|
541 |
+
accelerator.init_trackers("network_train")
|
542 |
+
|
543 |
+
loss_list = []
|
544 |
+
loss_total = 0.0
|
545 |
+
del train_dataset_group
|
546 |
+
|
547 |
+
# if hasattr(network, "on_step_start"):
|
548 |
+
# on_step_start = network.on_step_start
|
549 |
+
# else:
|
550 |
+
# on_step_start = lambda *args, **kwargs: None
|
551 |
+
|
552 |
+
for epoch in range(num_train_epochs):
|
553 |
+
if is_main_process:
|
554 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
555 |
+
current_epoch.value = epoch + 1
|
556 |
+
|
557 |
+
metadata["ss_epoch"] = str(epoch + 1)
|
558 |
+
|
559 |
+
network.on_epoch_start(text_encoder, unet)
|
560 |
+
|
561 |
+
for step, batch in enumerate(train_dataloader):
|
562 |
+
current_step.value = global_step
|
563 |
+
with accelerator.accumulate(network):
|
564 |
+
# on_step_start(text_encoder, unet)
|
565 |
+
|
566 |
+
with torch.no_grad():
|
567 |
+
if "latents" in batch and batch["latents"] is not None:
|
568 |
+
latents = batch["latents"].to(accelerator.device)
|
569 |
+
else:
|
570 |
+
# latentに変換
|
571 |
+
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
572 |
+
latents = latents * 0.18215
|
573 |
+
b_size = latents.shape[0]
|
574 |
+
|
575 |
+
with torch.set_grad_enabled(train_text_encoder):
|
576 |
+
# Get the text embedding for conditioning
|
577 |
+
if args.weighted_captions:
|
578 |
+
encoder_hidden_states = get_weighted_text_embeddings(
|
579 |
+
tokenizer,
|
580 |
+
text_encoder,
|
581 |
+
batch["captions"],
|
582 |
+
accelerator.device,
|
583 |
+
args.max_token_length // 75 if args.max_token_length else 1,
|
584 |
+
clip_skip=args.clip_skip,
|
585 |
+
)
|
586 |
+
else:
|
587 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
588 |
+
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
589 |
+
# Sample noise that we'll add to the latents
|
590 |
+
noise = torch.randn_like(latents, device=latents.device)
|
591 |
+
if args.noise_offset:
|
592 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
593 |
+
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
594 |
+
|
595 |
+
# Sample a random timestep for each image
|
596 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
597 |
+
timesteps = timesteps.long()
|
598 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
599 |
+
# (this is the forward diffusion process)
|
600 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
601 |
+
|
602 |
+
# Predict the noise residual
|
603 |
+
with accelerator.autocast():
|
604 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
605 |
+
|
606 |
+
if args.v_parameterization:
|
607 |
+
# v-parameterization training
|
608 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
609 |
+
else:
|
610 |
+
target = noise
|
611 |
+
|
612 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
613 |
+
loss = loss.mean([1, 2, 3])
|
614 |
+
|
615 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
616 |
+
loss = loss * loss_weights
|
617 |
+
|
618 |
+
if args.min_snr_gamma:
|
619 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
620 |
+
|
621 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
622 |
+
|
623 |
+
accelerator.backward(loss)
|
624 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
625 |
+
params_to_clip = network.get_trainable_params()
|
626 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
627 |
+
|
628 |
+
optimizer.step()
|
629 |
+
lr_scheduler.step()
|
630 |
+
optimizer.zero_grad(set_to_none=True)
|
631 |
+
|
632 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
633 |
+
if accelerator.sync_gradients:
|
634 |
+
progress_bar.update(1)
|
635 |
+
global_step += 1
|
636 |
+
|
637 |
+
train_util.sample_images(
|
638 |
+
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
639 |
+
)
|
640 |
+
|
641 |
+
current_loss = loss.detach().item()
|
642 |
+
if epoch == 0:
|
643 |
+
loss_list.append(current_loss)
|
644 |
+
else:
|
645 |
+
loss_total -= loss_list[step]
|
646 |
+
loss_list[step] = current_loss
|
647 |
+
loss_total += current_loss
|
648 |
+
avr_loss = loss_total / len(loss_list)
|
649 |
+
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
650 |
+
progress_bar.set_postfix(**logs)
|
651 |
+
|
652 |
+
if args.logging_dir is not None:
|
653 |
+
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
654 |
+
accelerator.log(logs, step=global_step)
|
655 |
+
|
656 |
+
if global_step >= args.max_train_steps:
|
657 |
+
break
|
658 |
+
|
659 |
+
if args.logging_dir is not None:
|
660 |
+
logs = {"loss/epoch": loss_total / len(loss_list)}
|
661 |
+
accelerator.log(logs, step=epoch + 1)
|
662 |
+
|
663 |
+
accelerator.wait_for_everyone()
|
664 |
+
|
665 |
+
if args.save_every_n_epochs is not None:
|
666 |
+
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
667 |
+
|
668 |
+
def save_func():
|
669 |
+
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
|
670 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
671 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
672 |
+
print(f"saving checkpoint: {ckpt_file}")
|
673 |
+
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
674 |
+
if args.huggingface_repo_id is not None:
|
675 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
676 |
+
|
677 |
+
def remove_old_func(old_epoch_no):
|
678 |
+
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
679 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
680 |
+
if os.path.exists(old_ckpt_file):
|
681 |
+
print(f"removing old checkpoint: {old_ckpt_file}")
|
682 |
+
os.remove(old_ckpt_file)
|
683 |
+
|
684 |
+
if is_main_process:
|
685 |
+
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
686 |
+
if saving and args.save_state:
|
687 |
+
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
688 |
+
|
689 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
690 |
+
|
691 |
+
# end of epoch
|
692 |
+
|
693 |
+
metadata["ss_epoch"] = str(num_train_epochs)
|
694 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
695 |
+
|
696 |
+
if is_main_process:
|
697 |
+
network = unwrap_model(network)
|
698 |
+
|
699 |
+
accelerator.end_training()
|
700 |
+
|
701 |
+
if args.save_state:
|
702 |
+
train_util.save_state_on_train_end(args, accelerator)
|
703 |
+
|
704 |
+
del accelerator # この後メモリを使うのでこれは消す
|
705 |
+
|
706 |
+
if is_main_process:
|
707 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
708 |
+
|
709 |
+
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
710 |
+
ckpt_name = model_name + "." + args.save_model_as
|
711 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
712 |
+
|
713 |
+
print(f"save trained model to {ckpt_file}")
|
714 |
+
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
715 |
+
if args.huggingface_repo_id is not None:
|
716 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
717 |
+
print("model saved.")
|
718 |
+
|
719 |
+
|
720 |
+
def setup_parser() -> argparse.ArgumentParser:
|
721 |
+
parser = argparse.ArgumentParser()
|
722 |
+
|
723 |
+
train_util.add_sd_models_arguments(parser)
|
724 |
+
train_util.add_dataset_arguments(parser, True, True, True)
|
725 |
+
train_util.add_training_arguments(parser, True)
|
726 |
+
train_util.add_optimizer_arguments(parser)
|
727 |
+
config_util.add_config_arguments(parser)
|
728 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
729 |
+
|
730 |
+
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
731 |
+
parser.add_argument(
|
732 |
+
"--save_model_as",
|
733 |
+
type=str,
|
734 |
+
default="safetensors",
|
735 |
+
choices=[None, "ckpt", "pt", "safetensors"],
|
736 |
+
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
737 |
+
)
|
738 |
+
|
739 |
+
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
740 |
+
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
741 |
+
|
742 |
+
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
|
743 |
+
parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール")
|
744 |
+
parser.add_argument(
|
745 |
+
"--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)"
|
746 |
+
)
|
747 |
+
parser.add_argument(
|
748 |
+
"--network_alpha",
|
749 |
+
type=float,
|
750 |
+
default=1,
|
751 |
+
help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
|
752 |
+
)
|
753 |
+
parser.add_argument(
|
754 |
+
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
|
755 |
+
)
|
756 |
+
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
|
757 |
+
parser.add_argument(
|
758 |
+
"--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
|
759 |
+
)
|
760 |
+
parser.add_argument(
|
761 |
+
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
|
762 |
+
)
|
763 |
+
|
764 |
+
return parser
|
765 |
+
|
766 |
+
|
767 |
+
if __name__ == "__main__":
|
768 |
+
parser = setup_parser()
|
769 |
+
|
770 |
+
args = parser.parse_args()
|
771 |
+
args = train_util.read_config_from_file(args, parser)
|
772 |
+
|
773 |
+
train(args)
|
train_network_README-ja.md
ADDED
@@ -0,0 +1,479 @@
1 |
+
# LoRAの学習について
|
2 |
+
|
3 |
+
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)をStable Diffusionに適用したものです。
|
4 |
+
|
5 |
+
[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を大いに参考にさせていただきました。ありがとうございます。
|
6 |
+
|
7 |
+
通常のLoRAは Linear およびカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
8 |
+
|
9 |
+
Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
|
10 |
+
|
11 |
+
8GB VRAMでもぎりぎり動作するようです。
|
12 |
+
|
13 |
+
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
14 |
+
|
15 |
+
# 学習できるLoRAの種類
|
16 |
+
|
17 |
+
以下の二種類をサポートします。以下は当リポジトリ内の独自の名称です。
|
18 |
+
|
19 |
+
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
|
20 |
+
|
21 |
+
Linear およびカーネルサイズ 1x1 の Conv2d に適用されるLoRA
|
22 |
+
|
23 |
+
2. __LoRA-C3Lier__ : (LoRA for __C__ onvolutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
|
24 |
+
|
25 |
+
1.に加え、カーネルサイズ 3x3 の Conv2d に適用されるLoRA
|
26 |
+
|
27 |
+
LoRA-LierLaに比べ、LoRA-C3Lierは適用される層が増える分、高い精度が期待できるかもしれません。
|
28 |
+
|
29 |
+
また学習時は __DyLoRA__ を使用することもできます(後述します)。
|
30 |
+
|
31 |
+
## 学習したモデルに関する注意
|
32 |
+
|
33 |
+
LoRA-LierLa は、AUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
|
34 |
+
|
35 |
+
LoRA-C3Lierを使いWeb UIで生成するには、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
36 |
+
|
37 |
+
いずれも学習したLoRAのモデルを、Stable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージすることもできます。
|
38 |
+
|
39 |
+
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
40 |
+
|
41 |
+
# 学習の手順
|
42 |
+
|
43 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
44 |
+
|
45 |
+
## データの準備
|
46 |
+
|
47 |
+
[学習データの準備について](./train_README-ja.md) を参照してください。
|
48 |
+
|
49 |
+
|
50 |
+
## 学習の実行
|
51 |
+
|
52 |
+
`train_network.py`を用います。
|
53 |
+
|
54 |
+
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのは`networks.lora`となりますので、それを指定してください。
|
55 |
+
|
56 |
+
なお学習率は通常のDreamBoothやfine tuningよりも高めの、`1e-4`~`1e-3`程度を指定するとよいようです。
|
57 |
+
|
58 |
+
以下はコマンドラインの例です。
|
59 |
+
|
60 |
+
```
|
61 |
+
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
62 |
+
    --pretrained_model_name_or_path=<.ckptまたは.safetensorsまたはDiffusers版モデルのディレクトリ>
|
63 |
+
--dataset_config=<データ準備で作成した.tomlファイル>
|
64 |
+
--output_dir=<学習したモデルの出力先フォルダ>
|
65 |
+
--output_name=<学習したモデル出力時のファイル名>
|
66 |
+
--save_model_as=safetensors
|
67 |
+
--prior_loss_weight=1.0
|
68 |
+
--max_train_steps=400
|
69 |
+
--learning_rate=1e-4
|
70 |
+
--optimizer_type="AdamW8bit"
|
71 |
+
--xformers
|
72 |
+
--mixed_precision="fp16"
|
73 |
+
--cache_latents
|
74 |
+
--gradient_checkpointing
|
75 |
+
--save_every_n_epochs=1
|
76 |
+
--network_module=networks.lora
|
77 |
+
```
|
78 |
+
|
79 |
+
このコマンドラインでは LoRA-LierLa が学習されます。
|
80 |
+
|
81 |
+
`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。
|
82 |
+
|
83 |
+
その他、以下のオプションが指定できます。
|
84 |
+
|
85 |
+
* `--network_dim`
|
86 |
+
  * LoRAのRANKを指定します(``--network_dim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
87 |
+
* `--network_alpha`
|
88 |
+
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
|
89 |
+
* `--persistent_data_loader_workers`
|
90 |
+
  * Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
|
91 |
+
* `--max_data_loader_n_workers`
|
92 |
+
* データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。
|
93 |
+
* `--network_weights`
|
94 |
+
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
95 |
+
* `--network_train_unet_only`
|
96 |
+
* U-Netに関連するLoRAモジュールのみ有効とします。fine tuning的な学習で指定するとよいかもしれません。
|
97 |
+
* `--network_train_text_encoder_only`
|
98 |
+
* Text Encoderに関連するLoRAモジュールのみ有効とします。Textual Inversion的な効果が期待できるかもしれません。
|
99 |
+
* `--unet_lr`
|
100 |
+
* U-Netに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。
|
101 |
+
* `--text_encoder_lr`
|
102 |
+
* Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
|
103 |
+
* `--network_args`
|
104 |
+
* 複数の引数を指定できます。後述します。
|
105 |
+
|
106 |
+
`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
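
参考までに、上記の `--network_dim`(rank)と `--network_alpha` が重みのスケーリングにどう効くかを示す最小限のスケッチを以下に示します。当リポジトリの実装そのものではなく、挙動を説明するための仮のコードです(クラス名・変数名は説明用の仮定です)。実際の実装は `networks/lora.py` を参照してください。

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    # 説明用に簡略化した LoRA 付き Linear 層のスケッチ
    def __init__(self, base: nn.Linear, network_dim: int = 4, network_alpha: float = 1.0):
        super().__init__()
        self.base = base
        self.lora_down = nn.Linear(base.in_features, network_dim, bias=False)
        self.lora_up = nn.Linear(network_dim, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # 学習開始時は元のモデルと同じ出力になるようにする
        # alpha / dim が LoRA 出力のスケールになる。alpha == dim のときスケール1(旧バージョン相当)
        self.scale = network_alpha / network_dim

    def forward(self, x):
        return self.base(x) + self.lora_up(self.lora_down(x)) * self.scale
```
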
107 |
+
|
108 |
+
# その他の学習方法
|
109 |
+
|
110 |
+
## LoRA-C3Lier を学習する
|
111 |
+
|
112 |
+
`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。
|
113 |
+
|
114 |
+
```
|
115 |
+
--network_args "conv_dim=4" "conv_alpha=1"
|
116 |
+
```
|
117 |
+
|
118 |
+
以下のように alpha 省略時は1になります。
|
119 |
+
|
120 |
+
```
|
121 |
+
--network_args "conv_dim=4"
|
122 |
+
```
|
123 |
+
|
124 |
+
## DyLoRA
|
125 |
+
|
126 |
+
DyLoRAはこちらの論文で提案されたものです。[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558) 公式実装は[こちら](https://github.com/huawei-noah/KD-NLP/tree/main/DyLoRA)です。
|
127 |
+
|
128 |
+
論文によると、LoRAのrankは必ずしも高いほうが良いわけではなく、対象のモデル、データセット、タスクなどにより適切なrankを探す必要があるようです。DyLoRAを使うと、指定したdim(rank)以下のさまざまなrankで同時にLoRAを学習します。これにより最適なrankをそれぞれ学習して探す手間を省くことができます。
|
129 |
+
|
130 |
+
当リポジトリの実装は公式実装をベースに独自の拡張を加えています(そのため不具合などあるかもしれません)。
|
131 |
+
|
132 |
+
### 当リポジトリのDyLoRAの特徴
|
133 |
+
|
134 |
+
学習後のDyLoRAのモデルファイルはLoRAと互換性があります。また、モデルファイルから指定したdim(rank)以下の複数のdimのLoRAを抽出できます。
|
135 |
+
|
136 |
+
DyLoRA-LierLa、DyLoRA-C3Lierのどちらも学習できます。
|
137 |
+
|
138 |
+
### DyLoRAで学習する
|
139 |
+
|
140 |
+
`--network_module=networks.dylora` のように、DyLoRAに対応する`network.dylora`を指定してください。
|
141 |
+
|
142 |
+
また `--network_args` に、たとえば`--network_args "unit=4"`のように`unit`を指定します。`unit`はrankを分割する単位です。たとえば`--network_dim=16 --network_args "unit=4"` のように指定します。`unit`は`network_dim`を割り切れる値(`network_dim`は`unit`の倍数)としてください。
|
143 |
+
|
144 |
+
`unit`を指定しない場合は、`unit=1`として扱われます。
|
145 |
+
|
146 |
+
記述例は以下です。
|
147 |
+
|
148 |
+
```
|
149 |
+
--network_module=networks.dylora --network_dim=16 --network_args "unit=4"
|
150 |
+
|
151 |
+
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "unit=4"
|
152 |
+
```
|
153 |
+
|
154 |
+
DyLoRA-C3Lierの場合は、`--network_args` に`"conv_dim=4"`のように`conv_dim`を指定します。通常のLoRAと異なり、`conv_dim`は`network_dim`と同じ値である必要があります。記述例は以下です。
|
155 |
+
|
156 |
+
```
|
157 |
+
--network_module=networks.dylora --network_dim=16 --network_args "conv_dim=16" "unit=4"
|
158 |
+
|
159 |
+
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "conv_dim=32" "conv_alpha=16" "unit=8"
|
160 |
+
```
|
161 |
+
|
162 |
+
たとえばdim=16、unit=4(後述)で学習すると、4、8、12、16の4つのrankのLoRAを学習、抽出できます。抽出した各モデルで画像を生成し、比較することで、最適なrankのLoRAを選択できます。
|
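
参考として、`network_dim` と `unit` から抽出可能な rank を列挙する簡単なスケッチです(説明用の仮のコードです)。

```python
# dim(rank)と unit から、学習・抽出できる rank の一覧を列挙する例
network_dim = 16
unit = 4
extractable_ranks = list(range(unit, network_dim + 1, unit))
print(extractable_ranks)  # [4, 8, 12, 16]
```
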
163 |
+
|
164 |
+
その他のオプションは通常のLoRAと同じです。
|
165 |
+
|
166 |
+
※ `unit`は当リポジトリの独自拡張で、DyLoRAでは同dim(rank)の通常LoRAに比べると学習時間が長くなることが予想されるため、分割単位を大きくしたものです。
|
167 |
+
|
168 |
+
### DyLoRAのモデルからLoRAモデルを抽出する
|
169 |
+
|
170 |
+
`networks`フォルダ内の `extract_lora_from_dylora.py`を使用します。指定した`unit`単位で、DyLoRAのモデルからLoRAのモデルを抽出します。
|
171 |
+
|
172 |
+
コマンドラインはたとえば以下のようになります。
|
173 |
+
|
174 |
+
```powershell
|
175 |
+
python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.safetensors" --save_to "foldername/dylora-model-split.safetensors" --unit 4
|
176 |
+
```
|
177 |
+
|
178 |
+
`--model` にはDyLoRAのモデルファイルを指定します。`--save_to` には抽出したモデルを保存するファイル名を指定します(rankの数値がファイル名に付加されます)。`--unit` にはDyLoRAの学習時の`unit`を指定します。
|
179 |
+
|
180 |
+
## 階層別学習率
|
181 |
+
|
182 |
+
詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。
|
183 |
+
|
184 |
+
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
|
185 |
+
|
186 |
+
`--network_args` で以下の引数を指定してください。
|
187 |
+
|
188 |
+
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
|
189 |
+
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
|
190 |
+
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。
|
191 |
+
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"mid_lr_weight=0.5"` のように数値を一つだけ指定します。
|
192 |
+
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
|
193 |
+
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
|
194 |
+
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
|
195 |
+
|
196 |
+
### 階層別学習率コマンドライン指定例:
|
197 |
+
|
198 |
+
```powershell
|
199 |
+
--network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"
|
200 |
+
|
201 |
+
--network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"
|
202 |
+
```
|
203 |
+
|
204 |
+
### 階層別学習率tomlファイル指定例:
|
205 |
+
|
206 |
+
```toml
|
207 |
+
network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]
|
208 |
+
|
209 |
+
network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]
|
210 |
+
```
|
211 |
+
|
212 |
+
## 階層別dim (rank)
|
213 |
+
|
214 |
+
フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
|
215 |
+
|
216 |
+
`--network_args` で以下の引数を指定してください。
|
217 |
+
|
218 |
+
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
|
219 |
+
- `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。
|
220 |
+
- `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。
|
221 |
+
- `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。
|
222 |
+
|
223 |
+
### 階層別dim (rank)コマンドライン指定例:
|
224 |
+
|
225 |
+
```powershell
|
226 |
+
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"
|
227 |
+
|
228 |
+
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
229 |
+
|
230 |
+
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
231 |
+
```
|
232 |
+
|
233 |
+
### 階層別dim (rank)tomlファイル指定例:
|
234 |
+
|
235 |
+
```toml
|
236 |
+
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]
|
237 |
+
|
238 |
+
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]
|
239 |
+
```
|
240 |
+
|
241 |
+
# その他のスクリプト
|
242 |
+
|
243 |
+
マージ等LoRAに関連するスクリプト群です。
|
244 |
+
|
245 |
+
## マージスクリプトについて
|
246 |
+
|
247 |
+
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
|
248 |
+
|
249 |
+
### Stable DiffusionのモデルにLoRAのモデルをマージする
|
250 |
+
|
251 |
+
マージ後のモデルは通常のStable Diffusionのckptと同様に扱えます。たとえば以下のようなコマンドラインになります。
|
252 |
+
|
253 |
+
```
|
254 |
+
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
255 |
+
--save_to ..\lora_train1\model-char1-merged.safetensors
|
256 |
+
--models ..\lora_train1\last.safetensors --ratios 0.8
|
257 |
+
```
|
258 |
+
|
259 |
+
Stable Diffusion v2.xのモデルで学習し、それにマージする場合は、--v2オプションを指定してください。
|
260 |
+
|
261 |
+
--sd_modelオプションにマージの元となるStable Diffusionのモデルファイルを指定します(.ckptまたは.safetensorsのみ対応で、Diffusersは今のところ対応していません)。
|
262 |
+
|
263 |
+
--save_toオプションにマージ後のモデルの保存先を指定します(.ckptまたは.safetensors、拡張子で自動判定)。
|
264 |
+
|
265 |
+
--modelsに学習したLoRAのモデルファイルを指定します。複数指定も可能で、その時は順にマージします。
|
266 |
+
|
267 |
+
--ratiosにそれぞれのモデルの適用率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。例えば過学習に近いような場合は、適用率を下げるとマシになるかもしれません。モデルの数と同じだけ指定してください。
|
268 |
+
|
269 |
+
複数指定時は以下のようになります。
|
270 |
+
|
271 |
+
```
|
272 |
+
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
273 |
+
--save_to ..\lora_train1\model-char1-merged.safetensors
|
274 |
+
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5
|
275 |
+
```
|
276 |
+
|
277 |
+
### 複数のLoRAのモデルをマージする
|
278 |
+
|
279 |
+
複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
|
280 |
+
|
281 |
+
たとえば以下のようなコマンドラインになります。
|
282 |
+
|
283 |
+
```
|
284 |
+
python networks\merge_lora.py
|
285 |
+
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
286 |
+
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
|
287 |
+
```
|
288 |
+
|
289 |
+
--sd_modelオプションは指定不要です。
|
290 |
+
|
291 |
+
--save_toオプションにマージ後のLoRAモデルの保存先を指定します(.ckptまたは.safetensors、拡張子で自動判定)。
|
292 |
+
|
293 |
+
--modelsに学習したLoRAのモデルファイルを指定します。三つ以上も指定可能です。
|
294 |
+
|
295 |
+
--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージする場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
|
296 |
+
|
297 |
+
v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
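
参考までに、`--ratios` の値がマージにどう効くかのイメージを示す簡単なスケッチです。実際の処理は `merge_lora.py` を参照してください(ここでの関数名・形状は説明用の仮定です)。

```python
import torch


def merge_lora_into_linear_weight(weight, lora_down, lora_up, alpha, ratio):
    # W' = W + ratio * (alpha / rank) * (up @ down) というイメージ(Linear 層の場合の概略)
    rank = lora_down.shape[0]
    return weight + ratio * (alpha / rank) * (lora_up @ lora_down)


# 使い方のイメージ(形状は説明用の仮定)
W = torch.zeros(320, 768)
down, up = torch.randn(4, 768), torch.randn(320, 4)
W_merged = merge_lora_into_linear_weight(W, down, up, alpha=1.0, ratio=0.8)
# 複数の LoRA をマージする場合は、これを各 LoRA について順に繰り返すイメージです
```
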
298 |
+
|
299 |
+
|
300 |
+
### その他のオプション
|
301 |
+
|
302 |
+
* precision
|
303 |
+
* マージ計算時の精度をfloat、fp16、bf16から指定できます。省略時は精度を確保するためfloatになります。メモリ使用量を減らしたい場合はfp16/bf16を指定してください。
|
304 |
+
* save_precision
|
305 |
+
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
|
306 |
+
|
307 |
+
|
308 |
+
## 複数のrankが異なるLoRAのモデルをマージする
|
309 |
+
|
310 |
+
複数のLoRAをひとつのLoRAで近似します(完全な再現はできません)。`svd_merge_lora.py`を用います。たとえば以下のようなコマンドラインになります。
|
311 |
+
|
312 |
+
```
|
313 |
+
python networks\svd_merge_lora.py
|
314 |
+
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
315 |
+
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
|
316 |
+
--ratios 0.6 0.4 --new_rank 32 --device cuda
|
317 |
+
```
|
318 |
+
|
319 |
+
`merge_lora.py` と主なオプションは同一です。以下のオプションが追加されています。
|
320 |
+
|
321 |
+
- `--new_rank`
|
322 |
+
- 作成するLoRAのrankを指定します。
|
323 |
+
- `--new_conv_rank`
|
324 |
+
- 作成する Conv2d 3x3 LoRA の rank を指定します。省略時は `new_rank` と同じになります。
|
325 |
+
- `--device`
|
326 |
+
- `--device cuda`としてcudaを指定すると計算をGPU上で行います。処理が速くなります。
|
327 |
+
|
328 |
+
## 当リポジトリ内の画像生成スクリプトで生成する
|
329 |
+
|
330 |
+
gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
|
331 |
+
|
332 |
+
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
333 |
+
|
334 |
+
## Diffusersのpipelineで生成する
|
335 |
+
|
336 |
+
以下の例を参考にしてください。必要なファイルはnetworks/lora.pyのみです。Diffusersのバージョンは0.10.2以外では動作しない可能性があります。
|
337 |
+
|
338 |
+
```python
|
339 |
+
import torch
|
340 |
+
from diffusers import StableDiffusionPipeline
|
341 |
+
from networks.lora import LoRAModule, create_network_from_weights
|
342 |
+
from safetensors.torch import load_file
|
343 |
+
|
344 |
+
# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.
|
345 |
+
|
346 |
+
model_id_or_dir = r"model_id_on_hugging_face_or_dir"
|
347 |
+
device = "cuda"
|
348 |
+
|
349 |
+
# create pipe
|
350 |
+
print(f"creating pipe from {model_id_or_dir}...")
|
351 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
|
352 |
+
pipe = pipe.to(device)
|
353 |
+
vae = pipe.vae
|
354 |
+
text_encoder = pipe.text_encoder
|
355 |
+
unet = pipe.unet
|
356 |
+
|
357 |
+
# load lora networks
|
358 |
+
print(f"loading lora networks...")
|
359 |
+
|
360 |
+
lora_path1 = r"lora1.safetensors"
|
361 |
+
sd = load_file(lora_path1) # If the file is .ckpt, use torch.load instead.
|
362 |
+
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
363 |
+
network1.apply_to(text_encoder, unet)
|
364 |
+
network1.load_state_dict(sd)
|
365 |
+
network1.to(device, dtype=torch.float16)
|
366 |
+
|
367 |
+
# # You can merge weights instead of apply_to+load_state_dict. network.set_multiplier does not work
|
368 |
+
# network.merge_to(text_encoder, unet, sd)
|
369 |
+
|
370 |
+
lora_path2 = r"lora2.safetensors"
|
371 |
+
sd = load_file(lora_path2)
|
372 |
+
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder,unet, sd)
|
373 |
+
network2.apply_to(text_encoder, unet)
|
374 |
+
network2.load_state_dict(sd)
|
375 |
+
network2.to(device, dtype=torch.float16)
|
376 |
+
|
377 |
+
lora_path3 = r"lora3.safetensors"
|
378 |
+
sd = load_file(lora_path3)
|
379 |
+
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
380 |
+
network3.apply_to(text_encoder, unet)
|
381 |
+
network3.load_state_dict(sd)
|
382 |
+
network3.to(device, dtype=torch.float16)
|
383 |
+
|
384 |
+
# prompts
|
385 |
+
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
|
386 |
+
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"
|
387 |
+
|
388 |
+
# exec pipe
|
389 |
+
print("generating image...")
|
390 |
+
with torch.autocast("cuda"):
|
391 |
+
image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]
|
392 |
+
|
393 |
+
# if not merged, you can use set_multiplier
|
394 |
+
# network1.set_multiplier(0.8)
|
395 |
+
# and generate image again...
|
396 |
+
|
397 |
+
# save image
|
398 |
+
image.save(r"by_diffusers..png")
|
399 |
+
```
|
400 |
+
|
401 |
+
## 二つのモデルの差分からLoRAモデルを作成する
|
402 |
+
|
403 |
+
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
|
404 |
+
|
405 |
+
二つのモデル(たとえばfine tuningの元モデルとfine tuning後のモデル)の差分を、LoRAで近似します。
|
406 |
+
|
407 |
+
### スクリプトの実行方法
|
408 |
+
|
409 |
+
以下のように指定してください。
|
410 |
+
```
|
411 |
+
python networks\extract_lora_from_models.py --model_org base-model.ckpt
|
412 |
+
--model_tuned fine-tuned-model.ckpt
|
413 |
+
--save_to lora-weights.safetensors --dim 4
|
414 |
+
```
|
415 |
+
|
416 |
+
--model_orgオプションに元のStable Diffusionモデルを指定します。作成したLoRAモデルを適用する場合は、このモデルを指定して適用することになります。.ckptまたは.safetensorsが指定できます。
|
417 |
+
|
418 |
+
--model_tunedオプションに差分を抽出する対象のStable Diffusionモデルを指定します。たとえばfine tuningやDreamBooth後のモデルを指定します。.ckptまたは.safetensorsが指定できます。
|
419 |
+
|
420 |
+
--save_toにLoRAモデルの保存先を指定します。--dimにLoRAの次元数を指定します。
|
421 |
+
|
422 |
+
生成されたLoRAモデルは、学習したLoRAモデルと同様に使用できます。
|
423 |
+
|
424 |
+
Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRAとなります。
|
425 |
+
|
426 |
+
### その他のオプション
|
427 |
+
|
428 |
+
- `--v2`
|
429 |
+
- v2.xのStable Diffusionモデルを使う場合に指定してください。
|
430 |
+
- `--device`
|
431 |
+
- ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。
|
432 |
+
- `--save_precision`
|
433 |
+
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
|
434 |
+
- `--conv_dim`
|
435 |
+
- 指定するとLoRAの適用範囲を Conv2d 3x3 へ拡大します。Conv2d 3x3 の rank を指定します。
|
436 |
+
|
437 |
+
## 画像リサイズスクリプト
|
438 |
+
|
439 |
+
(のちほどドキュメントを整理しますがとりあえずここに説明を書いておきます。)
|
440 |
+
|
441 |
+
Aspect Ratio Bucketingの機能拡張で、小さな画像については拡大しないでそのまま教師データとすることが可能になりました。元の教師画像を縮小した画像を、教師データに加えると精度が向上したという報告とともに前処理用のスクリプトをいただきましたので整備して追加しました。bmaltais氏に感謝します。
|
442 |
+
|
443 |
+
### スクリプトの実行方法
|
444 |
+
|
445 |
+
以下のように指定してください。元の画像そのまま、およびリサイズ後の画像が変換先フォルダに保存されます。リサイズ後の画像には、ファイル名に ``+512x512`` のようにリサイズ先の解像度が付け加えられます(画像サイズとは異なります)。リサイズ先の解像度より小さい画像は拡大されることはありません。
|
446 |
+
|
447 |
+
```
|
448 |
+
python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
|
449 |
+
--copy_associated_files 元画像フォルダ 変換先フォルダ
|
450 |
+
```
|
451 |
+
|
452 |
+
元画像フォルダ内の画像ファイルが、指定した解像度(複数指定可)と同じ面積になるようにリサイズされ、変換先フォルダに保存されます。画像以外のファイルはそのままコピーされます。
|
453 |
+
|
454 |
+
``--max_resolution`` オプションにリサイズ先のサイズを例のように指定してください。面積がそのサイズになるようにリサイズします。複数指定すると、それぞれの解像度でリサイズされます。``512x512,384x384,256x256``なら、変換先フォルダの画像は、元サイズとリサイズ後サイズ×3の計4枚になります。
|
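
「面積が指定サイズと同じになるようにリサイズする」という計算のイメージは以下のとおりです(丸め方などは説明用の仮定で、実際の処理は `tools/resize_images_to_resolution.py` を参照してください)。

```python
import math


def resized_size(width, height, max_resolution="512x512"):
    # 指定解像度と同じ面積になるような縦横サイズを計算する(縦横比は維持)
    target_w, target_h = (int(v) for v in max_resolution.split("x"))
    scale = math.sqrt((target_w * target_h) / (width * height))
    return round(width * scale), round(height * scale)


print(resized_size(1600, 900))  # おおよそ面積が 512*512 になるサイズが返る
```
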
455 |
+
|
456 |
+
``--save_as_png`` オプションを指定するとpng形式で保存します。省略するとjpeg形式(quality=100)で保存されます。
|
457 |
+
|
458 |
+
``--copy_associated_files`` オプションを指定すると、拡張子を除き画像と同じファイル名(たとえばキャプションなど)のファイルが、リサイズ後の画像のファイル名と同じ名前でコピーされます。
|
459 |
+
|
460 |
+
|
461 |
+
### その他のオプション
|
462 |
+
|
463 |
+
- divisible_by
|
464 |
+
- リサイズ後の画像のサイズ(縦、横のそれぞれ)がこの値で割り切れるように、画像中心を切り出します。
|
465 |
+
- interpolation
|
466 |
+
- 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
|
467 |
+
|
468 |
+
|
469 |
+
# 追加情報
|
470 |
+
|
471 |
+
## cloneofsimo氏のリポジトリとの違い
|
472 |
+
|
473 |
+
2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
|
474 |
+
|
475 |
+
またモジュール入れ替え機構は全く異なります。
|
476 |
+
|
477 |
+
## 将来拡張について
|
478 |
+
|
479 |
+
LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。
|
train_network_README.md
ADDED
@@ -0,0 +1,189 @@
1 |
+
## About learning LoRA
|
2 |
+
|
3 |
+
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) (arxiv), [LoRA](https://github.com/microsoft/LoRA) (github) applied to Stable Diffusion.
|
4 |
+
|
5 |
+
[cloneofsimo's repository](https://github.com/cloneofsimo/lora) was a great reference. Thank you very much.
|
6 |
+
|
7 |
+
It seems to just barely run even with 8GB of VRAM.
|
8 |
+
|
9 |
+
## A Note about Trained Models
|
10 |
+
|
11 |
+
This repository is currently incompatible with cloneofsimo's repository and d8ahazard's [Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension), because some enhancements have been made (see below).
|
12 |
+
|
13 |
+
When generating images with the Web UI, etc., either merge the trained LoRA model into the source Stable Diffusion model beforehand with the script in this repository, or use the [extension for the Web UI](https://github.com/kohya-ss/sd-webui-additional-networks).
|
14 |
+
|
15 |
+
## Learning method
|
16 |
+
|
17 |
+
Use train_network.py.
|
18 |
+
|
19 |
+
You can train with both the DreamBooth method (using an identifier such as sks and a class, optionally with regularization images) and the fine tuning method using captions.
|
20 |
+
|
21 |
+
Both methods can be learned in much the same way as existing scripts. We will discuss the differences later.
|
22 |
+
|
23 |
+
### Using the DreamBooth Method
|
24 |
+
|
25 |
+
Please refer to the [DreamBooth guide](./train_db_README.md) and prepare the data.
|
26 |
+
|
27 |
+
Specify train_network.py instead of train_db.py when training.
|
28 |
+
|
29 |
+
Almost all options are available (except Stable Diffusion model save related), but stop_text_encoder_training is not supported.
|
30 |
+
|
31 |
+
### When to use captions
|
32 |
+
|
33 |
+
Please refer to the [fine tuning guide](./fine_tune_README.md) and perform each step.
|
34 |
+
|
35 |
+
Specify train_network.py instead of fine_tune.py when training. Almost all options (except for model saving) can be used as is.
|
36 |
+
|
37 |
+
In addition, it will work even if you do not perform "Pre-obtain latents". Since the latent is acquired from the VAE when learning (or caching), the learning speed will be slower, but color_aug can be used instead.
|
38 |
+
|
39 |
+
### Options for Learning LoRA
|
40 |
+
|
41 |
+
In train_network.py, specify the name of the module to be trained in the --network_module option. `networks.lora` corresponds to LoRA, so please specify it.
|
42 |
+
|
43 |
+
The learning rate should be set to about 1e-4, which is higher than normal DreamBooth and fine tuning.
|
44 |
+
|
45 |
+
Below is an example command line (DreamBooth technique).
|
46 |
+
|
47 |
+
```
|
48 |
+
accelerate launch --num_cpu_threads_per_process 12 train_network.py
|
49 |
+
--pretrained_model_name_or_path=..\models\model.ckpt
|
50 |
+
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
|
51 |
+
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
|
52 |
+
--resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
|
53 |
+
--max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
|
54 |
+
--save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
|
55 |
+
--network_module=networks.lora
|
56 |
+
```
|
57 |
+
|
58 |
+
The LoRA model will be saved in the directory specified by the --output_dir option.
|
59 |
+
|
60 |
+
In addition, the following options can be specified.
|
61 |
+
|
62 |
+
* --network_dim
|
63 |
+
  * Specify the number of dimensions (rank) of LoRA (such as ``--network_dim=4``). Default is 4. The greater the number, the greater the expressive power, but the memory and time required for training also increase. It also seems that it is not good to increase it blindly.
|
64 |
+
* --network_weights
|
65 |
+
* Load pretrained LoRA weights before training and additionally learn from them.
|
66 |
+
* --network_train_unet_only
|
67 |
+
  * Only LoRA modules related to U-Net are enabled. It may be better to specify this for fine tuning-style training.
|
68 |
+
* --network_train_text_encoder_only
|
69 |
+
* Only LoRA modules related to Text Encoder are enabled. You may be able to expect a textual inversion effect.
|
70 |
+
* --unet_lr
|
71 |
+
* Specify when using a learning rate different from the normal learning rate (specified with the --learning_rate option) for the LoRA module related to U-Net.
|
72 |
+
* --text_encoder_lr
|
73 |
+
* Specify when using a learning rate different from the normal learning rate (specified with the --learning_rate option) for the LoRA module associated with the Text Encoder. Some people say that it is better to set the Text Encoder to a slightly lower learning rate (such as 5e-5).
|
74 |
+
|
75 |
+
When neither --network_train_unet_only nor --network_train_text_encoder_only is specified (default), both Text Encoder and U-Net LoRA modules are enabled.
|
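
Conceptually, separate learning rates for the U-Net and Text Encoder LoRA modules (--unet_lr / --text_encoder_lr) correspond to separate optimizer parameter groups, roughly as in the sketch below. This is only an illustration with dummy parameters, not the script's actual code.

```python
import torch
import torch.nn as nn

# dummy parameters standing in for the U-Net / Text Encoder LoRA modules (illustrative only)
unet_lora_params = [nn.Parameter(torch.zeros(8, 8))]
text_encoder_lora_params = [nn.Parameter(torch.zeros(8, 8))]

optimizer = torch.optim.AdamW([
    {"params": unet_lora_params, "lr": 1e-4},          # corresponds to --unet_lr (falls back to --learning_rate)
    {"params": text_encoder_lora_params, "lr": 5e-5},  # corresponds to --text_encoder_lr
])
```
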
76 |
+
|
77 |
+
## About the merge script
|
78 |
+
|
79 |
+
merge_lora.py allows you to merge LoRA training results into a Stable Diffusion model, or merge multiple LoRA models.
|
80 |
+
|
81 |
+
### Merge LoRA model into Stable Diffusion model
|
82 |
+
|
83 |
+
The model after merging can be handled in the same way as normal Stable Diffusion ckpt. For example, a command line like:
|
84 |
+
|
85 |
+
```
|
86 |
+
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
87 |
+
--save_to ..\lora_train1\model-char1-merged.safetensors
|
88 |
+
--models ..\lora_train1\last.safetensors --ratios 0.8
|
89 |
+
```
|
90 |
+
|
91 |
+
Specify the --v2 option if you want to train with a Stable Diffusion v2.x model and merge with it.
|
92 |
+
|
93 |
+
Specify the Stable Diffusion model file to be merged in the --sd_model option (only .ckpt or .safetensors are supported, Diffusers is not currently supported).
|
94 |
+
|
95 |
+
Specify the save destination of the model after merging in the --save_to option (.ckpt or .safetensors, automatically determined by extension).
|
96 |
+
|
97 |
+
Specify the LoRA model file learned in --models. It is possible to specify more than one, in which case they will be merged in order.
|
98 |
+
|
99 |
+
For --ratios, specify the application rate of each model (how much weight is reflected in the original model) with a numerical value from 0 to 1.0. For example, if it is close to overfitting, it may be better if the application rate is lowered. Specify as many as the number of models.
|
100 |
+
|
101 |
+
When specifying multiple, it will be as follows.
|
102 |
+
|
103 |
+
```
|
104 |
+
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
105 |
+
--save_to ..\lora_train1\model-char1-merged.safetensors
|
106 |
+
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5
|
107 |
+
```
|
108 |
+
|
109 |
+
### Merge multiple LoRA models
|
110 |
+
|
111 |
+
Applying multiple LoRA models to the SD model one by one, and merging multiple LoRA models first and then merging the result into the SD model, give slightly different results due to the order of calculation.
|
112 |
+
|
113 |
+
For example, a command line like:
|
114 |
+
|
115 |
+
```
|
116 |
+
python networks\merge_lora.py
|
117 |
+
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
118 |
+
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
|
119 |
+
```
|
120 |
+
|
121 |
+
The --sd_model option does not need to be specified.
|
122 |
+
|
123 |
+
Specify the save destination of the merged LoRA model in the --save_to option (.ckpt or .safetensors, automatically determined by extension).
|
124 |
+
|
125 |
+
Specify the LoRA model file learned in --models. Three or more can be specified.
|
126 |
+
|
127 |
+
For --ratios, specify the ratio of each model (how much weight is reflected in the original model) with a numerical value from 0 to 1.0. If you merge two models one-to-one, it will be "0.5 0.5". "1.0 1.0" would give too much weight to the sum, and the result would probably be less desirable.
|
128 |
+
|
129 |
+
A LoRA trained with v1 and a LoRA trained with v2, or LoRAs with different ranks (numbers of dimensions) or ``alpha``, cannot be merged. U-Net only LoRA and U-Net+Text Encoder LoRA should be able to be merged, but the result is unknown.
|
130 |
+
|
131 |
+
|
132 |
+
### Other Options
|
133 |
+
|
134 |
+
* precision
|
135 |
+
* The precision for merge calculation can be specified from float, fp16, and bf16. If omitted, it will be float to ensure accuracy. Specify fp16/bf16 if you want to reduce memory usage.
|
136 |
+
* save_precision
|
137 |
+
* You can specify the precision when saving the model from float, fp16, bf16. If omitted, the precision is the same as precision.
|
138 |
+
|
139 |
+
## Generate with the image generation script in this repository
|
140 |
+
|
141 |
+
Add options --network_module, --network_weights, --network_dim (optional) to gen_img_diffusers.py. The meaning is the same as when learning.
|
142 |
+
|
143 |
+
You can change the LoRA application rate by specifying a value between 0 and 1.0 with the --network_mul option.
|
144 |
+
|
145 |
+
## Create a LoRA model from the difference between two models
|
146 |
+
|
147 |
+
It was implemented with reference to [this discussion](https://github.com/cloneofsimo/lora/discussions/56). I used the formula as it is (I don't understand it well, but it seems that singular value decomposition is used for approximation).
|
148 |
+
|
149 |
+
The difference between two models (for example, the base model used for fine tuning and the model after fine tuning) is approximated with LoRA.
|
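
As a rough illustration of the idea (the actual script processes the whole set of U-Net / Text Encoder weights and handles conversion details; the function and variable names here are illustrative assumptions), the per-layer approximation via truncated SVD looks like this:

```python
import torch


def extract_lora_pair(w_org, w_tuned, dim=4):
    # approximate the weight difference with a rank-`dim` factorization via truncated SVD
    delta = (w_tuned - w_org).float()
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    lora_up = U[:, :dim] * S[:dim]  # [out_features, dim]
    lora_down = Vh[:dim, :]         # [dim, in_features]
    return lora_up, lora_down


w_org = torch.randn(320, 768)
w_tuned = w_org + 0.01 * torch.randn(320, 768)
up, down = extract_lora_pair(w_org, w_tuned, dim=4)
print((up @ down).shape)  # torch.Size([320, 768]): a rank-4 approximation of the difference
```
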
150 |
+
|
151 |
+
### How to run scripts
|
152 |
+
|
153 |
+
Please specify as follows.
|
154 |
+
```
|
155 |
+
python networks\extract_lora_from_models.py --model_org base-model.ckpt
|
156 |
+
--model_tuned fine-tuned-model.ckpt
|
157 |
+
--save_to lora-weights.safetensors --dim 4
|
158 |
+
```
|
159 |
+
|
160 |
+
Specify the original Stable Diffusion model for the --model_org option. When applying the created LoRA model, this model will be specified and applied. .ckpt or .safetensors can be specified.
|
161 |
+
|
162 |
+
Specify the Stable Diffusion model to extract the difference in the --model_tuned option. For example, specify a model after fine tuning or DreamBooth. .ckpt or .safetensors can be specified.
|
163 |
+
|
164 |
+
Specify the save destination of the LoRA model in --save_to. Specify the number of dimensions of LoRA in --dim.
|
165 |
+
|
166 |
+
A generated LoRA model can be used in the same way as a trained LoRA model.
|
167 |
+
|
168 |
+
If the Text Encoder is the same for both models, LoRA will be U-Net only LoRA.
|
169 |
+
|
170 |
+
### Other Options
|
171 |
+
|
172 |
+
--v2
|
173 |
+
- Please specify when using the v2.x Stable Diffusion model.
|
174 |
+
--device
|
175 |
+
- If cuda is specified as ``--device cuda``, the calculation will be performed on the GPU. Processing will be faster (because even the CPU is not that slow, it seems to be at most twice or several times faster).
|
176 |
+
--save_precision
|
177 |
+
- Specify the LoRA save format from "float", "fp16", "bf16". Default is float.
|
178 |
+
|
179 |
+
## Additional Information
|
180 |
+
|
181 |
+
### Differences from cloneofsimo's repository
|
182 |
+
|
183 |
+
As of 2022/12/25, this repository has expanded the LoRA application points to the Text Encoder's MLP, the U-Net's FFN, and the Transformer's in/out projections, increasing expressiveness. In exchange, memory usage has increased, and it now barely fits in 8GB.
|
184 |
+
|
185 |
+
Also, the module replacement mechanism is completely different.
|
186 |
+
|
187 |
+
### About Future Expansion
|
188 |
+
|
189 |
+
It is possible to support not only LoRA but also other expansions, so we plan to add them as well.
|
train_textual_inversion.py
ADDED
@@ -0,0 +1,598 @@
1 |
+
import importlib
|
2 |
+
import argparse
|
3 |
+
import gc
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import toml
|
7 |
+
from multiprocessing import Value
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
import torch
|
11 |
+
from accelerate.utils import set_seed
|
12 |
+
import diffusers
|
13 |
+
from diffusers import DDPMScheduler
|
14 |
+
|
15 |
+
import library.train_util as train_util
|
16 |
+
import library.huggingface_util as huggingface_util
|
17 |
+
import library.config_util as config_util
|
18 |
+
from library.config_util import (
|
19 |
+
ConfigSanitizer,
|
20 |
+
BlueprintGenerator,
|
21 |
+
)
|
22 |
+
import library.custom_train_functions as custom_train_functions
|
23 |
+
from library.custom_train_functions import apply_snr_weight
|
24 |
+
|
25 |
+
imagenet_templates_small = [
|
26 |
+
"a photo of a {}",
|
27 |
+
"a rendering of a {}",
|
28 |
+
"a cropped photo of the {}",
|
29 |
+
"the photo of a {}",
|
30 |
+
"a photo of a clean {}",
|
31 |
+
"a photo of a dirty {}",
|
32 |
+
"a dark photo of the {}",
|
33 |
+
"a photo of my {}",
|
34 |
+
"a photo of the cool {}",
|
35 |
+
"a close-up photo of a {}",
|
36 |
+
"a bright photo of the {}",
|
37 |
+
"a cropped photo of a {}",
|
38 |
+
"a photo of the {}",
|
39 |
+
"a good photo of the {}",
|
40 |
+
"a photo of one {}",
|
41 |
+
"a close-up photo of the {}",
|
42 |
+
"a rendition of the {}",
|
43 |
+
"a photo of the clean {}",
|
44 |
+
"a rendition of a {}",
|
45 |
+
"a photo of a nice {}",
|
46 |
+
"a good photo of a {}",
|
47 |
+
"a photo of the nice {}",
|
48 |
+
"a photo of the small {}",
|
49 |
+
"a photo of the weird {}",
|
50 |
+
"a photo of the large {}",
|
51 |
+
"a photo of a cool {}",
|
52 |
+
"a photo of a small {}",
|
53 |
+
]
|
54 |
+
|
55 |
+
imagenet_style_templates_small = [
|
56 |
+
"a painting in the style of {}",
|
57 |
+
"a rendering in the style of {}",
|
58 |
+
"a cropped painting in the style of {}",
|
59 |
+
"the painting in the style of {}",
|
60 |
+
"a clean painting in the style of {}",
|
61 |
+
"a dirty painting in the style of {}",
|
62 |
+
"a dark painting in the style of {}",
|
63 |
+
"a picture in the style of {}",
|
64 |
+
"a cool painting in the style of {}",
|
65 |
+
"a close-up painting in the style of {}",
|
66 |
+
"a bright painting in the style of {}",
|
67 |
+
"a cropped painting in the style of {}",
|
68 |
+
"a good painting in the style of {}",
|
69 |
+
"a close-up painting in the style of {}",
|
70 |
+
"a rendition in the style of {}",
|
71 |
+
"a nice painting in the style of {}",
|
72 |
+
"a small painting in the style of {}",
|
73 |
+
"a weird painting in the style of {}",
|
74 |
+
"a large painting in the style of {}",
|
75 |
+
]
|
76 |
+
|
77 |
+
|
78 |
+
def train(args):
|
79 |
+
if args.output_name is None:
|
80 |
+
args.output_name = args.token_string
|
81 |
+
use_template = args.use_object_template or args.use_style_template
|
82 |
+
|
83 |
+
train_util.verify_training_args(args)
|
84 |
+
train_util.prepare_dataset_args(args, True)
|
85 |
+
|
86 |
+
cache_latents = args.cache_latents
|
87 |
+
|
88 |
+
if args.seed is not None:
|
89 |
+
set_seed(args.seed)
|
90 |
+
|
91 |
+
tokenizer = train_util.load_tokenizer(args)
|
92 |
+
|
93 |
+
# acceleratorを準備する
|
94 |
+
print("prepare accelerator")
|
95 |
+
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
96 |
+
|
97 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
98 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
99 |
+
|
100 |
+
# モデルを読み込む
|
101 |
+
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
102 |
+
|
103 |
+
# Convert the init_word to token_id
|
104 |
+
if args.init_word is not None:
|
105 |
+
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
106 |
+
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
107 |
+
print(
|
108 |
+
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
109 |
+
)
|
110 |
+
else:
|
111 |
+
init_token_ids = None
|
112 |
+
|
113 |
+
# add new word to tokenizer, count is num_vectors_per_token
|
114 |
+
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
115 |
+
num_added_tokens = tokenizer.add_tokens(token_strings)
|
116 |
+
assert (
|
117 |
+
num_added_tokens == args.num_vectors_per_token
|
118 |
+
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
119 |
+
|
120 |
+
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
121 |
+
print(f"tokens are added: {token_ids}")
|
122 |
+
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, "token ids are not ordered"
|
123 |
+
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
124 |
+
|
125 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
126 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
127 |
+
|
128 |
+
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
129 |
+
token_embeds = text_encoder.get_input_embeddings().weight.data
|
130 |
+
if init_token_ids is not None:
|
131 |
+
for i, token_id in enumerate(token_ids):
|
132 |
+
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
|
133 |
+
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
134 |
+
|
135 |
+
# load weights
|
136 |
+
if args.weights is not None:
|
137 |
+
embeddings = load_weights(args.weights)
|
138 |
+
assert len(token_ids) == len(
|
139 |
+
embeddings
|
140 |
+
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
141 |
+
# print(token_ids, embeddings.size())
|
142 |
+
for token_id, embedding in zip(token_ids, embeddings):
|
143 |
+
token_embeds[token_id] = embedding
|
144 |
+
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
145 |
+
print(f"weighs loaded")
|
146 |
+
|
147 |
+
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
148 |
+
|
149 |
+
# データセットを準備する
|
150 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
151 |
+
if args.dataset_config is not None:
|
152 |
+
print(f"Load dataset config from {args.dataset_config}")
|
153 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
154 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
155 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
156 |
+
print(
|
157 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
158 |
+
", ".join(ignored)
|
159 |
+
)
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
use_dreambooth_method = args.in_json is None
|
163 |
+
if use_dreambooth_method:
|
164 |
+
print("Use DreamBooth method.")
|
165 |
+
user_config = {
|
166 |
+
"datasets": [
|
167 |
+
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
168 |
+
]
|
169 |
+
}
|
170 |
+
else:
|
171 |
+
print("Train with captions.")
|
172 |
+
user_config = {
|
173 |
+
"datasets": [
|
174 |
+
{
|
175 |
+
"subsets": [
|
176 |
+
{
|
177 |
+
"image_dir": args.train_data_dir,
|
178 |
+
"metadata_file": args.in_json,
|
179 |
+
}
|
180 |
+
]
|
181 |
+
}
|
182 |
+
]
|
183 |
+
}
|
184 |
+
|
185 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
186 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
187 |
+
|
188 |
+
current_epoch = Value("i", 0)
|
189 |
+
current_step = Value("i", 0)
|
190 |
+
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
191 |
+
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
192 |
+
|
193 |
+
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
194 |
+
if use_template:
|
195 |
+
print("use template for training captions. is object: {args.use_object_template}")
|
196 |
+
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
197 |
+
replace_to = " ".join(token_strings)
|
198 |
+
captions = []
|
199 |
+
for tmpl in templates:
|
200 |
+
captions.append(tmpl.format(replace_to))
|
201 |
+
train_dataset_group.add_replacement("", captions)
|
202 |
+
|
203 |
+
if args.num_vectors_per_token > 1:
|
204 |
+
prompt_replacement = (args.token_string, replace_to)
|
205 |
+
else:
|
206 |
+
prompt_replacement = None
|
207 |
+
else:
|
208 |
+
if args.num_vectors_per_token > 1:
|
209 |
+
replace_to = " ".join(token_strings)
|
210 |
+
train_dataset_group.add_replacement(args.token_string, replace_to)
|
211 |
+
prompt_replacement = (args.token_string, replace_to)
|
212 |
+
else:
|
213 |
+
prompt_replacement = None
|
214 |
+
|
215 |
+
if args.debug_dataset:
|
216 |
+
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
217 |
+
return
|
218 |
+
if len(train_dataset_group) == 0:
|
219 |
+
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
220 |
+
return
|
221 |
+
|
222 |
+
if cache_latents:
|
223 |
+
assert (
|
224 |
+
train_dataset_group.is_latent_cacheable()
|
225 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
226 |
+
|
227 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
228 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
229 |
+
|
230 |
+
# 学習を準備する
|
231 |
+
if cache_latents:
|
232 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
233 |
+
vae.requires_grad_(False)
|
234 |
+
vae.eval()
|
235 |
+
with torch.no_grad():
|
236 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
237 |
+
vae.to("cpu")
|
238 |
+
if torch.cuda.is_available():
|
239 |
+
torch.cuda.empty_cache()
|
240 |
+
gc.collect()
|
241 |
+
|
242 |
+
accelerator.wait_for_everyone()
|
243 |
+
|
244 |
+
if args.gradient_checkpointing:
|
245 |
+
unet.enable_gradient_checkpointing()
|
246 |
+
text_encoder.gradient_checkpointing_enable()
|
247 |
+
|
248 |
+
# 学習に必要なクラスを準備する
|
249 |
+
print("prepare optimizer, data loader etc.")
|
250 |
+
trainable_params = text_encoder.get_input_embeddings().parameters()
|
251 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
252 |
+
|
253 |
+
# dataloaderを準備する
|
254 |
+
# DataLoaderのプロセス数:0はメインプロセスになる
|
255 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
256 |
+
train_dataloader = torch.utils.data.DataLoader(
|
257 |
+
train_dataset_group,
|
258 |
+
batch_size=1,
|
259 |
+
shuffle=True,
|
260 |
+
collate_fn=collater,
|
261 |
+
num_workers=n_workers,
|
262 |
+
persistent_workers=args.persistent_data_loader_workers,
|
263 |
+
)
|
264 |
+
|
265 |
+
# 学習ステップ数を計算する
|
266 |
+
if args.max_train_epochs is not None:
|
267 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
268 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
269 |
+
)
|
270 |
+
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
271 |
+
|
272 |
+
# データセット側にも学習ステップを送信
|
273 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
274 |
+
|
275 |
+
# lr schedulerを用意する
|
276 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
277 |
+
|
278 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
279 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
280 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler
|
281 |
+
)
|
282 |
+
|
283 |
+
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
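# note: index_no_updates marks every pre-existing token id; their embeddings are restored after each optimizer step below, so only the newly added tokens are actually trained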
284 |
+
# print(len(index_no_updates), torch.sum(index_no_updates))
|
285 |
+
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
286 |
+
|
287 |
+
# Freeze all parameters except for the token embeddings in text encoder
|
288 |
+
text_encoder.requires_grad_(True)
|
289 |
+
text_encoder.text_model.encoder.requires_grad_(False)
|
290 |
+
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
291 |
+
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
292 |
+
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
293 |
+
|
294 |
+
unet.requires_grad_(False)
|
295 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
296 |
+
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
297 |
+
unet.train()
|
298 |
+
else:
|
299 |
+
unet.eval()
|
300 |
+
|
301 |
+
if not cache_latents:
|
302 |
+
vae.requires_grad_(False)
|
303 |
+
vae.eval()
|
304 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
305 |
+
|
306 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
307 |
+
if args.full_fp16:
|
308 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
309 |
+
text_encoder.to(weight_dtype)
|
310 |
+
|
311 |
+
# resumeする
|
312 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
313 |
+
|
314 |
+
# epoch数を計算する
|
315 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
316 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
317 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
318 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
319 |
+
|
320 |
+
# 学習する
|
321 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
322 |
+
print("running training / 学習開始")
|
323 |
+
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
324 |
+
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
325 |
+
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
326 |
+
print(f" num epochs / epoch数: {num_train_epochs}")
|
327 |
+
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
328 |
+
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
329 |
+
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
330 |
+
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
331 |
+
|
332 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
333 |
+
global_step = 0
|
334 |
+
|
335 |
+
noise_scheduler = DDPMScheduler(
|
336 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
337 |
+
)
|
338 |
+
|
339 |
+
if accelerator.is_main_process:
|
340 |
+
accelerator.init_trackers("textual_inversion")
|
341 |
+
|
342 |
+
for epoch in range(num_train_epochs):
|
343 |
+
print(f"epoch {epoch+1}/{num_train_epochs}")
|
344 |
+
current_epoch.value = epoch + 1
|
345 |
+
|
346 |
+
text_encoder.train()
|
347 |
+
|
348 |
+
loss_total = 0
|
349 |
+
|
350 |
+
for step, batch in enumerate(train_dataloader):
|
351 |
+
current_step.value = global_step
|
352 |
+
with accelerator.accumulate(text_encoder):
|
353 |
+
with torch.no_grad():
|
354 |
+
if "latents" in batch and batch["latents"] is not None:
|
355 |
+
latents = batch["latents"].to(accelerator.device)
|
356 |
+
else:
|
357 |
+
# latentに変換
|
358 |
+
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
359 |
+
latents = latents * 0.18215
|
360 |
+
b_size = latents.shape[0]
|
361 |
+
|
362 |
+
# Get the text embedding for conditioning
|
363 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
364 |
+
# use float instead of fp16/bf16 because text encoder is float
|
365 |
+
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
|
366 |
+
|
367 |
+
# Sample noise that we'll add to the latents
|
368 |
+
noise = torch.randn_like(latents, device=latents.device)
|
369 |
+
if args.noise_offset:
|
370 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
371 |
+
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
372 |
+
|
373 |
+
# Sample a random timestep for each image
|
374 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
375 |
+
timesteps = timesteps.long()
|
376 |
+
|
377 |
+
                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                with accelerator.autocast():
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                if args.min_snr_gamma:
                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                loss = loss * loss_weights

                loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = text_encoder.get_input_embeddings().parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
                        index_no_updates
                    ]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                train_util.sample_images(
                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
                )

            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()

        if args.save_every_n_epochs is not None:
            model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name

            def save_func():
                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
                ckpt_file = os.path.join(args.output_dir, ckpt_name)
                print(f"saving checkpoint: {ckpt_file}")
                save_weights(ckpt_file, updated_embs, save_dtype)
                if args.huggingface_repo_id is not None:
                    huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)

            def remove_old_func(old_epoch_no):
                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
                old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
                if os.path.exists(old_ckpt_file):
                    print(f"removing old checkpoint: {old_ckpt_file}")
                    os.remove(old_ckpt_file)

            saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
            if saving and args.save_state:
                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

        train_util.sample_images(
            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
        )

        # end of epoch

    is_main_process = accelerator.is_main_process
    if is_main_process:
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state:
        train_util.save_state_on_train_end(args, accelerator)

    updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()

    del accelerator  # この後メモリを使うのでこれは消す

    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

        model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
        ckpt_name = model_name + "." + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model to {ckpt_file}")
        save_weights(ckpt_file, updated_embs, save_dtype)
        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
        print("model saved.")


def save_weights(file, updated_embs, save_dtype):
    state_dict = {"emb_params": updated_embs}

    if save_dtype is not None:
        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, file)
    else:
        torch.save(state_dict, file)  # can be loaded in Web UI


def load_weights(file):
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        # compatible to Web UI's file format
        data = torch.load(file, map_location="cpu")
        if type(data) != dict:
            raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")

        if "string_to_param" in data:  # textual inversion embeddings
            data = data["string_to_param"]
            if hasattr(data, "_parameters"):  # support old PyTorch?
                data = getattr(data, "_parameters")

    emb = next(iter(data.values()))
    if type(emb) != torch.Tensor:
        raise ValueError(f"weight file does not contain Tensor / 重みファイルのデータがTensorではありません: {file}")

    if len(emb.size()) == 1:
        emb = emb.unsqueeze(0)

    return emb


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, False)
    train_util.add_training_arguments(parser, True)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser, False)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="pt",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
    )

    parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
    parser.add_argument(
        "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
    )
    parser.add_argument(
        "--token_string",
        type=str,
        default=None,
        help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
    )
    parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
    parser.add_argument(
        "--use_object_template",
        action="store_true",
        help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
    )
    parser.add_argument(
        "--use_style_template",
        action="store_true",
        help="ignore caption and use default templates for style / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)
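For reference, the checkpoint written by `save_weights()` above is a small state dict whose only entry is `emb_params`, a tensor of shape `(num_vectors_per_token, embedding_dim)`. Below is a minimal sketch for inspecting such a file; the path is a hypothetical placeholder, not a file shipped in this commit.

```python
# Minimal sketch: inspect an embedding saved by train_textual_inversion.py.
# "my_embedding.safetensors" is a hypothetical path used for illustration.
import os
import torch
from safetensors.torch import load_file

path = "my_embedding.safetensors"
if os.path.splitext(path)[1] == ".safetensors":
    data = load_file(path)
else:
    data = torch.load(path, map_location="cpu")  # .pt / .ckpt fallback
emb = data["emb_params"]  # key written by save_weights() above
print(emb.dtype, tuple(emb.shape))  # (num_vectors_per_token, embedding_dim), e.g. (4, 768) for SD 1.x
```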
train_textual_inversion_XTI.py
ADDED
@@ -0,0 +1,650 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value

from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


def train(args):
    if args.output_name is None:
        args.output_name = args.token_string
    use_template = args.use_object_template or args.use_style_template

    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
        print(
            "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
        )

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)

    tokenizer = train_util.load_tokenizer(args)

    # acceleratorを準備する
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # mixed precisionに対応した型を用意しておき適宜castする
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # モデルを読み込む
    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)

    # Convert the init_word to token_id
    if args.init_word is not None:
        init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
        if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
            print(
                f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
            )
    else:
        init_token_ids = None

    # add new word to tokenizer, count is num_vectors_per_token
    token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
    num_added_tokens = tokenizer.add_tokens(token_strings)
    assert (
        num_added_tokens == args.num_vectors_per_token
    ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"

    token_ids = tokenizer.convert_tokens_to_ids(token_strings)
    print(f"tokens are added: {token_ids}")
    assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
    assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"

    token_strings_XTI = []
    XTI_layers = [
        "IN01",
        "IN02",
        "IN04",
        "IN05",
        "IN07",
        "IN08",
        "MID",
        "OUT03",
        "OUT04",
        "OUT05",
        "OUT06",
        "OUT07",
        "OUT08",
        "OUT09",
        "OUT10",
        "OUT11",
    ]
    for layer_name in XTI_layers:
        token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]

    tokenizer.add_tokens(token_strings_XTI)
    token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
    print(f"tokens are added (XTI): {token_ids_XTI}")
    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    if init_token_ids is not None:
        for i, token_id in enumerate(token_ids_XTI):
            token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())

    # load weights
    if args.weights is not None:
        embeddings = load_weights(args.weights)
        assert len(token_ids) == len(
            embeddings
        ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
        # print(token_ids, embeddings.size())
        for token_id, embedding in zip(token_ids_XTI, embeddings):
            token_embeds[token_id] = embedding
            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
        print(f"weights loaded")

    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

    # データセットを準備する
    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
    if args.dataset_config is not None:
        print(f"Load dataset config from {args.dataset_config}")
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "reg_data_dir", "in_json"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        use_dreambooth_method = args.in_json is None
        if use_dreambooth_method:
            print("Use DreamBooth method.")
            user_config = {
                "datasets": [
                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                ]
            }
        else:
            print("Train with captions.")
            user_config = {
                "datasets": [
                    {
                        "subsets": [
                            {
                                "image_dir": args.train_data_dir,
                                "metadata_file": args.in_json,
                            }
                        ]
                    }
                ]
            }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

    # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
    if use_template:
        print(f"use template for training captions. is object: {args.use_object_template}")
        templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
        replace_to = " ".join(token_strings)
        captions = []
        for tmpl in templates:
            captions.append(tmpl.format(replace_to))
        train_dataset_group.add_replacement("", captions)

        if args.num_vectors_per_token > 1:
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None
    else:
        if args.num_vectors_per_token > 1:
            replace_to = " ".join(token_strings)
            train_dataset_group.add_replacement(args.token_string, replace_to)
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group, show_input_ids=True)
        return
    if len(train_dataset_group) == 0:
        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # モデルに xformers とか memory efficient attention を組み込む
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
    diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
    diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
    diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI

    # 学習を準備する
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        accelerator.wait_for_everyone()

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    # 学習に必要なクラスを準備する
    print("prepare optimizer, data loader etc.")
    trainable_params = text_encoder.get_input_embeddings().parameters()
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)

    # dataloaderを準備する
    # DataLoaderのプロセス数:0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collater,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # データセット側にも学習ステップを送信
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # lr schedulerを用意する
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # acceleratorがなんかよろしくやってくれるらしい
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
    # print(len(index_no_updates), torch.sum(index_no_updates))
    orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.requires_grad_(True)
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
        unet.train()
    else:
        unet.eval()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)
        text_encoder.to(weight_dtype)

    # resumeする
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # epoch数を計算する
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # 学習する
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f"  num epochs / epoch数: {num_train_epochs}")
    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )

    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion")

    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        text_encoder.train()

        loss_total = 0

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            with accelerator.accumulate(text_encoder):
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        # latentに変換
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                        latents = latents * 0.18215
                b_size = latents.shape[0]

                # Get the text embedding for conditioning
                input_ids = batch["input_ids"].to(accelerator.device)
                # use float instead of fp16/bf16 because text encoder is float
                encoder_hidden_states = torch.stack(
                    [
                        train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
                        for s in torch.split(input_ids, 1, dim=1)
                    ]
                )

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                with accelerator.autocast():
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                if args.min_snr_gamma:
                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                loss = loss * loss_weights

                loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = text_encoder.get_input_embeddings().parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
                        index_no_updates
                    ]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                # TODO: fix sample_images
                # train_util.sample_images(
                #     accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
                # )

            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()

        if args.save_every_n_epochs is not None:
            model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name

            def save_func():
                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
                ckpt_file = os.path.join(args.output_dir, ckpt_name)
                print(f"saving checkpoint: {ckpt_file}")
                save_weights(ckpt_file, updated_embs, save_dtype)
                if args.huggingface_repo_id is not None:
                    huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)

            def remove_old_func(old_epoch_no):
                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
                old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
                if os.path.exists(old_ckpt_file):
                    print(f"removing old checkpoint: {old_ckpt_file}")
                    os.remove(old_ckpt_file)

            saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
            if saving and args.save_state:
                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

        # TODO: fix sample_images
        # train_util.sample_images(
        #     accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
        # )

        # end of epoch

    is_main_process = accelerator.is_main_process
    if is_main_process:
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state:
        train_util.save_state_on_train_end(args, accelerator)

    updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()

    del accelerator  # この後メモリを使うのでこれは消す

    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

        model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
        ckpt_name = model_name + "." + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model to {ckpt_file}")
        save_weights(ckpt_file, updated_embs, save_dtype)
        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
        print("model saved.")


def save_weights(file, updated_embs, save_dtype):
    updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
    updated_embs = updated_embs.chunk(16)
    XTI_layers = [
        "IN01",
        "IN02",
        "IN04",
        "IN05",
        "IN07",
        "IN08",
        "MID",
        "OUT03",
        "OUT04",
        "OUT05",
        "OUT06",
        "OUT07",
        "OUT08",
        "OUT09",
        "OUT10",
        "OUT11",
    ]
    state_dict = {}
    for i, layer_name in enumerate(XTI_layers):
        state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)

    # if save_dtype is not None:
    #     for key in list(state_dict.keys()):
    #         v = state_dict[key]
    #         v = v.detach().clone().to("cpu").to(save_dtype)
    #         state_dict[key] = v

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, file)
    else:
        torch.save(state_dict, file)  # can be loaded in Web UI


def load_weights(file):
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        raise ValueError(f"NOT XTI: {file}")

    if len(data.values()) != 16:
        raise ValueError(f"NOT XTI: {file}")

    emb = torch.concat([x for x in data.values()])

    return emb


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, False)
    train_util.add_training_arguments(parser, True)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser, False)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="pt",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
    )

    parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
    parser.add_argument(
        "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
    )
    parser.add_argument(
        "--token_string",
        type=str,
        default=None,
        help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
    )
    parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
    parser.add_argument(
        "--use_object_template",
        action="store_true",
        help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
    )
    parser.add_argument(
        "--use_style_template",
        action="store_true",
        help="ignore caption and use default templates for style / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)
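Unlike the plain Textual Inversion checkpoint, the XTI checkpoint written by `save_weights()` above holds one embedding tensor per U-Net block, keyed by layer name. Below is a minimal sketch for listing the contents of such a file; the path is a hypothetical placeholder used for illustration.

```python
# Minimal sketch: list the per-layer embeddings in an XTI checkpoint.
# "my_xti_embedding.safetensors" is a hypothetical path used for illustration.
from safetensors.torch import load_file

data = load_file("my_xti_embedding.safetensors")
for layer_name, emb in data.items():  # 16 keys: IN01 ... MID ... OUT11
    print(layer_name, tuple(emb.shape))  # each (num_vectors_per_token, embedding_dim)
```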
train_ti_README-ja.md
ADDED
@@ -0,0 +1,105 @@
[Textual Inversion](https://textual-inversion.github.io/) の学習についての説明です。

[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。

実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。

学習したモデルはWeb UIでもそのまま使えます。

# 学習の手順

あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。

## データの準備

[学習データの準備について](./train_README-ja.md) を参照してください。

## 学習の実行

``train_textual_inversion.py`` を用います。以下はコマンドラインの例です(DreamBooth手法)。

```
accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py
    --dataset_config=<データ準備で作成した.tomlファイル>
    --output_dir=<学習したモデルの出力先フォルダ>
    --output_name=<学習したモデル出力時のファイル名>
    --save_model_as=safetensors
    --prior_loss_weight=1.0
    --max_train_steps=1600
    --learning_rate=1e-6
    --optimizer_type="AdamW8bit"
    --xformers
    --mixed_precision="fp16"
    --cache_latents
    --gradient_checkpointing
    --token_string=mychar4 --init_word=cute --num_vectors_per_token=4
```

``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。DreamBooth, class+identifier形式のデータセットとして、`token_string` をトークン文字列にするのが最も簡単で確実です。

プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。

```
input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407]])
```

tokenizerがすでに持っている単語(一般的な単語)は使用できません。

``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。

``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は(一般的なプロンプトの77トークン制限のうち)8トークンを消費します。

以上がTextual Inversionのための主なオプションです。以降は他の学習スクリプトと同様です。

`num_cpu_threads_per_process` には通常は1を指定するとよいようです。

`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。

`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。

`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。

学習させるステップ数 `max_train_steps` は上の例では1600としています。学習率 `learning_rate` はここでは1e-6を指定しています。

省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。

オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。

`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。

ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `8` くらいに増やしてください(高速化と精度向上の可能性があります)。

### よく使われるオプションについて

以下の場合にはオプションに関するドキュメントを参照してください。

- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
- clip skipを2以上を前提としたモデルを学習する
- 75トークンを超えたキャプションで学習する

### Textual Inversionでのバッチサイズについて

モデル全体を学習するDreamBoothやfine tuningに比べてメモリ使用量が少ないため、バッチサイズは大きめにできます。

# Textual Inversionのその他の主なオプション

すべてのオプションについては別文書を参照してください。

* `--weights`
  * 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。
* `--use_object_template`
  * キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。
* `--use_style_template`
  * キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。

## 当リポジトリ内の画像生成スクリプトで生成する

gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください(複数可)。プロンプトでembeddingsファイルのファイル名(拡張子を除く)を使うと、そのembeddingsが適用されます。
train_ti_README.md
ADDED
@@ -0,0 +1,62 @@
## About Textual Inversion training

This explains training with [Textual Inversion](https://textual-inversion.github.io/). I heavily referenced https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion for the implementation.

The trained model can be used as is in the Web UI.

In addition, it is probably compatible with SD2.x, but it has not been tested at this time.

## Training method

Use ``train_textual_inversion.py``.

Data preparation is exactly the same as for ``train_network.py``, so please refer to [its documentation](./train_network_README.md).

## Options

Below is an example command line (DreamBooth technique).

```
accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py
    --pretrained_model_name_or_path=..\models\model.ckpt
    --train_data_dir=..\data\db\char1 --output_dir=..\ti_train1
    --resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
    --max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
    --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
    --token_string=mychar4 --init_word=cute --num_vectors_per_token=4
```

``--token_string`` specifies the token string used for training. __The training prompt should contain this string (e.g. ``mychar4 1girl`` if token_string is mychar4)__. This part of the prompt is replaced with the new Textual Inversion tokens and trained.

``--debug_dataset`` displays the token ids after substitution, so you can confirm that tokens from ``49408`` onward are present, as shown below.

```
input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407]])
```

Words that the tokenizer already has (common words) cannot be used.

``--init_word`` specifies the string of the source token whose embedding is copied when initializing the new embeddings. It seems to be a good idea to choose something close to the concept you want to learn. You cannot specify a string that becomes two or more tokens.

``--num_vectors_per_token`` specifies how many tokens to use for this training. The higher the number, the more expressive it is, but it consumes more tokens. For example, if num_vectors_per_token=8, then the specified token string will consume 8 tokens (out of the 77 token limit for a typical prompt).
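As a rough illustration (a sketch based on the token handling in `train_textual_inversion.py`, not part of the official documentation), the token string is expanded into numbered variants that are added to the tokenizer and substituted into captions and prompts during training:

```python
# Sketch of the token expansion performed by the training script;
# mychar4 is the example token string from the command line above.
token_string = "mychar4"
num_vectors_per_token = 4
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
print(token_strings)  # ['mychar4', 'mychar41', 'mychar42', 'mychar43']
# During training, "mychar4" in a caption is replaced by "mychar4 mychar41 mychar42 mychar43".
```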
In addition, the following options can be specified.

* --weights
  * Load trained embeddings before training and continue training from them.
* --use_object_template
  * Train with the default object template strings (such as ``a photo of a {}``) instead of captions. This matches the official implementation. Captions are ignored.
* --use_style_template
  * Train with the default style template strings (such as ``a painting in the style of {}``) instead of captions. This matches the official implementation. Captions are ignored.

## Generate with the image generation script in this repository

In gen_img_diffusers.py, specify the trained embeddings file(s) with the ``--textual_inversion_embeddings`` option. Using the filename (without the extension) of an embeddings file in the prompt applies that embedding.
upgrade.bat
ADDED
@@ -0,0 +1,16 @@
@echo off
:: Check if there are any changes that need to be committed
:: (git status --short always exits 0, so test its output instead of its errorlevel)
git status --short | findstr . >nul
if %errorlevel%==0 (
    echo There are changes that need to be committed. Please stash or undo your changes before running this script.
    exit /b 1
)

:: Pull the latest changes from the remote repository
git pull

:: Activate the virtual environment
call .\venv\Scripts\activate.bat

:: Upgrade the required packages
pip install --use-pep517 --upgrade -r requirements.txt
upgrade.ps1
ADDED
@@ -0,0 +1,14 @@
# Check if there are any changes that need to be committed
if (git status --short) {
    Write-Error "There are changes that need to be committed. Please stash or undo your changes before running this script."
    return
}

# Pull the latest changes from the remote repository
git pull

# Activate the virtual environment
.\venv\Scripts\activate

# Upgrade the required packages
pip install --use-pep517 --upgrade -r requirements.txt
upgrade.sh
ADDED
@@ -0,0 +1,16 @@
#!/bin/bash

# Check if there are any changes that need to be committed
if git status --short | grep -q "^[^ ?][^?]*"; then
    echo "There are changes that need to be committed. Please stash or undo your changes before running this script."
    exit 1
fi

# Pull the latest changes from the remote repository
git pull

# Activate the virtual environment
source venv/bin/activate

# Upgrade the required packages
pip install --use-pep517 --upgrade -r requirements.txt
utilities.cmd
ADDED
@@ -0,0 +1 @@
.\venv\Scripts\python.exe library\utilities.py