Update inference_manager.py
inference_manager.py  CHANGED  +7 -3
@@ -151,9 +151,13 @@ class InferenceManager:
             vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
             pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
         else:
-
-
-
+            use_vae = cfg.get("vae", "")
+            if not use_vae:
+                vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
+            elif use_vae == "tae":
+                vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
+            else:
+                vae = AutoencoderTiny.from_pretrained(use_vae, torch_dtype=torch.bfloat16)
             print(ckpt_dir)
             pipe = DiffusionPipeline.from_pretrained(
                 ckpt_dir,
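
For context, a minimal sketch of the VAE-selection logic this hunk adds, pulled out into a standalone helper. The load_vae name, function boundary, and dict-style cfg are illustrative assumptions for the sketch; the commit itself only adds the branch inside the else block.

# Sketch of the new config-driven VAE selection, assuming `cfg` is a dict-like
# config and `ckpt_dir` is a local diffusers checkpoint directory.
import os
import torch
from diffusers import AutoencoderKL, AutoencoderTiny

def load_vae(cfg: dict, ckpt_dir: str):
    use_vae = cfg.get("vae", "")
    if not use_vae:
        # Default: load the full KL autoencoder bundled with the checkpoint.
        return AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
    if use_vae == "tae":
        # Shorthand for the Tiny AutoEncoder for SDXL.
        return AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
    # Any other value is treated as a Hub id or local path to a tiny VAE.
    return AutoencoderTiny.from_pretrained(use_vae, torch_dtype=torch.bfloat16)

So an empty "vae" key keeps the old behavior, "tae" swaps in the taesdxl decoder for faster decoding, and any other string is passed straight to AutoencoderTiny.from_pretrained.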