Update inference_manager.py
Browse files- inference_manager.py +23 -6
inference_manager.py
CHANGED
@@ -27,12 +27,29 @@ import re
|
|
27 |
import gradio as gr
|
28 |
import uuid
|
29 |
from PIL import Image
|
30 |
-
MAX_SEED =
|
31 |
#from onediffx import compile_pipe, save_pipe, load_pipe
|
32 |
|
33 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
34 |
VAR_PUBLIC_KEY = os.getenv('PUBLIC_KEY')
|
35 |
DATASET_ID = 'nsfwalex/checkpoint_n_lora'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
class AuthHelper:
|
38 |
def load_public_key_from_file(self):
|
@@ -179,7 +196,7 @@ class InferenceManager:
|
|
179 |
pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
|
180 |
else:
|
181 |
use_vae = cfg.get("vae", "")
|
182 |
-
if not use_vae
|
183 |
vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
|
184 |
elif use_vae == "tae":
|
185 |
vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
|
@@ -193,7 +210,7 @@ class InferenceManager:
|
|
193 |
torch_dtype=torch.bfloat16,
|
194 |
use_safetensors=True,
|
195 |
#variant="fp16",
|
196 |
-
|
197 |
)
|
198 |
#pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
199 |
clip_skip = cfg.get("clip_skip", 1)
|
@@ -208,7 +225,7 @@ class InferenceManager:
|
|
208 |
ip_ckpt = self.ext_model_pathes.get("ip-adapter-faceid-sdxl", "")
|
209 |
if ip_ckpt:
|
210 |
print(f"loading ip adapter model...")
|
211 |
-
self.ip_adapter_faceid_pipeline = ipown.IPAdapterFaceIDXL(pipe, ip_ckpt, 'cuda')
|
212 |
else:
|
213 |
print("ip-adapter-faceid-sdxl not found, skip")
|
214 |
|
@@ -583,8 +600,8 @@ class ModelManager:
|
|
583 |
generator=generator,
|
584 |
num_images_per_prompt=1,
|
585 |
output_type="pil",
|
586 |
-
callback_on_step_end=callback_dynamic_cfg,
|
587 |
-
callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
|
588 |
).images
|
589 |
cost = round(time.time() - start, 2)
|
590 |
print(f"inference done in {cost}s")
|
|
|
27 |
import gradio as gr
|
28 |
import uuid
|
29 |
from PIL import Image
|
30 |
+
MAX_SEED = np.iinfo(np.int32).max
|
31 |
#from onediffx import compile_pipe, save_pipe, load_pipe
|
32 |
|
33 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
34 |
VAR_PUBLIC_KEY = os.getenv('PUBLIC_KEY')
|
35 |
DATASET_ID = 'nsfwalex/checkpoint_n_lora'
|
36 |
+
scheduler_config = {
|
37 |
+
"num_train_timesteps": 1000,
|
38 |
+
"beta_start": 0.00085,
|
39 |
+
"beta_end": 0.012,
|
40 |
+
"beta_schedule": "scaled_linear",
|
41 |
+
"set_alpha_to_one": False,
|
42 |
+
"steps_offset": 1,
|
43 |
+
"prediction_type": "epsilon",
|
44 |
+
}
|
45 |
+
samplers = {
|
46 |
+
"Euler a": EulerAncestralDiscreteScheduler.from_config(scheduler_config),
|
47 |
+
"DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True),
|
48 |
+
"DPM2 a": DPMSolverMultistepScheduler.from_config(scheduler_config),
|
49 |
+
"DPM++ SDE": DPMSolverSDEScheduler.from_config(scheduler_config),
|
50 |
+
"DPM++ 2M SDE": DPMSolverSDEScheduler.from_config(scheduler_config, use_2m=True),
|
51 |
+
"DPM++ 2S a": DPMSolverMultistepScheduler.from_config(scheduler_config, use_2s=True)
|
52 |
+
}
|
53 |
|
54 |
class AuthHelper:
|
55 |
def load_public_key_from_file(self):
|
|
|
196 |
pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
|
197 |
else:
|
198 |
use_vae = cfg.get("vae", "")
|
199 |
+
if not use_vae or True:#!TEST! default vae for test
|
200 |
vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
|
201 |
elif use_vae == "tae":
|
202 |
vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
|
|
|
210 |
torch_dtype=torch.bfloat16,
|
211 |
use_safetensors=True,
|
212 |
#variant="fp16",
|
213 |
+
custom_pipeline = "lpw_stable_diffusion_xl",
|
214 |
)
|
215 |
#pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
216 |
clip_skip = cfg.get("clip_skip", 1)
|
|
|
225 |
ip_ckpt = self.ext_model_pathes.get("ip-adapter-faceid-sdxl", "")
|
226 |
if ip_ckpt:
|
227 |
print(f"loading ip adapter model...")
|
228 |
+
self.ip_adapter_faceid_pipeline = ipown.IPAdapterFaceIDXL(pipe, ip_ckpt, 'cuda', torch_dtype=torch.bfloat16)
|
229 |
else:
|
230 |
print("ip-adapter-faceid-sdxl not found, skip")
|
231 |
|
|
|
600 |
generator=generator,
|
601 |
num_images_per_prompt=1,
|
602 |
output_type="pil",
|
603 |
+
#callback_on_step_end=callback_dynamic_cfg,
|
604 |
+
#callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
|
605 |
).images
|
606 |
cost = round(time.time() - start, 2)
|
607 |
print(f"inference done in {cost}s")
|