fffiloni commited on
Commit
8336b37
1 Parent(s): 032504d

logs current steps + auto push checkpoint step to repo

Browse files
Files changed (1) hide show
  1. train_dreambooth_lora_sdxl.py +103 -7
train_dreambooth_lora_sdxl.py CHANGED
@@ -13,6 +13,7 @@
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
 
 
16
  import argparse
17
  import gc
18
  import hashlib
@@ -62,6 +63,73 @@ check_min_version("0.21.0.dev0")
62
 
63
  logger = get_logger(__name__)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def save_model_card(
67
  repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
@@ -92,9 +160,9 @@ datasets:
92
  These are LoRA adaption weights for {base_model}.
93
 
94
  The weights were trained on the concept prompt:
95
-
96
- `{prompt}`
97
-
98
  Use this keyword to trigger your custom model in your prompts.
99
 
100
  LoRA for the text encoder was enabled: {train_text_encoder}.
@@ -126,11 +194,11 @@ pipe = DiffusionPipeline.from_pretrained(
126
  use_safetensors=True
127
  )
128
 
 
 
129
  # This is where you load your trained weights
130
  pipe.load_lora_weights('{repo_id}')
131
 
132
- pipe.to("cuda")
133
-
134
  prompt = "A majestic {prompt} jumping from a big stone at night"
135
 
136
  image = pipe(prompt=prompt, num_inference_steps=50).images[0]
@@ -1067,6 +1135,7 @@ def main(args):
1067
  accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
1068
 
1069
  # Train!
 
1070
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1071
 
1072
  logger.info("***** Running training *****")
@@ -1110,6 +1179,9 @@ def main(args):
1110
  progress_bar.set_description("Steps")
1111
 
1112
  for epoch in range(first_epoch, args.num_train_epochs):
 
 
 
1113
  unet.train()
1114
  if args.train_text_encoder:
1115
  text_encoder_one.train()
@@ -1118,7 +1190,7 @@ def main(args):
1118
  # Skip steps until we reach the resumed step
1119
  if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
1120
  if step % args.gradient_accumulation_steps == 0:
1121
- progress_bar.update(1)
1122
  continue
1123
 
1124
  with accelerator.accumulate(unet):
@@ -1211,6 +1283,8 @@ def main(args):
1211
 
1212
  # Checks if the accelerator has performed an optimization step behind the scenes
1213
  if accelerator.sync_gradients:
 
 
1214
  progress_bar.update(1)
1215
  global_step += 1
1216
 
@@ -1240,10 +1314,32 @@ def main(args):
1240
  accelerator.save_state(save_path)
1241
  logger.info(f"Saved state to {save_path}")
1242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1243
  logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1244
  progress_bar.set_postfix(**logs)
1245
  accelerator.log(logs, step=global_step)
1246
-
 
 
1247
  if global_step >= args.max_train_steps:
1248
  break
1249
 
 
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
 
16
+ import gradio as gr
17
  import argparse
18
  import gc
19
  import hashlib
 
63
 
64
  logger = get_logger(__name__)
65
 
66
+ def save_tempo_model_card(
67
+ repo_id: str, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None, last_checkpoint=str
68
+ ):
69
+
70
+ yaml = f"""
71
+ ---
72
+ base_model: {base_model}
73
+ instance_prompt: {prompt}
74
+ tags:
75
+ - stable-diffusion-xl
76
+ - stable-diffusion-xl-diffusers
77
+ - text-to-image
78
+ - diffusers
79
+ - lora
80
+ inference: false
81
+ datasets:
82
+ - {dataset_id}
83
+ ---
84
+ """
85
+ model_card = f"""
86
+ # LoRA DreamBooth - {repo_id}
87
+
88
+ ## MODEL IS CURRENTLY TRAINING ...
89
+ Last checkpoint saved: {last_checkpoint}
90
+
91
+ These are LoRA adaption weights for {base_model}.
92
+
93
+ The weights were trained on the concept prompt:
94
+ ```
95
+ {prompt}
96
+ ```
97
+ Use this keyword to trigger your custom model in your prompts.
98
+
99
+ LoRA for the text encoder was enabled: {train_text_encoder}.
100
+
101
+ Special VAE used for training: {vae_path}.
102
+
103
+ ## Usage
104
+
105
+ Make sure to upgrade diffusers to >= 0.19.0:
106
+ ```
107
+ pip install diffusers --upgrade
108
+ ```
109
+ In addition make sure to install transformers, safetensors, accelerate as well as the invisible watermark:
110
+ ```
111
+ pip install invisible_watermark transformers accelerate safetensors
112
+ ```
113
+ To just use the base model, you can run:
114
+ ```python
115
+ import torch
116
+ from diffusers import DiffusionPipeline, AutoencoderKL
117
+ vae = AutoencoderKL.from_pretrained('{vae_path}', torch_dtype=torch.float16)
118
+ pipe = DiffusionPipeline.from_pretrained(
119
+ "stabilityai/stable-diffusion-xl-base-1.0",
120
+ vae=vae, torch_dtype=torch.float16, variant="fp16",
121
+ use_safetensors=True
122
+ )
123
+ pipe.to("cuda")
124
+ # This is where you load your trained weights
125
+ pipe.load_lora_weights('{repo_id}')
126
+
127
+ prompt = "A majestic {prompt} jumping from a big stone at night"
128
+ image = pipe(prompt=prompt, num_inference_steps=50).images[0]
129
+ ```
130
+ """
131
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
132
+ f.write(yaml + model_card)
133
 
134
  def save_model_card(
135
  repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
 
160
  These are LoRA adaption weights for {base_model}.
161
 
162
  The weights were trained on the concept prompt:
163
+ ```
164
+ {prompt}
165
+ ```
166
  Use this keyword to trigger your custom model in your prompts.
167
 
168
  LoRA for the text encoder was enabled: {train_text_encoder}.
 
194
  use_safetensors=True
195
  )
196
 
197
+ pipe.to("cuda")
198
+
199
  # This is where you load your trained weights
200
  pipe.load_lora_weights('{repo_id}')
201
 
 
 
202
  prompt = "A majestic {prompt} jumping from a big stone at night"
203
 
204
  image = pipe(prompt=prompt, num_inference_steps=50).images[0]
 
1135
  accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
1136
 
1137
  # Train!
1138
+ gr.Info("Training Starts now")
1139
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1140
 
1141
  logger.info("***** Running training *****")
 
1179
  progress_bar.set_description("Steps")
1180
 
1181
  for epoch in range(first_epoch, args.num_train_epochs):
1182
+ # Print a message for each epoch
1183
+ print(f"Epoch {epoch}: Training in progress...")
1184
+
1185
  unet.train()
1186
  if args.train_text_encoder:
1187
  text_encoder_one.train()
 
1190
  # Skip steps until we reach the resumed step
1191
  if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
1192
  if step % args.gradient_accumulation_steps == 0:
1193
+ progress_bar.update(1)
1194
  continue
1195
 
1196
  with accelerator.accumulate(unet):
 
1283
 
1284
  # Checks if the accelerator has performed an optimization step behind the scenes
1285
  if accelerator.sync_gradients:
1286
+ # Print a message for each step
1287
+ print(f"Step {global_step}/{args.max_train_steps}: Done")
1288
  progress_bar.update(1)
1289
  global_step += 1
1290
 
 
1314
  accelerator.save_state(save_path)
1315
  logger.info(f"Saved state to {save_path}")
1316
 
1317
+ gr.Info(f"Saving checkpoint-{global_step} to {repo_id}")
1318
+ save_tempo_model_card(
1319
+ repo_id,
1320
+ dataset_id=args.dataset_id,
1321
+ base_model=args.pretrained_model_name_or_path,
1322
+ train_text_encoder=args.train_text_encoder,
1323
+ prompt=args.instance_prompt,
1324
+ repo_folder=args.output_dir,
1325
+ vae_path=args.pretrained_vae_model_name_or_path,
1326
+ last_checkpoint = f"checkpoint-{global_step}"
1327
+ )
1328
+
1329
+ upload_folder(
1330
+ repo_id=repo_id,
1331
+ folder_path=args.output_dir,
1332
+ commit_message=f"saving checkpoint-{global_step}",
1333
+ ignore_patterns=["step_*", "epoch_*"],
1334
+ token=args.hub_token
1335
+ )
1336
+
1337
  logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1338
  progress_bar.set_postfix(**logs)
1339
  accelerator.log(logs, step=global_step)
1340
+
1341
+
1342
+
1343
  if global_step >= args.max_train_steps:
1344
  break
1345