Update inference_manager.py
Browse files- inference_manager.py +4 -2
inference_manager.py
CHANGED
@@ -18,6 +18,7 @@ import base64
|
|
18 |
import json
|
19 |
import jwt
|
20 |
import glob
|
|
|
21 |
|
22 |
#from onediffx import compile_pipe, save_pipe, load_pipe
|
23 |
|
@@ -153,7 +154,7 @@ class InferenceManager:
|
|
153 |
|
154 |
#vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
|
155 |
vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
|
156 |
-
|
157 |
pipe = DiffusionPipeline.from_pretrained(
|
158 |
ckpt_dir,
|
159 |
vae=vae,
|
@@ -287,7 +288,7 @@ class ModelManager:
|
|
287 |
return
|
288 |
|
289 |
for file_path in model_files:
|
290 |
-
model_name = self.get_model_name_from_url(file_path)
|
291 |
print(f"Initializing model: {model_name} from {file_path}")
|
292 |
try:
|
293 |
# Initialize InferenceManager for each model
|
@@ -324,6 +325,7 @@ class ModelManager:
|
|
324 |
print(f"Building pipeline with LoRAs for model {model_id}...")
|
325 |
return model.build_pipeline_with_lora(lora_list, sampler, new_pipeline)
|
326 |
except Exception as e:
|
|
|
327 |
print(f"Failed to build pipeline for model {model_id}: {e}")
|
328 |
return None
|
329 |
|
|
|
18 |
import json
|
19 |
import jwt
|
20 |
import glob
|
21 |
+
import traceback
|
22 |
|
23 |
#from onediffx import compile_pipe, save_pipe, load_pipe
|
24 |
|
|
|
154 |
|
155 |
#vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
|
156 |
vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
|
157 |
+
print(ckpt_dir)
|
158 |
pipe = DiffusionPipeline.from_pretrained(
|
159 |
ckpt_dir,
|
160 |
vae=vae,
|
|
|
288 |
return
|
289 |
|
290 |
for file_path in model_files:
|
291 |
+
model_name = self.get_model_name_from_url(file_path).split(".")[0]
|
292 |
print(f"Initializing model: {model_name} from {file_path}")
|
293 |
try:
|
294 |
# Initialize InferenceManager for each model
|
|
|
325 |
print(f"Building pipeline with LoRAs for model {model_id}...")
|
326 |
return model.build_pipeline_with_lora(lora_list, sampler, new_pipeline)
|
327 |
except Exception as e:
|
328 |
+
traceback.print_exc()
|
329 |
print(f"Failed to build pipeline for model {model_id}: {e}")
|
330 |
return None
|
331 |
|