Spaces:
Runtime error
Runtime error
Update train_dreambooth_lora.py
Browse files- train_dreambooth_lora.py +2 -2
train_dreambooth_lora.py
CHANGED
@@ -940,11 +940,11 @@ def main(args):
|
|
940 |
torch_dtype=weight_dtype,
|
941 |
)
|
942 |
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
943 |
-
pipeline = pipeline.to(
|
944 |
pipeline.set_progress_bar_config(disable=True)
|
945 |
|
946 |
# run inference
|
947 |
-
generator = torch.Generator(device=
|
948 |
prompt = args.num_validation_images * [args.validation_prompt]
|
949 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
950 |
|
|
|
940 |
torch_dtype=weight_dtype,
|
941 |
)
|
942 |
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
943 |
+
pipeline = pipeline.to(accelerator.device)
|
944 |
pipeline.set_progress_bar_config(disable=True)
|
945 |
|
946 |
# run inference
|
947 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
948 |
prompt = args.num_validation_images * [args.validation_prompt]
|
949 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
950 |
|