Spaces:
Runtime error
Runtime error
logs current steps + auto push checkpoint step to repo
Browse files- train_dreambooth_lora_sdxl.py +103 -7
train_dreambooth_lora_sdxl.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
# See the License for the specific language governing permissions and
|
15 |
|
|
|
16 |
import argparse
|
17 |
import gc
|
18 |
import hashlib
|
@@ -62,6 +63,73 @@ check_min_version("0.21.0.dev0")
|
|
62 |
|
63 |
logger = get_logger(__name__)
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def save_model_card(
|
67 |
repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
|
@@ -92,9 +160,9 @@ datasets:
|
|
92 |
These are LoRA adaption weights for {base_model}.
|
93 |
|
94 |
The weights were trained on the concept prompt:
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
Use this keyword to trigger your custom model in your prompts.
|
99 |
|
100 |
LoRA for the text encoder was enabled: {train_text_encoder}.
|
@@ -126,11 +194,11 @@ pipe = DiffusionPipeline.from_pretrained(
|
|
126 |
use_safetensors=True
|
127 |
)
|
128 |
|
|
|
|
|
129 |
# This is where you load your trained weights
|
130 |
pipe.load_lora_weights('{repo_id}')
|
131 |
|
132 |
-
pipe.to("cuda")
|
133 |
-
|
134 |
prompt = "A majestic {prompt} jumping from a big stone at night"
|
135 |
|
136 |
image = pipe(prompt=prompt, num_inference_steps=50).images[0]
|
@@ -1067,6 +1135,7 @@ def main(args):
|
|
1067 |
accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
|
1068 |
|
1069 |
# Train!
|
|
|
1070 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
1071 |
|
1072 |
logger.info("***** Running training *****")
|
@@ -1110,6 +1179,9 @@ def main(args):
|
|
1110 |
progress_bar.set_description("Steps")
|
1111 |
|
1112 |
for epoch in range(first_epoch, args.num_train_epochs):
|
|
|
|
|
|
|
1113 |
unet.train()
|
1114 |
if args.train_text_encoder:
|
1115 |
text_encoder_one.train()
|
@@ -1118,7 +1190,7 @@ def main(args):
|
|
1118 |
# Skip steps until we reach the resumed step
|
1119 |
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
1120 |
if step % args.gradient_accumulation_steps == 0:
|
1121 |
-
progress_bar.update(1)
|
1122 |
continue
|
1123 |
|
1124 |
with accelerator.accumulate(unet):
|
@@ -1211,6 +1283,8 @@ def main(args):
|
|
1211 |
|
1212 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
1213 |
if accelerator.sync_gradients:
|
|
|
|
|
1214 |
progress_bar.update(1)
|
1215 |
global_step += 1
|
1216 |
|
@@ -1240,10 +1314,32 @@ def main(args):
|
|
1240 |
accelerator.save_state(save_path)
|
1241 |
logger.info(f"Saved state to {save_path}")
|
1242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1243 |
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
1244 |
progress_bar.set_postfix(**logs)
|
1245 |
accelerator.log(logs, step=global_step)
|
1246 |
-
|
|
|
|
|
1247 |
if global_step >= args.max_train_steps:
|
1248 |
break
|
1249 |
|
|
|
13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
# See the License for the specific language governing permissions and
|
15 |
|
16 |
+
import gradio as gr
|
17 |
import argparse
|
18 |
import gc
|
19 |
import hashlib
|
|
|
63 |
|
64 |
logger = get_logger(__name__)
|
65 |
|
66 |
+
def save_tempo_model_card(
    repo_id: str,
    dataset_id=None,
    base_model=None,
    train_text_encoder=False,
    prompt=None,
    repo_folder=None,
    vae_path=None,
    last_checkpoint=None,
):
    """Write an interim ``README.md`` model card while training is still running.

    The card is made of a YAML front-matter header followed by a Markdown body
    that announces the model is still training, names the last checkpoint that
    was saved and shows a usage snippet. The file is written to
    ``<repo_folder>/README.md``, overwriting any existing one.

    Args:
        repo_id: Hub repository id; used in the card title and in the
            ``load_lora_weights`` call of the usage snippet.
        dataset_id: Dataset id listed under ``datasets:`` in the front matter.
        base_model: Base model id the LoRA weights were trained for.
        train_text_encoder: Value reported for "LoRA for the text encoder was
            enabled".
        prompt: Instance/concept prompt used for training.
        repo_folder: Directory that receives ``README.md``. Required.
        vae_path: VAE identifier reported in the card and used in the snippet.
        last_checkpoint: Label of the most recent checkpoint, e.g.
            ``"checkpoint-500"``.

    Raises:
        TypeError: If ``repo_folder`` is ``None``.
    """
    # The string parameters used to default to the ``str`` *type* itself
    # (``dataset_id=str``), which rendered as "<class 'str'>" in the card when
    # a caller omitted them. They now default to ``None``.
    if repo_folder is None:
        # Same exception type os.path.join would raise, but with a usable message.
        raise TypeError("repo_folder must be a directory path, not None")

    front_matter = f"""
---
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
inference: false
datasets:
- {dataset_id}
---
"""
    model_card = f"""
# LoRA DreamBooth - {repo_id}

## MODEL IS CURRENTLY TRAINING ...
Last checkpoint saved: {last_checkpoint}

These are LoRA adaption weights for {base_model}.

The weights were trained on the concept prompt:
```
{prompt}
```
Use this keyword to trigger your custom model in your prompts.

LoRA for the text encoder was enabled: {train_text_encoder}.

Special VAE used for training: {vae_path}.

## Usage

Make sure to upgrade diffusers to >= 0.19.0:
```
pip install diffusers --upgrade
```
In addition make sure to install transformers, safetensors, accelerate as well as the invisible watermark:
```
pip install invisible_watermark transformers accelerate safetensors
```
To just use the base model, you can run:
```python
import torch
from diffusers import DiffusionPipeline, AutoencoderKL
vae = AutoencoderKL.from_pretrained('{vae_path}', torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae, torch_dtype=torch.float16, variant="fp16",
    use_safetensors=True
)
pipe.to("cuda")
# This is where you load your trained weights
pipe.load_lora_weights('{repo_id}')

prompt = "A majestic {prompt} jumping from a big stone at night"
image = pipe(prompt=prompt, num_inference_steps=50).images[0]
```
"""
    readme_path = os.path.join(repo_folder, "README.md")
    # Explicit UTF-8 so prompts with non-ASCII characters are written
    # correctly regardless of the platform's default locale encoding.
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(front_matter + model_card)
|
133 |
|
134 |
def save_model_card(
|
135 |
repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
|
|
|
160 |
These are LoRA adaption weights for {base_model}.
|
161 |
|
162 |
The weights were trained on the concept prompt:
|
163 |
+
```
|
164 |
+
{prompt}
|
165 |
+
```
|
166 |
Use this keyword to trigger your custom model in your prompts.
|
167 |
|
168 |
LoRA for the text encoder was enabled: {train_text_encoder}.
|
|
|
194 |
use_safetensors=True
|
195 |
)
|
196 |
|
197 |
+
pipe.to("cuda")
|
198 |
+
|
199 |
# This is where you load your trained weights
|
200 |
pipe.load_lora_weights('{repo_id}')
|
201 |
|
|
|
|
|
202 |
prompt = "A majestic {prompt} jumping from a big stone at night"
|
203 |
|
204 |
image = pipe(prompt=prompt, num_inference_steps=50).images[0]
|
|
|
1135 |
accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
|
1136 |
|
1137 |
# Train!
|
1138 |
+
gr.Info("Training Starts now")
|
1139 |
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
1140 |
|
1141 |
logger.info("***** Running training *****")
|
|
|
1179 |
progress_bar.set_description("Steps")
|
1180 |
|
1181 |
for epoch in range(first_epoch, args.num_train_epochs):
|
1182 |
+
# Print a message for each epoch
|
1183 |
+
print(f"Epoch {epoch}: Training in progress...")
|
1184 |
+
|
1185 |
unet.train()
|
1186 |
if args.train_text_encoder:
|
1187 |
text_encoder_one.train()
|
|
|
1190 |
# Skip steps until we reach the resumed step
|
1191 |
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
1192 |
if step % args.gradient_accumulation_steps == 0:
|
1193 |
+
progress_bar.update(1)
|
1194 |
continue
|
1195 |
|
1196 |
with accelerator.accumulate(unet):
|
|
|
1283 |
|
1284 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
1285 |
if accelerator.sync_gradients:
|
1286 |
+
# Print a message for each step
|
1287 |
+
print(f"Step {global_step}/{args.max_train_steps}: Done")
|
1288 |
progress_bar.update(1)
|
1289 |
global_step += 1
|
1290 |
|
|
|
1314 |
accelerator.save_state(save_path)
|
1315 |
logger.info(f"Saved state to {save_path}")
|
1316 |
|
1317 |
+
gr.Info(f"Saving checkpoint-{global_step} to {repo_id}")
|
1318 |
+
save_tempo_model_card(
|
1319 |
+
repo_id,
|
1320 |
+
dataset_id=args.dataset_id,
|
1321 |
+
base_model=args.pretrained_model_name_or_path,
|
1322 |
+
train_text_encoder=args.train_text_encoder,
|
1323 |
+
prompt=args.instance_prompt,
|
1324 |
+
repo_folder=args.output_dir,
|
1325 |
+
vae_path=args.pretrained_vae_model_name_or_path,
|
1326 |
+
last_checkpoint = f"checkpoint-{global_step}"
|
1327 |
+
)
|
1328 |
+
|
1329 |
+
upload_folder(
|
1330 |
+
repo_id=repo_id,
|
1331 |
+
folder_path=args.output_dir,
|
1332 |
+
commit_message=f"saving checkpoint-{global_step}",
|
1333 |
+
ignore_patterns=["step_*", "epoch_*"],
|
1334 |
+
token=args.hub_token
|
1335 |
+
)
|
1336 |
+
|
1337 |
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
1338 |
progress_bar.set_postfix(**logs)
|
1339 |
accelerator.log(logs, step=global_step)
|
1340 |
+
|
1341 |
+
|
1342 |
+
|
1343 |
if global_step >= args.max_train_steps:
|
1344 |
break
|
1345 |
|