multimodalart HF staff commited on
Commit
9c25812
1 Parent(s): c35f883

Remove text encoder deletion

Browse files
Files changed (1) hide show
  1. train_dreambooth.py +2 -2
train_dreambooth.py CHANGED
@@ -663,8 +663,8 @@ def run_training(args_imported):
663
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
664
 
665
  del vae
666
- if not args.train_text_encoder:
667
- del text_encoder
668
  if torch.cuda.is_available():
669
  torch.cuda.empty_cache()
670
 
 
663
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
664
 
665
  del vae
666
+ #if not args.train_text_encoder:
667
+ # del text_encoder
668
  if torch.cuda.is_available():
669
  torch.cuda.empty_cache()
670