fffiloni committed
Commit 716d755
1 Parent(s): 4d8bece

Update train_dreambooth_lora_sdxl.py

Files changed (1)
  1. train_dreambooth_lora_sdxl.py +33 -55
train_dreambooth_lora_sdxl.py CHANGED
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 
-import gradio as gr
 import argparse
 import gc
 import hashlib
@@ -59,14 +58,14 @@ from diffusers.utils.import_utils import is_xformers_available
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.22.0.dev0")
 
 logger = get_logger(__name__)
 
 def save_tempo_model_card(
     repo_id: str, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None, last_checkpoint=str
 ):
-
+
     yaml = f"""
 ---
 base_model: {base_model}
@@ -84,24 +83,17 @@ datasets:
 """
     model_card = f"""
 # LoRA DreamBooth - {repo_id}
-
 ## MODEL IS CURRENTLY TRAINING ...
 Last checkpoint saved: {last_checkpoint}
-
-These are LoRA adaption weights for {base_model}.
-
-The weights is currently trained on the concept prompt:
+These are LoRA adaption weights for {base_model} trained on @fffiloni's SD-XL trainer.
+The weights were trained on the concept prompt:
 ```
 {prompt}
-```
+```
 Use this keyword to trigger your custom model in your prompts.
-
 LoRA for the text encoder was enabled: {train_text_encoder}.
-
 Special VAE used for training: {vae_path}.
-
 ## Usage
-
 Make sure to upgrade diffusers to >= 0.19.0:
 ```
 pip install diffusers --upgrade
@@ -114,18 +106,28 @@ To just use the base model, you can run:
 ```python
 import torch
 from diffusers import DiffusionPipeline, AutoencoderKL
+device = "cuda" if torch.cuda.is_available() else "cpu"
 vae = AutoencoderKL.from_pretrained('{vae_path}', torch_dtype=torch.float16)
 pipe = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
     vae=vae, torch_dtype=torch.float16, variant="fp16",
     use_safetensors=True
 )
-pipe.to("cuda")
+pipe.to(device)
 # This is where you load your trained weights
-pipe.load_lora_weights('{repo_id}')
-
+specific_safetensors = "pytorch_lora_weights.safetensors"
+lora_scale = 0.9
+pipe.load_lora_weights(
+    '{repo_id}',
+    weight_name = specific_safetensors,
+    # use_auth_token = True
+)
 prompt = "A majestic {prompt} jumping from a big stone at night"
-image = pipe(prompt=prompt, num_inference_steps=50).images[0]
+image = pipe(
+    prompt=prompt,
+    num_inference_steps=50,
+    cross_attention_kwargs={{"scale": lora_scale}}
+).images[0]
 ```
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
@@ -156,62 +158,44 @@ datasets:
 """
     model_card = f"""
 # LoRA DreamBooth - {repo_id}
-
 These are LoRA adaption weights for {base_model} trained on @fffiloni's SD-XL trainer.
-
 The weights were trained on the concept prompt:
 ```
 {prompt}
 ```
 Use this keyword to trigger your custom model in your prompts.
-
 LoRA for the text encoder was enabled: {train_text_encoder}.
-
 Special VAE used for training: {vae_path}.
-
 ## Usage
-
 Make sure to upgrade diffusers to >= 0.19.0:
 ```
 pip install diffusers --upgrade
 ```
-
 In addition make sure to install transformers, safetensors, accelerate as well as the invisible watermark:
 ```
 pip install invisible_watermark transformers accelerate safetensors
 ```
-
 To just use the base model, you can run:
-
 ```python
 import torch
 from diffusers import DiffusionPipeline, AutoencoderKL
-
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
 vae = AutoencoderKL.from_pretrained('{vae_path}', torch_dtype=torch.float16)
-
 pipe = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
     vae=vae, torch_dtype=torch.float16, variant="fp16",
    use_safetensors=True
 )
-
 pipe.to(device)
-
 # This is where you load your trained weights
-
 specific_safetensors = "pytorch_lora_weights.safetensors"
 lora_scale = 0.9
-
 pipe.load_lora_weights(
     '{repo_id}',
     weight_name = specific_safetensors,
     # use_auth_token = True
 )
-
 prompt = "A majestic {prompt} jumping from a big stone at night"
-
 image = pipe(
     prompt=prompt,
     num_inference_steps=50,
@@ -809,7 +793,7 @@ def main(args):
 
     if args.push_to_hub:
         repo_id = create_repo(
-            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, private=True, token=args.hub_token
+            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
        ).repo_id
 
     # Load the tokenizers
@@ -1150,7 +1134,6 @@ def main(args):
         accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
 
     # Train!
-    gr.Info("Training Starts now")
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
@@ -1180,34 +1163,34 @@ def main(args):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
         # Print a message for each epoch
         print(f"Epoch {epoch}: Training in progress...")
-
         unet.train()
         if args.train_text_encoder:
             text_encoder_one.train()
             text_encoder_two.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
 
@@ -1329,7 +1312,6 @@ def main(args):
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
-                        gr.Info(f"Saving checkpoint-{global_step} to {repo_id}")
                         save_tempo_model_card(
                             repo_id,
                             dataset_id=args.dataset_id,
@@ -1352,9 +1334,7 @@ def main(args):
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-
-
-
+
             if global_step >= args.max_train_steps:
                 break
 
@@ -1512,7 +1492,6 @@ def main(args):
                 prompt=args.instance_prompt,
                 repo_folder=args.output_dir,
                 vae_path=args.pretrained_vae_model_name_or_path,
-
             )
             upload_folder(
                 repo_id=repo_id,
@@ -1524,7 +1503,6 @@ def main(args):
 
     accelerator.end_training()
 
-
 if __name__ == "__main__":
    args = parse_args()
    main(args)
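
The most substantive functional change above is the checkpoint-resume logic: instead of iterating over the dataloader and skipping already-seen batches, the script now records an `initial_global_step` and passes it to `tqdm` via `initial=`. A minimal, self-contained sketch of that pattern follows; the variable names mirror the script, while the dummy step counts and the fake checkpoint name are purely illustrative.

```python
from tqdm.auto import tqdm

# Pretend a checkpoint folder named "checkpoint-230" was restored (illustrative values).
max_train_steps = 500
num_update_steps_per_epoch = 100
resumed_path = "checkpoint-230"

if resumed_path is not None:
    global_step = int(resumed_path.split("-")[1])              # 230
    initial_global_step = global_step
    first_epoch = global_step // num_update_steps_per_epoch    # resume in epoch 2
else:
    global_step = 0
    initial_global_step = 0
    first_epoch = 0

# The bar starts at 230/500; no batches are replayed just to advance it.
progress_bar = tqdm(range(0, max_train_steps), initial=initial_global_step, desc="Steps")

for _ in range(initial_global_step, max_train_steps):
    progress_bar.update(1)
```

The end state is the same as the old skip loop, but resuming no longer costs a pass over the skipped batches, which matters when the resumed step count is large.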
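
One behavioral note on the `create_repo` change: with `private=True` removed, the call falls back to the `huggingface_hub` default visibility, so newly created repos will typically be public unless the account or organization default says otherwise. A hedged sketch of the call as it now stands; the repo name here is hypothetical, and the script itself derives it from `args`.

```python
from huggingface_hub import create_repo

# Visibility now follows the hub default (typically public).
repo_id = create_repo(
    repo_id="your-username/your-sdxl-lora",  # hypothetical name
    exist_ok=True,                           # reuse the repo if it already exists
    # private=True,                          # re-add to keep the trained weights private
).repo_id

# The script later pushes the output folder to this repo with upload_folder(...).
```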