multimodalart committed
Commit 1235e6e
Parent: 173552f

Update training 2

Files changed (2):
  1. app.py +12 -4
  2. train_dreambooth.py +68 -11
app.py CHANGED
@@ -30,7 +30,7 @@ maximum_concepts = 3
 
 #Pre download the files
 model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
-#model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2")
+model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2")
 model_v2_512 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-base")
 safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
 
@@ -171,6 +171,10 @@ def train(*inputs):
         Training_Steps=1400
 
     stptxt = int((Training_Steps*Train_text_encoder_for)/100)
+    #gradient_checkpointing = False if which_model == "v1-5" else True
+    gradient_checkpointing=False
+    resolution = 512 if which_model != "v2-768" else 768
+    cache_latents = True if which_model != "v1-5" else False
     if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
         args_general = argparse.Namespace(
            image_captions_filename = True,
@@ -183,7 +187,7 @@ def train(*inputs):
            output_dir="output_model",
            instance_prompt="",
            seed=42,
-           resolution=512,
+           resolution=resolution,
            mixed_precision="fp16",
            train_batch_size=1,
            gradient_accumulation_steps=1,
@@ -192,6 +196,8 @@ def train(*inputs):
            lr_scheduler="polynomial",
            lr_warmup_steps = 0,
            max_train_steps=Training_Steps,
+           gradient_checkpointing=gradient_checkpointing,
+           cache_latents=cache_latents,
         )
         print("Starting single training...")
         lock_file = open("intraining.lock", "w")
@@ -211,7 +217,7 @@ def train(*inputs):
            prior_loss_weight=1.0,
            instance_prompt="",
            seed=42,
-           resolution=512,
+           resolution=resolution,
            mixed_precision="fp16",
            train_batch_size=1,
            gradient_accumulation_steps=1,
@@ -220,7 +226,9 @@ def train(*inputs):
            lr_scheduler="polynomial",
            lr_warmup_steps = 0,
            max_train_steps=Training_Steps,
-           num_class_images=200,
+           num_class_images=200,
+           gradient_checkpointing=gradient_checkpointing,
+           cache_latents=cache_latents,
         )
         print("Starting multi-training...")
         lock_file = open("intraining.lock", "w")
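Note: the three settings introduced above are derived from the selected base model before app.py assembles the argparse.Namespace it hands to run_training. The minimal sketch below shows only that derivation; the chosen which_model value and the trimmed-down Namespace are illustrative (the app passes many more fields, as the hunks above show), and any model label other than "v1-5" and "v2-768" is an assumption.

# Sketch only: per-model training settings as introduced in this commit.
import argparse

which_model = "v2-768"                                     # illustrative choice

gradient_checkpointing = False                             # kept off for every model in this commit
resolution = 512 if which_model != "v2-768" else 768       # only the 768px v2 checkpoint trains at 768
cache_latents = True if which_model != "v1-5" else False   # pre-encode latents for the v2 models

args_general = argparse.Namespace(
    # ...plus the other fields shown in the diff (seed, mixed_precision, max_train_steps, ...)
    resolution=resolution,
    gradient_checkpointing=gradient_checkpointing,
    cache_latents=cache_latents,
)
print(args_general)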
train_dreambooth.py CHANGED
@@ -235,6 +235,13 @@ def parse_args():
         help="Train only the unet",
     )
 
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Train only the unet",
+    )
+
     parser.add_argument(
         "--Session_dir",
         type=str,
@@ -382,6 +389,16 @@ class PromptDataset(Dataset):
         example["index"] = index
         return example
 
+class LatentsDataset(Dataset):
+    def __init__(self, latents_cache, text_encoder_cache):
+        self.latents_cache = latents_cache
+        self.text_encoder_cache = text_encoder_cache
+
+    def __len__(self):
+        return len(self.latents_cache)
+
+    def __getitem__(self, index):
+        return self.latents_cache[index], self.text_encoder_cache[index]
 
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
@@ -631,6 +648,28 @@ def run_training(args_imported):
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
 
+
+    if args.cache_latents:
+        latents_cache = []
+        text_encoder_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
+                batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                if args.train_text_encoder:
+                    text_encoder_cache.append(batch["input_ids"])
+                else:
+                    text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
+        train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
+        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
+
+        del vae
+        if not args.train_text_encoder:
+            del text_encoder
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
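Note: the block above is the core of the new --cache_latents path. Before the training loop starts, every image is pushed through the frozen VAE once (and, when the text encoder is frozen, every caption through the text encoder once), the results are wrapped in LatentsDataset, the dataloader is swapped for one over the cache, and the now-unneeded encoders are deleted to free VRAM. The self-contained sketch below shows the same pattern; the random tensors and their shapes are stand-ins for what vae.encode(...).latent_dist and text_encoder(input_ids)[0] would return.

# Sketch of the latent-caching pattern, with stand-in tensors instead of real encoders.
import torch
from torch.utils.data import Dataset, DataLoader

class LatentsDataset(Dataset):
    def __init__(self, latents_cache, text_encoder_cache):
        self.latents_cache = latents_cache
        self.text_encoder_cache = text_encoder_cache

    def __len__(self):
        return len(self.latents_cache)

    def __getitem__(self, index):
        return self.latents_cache[index], self.text_encoder_cache[index]

# Pretend these were produced once, before training, by the frozen VAE and text encoder.
latents_cache = [torch.randn(1, 4, 64, 64) for _ in range(8)]       # stand-in for vae.encode(...).latent_dist
text_encoder_cache = [torch.randn(1, 77, 1024) for _ in range(8)]   # stand-in for text_encoder(input_ids)[0]

# collate_fn=lambda x: x keeps each item as a (latents, embeddings) tuple,
# which is why the training loop in the next hunks reads batch[0][0] and batch[0][1].
train_dataloader = DataLoader(LatentsDataset(latents_cache, text_encoder_cache),
                              batch_size=1, collate_fn=lambda x: x, shuffle=True)

for batch in train_dataloader:
    latents, encoder_hidden_states = batch[0]
    # ...UNet forward pass and loss would go here; the VAE is no longer needed at each step.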
 
@@ -669,8 +708,12 @@ def run_training(args_imported):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                with torch.no_grad():
+                    if args.cache_latents:
+                        latents = batch[0][0]
+                    else:
+                        latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                    latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
@@ -684,26 +727,40 @@ def run_training(args_imported):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                if(args.cache_latents):
+                    if args.train_text_encoder:
+                        encoder_hidden_states = text_encoder(batch[0][1])[0]
+                    else:
+                        encoder_hidden_states = batch[0][1]
+                else:
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
                 if args.with_prior_preservation:
-                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                    target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
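Note: the prediction-type branch added above is what lets the same script fine-tune both v1-style checkpoints, which are trained to predict the added noise ("epsilon"), and the 768px Stable Diffusion 2 checkpoint, which is trained with "v_prediction" and therefore needs the velocity as the regression target. The small sketch below isolates that branch, assuming a recent diffusers DDPMScheduler; the scheduler settings and tensor shapes are illustrative rather than the ones the script loads from the model repo.

# Sketch of the epsilon vs. v_prediction targets, assuming diffusers' DDPMScheduler API.
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="v_prediction")

latents = torch.randn(2, 4, 96, 96)   # illustrative latent batch (768px images -> 96x96 latents)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],))

if noise_scheduler.config.prediction_type == "epsilon":
    target = noise                    # v1-style objective: the UNet predicts the added noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    # velocity = sqrt(alpha_cumprod_t) * noise - sqrt(1 - alpha_cumprod_t) * latents
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

print(target.shape)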