harkov000 commited on
Commit
effdc75
1 Parent(s): 8c26bca

Update train_dreambooth_lora.py

Browse files
Files changed (1) hide show
  1. train_dreambooth_lora.py +2 -2
train_dreambooth_lora.py CHANGED
@@ -940,11 +940,11 @@ def main(args):
940
  torch_dtype=weight_dtype,
941
  )
942
  pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
943
- pipeline = pipeline.to('cpu')
944
  pipeline.set_progress_bar_config(disable=True)
945
 
946
  # run inference
947
- generator = torch.Generator(device='cpu').manual_seed(args.seed)
948
  prompt = args.num_validation_images * [args.validation_prompt]
949
  images = pipeline(prompt, num_inference_steps=25, generator=generator).images
950
 
 
940
  torch_dtype=weight_dtype,
941
  )
942
  pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
943
+ pipeline = pipeline.to(accelerator.device)
944
  pipeline.set_progress_bar_config(disable=True)
945
 
946
  # run inference
947
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
948
  prompt = args.num_validation_images * [args.validation_prompt]
949
  images = pipeline(prompt, num_inference_steps=25, generator=generator).images
950