multimodalart (HF staff) committed on
Commit 95ea872
1 Parent(s): 62ee77b

Enable xformers

Files changed (1)
  1. train_dreambooth.py +9 -1
train_dreambooth.py CHANGED
@@ -18,6 +18,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.optimization import get_scheduler
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
@@ -533,7 +534,14 @@ def run_training(args_imported):
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
-
+    if is_xformers_available():
+        try:
+            print("Enabling memory efficient attention with xformers...")
+            unet.enable_xformers_memory_efficient_attention()
+        except Exception as e:
+            logger.warning(
+                f"Could not enable memory efficient attention. Make sure xformers is installed correctly and a GPU is available: {e}"
+            )
     vae.requires_grad_(False)
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
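
For context, the same guard pattern can be reused outside this training script, e.g. at inference time. The following is a minimal sketch, not part of the commit: it assumes diffusers, torch, and (optionally) xformers are installed and a CUDA GPU is available, and the checkpoint id "runwayml/stable-diffusion-v1-5" is only an illustrative choice.

import logging

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

logger = logging.getLogger(__name__)

# Illustrative checkpoint; any Stable Diffusion checkpoint on the Hub would do.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Same pattern as the commit: enable xformers attention only when the package
# is importable, and log a warning instead of crashing if enabling fails
# (for example, no compatible GPU or a mismatched xformers build).
if is_xformers_available():
    try:
        pipe.unet.enable_xformers_memory_efficient_attention()
    except Exception as e:
        logger.warning(f"Could not enable memory efficient attention: {e}")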