nsfwalex commited on
Commit
8197e41
1 Parent(s): 8119835

Update inference_manager.py

Browse files
Files changed (1) hide show
  1. inference_manager.py +7 -3
inference_manager.py CHANGED
@@ -151,9 +151,13 @@ class InferenceManager:
151
  vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
152
  pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
153
  else:
154
-
155
- #vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
156
- vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
 
 
 
 
157
  print(ckpt_dir)
158
  pipe = DiffusionPipeline.from_pretrained(
159
  ckpt_dir,
 
151
  vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
152
  pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
153
  else:
154
+ use_vae = cfg.get("vae", "")
155
+ if not use_vae:
156
+ vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
157
+ elif use_vae == "tae":
158
+ vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
159
+ else:
160
+ vae = AutoencoderTiny.from_pretrained(use_vae, torch_dtype=torch.bfloat16)
161
  print(ckpt_dir)
162
  pipe = DiffusionPipeline.from_pretrained(
163
  ckpt_dir,