harkov000 smangrul commited on
Commit
a1a0fc5
0 Parent(s):

Duplicate from smangrul/peft-lora-sd-dreambooth

Browse files

Co-authored-by: Sourab Mangrulkar <smangrul@users.noreply.huggingface.co>

Files changed (10) hide show
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +371 -0
  4. colab.py +371 -0
  5. inference.py +91 -0
  6. requirements.txt +12 -0
  7. style.css +3 -0
  8. train_dreambooth.py +1005 -0
  9. trainer.py +156 -0
  10. uploader.py +17 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Peft Lora Sd Dreambooth
3
+ emoji: 🎨
4
+ colorFrom: purple
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.16.2
8
+ app_file: app.py
9
+ pinned: false
10
+ license: openrail
11
+ duplicated_from: smangrul/peft-lora-sd-dreambooth
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Demo showcasing parameter-efficient fine-tuning of Stable Dissfusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)
4
+
5
+ The code in this repo is partly adapted from the following repositories:
6
+ https://huggingface.co/spaces/hysts/LoRA-SD-training
7
+ https://huggingface.co/spaces/multimodalart/dreambooth-training
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import pathlib
13
+
14
+ import gradio as gr
15
+ import torch
16
+ from typing import List
17
+
18
+ from inference import InferencePipeline
19
+ from trainer import Trainer
20
+ from uploader import upload
21
+
22
+
23
+ TITLE = "# LoRA + Dreambooth Training and Inference Demo 🎨"
24
+ DESCRIPTION = "Demo showcasing parameter-efficient fine-tuning of Stable Dissfusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)."
25
+
26
+
27
+ ORIGINAL_SPACE_ID = "smangrul/peft-lora-sd-dreambooth"
28
+
29
+ SPACE_ID = os.getenv("SPACE_ID", ORIGINAL_SPACE_ID)
30
+ SHARED_UI_WARNING = f"""# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
31
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
32
+ """
33
+ if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID:
34
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
35
+
36
+ else:
37
+ SETTINGS = "Settings"
38
+ CUDA_NOT_AVAILABLE_WARNING = f"""# Attention - Running on CPU.
39
+ <center>
40
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
41
+ "T4 small" is sufficient to run this demo.
42
+ </center>
43
+ """
44
+
45
+
46
+ def show_warning(warning_text: str) -> gr.Blocks:
47
+ with gr.Blocks() as demo:
48
+ with gr.Box():
49
+ gr.Markdown(warning_text)
50
+ return demo
51
+
52
+
53
+ def update_output_files() -> dict:
54
+ paths = sorted(pathlib.Path("results").glob("*.pt"))
55
+ config_paths = sorted(pathlib.Path("results").glob("*.json"))
56
+ paths = paths + config_paths
57
+ paths = [path.as_posix() for path in paths] # type: ignore
58
+ return gr.update(value=paths or None)
59
+
60
+
61
+ def create_training_demo(trainer: Trainer, pipe: InferencePipeline) -> gr.Blocks:
62
+ with gr.Blocks() as demo:
63
+ base_model = gr.Dropdown(
64
+ choices=[
65
+ "CompVis/stable-diffusion-v1-4",
66
+ "runwayml/stable-diffusion-v1-5",
67
+ "stabilityai/stable-diffusion-2-1-base",
68
+ ],
69
+ value="runwayml/stable-diffusion-v1-5",
70
+ label="Base Model",
71
+ visible=True,
72
+ )
73
+ resolution = gr.Dropdown(choices=["512"], value="512", label="Resolution", visible=False)
74
+
75
+ with gr.Row():
76
+ with gr.Box():
77
+ gr.Markdown("Training Data")
78
+ concept_images = gr.Files(label="Images for your concept")
79
+ concept_prompt = gr.Textbox(label="Concept Prompt", max_lines=1)
80
+ gr.Markdown(
81
+ """
82
+ - Upload images of the style you are planning on training on.
83
+ - For a concept prompt, use a unique, made up word to avoid collisions.
84
+ - Guidelines for getting good results:
85
+ - Dreambooth for an `object` or `style`:
86
+ - 5-10 images of the object from different angles
87
+ - 500-800 iterations should be good enough.
88
+ - Prior preservation is recommended.
89
+ - `class_prompt`:
90
+ - `a photo of object`
91
+ - `style`
92
+ - `concept_prompt`:
93
+ - `<concept prompt> object`
94
+ - `<concept prompt> style`
95
+ - `a photo of <concept prompt> object`
96
+ - `a photo of <concept prompt> style`
97
+ - Dreambooth for a `Person/Face`:
98
+ - 15-50 images of the person from different angles, lighting, and expressions.
99
+ Have considerable photos with close up faces.
100
+ - 800-1200 iterations should be good enough.
101
+ - good defaults for hyperparams
102
+ - Model - `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1-base`
103
+ - Use/check Prior preservation.
104
+ - Number of class images to use - 200
105
+ - Prior Loss Weight - 1
106
+ - LoRA Rank for unet - 16
107
+ - LoRA Alpha for unet - 20
108
+ - lora dropout - 0
109
+ - LoRA Bias for unet - `all`
110
+ - LoRA Rank for CLIP - 16
111
+ - LoRA Alpha for CLIP - 17
112
+ - LoRA Bias for CLIP - `all`
113
+ - lora dropout for CLIP - 0
114
+ - Uncheck `FP16` and `8bit-Adam` (don't use them for faces)
115
+ - `class_prompt`: Use the gender related word of the person
116
+ - `man`
117
+ - `woman`
118
+ - `boy`
119
+ - `girl`
120
+ - `concept_prompt`: just the unique, made up word, e.g., `srm`
121
+ - Choose `all` for `lora_bias` and `text_encode_lora_bias`
122
+ - Dreambooth for a `Scene`:
123
+ - 15-50 images of the scene from different angles, lighting, and expressions.
124
+ - 800-1200 iterations should be good enough.
125
+ - Prior preservation is recommended.
126
+ - `class_prompt`:
127
+ - `scene`
128
+ - `landscape`
129
+ - `city`
130
+ - `beach`
131
+ - `mountain`
132
+ - `concept_prompt`:
133
+ - `<concept prompt> scene`
134
+ - `<concept prompt> landscape`
135
+ - Experiment with various values for lora dropouts, enabling/disabling fp16 and 8bit-Adam
136
+ """
137
+ )
138
+ with gr.Box():
139
+ gr.Markdown("Training Parameters")
140
+ num_training_steps = gr.Number(label="Number of Training Steps", value=1000, precision=0)
141
+ learning_rate = gr.Number(label="Learning Rate", value=0.0001)
142
+ gradient_checkpointing = gr.Checkbox(label="Whether to use gradient checkpointing", value=True)
143
+ train_text_encoder = gr.Checkbox(label="Train Text Encoder", value=True)
144
+ with_prior_preservation = gr.Checkbox(label="Prior Preservation", value=True)
145
+ class_prompt = gr.Textbox(
146
+ label="Class Prompt", max_lines=1, placeholder='Example: "a photo of object"'
147
+ )
148
+ num_class_images = gr.Number(label="Number of class images to use", value=50, precision=0)
149
+ prior_loss_weight = gr.Number(label="Prior Loss Weight", value=1.0, precision=1)
150
+ # use_lora = gr.Checkbox(label="Whether to use LoRA", value=True)
151
+ lora_r = gr.Number(label="LoRA Rank for unet", value=4, precision=0)
152
+ lora_alpha = gr.Number(
153
+ label="LoRA Alpha for unet. scaling factor = lora_alpha/lora_r", value=4, precision=0
154
+ )
155
+ lora_dropout = gr.Number(label="lora dropout", value=0.00)
156
+ lora_bias = gr.Dropdown(
157
+ choices=["none", "all", "lora_only"],
158
+ value="none",
159
+ label="LoRA Bias for unet. This enables bias params to be trainable based on the bias type",
160
+ visible=True,
161
+ )
162
+ lora_text_encoder_r = gr.Number(label="LoRA Rank for CLIP", value=4, precision=0)
163
+ lora_text_encoder_alpha = gr.Number(
164
+ label="LoRA Alpha for CLIP. scaling factor = lora_alpha/lora_r", value=4, precision=0
165
+ )
166
+ lora_text_encoder_dropout = gr.Number(label="lora dropout for CLIP", value=0.00)
167
+ lora_text_encoder_bias = gr.Dropdown(
168
+ choices=["none", "all", "lora_only"],
169
+ value="none",
170
+ label="LoRA Bias for CLIP. This enables bias params to be trainable based on the bias type",
171
+ visible=True,
172
+ )
173
+ gradient_accumulation = gr.Number(label="Number of Gradient Accumulation", value=1, precision=0)
174
+ fp16 = gr.Checkbox(label="FP16", value=True)
175
+ use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=True)
176
+ gr.Markdown(
177
+ """
178
+ - It will take about 20-30 minutes to train for 1000 steps with a T4 GPU.
179
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
180
+ - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
181
+ """
182
+ )
183
+
184
+ run_button = gr.Button("Start Training")
185
+ with gr.Box():
186
+ with gr.Row():
187
+ check_status_button = gr.Button("Check Training Status")
188
+ with gr.Column():
189
+ with gr.Box():
190
+ gr.Markdown("Message")
191
+ training_status = gr.Markdown()
192
+ output_files = gr.Files(label="Trained Weight Files and Configs")
193
+
194
+ run_button.click(fn=pipe.clear)
195
+
196
+ run_button.click(
197
+ fn=trainer.run,
198
+ inputs=[
199
+ base_model,
200
+ resolution,
201
+ num_training_steps,
202
+ concept_images,
203
+ concept_prompt,
204
+ learning_rate,
205
+ gradient_accumulation,
206
+ fp16,
207
+ use_8bit_adam,
208
+ gradient_checkpointing,
209
+ train_text_encoder,
210
+ with_prior_preservation,
211
+ prior_loss_weight,
212
+ class_prompt,
213
+ num_class_images,
214
+ lora_r,
215
+ lora_alpha,
216
+ lora_bias,
217
+ lora_dropout,
218
+ lora_text_encoder_r,
219
+ lora_text_encoder_alpha,
220
+ lora_text_encoder_bias,
221
+ lora_text_encoder_dropout,
222
+ ],
223
+ outputs=[
224
+ training_status,
225
+ output_files,
226
+ ],
227
+ queue=False,
228
+ )
229
+ check_status_button.click(fn=trainer.check_if_running, inputs=None, outputs=training_status, queue=False)
230
+ check_status_button.click(fn=update_output_files, inputs=None, outputs=output_files, queue=False)
231
+ return demo
232
+
233
+
234
+ def find_weight_files() -> List[str]:
235
+ curr_dir = pathlib.Path(__file__).parent
236
+ paths = sorted(curr_dir.rglob("*.pt"))
237
+ return [path.relative_to(curr_dir).as_posix() for path in paths]
238
+
239
+
240
+ def reload_lora_weight_list() -> dict:
241
+ return gr.update(choices=find_weight_files())
242
+
243
+
244
+ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
245
+ with gr.Blocks() as demo:
246
+ with gr.Row():
247
+ with gr.Column():
248
+ base_model = gr.Dropdown(
249
+ choices=[
250
+ "CompVis/stable-diffusion-v1-4",
251
+ "runwayml/stable-diffusion-v1-5",
252
+ "stabilityai/stable-diffusion-2-1-base",
253
+ ],
254
+ value="runwayml/stable-diffusion-v1-5",
255
+ label="Base Model",
256
+ visible=True,
257
+ )
258
+ reload_button = gr.Button("Reload Weight List")
259
+ lora_weight_name = gr.Dropdown(
260
+ choices=find_weight_files(), value="lora/lora_disney.pt", label="LoRA Weight File"
261
+ )
262
+ prompt = gr.Textbox(label="Prompt", max_lines=1, placeholder='Example: "style of sks, baby lion"')
263
+ negative_prompt = gr.Textbox(
264
+ label="Negative Prompt", max_lines=1, placeholder='Example: "blurry, botched, low quality"'
265
+ )
266
+ seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=1)
267
+ with gr.Accordion("Other Parameters", open=False):
268
+ num_steps = gr.Slider(label="Number of Steps", minimum=0, maximum=1000, step=1, value=50)
269
+ guidance_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=50, step=0.1, value=7)
270
+
271
+ run_button = gr.Button("Generate")
272
+
273
+ gr.Markdown(
274
+ """
275
+ - After training, you can press "Reload Weight List" button to load your trained model names.
276
+ - Few repos to refer for ideas:
277
+ - https://huggingface.co/smangrul/smangrul
278
+ - https://huggingface.co/smangrul/painting-in-the-style-of-smangrul
279
+ - https://huggingface.co/smangrul/erenyeager
280
+ """
281
+ )
282
+ with gr.Column():
283
+ result = gr.Image(label="Result")
284
+
285
+ reload_button.click(fn=reload_lora_weight_list, inputs=None, outputs=lora_weight_name)
286
+ prompt.submit(
287
+ fn=pipe.run,
288
+ inputs=[
289
+ base_model,
290
+ lora_weight_name,
291
+ prompt,
292
+ negative_prompt,
293
+ seed,
294
+ num_steps,
295
+ guidance_scale,
296
+ ],
297
+ outputs=result,
298
+ queue=False,
299
+ )
300
+ run_button.click(
301
+ fn=pipe.run,
302
+ inputs=[
303
+ base_model,
304
+ lora_weight_name,
305
+ prompt,
306
+ negative_prompt,
307
+ seed,
308
+ num_steps,
309
+ guidance_scale,
310
+ ],
311
+ outputs=result,
312
+ queue=False,
313
+ )
314
+ seed.change(
315
+ fn=pipe.run,
316
+ inputs=[
317
+ base_model,
318
+ lora_weight_name,
319
+ prompt,
320
+ negative_prompt,
321
+ seed,
322
+ num_steps,
323
+ guidance_scale,
324
+ ],
325
+ outputs=result,
326
+ queue=False,
327
+ )
328
+ return demo
329
+
330
+
331
+ def create_upload_demo() -> gr.Blocks:
332
+ with gr.Blocks() as demo:
333
+ model_name = gr.Textbox(label="Model Name")
334
+ hf_token = gr.Textbox(label="Hugging Face Token (with write permission)")
335
+ upload_button = gr.Button("Upload")
336
+ with gr.Box():
337
+ gr.Markdown("Message")
338
+ result = gr.Markdown()
339
+ gr.Markdown(
340
+ """
341
+ - You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
342
+ - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
343
+ """
344
+ )
345
+
346
+ upload_button.click(fn=upload, inputs=[model_name, hf_token], outputs=result)
347
+
348
+ return demo
349
+
350
+
351
+ pipe = InferencePipeline()
352
+ trainer = Trainer()
353
+
354
+ with gr.Blocks(css="style.css") as demo:
355
+ if os.getenv("IS_SHARED_UI"):
356
+ show_warning(SHARED_UI_WARNING)
357
+ if not torch.cuda.is_available():
358
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
359
+
360
+ gr.Markdown(TITLE)
361
+ gr.Markdown(DESCRIPTION)
362
+
363
+ with gr.Tabs():
364
+ with gr.TabItem("Train"):
365
+ create_training_demo(trainer, pipe)
366
+ with gr.TabItem("Test"):
367
+ create_inference_demo(pipe)
368
+ with gr.TabItem("Upload"):
369
+ create_upload_demo()
370
+
371
+ demo.queue(default_enabled=False).launch(share=False)
colab.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Demo showcasing parameter-efficient fine-tuning of Stable Dissfusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)
4
+
5
+ The code in this repo is partly adapted from the following repositories:
6
+ https://huggingface.co/spaces/hysts/LoRA-SD-training
7
+ https://huggingface.co/spaces/multimodalart/dreambooth-training
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import pathlib
13
+
14
+ import gradio as gr
15
+ import torch
16
+ from typing import List
17
+
18
+ from inference import InferencePipeline
19
+ from trainer import Trainer
20
+ from uploader import upload
21
+
22
+
23
+ TITLE = "# LoRA + Dreambooth Training and Inference Demo 🎨"
24
+ DESCRIPTION = "Demo showcasing parameter-efficient fine-tuning of Stable Dissfusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)."
25
+
26
+
27
+ ORIGINAL_SPACE_ID = "smangrul/peft-lora-sd-dreambooth"
28
+
29
+ SPACE_ID = os.getenv("SPACE_ID", ORIGINAL_SPACE_ID)
30
+ SHARED_UI_WARNING = f"""# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
31
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
32
+ """
33
+ if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID:
34
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
35
+
36
+ else:
37
+ SETTINGS = "Settings"
38
+ CUDA_NOT_AVAILABLE_WARNING = f"""# Attention - Running on CPU.
39
+ <center>
40
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
41
+ "T4 small" is sufficient to run this demo.
42
+ </center>
43
+ """
44
+
45
+
46
+ def show_warning(warning_text: str) -> gr.Blocks:
47
+ with gr.Blocks() as demo:
48
+ with gr.Box():
49
+ gr.Markdown(warning_text)
50
+ return demo
51
+
52
+
53
+ def update_output_files() -> dict:
54
+ paths = sorted(pathlib.Path("results").glob("*.pt"))
55
+ config_paths = sorted(pathlib.Path("results").glob("*.json"))
56
+ paths = paths + config_paths
57
+ paths = [path.as_posix() for path in paths] # type: ignore
58
+ return gr.update(value=paths or None)
59
+
60
+
61
+ def create_training_demo(trainer: Trainer, pipe: InferencePipeline) -> gr.Blocks:
62
+ with gr.Blocks() as demo:
63
+ base_model = gr.Dropdown(
64
+ choices=[
65
+ "CompVis/stable-diffusion-v1-4",
66
+ "runwayml/stable-diffusion-v1-5",
67
+ "stabilityai/stable-diffusion-2-1-base",
68
+ ],
69
+ value="runwayml/stable-diffusion-v1-5",
70
+ label="Base Model",
71
+ visible=True,
72
+ )
73
+ resolution = gr.Dropdown(choices=["512"], value="512", label="Resolution", visible=False)
74
+
75
+ with gr.Row():
76
+ with gr.Box():
77
+ gr.Markdown("Training Data")
78
+ concept_images = gr.Files(label="Images for your concept")
79
+ concept_prompt = gr.Textbox(label="Concept Prompt", max_lines=1)
80
+ gr.Markdown(
81
+ """
82
+ - Upload images of the style you are planning on training on.
83
+ - For a concept prompt, use a unique, made up word to avoid collisions.
84
+ - Guidelines for getting good results:
85
+ - Dreambooth for an `object` or `style`:
86
+ - 5-10 images of the object from different angles
87
+ - 500-800 iterations should be good enough.
88
+ - Prior preservation is recommended.
89
+ - `class_prompt`:
90
+ - `a photo of object`
91
+ - `style`
92
+ - `concept_prompt`:
93
+ - `<concept prompt> object`
94
+ - `<concept prompt> style`
95
+ - `a photo of <concept prompt> object`
96
+ - `a photo of <concept prompt> style`
97
+ - Dreambooth for a `Person/Face`:
98
+ - 15-50 images of the person from different angles, lighting, and expressions.
99
+ Have considerable photos with close up faces.
100
+ - 800-1200 iterations should be good enough.
101
+ - good defaults for hyperparams
102
+ - Model - `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1-base`
103
+ - Use/check Prior preservation.
104
+ - Number of class images to use - 200
105
+ - Prior Loss Weight - 1
106
+ - LoRA Rank for unet - 16
107
+ - LoRA Alpha for unet - 20
108
+ - lora dropout - 0
109
+ - LoRA Bias for unet - `all`
110
+ - LoRA Rank for CLIP - 16
111
+ - LoRA Alpha for CLIP - 17
112
+ - LoRA Bias for CLIP - `all`
113
+ - lora dropout for CLIP - 0
114
+ - Uncheck `FP16` and `8bit-Adam` (don't use them for faces)
115
+ - `class_prompt`: Use the gender related word of the person
116
+ - `man`
117
+ - `woman`
118
+ - `boy`
119
+ - `girl`
120
+ - `concept_prompt`: just the unique, made up word, e.g., `srm`
121
+ - Choose `all` for `lora_bias` and `text_encode_lora_bias`
122
+ - Dreambooth for a `Scene`:
123
+ - 15-50 images of the scene from different angles, lighting, and expressions.
124
+ - 800-1200 iterations should be good enough.
125
+ - Prior preservation is recommended.
126
+ - `class_prompt`:
127
+ - `scene`
128
+ - `landscape`
129
+ - `city`
130
+ - `beach`
131
+ - `mountain`
132
+ - `concept_prompt`:
133
+ - `<concept prompt> scene`
134
+ - `<concept prompt> landscape`
135
+ - Experiment with various values for lora dropouts, enabling/disabling fp16 and 8bit-Adam
136
+ """
137
+ )
138
+ with gr.Box():
139
+ gr.Markdown("Training Parameters")
140
+ num_training_steps = gr.Number(label="Number of Training Steps", value=1000, precision=0)
141
+ learning_rate = gr.Number(label="Learning Rate", value=0.0001)
142
+ gradient_checkpointing = gr.Checkbox(label="Whether to use gradient checkpointing", value=True)
143
+ train_text_encoder = gr.Checkbox(label="Train Text Encoder", value=True)
144
+ with_prior_preservation = gr.Checkbox(label="Prior Preservation", value=True)
145
+ class_prompt = gr.Textbox(
146
+ label="Class Prompt", max_lines=1, placeholder='Example: "a photo of object"'
147
+ )
148
+ num_class_images = gr.Number(label="Number of class images to use", value=50, precision=0)
149
+ prior_loss_weight = gr.Number(label="Prior Loss Weight", value=1.0, precision=1)
150
+ # use_lora = gr.Checkbox(label="Whether to use LoRA", value=True)
151
+ lora_r = gr.Number(label="LoRA Rank for unet", value=4, precision=0)
152
+ lora_alpha = gr.Number(
153
+ label="LoRA Alpha for unet. scaling factor = lora_r/lora_alpha", value=4, precision=0
154
+ )
155
+ lora_dropout = gr.Number(label="lora dropout", value=0.00)
156
+ lora_bias = gr.Dropdown(
157
+ choices=["none", "all", "lora_only"],
158
+ value="none",
159
+ label="LoRA Bias for unet. This enables bias params to be trainable based on the bias type",
160
+ visible=True,
161
+ )
162
+ lora_text_encoder_r = gr.Number(label="LoRA Rank for CLIP", value=4, precision=0)
163
+ lora_text_encoder_alpha = gr.Number(
164
+ label="LoRA Alpha for CLIP. scaling factor = lora_r/lora_alpha", value=4, precision=0
165
+ )
166
+ lora_text_encoder_dropout = gr.Number(label="lora dropout for CLIP", value=0.00)
167
+ lora_text_encoder_bias = gr.Dropdown(
168
+ choices=["none", "all", "lora_only"],
169
+ value="none",
170
+ label="LoRA Bias for CLIP. This enables bias params to be trainable based on the bias type",
171
+ visible=True,
172
+ )
173
+ gradient_accumulation = gr.Number(label="Number of Gradient Accumulation", value=1, precision=0)
174
+ fp16 = gr.Checkbox(label="FP16", value=True)
175
+ use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=True)
176
+ gr.Markdown(
177
+ """
178
+ - It will take about 20-30 minutes to train for 1000 steps with a T4 GPU.
179
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
180
+ - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
181
+ """
182
+ )
183
+
184
+ run_button = gr.Button("Start Training")
185
+ with gr.Box():
186
+ with gr.Row():
187
+ check_status_button = gr.Button("Check Training Status")
188
+ with gr.Column():
189
+ with gr.Box():
190
+ gr.Markdown("Message")
191
+ training_status = gr.Markdown()
192
+ output_files = gr.Files(label="Trained Weight Files and Configs")
193
+
194
+ run_button.click(fn=pipe.clear)
195
+
196
+ run_button.click(
197
+ fn=trainer.run,
198
+ inputs=[
199
+ base_model,
200
+ resolution,
201
+ num_training_steps,
202
+ concept_images,
203
+ concept_prompt,
204
+ learning_rate,
205
+ gradient_accumulation,
206
+ fp16,
207
+ use_8bit_adam,
208
+ gradient_checkpointing,
209
+ train_text_encoder,
210
+ with_prior_preservation,
211
+ prior_loss_weight,
212
+ class_prompt,
213
+ num_class_images,
214
+ lora_r,
215
+ lora_alpha,
216
+ lora_bias,
217
+ lora_dropout,
218
+ lora_text_encoder_r,
219
+ lora_text_encoder_alpha,
220
+ lora_text_encoder_bias,
221
+ lora_text_encoder_dropout,
222
+ ],
223
+ outputs=[
224
+ training_status,
225
+ output_files,
226
+ ],
227
+ queue=False,
228
+ )
229
+ check_status_button.click(fn=trainer.check_if_running, inputs=None, outputs=training_status, queue=False)
230
+ check_status_button.click(fn=update_output_files, inputs=None, outputs=output_files, queue=False)
231
+ return demo
232
+
233
+
234
+ def find_weight_files() -> List[str]:
235
+ curr_dir = pathlib.Path(__file__).parent
236
+ paths = sorted(curr_dir.rglob("*.pt"))
237
+ return [path.relative_to(curr_dir).as_posix() for path in paths]
238
+
239
+
240
+ def reload_lora_weight_list() -> dict:
241
+ return gr.update(choices=find_weight_files())
242
+
243
+
244
+ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
245
+ with gr.Blocks() as demo:
246
+ with gr.Row():
247
+ with gr.Column():
248
+ base_model = gr.Dropdown(
249
+ choices=[
250
+ "CompVis/stable-diffusion-v1-4",
251
+ "runwayml/stable-diffusion-v1-5",
252
+ "stabilityai/stable-diffusion-2-1-base",
253
+ ],
254
+ value="runwayml/stable-diffusion-v1-5",
255
+ label="Base Model",
256
+ visible=True,
257
+ )
258
+ reload_button = gr.Button("Reload Weight List")
259
+ lora_weight_name = gr.Dropdown(
260
+ choices=find_weight_files(), value="lora/lora_disney.pt", label="LoRA Weight File"
261
+ )
262
+ prompt = gr.Textbox(label="Prompt", max_lines=1, placeholder='Example: "style of sks, baby lion"')
263
+ negative_prompt = gr.Textbox(
264
+ label="Negative Prompt", max_lines=1, placeholder='Example: "blurry, botched, low quality"'
265
+ )
266
+ seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=1)
267
+ with gr.Accordion("Other Parameters", open=False):
268
+ num_steps = gr.Slider(label="Number of Steps", minimum=0, maximum=1000, step=1, value=50)
269
+ guidance_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=50, step=0.1, value=7)
270
+
271
+ run_button = gr.Button("Generate")
272
+
273
+ gr.Markdown(
274
+ """
275
+ - After training, you can press "Reload Weight List" button to load your trained model names.
276
+ - Few repos to refer for ideas:
277
+ - https://huggingface.co/smangrul/smangrul
278
+ - https://huggingface.co/smangrul/painting-in-the-style-of-smangrul
279
+ - https://huggingface.co/smangrul/erenyeager
280
+ """
281
+ )
282
+ with gr.Column():
283
+ result = gr.Image(label="Result")
284
+
285
+ reload_button.click(fn=reload_lora_weight_list, inputs=None, outputs=lora_weight_name)
286
+ prompt.submit(
287
+ fn=pipe.run,
288
+ inputs=[
289
+ base_model,
290
+ lora_weight_name,
291
+ prompt,
292
+ negative_prompt,
293
+ seed,
294
+ num_steps,
295
+ guidance_scale,
296
+ ],
297
+ outputs=result,
298
+ queue=False,
299
+ )
300
+ run_button.click(
301
+ fn=pipe.run,
302
+ inputs=[
303
+ base_model,
304
+ lora_weight_name,
305
+ prompt,
306
+ negative_prompt,
307
+ seed,
308
+ num_steps,
309
+ guidance_scale,
310
+ ],
311
+ outputs=result,
312
+ queue=False,
313
+ )
314
+ seed.change(
315
+ fn=pipe.run,
316
+ inputs=[
317
+ base_model,
318
+ lora_weight_name,
319
+ prompt,
320
+ negative_prompt,
321
+ seed,
322
+ num_steps,
323
+ guidance_scale,
324
+ ],
325
+ outputs=result,
326
+ queue=False,
327
+ )
328
+ return demo
329
+
330
+
331
+ def create_upload_demo() -> gr.Blocks:
332
+ with gr.Blocks() as demo:
333
+ model_name = gr.Textbox(label="Model Name")
334
+ hf_token = gr.Textbox(label="Hugging Face Token (with write permission)")
335
+ upload_button = gr.Button("Upload")
336
+ with gr.Box():
337
+ gr.Markdown("Message")
338
+ result = gr.Markdown()
339
+ gr.Markdown(
340
+ """
341
+ - You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
342
+ - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
343
+ """
344
+ )
345
+
346
+ upload_button.click(fn=upload, inputs=[model_name, hf_token], outputs=result)
347
+
348
+ return demo
349
+
350
+
351
+ pipe = InferencePipeline()
352
+ trainer = Trainer()
353
+
354
+ with gr.Blocks(css="style.css") as demo:
355
+ if os.getenv("IS_SHARED_UI"):
356
+ show_warning(SHARED_UI_WARNING)
357
+ if not torch.cuda.is_available():
358
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
359
+
360
+ gr.Markdown(TITLE)
361
+ gr.Markdown(DESCRIPTION)
362
+
363
+ with gr.Tabs():
364
+ with gr.TabItem("Train"):
365
+ create_training_demo(trainer, pipe)
366
+ with gr.TabItem("Test"):
367
+ create_inference_demo(pipe)
368
+ with gr.TabItem("Upload"):
369
+ create_upload_demo()
370
+
371
+ demo.queue(default_enabled=False).launch(share=True)
inference.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import json
5
+ import pathlib
6
+ import sys
7
+
8
+ import gradio as gr
9
+ import PIL.Image
10
+ import torch
11
+ from diffusers import StableDiffusionPipeline
12
+ from peft import LoraModel, LoraConfig, set_peft_model_state_dict
13
+
14
+
15
+ class InferencePipeline:
16
+ def __init__(self):
17
+ self.pipe = None
18
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
+ self.weight_path = None
20
+
21
+ def clear(self) -> None:
22
+ self.weight_path = None
23
+ del self.pipe
24
+ self.pipe = None
25
+ torch.cuda.empty_cache()
26
+ gc.collect()
27
+
28
+ @staticmethod
29
+ def get_lora_weight_path(name: str) -> pathlib.Path:
30
+ curr_dir = pathlib.Path(__file__).parent
31
+ return curr_dir / name, curr_dir / f'{name.replace(".pt", "_config.json")}'
32
+
33
+ def load_and_set_lora_ckpt(self, pipe, weight_path, config_path, dtype):
34
+ with open(config_path, "r") as f:
35
+ lora_config = json.load(f)
36
+ lora_checkpoint_sd = torch.load(weight_path, map_location=self.device)
37
+ unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
38
+ text_encoder_lora_ds = {
39
+ k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
40
+ }
41
+
42
+ unet_config = LoraConfig(**lora_config["peft_config"])
43
+ pipe.unet = LoraModel(unet_config, pipe.unet)
44
+ set_peft_model_state_dict(pipe.unet, unet_lora_ds)
45
+
46
+ if "text_encoder_peft_config" in lora_config:
47
+ text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
48
+ pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
49
+ set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)
50
+
51
+ if dtype in (torch.float16, torch.bfloat16):
52
+ pipe.unet.half()
53
+ pipe.text_encoder.half()
54
+
55
+ pipe.to(self.device)
56
+ return pipe
57
+
58
+ def load_pipe(self, model_id: str, lora_filename: str) -> None:
59
+ weight_path, config_path = self.get_lora_weight_path(lora_filename)
60
+ if weight_path == self.weight_path:
61
+ return
62
+
63
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(self.device)
64
+ pipe = pipe.to(self.device)
65
+ pipe = self.load_and_set_lora_ckpt(pipe, weight_path, config_path, torch.float16)
66
+ self.pipe = pipe
67
+
68
+ def run(
69
+ self,
70
+ base_model: str,
71
+ lora_weight_name: str,
72
+ prompt: str,
73
+ negative_prompt: str,
74
+ seed: int,
75
+ n_steps: int,
76
+ guidance_scale: float,
77
+ ) -> PIL.Image.Image:
78
+ if not torch.cuda.is_available():
79
+ raise gr.Error("CUDA is not available.")
80
+
81
+ self.load_pipe(base_model, lora_weight_name)
82
+
83
+ generator = torch.Generator(device=self.device).manual_seed(seed)
84
+ out = self.pipe(
85
+ prompt,
86
+ num_inference_steps=n_steps,
87
+ guidance_scale=guidance_scale,
88
+ generator=generator,
89
+ negative_prompt=negative_prompt if negative_prompt else None,
90
+ ) # type: ignore
91
+ return out.images[0]
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ peft
4
+ datasets
5
+ git+https://github.com/huggingface/accelerate
6
+ git+https://github.com/huggingface/diffusers
7
+ git+https://github.com/huggingface/transformers
8
+ tqdm
9
+ ftfy
10
+ Pillow
11
+ bitsandbytes
12
+ gradio
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
train_dreambooth.py ADDED
@@ -0,0 +1,1005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gc
3
+ import hashlib
4
+ import itertools
5
+ import json
6
+ import logging
7
+ import math
8
+ import os
9
+ import threading
10
+ import warnings
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint
17
+ import transformers
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import set_seed
21
+ from torch.utils.data import Dataset
22
+ from transformers import AutoTokenizer, PretrainedConfig
23
+
24
+ import datasets
25
+ import diffusers
26
+ import psutil
27
+ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
28
+ from diffusers.optimization import get_scheduler
29
+ from diffusers.utils import check_min_version
30
+ from diffusers.utils.import_utils import is_xformers_available
31
+ from huggingface_hub import HfFolder, Repository, whoami
32
+ from peft import LoraConfig, LoraModel, get_peft_model_state_dict
33
+ from PIL import Image
34
+ from torchvision import transforms
35
+ from tqdm.auto import tqdm
36
+
37
+
38
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
39
+ check_min_version("0.10.0.dev0")
40
+
41
+ logger = get_logger(__name__)
42
+
43
+ UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] # , "ff.net.0.proj"]
44
+ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
45
+
46
+
47
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
48
+ text_encoder_config = PretrainedConfig.from_pretrained(
49
+ pretrained_model_name_or_path,
50
+ subfolder="text_encoder",
51
+ revision=revision,
52
+ )
53
+ model_class = text_encoder_config.architectures[0]
54
+
55
+ if model_class == "CLIPTextModel":
56
+ from transformers import CLIPTextModel
57
+
58
+ return CLIPTextModel
59
+ elif model_class == "RobertaSeriesModelWithTransformation":
60
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
61
+
62
+ return RobertaSeriesModelWithTransformation
63
+ else:
64
+ raise ValueError(f"{model_class} is not supported.")
65
+
66
+
67
+ def parse_args(input_args=None):
68
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
69
+ parser.add_argument(
70
+ "--pretrained_model_name_or_path",
71
+ type=str,
72
+ default=None,
73
+ required=True,
74
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
75
+ )
76
+ parser.add_argument(
77
+ "--revision",
78
+ type=str,
79
+ default=None,
80
+ required=False,
81
+ help="Revision of pretrained model identifier from huggingface.co/models.",
82
+ )
83
+ parser.add_argument(
84
+ "--tokenizer_name",
85
+ type=str,
86
+ default=None,
87
+ help="Pretrained tokenizer name or path if not the same as model_name",
88
+ )
89
+ parser.add_argument(
90
+ "--instance_data_dir",
91
+ type=str,
92
+ default=None,
93
+ required=True,
94
+ help="A folder containing the training data of instance images.",
95
+ )
96
+ parser.add_argument(
97
+ "--class_data_dir",
98
+ type=str,
99
+ default=None,
100
+ required=False,
101
+ help="A folder containing the training data of class images.",
102
+ )
103
+ parser.add_argument(
104
+ "--instance_prompt",
105
+ type=str,
106
+ default=None,
107
+ required=True,
108
+ help="The prompt with identifier specifying the instance",
109
+ )
110
+ parser.add_argument(
111
+ "--class_prompt",
112
+ type=str,
113
+ default=None,
114
+ help="The prompt to specify images in the same class as provided instance images.",
115
+ )
116
+ parser.add_argument(
117
+ "--with_prior_preservation",
118
+ default=False,
119
+ action="store_true",
120
+ help="Flag to add prior preservation loss.",
121
+ )
122
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
123
+ parser.add_argument(
124
+ "--num_class_images",
125
+ type=int,
126
+ default=100,
127
+ help=(
128
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
129
+ " class_data_dir, additional images will be sampled with class_prompt."
130
+ ),
131
+ )
132
+ parser.add_argument(
133
+ "--output_dir",
134
+ type=str,
135
+ default="text-inversion-model",
136
+ help="The output directory where the model predictions and checkpoints will be written.",
137
+ )
138
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
139
+ parser.add_argument(
140
+ "--resolution",
141
+ type=int,
142
+ default=512,
143
+ help=(
144
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
145
+ " resolution"
146
+ ),
147
+ )
148
+ parser.add_argument(
149
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
150
+ )
151
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
152
+
153
+ # lora args
154
+ parser.add_argument("--use_lora", action="store_true", help="Whether to use Lora for parameter efficient tuning")
155
+ parser.add_argument("--lora_r", type=int, default=8, help="Lora rank, only used if use_lora is True")
156
+ parser.add_argument("--lora_alpha", type=int, default=32, help="Lora alpha, only used if use_lora is True")
157
+ parser.add_argument("--lora_dropout", type=float, default=0.0, help="Lora dropout, only used if use_lora is True")
158
+ parser.add_argument(
159
+ "--lora_bias",
160
+ type=str,
161
+ default="none",
162
+ help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
163
+ )
164
+ parser.add_argument(
165
+ "--lora_text_encoder_r",
166
+ type=int,
167
+ default=8,
168
+ help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
169
+ )
170
+ parser.add_argument(
171
+ "--lora_text_encoder_alpha",
172
+ type=int,
173
+ default=32,
174
+ help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
175
+ )
176
+ parser.add_argument(
177
+ "--lora_text_encoder_dropout",
178
+ type=float,
179
+ default=0.0,
180
+ help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
181
+ )
182
+ parser.add_argument(
183
+ "--lora_text_encoder_bias",
184
+ type=str,
185
+ default="none",
186
+ help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
191
+ )
192
+ parser.add_argument(
193
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
194
+ )
195
+ parser.add_argument("--num_train_epochs", type=int, default=1)
196
+ parser.add_argument(
197
+ "--max_train_steps",
198
+ type=int,
199
+ default=None,
200
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
201
+ )
202
+ parser.add_argument(
203
+ "--checkpointing_steps",
204
+ type=int,
205
+ default=500,
206
+ help=(
207
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
208
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
209
+ " training using `--resume_from_checkpoint`."
210
+ ),
211
+ )
212
+ parser.add_argument(
213
+ "--resume_from_checkpoint",
214
+ type=str,
215
+ default=None,
216
+ help=(
217
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
218
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
219
+ ),
220
+ )
221
+ parser.add_argument(
222
+ "--gradient_accumulation_steps",
223
+ type=int,
224
+ default=1,
225
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
226
+ )
227
+ parser.add_argument(
228
+ "--gradient_checkpointing",
229
+ action="store_true",
230
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
231
+ )
232
+ parser.add_argument(
233
+ "--learning_rate",
234
+ type=float,
235
+ default=5e-6,
236
+ help="Initial learning rate (after the potential warmup period) to use.",
237
+ )
238
+ parser.add_argument(
239
+ "--scale_lr",
240
+ action="store_true",
241
+ default=False,
242
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
243
+ )
244
+ parser.add_argument(
245
+ "--lr_scheduler",
246
+ type=str,
247
+ default="constant",
248
+ help=(
249
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
250
+ ' "constant", "constant_with_warmup"]'
251
+ ),
252
+ )
253
+ parser.add_argument(
254
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
255
+ )
256
+ parser.add_argument(
257
+ "--lr_num_cycles",
258
+ type=int,
259
+ default=1,
260
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
261
+ )
262
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
263
+ parser.add_argument(
264
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
265
+ )
266
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
267
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
268
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
269
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
270
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
271
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
272
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
273
+ parser.add_argument(
274
+ "--hub_model_id",
275
+ type=str,
276
+ default=None,
277
+ help="The name of the repository to keep in sync with the local `output_dir`.",
278
+ )
279
+ parser.add_argument(
280
+ "--logging_dir",
281
+ type=str,
282
+ default="logs",
283
+ help=(
284
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
285
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
286
+ ),
287
+ )
288
+ parser.add_argument(
289
+ "--allow_tf32",
290
+ action="store_true",
291
+ help=(
292
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
293
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
294
+ ),
295
+ )
296
+ parser.add_argument(
297
+ "--report_to",
298
+ type=str,
299
+ default="tensorboard",
300
+ help=(
301
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
302
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
303
+ ),
304
+ )
305
+ parser.add_argument(
306
+ "--mixed_precision",
307
+ type=str,
308
+ default=None,
309
+ choices=["no", "fp16", "bf16"],
310
+ help=(
311
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
312
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
313
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
314
+ ),
315
+ )
316
+ parser.add_argument(
317
+ "--prior_generation_precision",
318
+ type=str,
319
+ default=None,
320
+ choices=["no", "fp32", "fp16", "bf16"],
321
+ help=(
322
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
323
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
324
+ ),
325
+ )
326
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
327
+ parser.add_argument(
328
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
329
+ )
330
+
331
+ if input_args is not None:
332
+ args = parser.parse_args(input_args)
333
+ else:
334
+ args = parser.parse_args()
335
+
336
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
337
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
338
+ args.local_rank = env_local_rank
339
+
340
+ if args.with_prior_preservation:
341
+ if args.class_data_dir is None:
342
+ raise ValueError("You must specify a data directory for class images.")
343
+ if args.class_prompt is None:
344
+ raise ValueError("You must specify prompt for class images.")
345
+ else:
346
+ # logger is not available yet
347
+ if args.class_data_dir is not None:
348
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
349
+ if args.class_prompt is not None:
350
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
351
+
352
+ return args
353
+
354
+
355
+ # Converting Bytes to Megabytes
356
+ def b2mb(x):
357
+ return int(x / 2**20)
358
+
359
+
360
+ # This context manager is used to track the peak memory usage of the process
361
+ class TorchTracemalloc:
362
+ def __enter__(self):
363
+ gc.collect()
364
+ torch.cuda.empty_cache()
365
+ torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
366
+ self.begin = torch.cuda.memory_allocated()
367
+ self.process = psutil.Process()
368
+
369
+ self.cpu_begin = self.cpu_mem_used()
370
+ self.peak_monitoring = True
371
+ peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
372
+ peak_monitor_thread.daemon = True
373
+ peak_monitor_thread.start()
374
+ return self
375
+
376
+ def cpu_mem_used(self):
377
+ """get resident set size memory for the current process"""
378
+ return self.process.memory_info().rss
379
+
380
+ def peak_monitor_func(self):
381
+ self.cpu_peak = -1
382
+
383
+ while True:
384
+ self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
385
+
386
+ # can't sleep or will not catch the peak right (this comment is here on purpose)
387
+ # time.sleep(0.001) # 1msec
388
+
389
+ if not self.peak_monitoring:
390
+ break
391
+
392
+ def __exit__(self, *exc):
393
+ self.peak_monitoring = False
394
+
395
+ gc.collect()
396
+ torch.cuda.empty_cache()
397
+ self.end = torch.cuda.memory_allocated()
398
+ self.peak = torch.cuda.max_memory_allocated()
399
+ self.used = b2mb(self.end - self.begin)
400
+ self.peaked = b2mb(self.peak - self.begin)
401
+
402
+ self.cpu_end = self.cpu_mem_used()
403
+ self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
404
+ self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
405
+ # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
406
+
407
+
408
+ def print_trainable_parameters(model):
409
+ """
410
+ Prints the number of trainable parameters in the model.
411
+ """
412
+ trainable_params = 0
413
+ all_param = 0
414
+ for _, param in model.named_parameters():
415
+ all_param += param.numel()
416
+ if param.requires_grad:
417
+ trainable_params += param.numel()
418
+ print(
419
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
420
+ )
421
+
422
+
423
+ class DreamBoothDataset(Dataset):
424
+ """
425
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
426
+ It pre-processes the images and the tokenizes prompts.
427
+ """
428
+
429
+ def __init__(
430
+ self,
431
+ instance_data_root,
432
+ instance_prompt,
433
+ tokenizer,
434
+ class_data_root=None,
435
+ class_prompt=None,
436
+ size=512,
437
+ center_crop=False,
438
+ ):
439
+ self.size = size
440
+ self.center_crop = center_crop
441
+ self.tokenizer = tokenizer
442
+
443
+ self.instance_data_root = Path(instance_data_root)
444
+ if not self.instance_data_root.exists():
445
+ raise ValueError("Instance images root doesn't exists.")
446
+
447
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
448
+ self.num_instance_images = len(self.instance_images_path)
449
+ self.instance_prompt = instance_prompt
450
+ self._length = self.num_instance_images
451
+
452
+ if class_data_root is not None:
453
+ self.class_data_root = Path(class_data_root)
454
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
455
+ self.class_images_path = list(self.class_data_root.iterdir())
456
+ self.num_class_images = len(self.class_images_path)
457
+ self._length = max(self.num_class_images, self.num_instance_images)
458
+ self.class_prompt = class_prompt
459
+ else:
460
+ self.class_data_root = None
461
+
462
+ self.image_transforms = transforms.Compose(
463
+ [
464
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
465
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
466
+ transforms.ToTensor(),
467
+ transforms.Normalize([0.5], [0.5]),
468
+ ]
469
+ )
470
+
471
+ def __len__(self):
472
+ return self._length
473
+
474
+ def __getitem__(self, index):
475
+ example = {}
476
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
477
+ if not instance_image.mode == "RGB":
478
+ instance_image = instance_image.convert("RGB")
479
+ example["instance_images"] = self.image_transforms(instance_image)
480
+ example["instance_prompt_ids"] = self.tokenizer(
481
+ self.instance_prompt,
482
+ truncation=True,
483
+ padding="max_length",
484
+ max_length=self.tokenizer.model_max_length,
485
+ return_tensors="pt",
486
+ ).input_ids
487
+
488
+ if self.class_data_root:
489
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
490
+ if not class_image.mode == "RGB":
491
+ class_image = class_image.convert("RGB")
492
+ example["class_images"] = self.image_transforms(class_image)
493
+ example["class_prompt_ids"] = self.tokenizer(
494
+ self.class_prompt,
495
+ truncation=True,
496
+ padding="max_length",
497
+ max_length=self.tokenizer.model_max_length,
498
+ return_tensors="pt",
499
+ ).input_ids
500
+
501
+ return example
502
+
503
+
504
+ def collate_fn(examples, with_prior_preservation=False):
505
+ input_ids = [example["instance_prompt_ids"] for example in examples]
506
+ pixel_values = [example["instance_images"] for example in examples]
507
+
508
+ # Concat class and instance examples for prior preservation.
509
+ # We do this to avoid doing two forward passes.
510
+ if with_prior_preservation:
511
+ input_ids += [example["class_prompt_ids"] for example in examples]
512
+ pixel_values += [example["class_images"] for example in examples]
513
+
514
+ pixel_values = torch.stack(pixel_values)
515
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
516
+
517
+ input_ids = torch.cat(input_ids, dim=0)
518
+
519
+ batch = {
520
+ "input_ids": input_ids,
521
+ "pixel_values": pixel_values,
522
+ }
523
+ return batch
524
+
525
+
526
+ class PromptDataset(Dataset):
527
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
528
+
529
+ def __init__(self, prompt, num_samples):
530
+ self.prompt = prompt
531
+ self.num_samples = num_samples
532
+
533
+ def __len__(self):
534
+ return self.num_samples
535
+
536
+ def __getitem__(self, index):
537
+ example = {}
538
+ example["prompt"] = self.prompt
539
+ example["index"] = index
540
+ return example
541
+
542
+
543
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
544
+ if token is None:
545
+ token = HfFolder.get_token()
546
+ if organization is None:
547
+ username = whoami(token)["name"]
548
+ return f"{username}/{model_id}"
549
+ else:
550
+ return f"{organization}/{model_id}"
551
+
552
+
553
+ def main(args):
554
+ logging_dir = Path(args.output_dir, args.logging_dir)
555
+
556
+ accelerator = Accelerator(
557
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
558
+ mixed_precision=args.mixed_precision,
559
+ log_with=args.report_to,
560
+ logging_dir=logging_dir,
561
+ )
562
+
563
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
564
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
565
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
566
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
567
+ raise ValueError(
568
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
569
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
570
+ )
571
+
572
+ # Make one log on every process with the configuration for debugging.
573
+ logging.basicConfig(
574
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
575
+ datefmt="%m/%d/%Y %H:%M:%S",
576
+ level=logging.INFO,
577
+ )
578
+ logger.info(accelerator.state, main_process_only=False)
579
+ if accelerator.is_local_main_process:
580
+ datasets.utils.logging.set_verbosity_warning()
581
+ transformers.utils.logging.set_verbosity_warning()
582
+ diffusers.utils.logging.set_verbosity_info()
583
+ else:
584
+ datasets.utils.logging.set_verbosity_error()
585
+ transformers.utils.logging.set_verbosity_error()
586
+ diffusers.utils.logging.set_verbosity_error()
587
+
588
+ # If passed along, set the training seed now.
589
+ if args.seed is not None:
590
+ set_seed(args.seed)
591
+
592
+ # Generate class images if prior preservation is enabled.
593
+ if args.with_prior_preservation:
594
+ class_images_dir = Path(args.class_data_dir)
595
+ if not class_images_dir.exists():
596
+ class_images_dir.mkdir(parents=True)
597
+ cur_class_images = len(list(class_images_dir.iterdir()))
598
+
599
+ if cur_class_images < args.num_class_images:
600
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
601
+ if args.prior_generation_precision == "fp32":
602
+ torch_dtype = torch.float32
603
+ elif args.prior_generation_precision == "fp16":
604
+ torch_dtype = torch.float16
605
+ elif args.prior_generation_precision == "bf16":
606
+ torch_dtype = torch.bfloat16
607
+ pipeline = DiffusionPipeline.from_pretrained(
608
+ args.pretrained_model_name_or_path,
609
+ torch_dtype=torch_dtype,
610
+ safety_checker=None,
611
+ revision=args.revision,
612
+ )
613
+ pipeline.set_progress_bar_config(disable=True)
614
+
615
+ num_new_images = args.num_class_images - cur_class_images
616
+ logger.info(f"Number of class images to sample: {num_new_images}.")
617
+
618
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
619
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
620
+
621
+ sample_dataloader = accelerator.prepare(sample_dataloader)
622
+ pipeline.to(accelerator.device)
623
+
624
+ for example in tqdm(
625
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
626
+ ):
627
+ images = pipeline(example["prompt"]).images
628
+
629
+ for i, image in enumerate(images):
630
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
631
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
632
+ image.save(image_filename)
633
+
634
+ del pipeline
635
+ if torch.cuda.is_available():
636
+ torch.cuda.empty_cache()
637
+
638
+ # Handle the repository creation
639
+ if accelerator.is_main_process:
640
+ if args.push_to_hub:
641
+ if args.hub_model_id is None:
642
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
643
+ else:
644
+ repo_name = args.hub_model_id
645
+ repo = Repository(args.output_dir, clone_from=repo_name) # noqa: F841
646
+
647
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
648
+ if "step_*" not in gitignore:
649
+ gitignore.write("step_*\n")
650
+ if "epoch_*" not in gitignore:
651
+ gitignore.write("epoch_*\n")
652
+ elif args.output_dir is not None:
653
+ os.makedirs(args.output_dir, exist_ok=True)
654
+
655
+ # Load the tokenizer
656
+ if args.tokenizer_name:
657
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
658
+ elif args.pretrained_model_name_or_path:
659
+ tokenizer = AutoTokenizer.from_pretrained(
660
+ args.pretrained_model_name_or_path,
661
+ subfolder="tokenizer",
662
+ revision=args.revision,
663
+ use_fast=False,
664
+ )
665
+
666
+ # import correct text encoder class
667
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
668
+
669
+ # Load scheduler and models
670
+ noise_scheduler = DDPMScheduler(
671
+ beta_start=0.00085,
672
+ beta_end=0.012,
673
+ beta_schedule="scaled_linear",
674
+ num_train_timesteps=1000,
675
+ ) # DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
676
+ text_encoder = text_encoder_cls.from_pretrained(
677
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
678
+ )
679
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
680
+ unet = UNet2DConditionModel.from_pretrained(
681
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
682
+ )
683
+
684
+ if args.use_lora:
685
+ config = LoraConfig(
686
+ r=args.lora_r,
687
+ lora_alpha=args.lora_alpha,
688
+ target_modules=UNET_TARGET_MODULES,
689
+ lora_dropout=args.lora_dropout,
690
+ bias=args.lora_bias,
691
+ )
692
+ unet = LoraModel(config, unet)
693
+ print_trainable_parameters(unet)
694
+ print(unet)
695
+
696
+ vae.requires_grad_(False)
697
+ if not args.train_text_encoder:
698
+ text_encoder.requires_grad_(False)
699
+ elif args.train_text_encoder and args.use_lora:
700
+ config = LoraConfig(
701
+ r=args.lora_text_encoder_r,
702
+ lora_alpha=args.lora_text_encoder_alpha,
703
+ target_modules=TEXT_ENCODER_TARGET_MODULES,
704
+ lora_dropout=args.lora_text_encoder_dropout,
705
+ bias=args.lora_text_encoder_bias,
706
+ )
707
+ text_encoder = LoraModel(config, text_encoder)
708
+ print_trainable_parameters(text_encoder)
709
+ print(text_encoder)
710
+
711
+ if args.enable_xformers_memory_efficient_attention:
712
+ if is_xformers_available():
713
+ unet.enable_xformers_memory_efficient_attention()
714
+ else:
715
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
716
+
717
+ if args.gradient_checkpointing:
718
+ unet.enable_gradient_checkpointing()
719
+ # below fails when using lora so commenting it out
720
+ if args.train_text_encoder and not args.use_lora:
721
+ text_encoder.gradient_checkpointing_enable()
722
+
723
+ # Enable TF32 for faster training on Ampere GPUs,
724
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
725
+ if args.allow_tf32:
726
+ torch.backends.cuda.matmul.allow_tf32 = True
727
+
728
+ if args.scale_lr:
729
+ args.learning_rate = (
730
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
731
+ )
732
+
733
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
734
+ if args.use_8bit_adam:
735
+ try:
736
+ import bitsandbytes as bnb
737
+ except ImportError:
738
+ raise ImportError(
739
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
740
+ )
741
+
742
+ optimizer_class = bnb.optim.AdamW8bit
743
+ else:
744
+ optimizer_class = torch.optim.AdamW
745
+
746
+ # Optimizer creation
747
+ params_to_optimize = (
748
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
749
+ )
750
+ optimizer = optimizer_class(
751
+ params_to_optimize,
752
+ lr=args.learning_rate,
753
+ betas=(args.adam_beta1, args.adam_beta2),
754
+ weight_decay=args.adam_weight_decay,
755
+ eps=args.adam_epsilon,
756
+ )
757
+
758
+ # Dataset and DataLoaders creation:
759
+ train_dataset = DreamBoothDataset(
760
+ instance_data_root=args.instance_data_dir,
761
+ instance_prompt=args.instance_prompt,
762
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
763
+ class_prompt=args.class_prompt,
764
+ tokenizer=tokenizer,
765
+ size=args.resolution,
766
+ center_crop=args.center_crop,
767
+ )
768
+
769
+ train_dataloader = torch.utils.data.DataLoader(
770
+ train_dataset,
771
+ batch_size=args.train_batch_size,
772
+ shuffle=True,
773
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
774
+ num_workers=1,
775
+ )
776
+
777
+ # Scheduler and math around the number of training steps.
778
+ overrode_max_train_steps = False
779
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
780
+ if args.max_train_steps is None:
781
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
782
+ overrode_max_train_steps = True
783
+
784
+ lr_scheduler = get_scheduler(
785
+ args.lr_scheduler,
786
+ optimizer=optimizer,
787
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
788
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
789
+ num_cycles=args.lr_num_cycles,
790
+ power=args.lr_power,
791
+ )
792
+
793
+ # Prepare everything with our `accelerator`.
794
+ if args.train_text_encoder:
795
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
796
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
797
+ )
798
+ else:
799
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
800
+ unet, optimizer, train_dataloader, lr_scheduler
801
+ )
802
+
803
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
804
+ # as these models are only used for inference, keeping weights in full precision is not required.
805
+ weight_dtype = torch.float32
806
+ if accelerator.mixed_precision == "fp16":
807
+ weight_dtype = torch.float16
808
+ elif accelerator.mixed_precision == "bf16":
809
+ weight_dtype = torch.bfloat16
810
+
811
+ # Move vae and text_encoder to device and cast to weight_dtype
812
+ vae.to(accelerator.device, dtype=weight_dtype)
813
+ if not args.train_text_encoder:
814
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
815
+
816
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
817
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
818
+ if overrode_max_train_steps:
819
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
820
+ # Afterwards we recalculate our number of training epochs
821
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
822
+
823
+ # We need to initialize the trackers we use, and also store our configuration.
824
+ # The trackers initializes automatically on the main process.
825
+ if accelerator.is_main_process:
826
+ accelerator.init_trackers("dreambooth", config=vars(args))
827
+
828
+ # Train!
829
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
830
+
831
+ logger.info("***** Running training *****")
832
+ logger.info(f" Num examples = {len(train_dataset)}")
833
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
834
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
835
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
836
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
837
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
838
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
839
+ global_step = 0
840
+ first_epoch = 0
841
+
842
+ # Potentially load in the weights and states from a previous save
843
+ if args.resume_from_checkpoint:
844
+ if args.resume_from_checkpoint != "latest":
845
+ path = os.path.basename(args.resume_from_checkpoint)
846
+ else:
847
+ # Get the mos recent checkpoint
848
+ dirs = os.listdir(args.output_dir)
849
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
850
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
851
+ path = dirs[-1]
852
+ accelerator.print(f"Resuming from checkpoint {path}")
853
+ accelerator.load_state(os.path.join(args.output_dir, path))
854
+ global_step = int(path.split("-")[1])
855
+
856
+ resume_global_step = global_step * args.gradient_accumulation_steps
857
+ first_epoch = resume_global_step // num_update_steps_per_epoch
858
+ resume_step = resume_global_step % num_update_steps_per_epoch
859
+
860
+ # Only show the progress bar once on each machine.
861
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
862
+ progress_bar.set_description("Steps")
863
+
864
+ for epoch in range(first_epoch, args.num_train_epochs):
865
+ unet.train()
866
+ if args.train_text_encoder:
867
+ text_encoder.train()
868
+ with TorchTracemalloc() as tracemalloc:
869
+ for step, batch in enumerate(train_dataloader):
870
+ # Skip steps until we reach the resumed step
871
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
872
+ if step % args.gradient_accumulation_steps == 0:
873
+ progress_bar.update(1)
874
+ continue
875
+
876
+ with accelerator.accumulate(unet):
877
+ # Convert images to latent space
878
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
879
+ latents = latents * 0.18215
880
+
881
+ # Sample noise that we'll add to the latents
882
+ noise = torch.randn_like(latents)
883
+ bsz = latents.shape[0]
884
+ # Sample a random timestep for each image
885
+ timesteps = torch.randint(
886
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
887
+ )
888
+ timesteps = timesteps.long()
889
+
890
+ # Add noise to the latents according to the noise magnitude at each timestep
891
+ # (this is the forward diffusion process)
892
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
893
+
894
+ # Get the text embedding for conditioning
895
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
896
+
897
+ # Predict the noise residual
898
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
899
+
900
+ # Get the target for loss depending on the prediction type
901
+ if noise_scheduler.config.prediction_type == "epsilon":
902
+ target = noise
903
+ elif noise_scheduler.config.prediction_type == "v_prediction":
904
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
905
+ else:
906
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
907
+
908
+ if args.with_prior_preservation:
909
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
910
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
911
+ target, target_prior = torch.chunk(target, 2, dim=0)
912
+
913
+ # Compute instance loss
914
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
915
+
916
+ # Compute prior loss
917
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
918
+
919
+ # Add the prior loss to the instance loss.
920
+ loss = loss + args.prior_loss_weight * prior_loss
921
+ else:
922
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
923
+
924
+ accelerator.backward(loss)
925
+ if accelerator.sync_gradients:
926
+ params_to_clip = (
927
+ itertools.chain(unet.parameters(), text_encoder.parameters())
928
+ if args.train_text_encoder
929
+ else unet.parameters()
930
+ )
931
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
932
+ optimizer.step()
933
+ lr_scheduler.step()
934
+ optimizer.zero_grad()
935
+
936
+ # Checks if the accelerator has performed an optimization step behind the scenes
937
+ if accelerator.sync_gradients:
938
+ progress_bar.update(1)
939
+ global_step += 1
940
+
941
+ # if global_step % args.checkpointing_steps == 0:
942
+ # if accelerator.is_main_process:
943
+ # save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
944
+ # accelerator.save_state(save_path)
945
+ # logger.info(f"Saved state to {save_path}")
946
+
947
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
948
+ progress_bar.set_postfix(**logs)
949
+ accelerator.log(logs, step=global_step)
950
+
951
+ if global_step >= args.max_train_steps:
952
+ break
953
+ # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
954
+ accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
955
+ accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
956
+ accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
957
+ accelerator.print(
958
+ "GPU Total Peak Memory consumed during the train (max): {}".format(
959
+ tracemalloc.peaked + b2mb(tracemalloc.begin)
960
+ )
961
+ )
962
+
963
+ accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin)))
964
+ accelerator.print("CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used))
965
+ accelerator.print("CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked))
966
+ accelerator.print(
967
+ "CPU Total Peak Memory consumed during the train (max): {}".format(
968
+ tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
969
+ )
970
+ )
971
+
972
+ # Create the pipeline using using the trained modules and save it.
973
+ accelerator.wait_for_everyone()
974
+ if accelerator.is_main_process:
975
+ if args.use_lora:
976
+ lora_config = {}
977
+ state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
978
+ lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
979
+ if args.train_text_encoder:
980
+ text_encoder_state_dict = get_peft_model_state_dict(
981
+ text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
982
+ )
983
+ text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
984
+ state_dict.update(text_encoder_state_dict)
985
+ lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
986
+
987
+ accelerator.print(state_dict)
988
+ accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
989
+ with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f:
990
+ json.dump(lora_config, f)
991
+ else:
992
+ pipeline = DiffusionPipeline.from_pretrained(
993
+ args.pretrained_model_name_or_path,
994
+ unet=accelerator.unwrap_model(unet),
995
+ text_encoder=accelerator.unwrap_model(text_encoder),
996
+ revision=args.revision,
997
+ )
998
+ pipeline.save_pretrained(args.output_dir)
999
+
1000
+ accelerator.end_training()
1001
+
1002
+
1003
+ if __name__ == "__main__":
1004
+ args = parse_args()
1005
+ main(args)
trainer.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import shlex
6
+ import shutil
7
+ import subprocess
8
+
9
+ import gradio as gr
10
+ import PIL.Image
11
+ import torch
12
+
13
+
14
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
15
+ w, h = image.size
16
+ if w == h:
17
+ return image
18
+ elif w > h:
19
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
20
+ new_image.paste(image, (0, (w - h) // 2))
21
+ return new_image
22
+ else:
23
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
24
+ new_image.paste(image, ((h - w) // 2, 0))
25
+ return new_image
26
+
27
+
28
+ class Trainer:
29
+ def __init__(self):
30
+ self.is_running = False
31
+ self.is_running_message = "Another training is in progress."
32
+
33
+ self.output_dir = pathlib.Path("results")
34
+ self.instance_data_dir = self.output_dir / "training_data"
35
+
36
+ def check_if_running(self) -> dict:
37
+ if self.is_running:
38
+ return gr.update(value=self.is_running_message)
39
+ else:
40
+ return gr.update(value="No training is running.")
41
+
42
+ def cleanup_dirs(self) -> None:
43
+ shutil.rmtree(self.output_dir, ignore_errors=True)
44
+
45
+ def prepare_dataset(self, concept_images: list, resolution: int) -> None:
46
+ self.instance_data_dir.mkdir(parents=True)
47
+ for i, temp_path in enumerate(concept_images):
48
+ image = PIL.Image.open(temp_path.name)
49
+ image = pad_image(image)
50
+ image = image.resize((resolution, resolution))
51
+ image = image.convert("RGB")
52
+ out_path = self.instance_data_dir / f"{i:03d}.jpg"
53
+ image.save(out_path, format="JPEG", quality=100)
54
+
55
+ def run(
56
+ self,
57
+ base_model: str,
58
+ resolution_s: str,
59
+ n_steps: int,
60
+ concept_images: list | None,
61
+ concept_prompt: str,
62
+ learning_rate: float,
63
+ gradient_accumulation: int,
64
+ fp16: bool,
65
+ use_8bit_adam: bool,
66
+ gradient_checkpointing: bool,
67
+ train_text_encoder: bool,
68
+ with_prior_preservation: bool,
69
+ prior_loss_weight: float,
70
+ class_prompt: str,
71
+ num_class_images: int,
72
+ lora_r: int,
73
+ lora_alpha: int,
74
+ lora_bias: str,
75
+ lora_dropout: float,
76
+ lora_text_encoder_r: int,
77
+ lora_text_encoder_alpha: int,
78
+ lora_text_encoder_bias: str,
79
+ lora_text_encoder_dropout: float,
80
+ ) -> tuple[dict, list[pathlib.Path]]:
81
+ if not torch.cuda.is_available():
82
+ raise gr.Error("CUDA is not available.")
83
+
84
+ if self.is_running:
85
+ return gr.update(value=self.is_running_message), []
86
+
87
+ if concept_images is None:
88
+ raise gr.Error("You need to upload images.")
89
+ if not concept_prompt:
90
+ raise gr.Error("The concept prompt is missing.")
91
+
92
+ resolution = int(resolution_s)
93
+
94
+ self.cleanup_dirs()
95
+ self.prepare_dataset(concept_images, resolution)
96
+
97
+ command = f"""
98
+ accelerate launch train_dreambooth.py \
99
+ --pretrained_model_name_or_path={base_model} \
100
+ --instance_data_dir={self.instance_data_dir} \
101
+ --output_dir={self.output_dir} \
102
+ --train_text_encoder \
103
+ --instance_prompt="{concept_prompt}" \
104
+ --resolution={resolution} \
105
+ --gradient_accumulation_steps={gradient_accumulation} \
106
+ --learning_rate={learning_rate} \
107
+ --max_train_steps={n_steps} \
108
+ --train_batch_size=1 \
109
+ --lr_scheduler=constant \
110
+ --lr_warmup_steps=0 \
111
+ --num_class_images={num_class_images} \
112
+ """
113
+ if train_text_encoder:
114
+ command += f" --train_text_encoder"
115
+ if with_prior_preservation:
116
+ command += f""" --with_prior_preservation \
117
+ --prior_loss_weight={prior_loss_weight} \
118
+ --class_prompt="{class_prompt}" \
119
+ --class_data_dir={self.output_dir / 'class_data'}
120
+ """
121
+
122
+ command += f""" --use_lora \
123
+ --lora_r={lora_r} \
124
+ --lora_alpha={lora_alpha} \
125
+ --lora_bias={lora_bias} \
126
+ --lora_dropout={lora_dropout}
127
+ """
128
+
129
+ if train_text_encoder:
130
+ command += f""" --lora_text_encoder_r={lora_text_encoder_r} \
131
+ --lora_text_encoder_alpha={lora_text_encoder_alpha} \
132
+ --lora_text_encoder_bias={lora_text_encoder_bias} \
133
+ --lora_text_encoder_dropout={lora_text_encoder_dropout}
134
+ """
135
+ if fp16:
136
+ command += " --mixed_precision fp16"
137
+ if use_8bit_adam:
138
+ command += " --use_8bit_adam"
139
+ if gradient_checkpointing:
140
+ command += " --gradient_checkpointing"
141
+
142
+ with open(self.output_dir / "train.sh", "w") as f:
143
+ command_s = " ".join(command.split())
144
+ f.write(command_s)
145
+
146
+ self.is_running = True
147
+ res = subprocess.run(shlex.split(command))
148
+ self.is_running = False
149
+
150
+ if res.returncode == 0:
151
+ result_message = "Training Completed!"
152
+ else:
153
+ result_message = "Training Failed!"
154
+ weight_paths = sorted(self.output_dir.glob("*.pt"))
155
+ config_paths = sorted(self.output_dir.glob("*.json"))
156
+ return gr.update(value=result_message), weight_paths + config_paths
uploader.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi
3
+
4
+
5
+ def upload(model_name: str, hf_token: str) -> None:
6
+ api = HfApi(token=hf_token)
7
+ user_name = api.whoami()["name"]
8
+ model_id = f"{user_name}/{model_name}"
9
+ try:
10
+ api.create_repo(model_id, repo_type="model", private=True)
11
+ api.upload_folder(repo_id=model_id, folder_path="results", path_in_repo="results", repo_type="model")
12
+ url = f"https://huggingface.co/{model_id}"
13
+ message = f"Your model was successfully uploaded to [{url}]({url})."
14
+ except Exception as e:
15
+ message = str(e)
16
+
17
+ return gr.update(value=message, visible=True)