Spaces: Running on Zero

adamelliotfields committed · 9769856 · Parent(s): 1b15230

Rewrite loading and inference

- lib/inference.py +68 -118
- lib/loader.py +242 -260
lib/inference.py CHANGED

Removed in this commit: the `from .loader import Loader` import and the `pipe = loader.pipe` attribute, the `Info=None` and `progress=None` parameters, the `Config.ZERO_GPU` branch with its `safe_progress(progress, 100, 100, "ZeroGPU init")` call, the `safe_progress`/`timer` progress plumbing around loading and upscaling, and the `clear_cuda_cache()` flush. The `gpu_duration()` heuristic that parameterized the `@GPU` decorator is also gone:

```python
# Dynamic signature for the GPU duration function
def gpu_duration(**kwargs):
    loading = 20
    duration = 10
    width = kwargs.get("width", 512)
    height = kwargs.get("height", 512)
    scale = kwargs.get("scale", 1)
    num_images = kwargs.get("num_images", 1)
    size = width * height
    if size > 500_000:
        duration += 5
    if scale == 4:
        duration += 5
    return loading + (duration * num_images)


# Request GPU when deployed to Hugging Face
@GPU(duration=gpu_duration)
```

Also removed is the closure inside `generate()` that drove a per-image progress bar:

```python
# Custom progress bar for multiple images
def callback_on_step_end(pipeline, step, timestep, latents):
    nonlocal CURRENT_STEP, CURRENT_IMAGE
    if progress is not None:
        # calculate total steps for img2img based on denoising strength
        strength = denoising_strength if KIND == "img2img" else 1
        total_steps = min(int(inference_steps * strength), inference_steps)
        CURRENT_STEP = step + 1
        progress(
            (CURRENT_STEP, total_steps),
            desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
        )
    return latents
```
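For reference, the duration the removed heuristic would have reserved for one concrete request (plain arithmetic on the function above; the request values are made up):

```python
# 1024x1024 at 4x upscale, two images:
#   size = 1024 * 1024 = 1_048_576 > 500_000  -> duration = 10 + 5 = 15
#   scale == 4                                -> duration = 15 + 5 = 20
#   reserved = loading + duration * num_images = 20 + 20 * 2 = 60
gpu_duration(width=1024, height=1024, scale=4, num_images=2)  # 60 seconds
```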
The new implementation:

@@ -5,153 +5,112 @@ from datetime import datetime
```python
import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from gradio import Error, Info, Progress
from spaces import GPU, config

from .loader import get_loader
from .logger import Logger
from .utils import annotate_image, cuda_collect, resize_image, timer


@GPU
def generate(
    positive_prompt="",
    negative_prompt="",
    image_input=None,
    controlnet_input=None,
    ip_adapter_input=None,
    seed=None,
    model="XpucT/Reliberate",
    scheduler="UniPC",
    controlnet_annotator="canny",
    width=512,
    height=512,
    guidance_scale=6.0,
    inference_steps=40,
    denoising_strength=0.8,
    deepcache_interval=1,
    scale=1,
    num_images=1,
    use_karras=False,
    use_ip_adapter_face=False,
    _=Progress(track_tqdm=True),
):
    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    if positive_prompt.strip() == "":
        raise Error("You must enter a prompt")

    start = time.perf_counter()
    log = Logger("generate")
    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")

    KIND = "img2img" if image_input is not None else "txt2img"
    KIND = f"controlnet_{KIND}" if controlnet_input is not None else KIND

    EMBEDDINGS_TYPE = ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED

    FAST_NEGATIVE = "<fast_negative>" in negative_prompt

    if ip_adapter_input:
        IP_KIND = "full-face" if use_ip_adapter_face else "plus"
    else:
        IP_KIND = ""

    # ZeroGPU is serverless so you want ephemeral instances
    # You want a singleton on localhost so the pipeline stays in memory
    loader = get_loader(singleton=not config.Config.zero_gpu)
    loader.load(
        KIND,
        IP_KIND,
        model,
        scheduler,
        controlnet_annotator,
        deepcache_interval,
        scale,
        use_karras,
    )

    pipeline = loader.pipeline
    upscaler = loader.upscaler

    # Probably a typo in the config
    if pipeline is None:
        raise Error(f"Error loading {model}")

    # Load fast negative embedding
    if FAST_NEGATIVE:
        embeddings_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "embeddings")
        )
        pipeline.load_textual_inversion(
            pretrained_model_name_or_path=f"{embeddings_dir}/fast_negative.pt",
            token="<fast_negative>",
        )

    # Embed prompts with weights
    compel = Compel(
        device=pipeline.device,
        tokenizer=pipeline.tokenizer,
        truncate_long_prompts=False,
        text_encoder=pipeline.text_encoder,
        returned_embeddings_type=EMBEDDINGS_TYPE,
        dtype_for_device_getter=lambda _: pipeline.dtype,
        textual_inversion_manager=DiffusersTextualInversionManager(pipeline),
    )

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

    # Increment the seed after each iteration
    images = []
    current_seed = seed

    for i in range(num_images):
        try:
            generator = torch.Generator(device=pipeline.device).manual_seed(current_seed)
            positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
                [compel(positive_prompt), compel(negative_prompt)]
            )
```

@@ -169,53 +128,44 @@ def generate(
```python
                "output_type": "np" if scale > 1 else "pil",
            }

            if KIND == "img2img" or KIND == "controlnet_img2img":
                kwargs["strength"] = denoising_strength
                kwargs["image"] = resize_image(image_input, (width, height))

            if KIND == "controlnet_txt2img":
                kwargs["image"] = annotate_image(controlnet_input, controlnet_annotator)

            if KIND == "controlnet_img2img":
                kwargs["control_image"] = annotate_image(controlnet_input, controlnet_annotator)

            if IP_KIND:
                # No size means preserve aspect ratio
                kwargs["ip_adapter_image"] = resize_image(ip_adapter_input)

            try:
                image = pipeline(**kwargs).images[0]
                images.append((image, str(current_seed)))  # tuple with seed for gallery caption
                current_seed += 1
            finally:
                if FAST_NEGATIVE:
                    pipeline.unload_textual_inversion()

    # Upscale
    if scale > 1:
        with timer(f"Upscaling {num_images} images {scale}x", logger=log.info):
            for i, image in enumerate(images):
                image = upscaler.predict(image[0])
                seed = images[i][1]
                images[i] = (image, seed)  # tuple again

    end = time.perf_counter()
    msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} took {end - start:.2f}s"
    log.info(msg)

    if Info:
        Info(msg)

    # Flush cache before returning
    cuda_collect()

    return images
```
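Since `generate()` now raises `gradio.Error` and declares `_=Progress(track_tqdm=True)` as its final parameter, it can be bound directly to a Gradio event. A minimal wiring sketch; the component layout here is hypothetical, not taken from this Space:

```python
import gradio as gr

from lib.inference import generate

# Hypothetical UI: generate() returns a list of (image, seed) tuples,
# which gr.Gallery renders as captioned images.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    gallery = gr.Gallery(label="Results")
    gr.Button("Generate").click(generate, inputs=[prompt], outputs=[gallery])

demo.launch()
```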
lib/loader.py CHANGED

Removed in this commit: `import gc` and `from threading import Lock`, the `self.pipe` attribute (now `self.pipeline`), the private `_should_*`/`_load_*`/`_unload_*` method names (now public), the kind-by-kind `isinstance(self.pipe, Config.PIPELINES[...])` checks written out for `txt2img`, `img2img`, `controlnet_txt2img`, and `controlnet_img2img`, the `progress` parameter with its `safe_progress`/`CURRENT_STEP`/`TOTAL_STEPS` bookkeeping, and the helpers below.

```python
def _unload_upscaler(self):
    if self.upscaler is not None:
        with timer(f"Unloading {self.upscaler.scale}x upscaler", logger=self.log.info):
            self.upscaler.to("cpu")

def _unload_deepcache(self):
    if self.pipe.deepcache is not None:
        self.log.info("Disabling DeepCache")
        self.pipe.deepcache.disable()
        delattr(self.pipe, "deepcache")
```

The scheduler swap was previously decided inline in `load()`, including a DDIM special case:

```python
# https://github.com/huggingface/diffusers/blob/8a3f0c1/scripts/convert_original_stable_diffusion_to_diffusers.py#L939
if scheduler == "DDIM":
    scheduler_kwargs["clip_sample"] = False
    scheduler_kwargs["set_alpha_to_one"] = False

same_karras = (
    not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
    or self.pipe.scheduler.config.use_karras_sigmas == karras
)
if not same_scheduler:
    self.log.info(f"Enabling {scheduler} scheduler")
if not same_karras:
    self.log.info(f"{'Enabling' if karras else 'Disabling'} Karras sigmas")
if not same_scheduler or not same_karras:
    self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
```
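In the rewrite below, those inline comparisons become explicit `should_load_*`/`should_unload_*` checks: `load()` first reconciles state with `unload_all()`, then loads whatever is missing. An illustrative call sequence (argument values mirror the repo's defaults; the scenario itself is assumed):

```python
loader = get_loader(singleton=True)

# First call: loads the pipeline (with its UniPC scheduler) and a 2x upscaler.
loader.load("txt2img", "", "XpucT/Reliberate", "UniPC", "canny", 1, 2, False)

# Same model and scale, so nothing is unloaded; only should_load_scheduler()
# fires, because use_karras_sigmas changed.
loader.load("txt2img", "", "XpucT/Reliberate", "UniPC", "canny", 1, 2, True)
```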
The new implementation:

@@ -1,6 +1,3 @@
```python
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import ControlNetModel
```

@@ -9,328 +6,313 @@ from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttn
```python
from .config import Config
from .logger import Logger
from .upscaler import RealESRGAN
from .utils import timer


class Loader:
    """
    A lazy-loading resource manager for Stable Diffusion pipelines. Lifecycles are managed by
    comparing the current state with desired. Can be used as a singleton when created by the
    `get_loader()` helper.

    Usage:
        loader = get_loader(singleton=True)
        loader.load(
            pipeline_id="controlnet_txt2img",
            ip_adapter_model="full-face",
            model="XpucT/Reliberate",
            scheduler="UniPC",
            controlnet_annotator="canny",
            deepcache_interval=2,
            scale=2,
            use_karras=True
        )
    """

    def __init__(self):
        self.model = ""
        self.pipeline = None
        self.upscaler = None
        self.controlnet = None
        self.annotator = ""  # controlnet annotator (canny)
        self.ip_adapter = ""  # ip-adapter kind (full-face or plus)
        self.log = Logger("Loader")

    def should_unload_upscaler(self, scale=1):
        return self.upscaler is not None and self.upscaler.scale != scale

    def should_unload_deepcache(self, cache_interval=1):
        has_deepcache = hasattr(self.pipeline, "deepcache")
        if has_deepcache and cache_interval == 1:
            return True
        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != cache_interval:
            # Unload if interval is different so it can be reloaded
            return True
        return False

    def should_unload_ip_adapter(self, ip_adapter_model=""):
        if not self.ip_adapter:
            return False
        if not ip_adapter_model:
            return True
        if self.ip_adapter != ip_adapter_model:
            # Unload if model is different so it can be reloaded
            return True
        return False

    def should_unload_controlnet(self, pipeline_id="", annotator=""):
        if self.controlnet is None:
            return False
        if self.annotator != annotator:
            return True
        if not pipeline_id.startswith("controlnet_"):
            return True
        return False

    def should_unload_pipeline(self, model=""):
        if self.pipeline is None:
            return False
        if self.model != model:
            return True
        return False

    # Copied from https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
    def unload_ip_adapter(self):
        # Remove the image encoder if text-to-image
        if isinstance(self.pipeline, Config.PIPELINES["txt2img"]):
            self.pipeline.image_encoder = None
            self.pipeline.register_to_config(image_encoder=[None, None])

        # Remove hidden projection layer added by IP-Adapter
        self.pipeline.unet.encoder_hid_proj = None
        self.pipeline.unet.config.encoder_hid_dim_type = None

        # Remove the feature extractor
        self.pipeline.feature_extractor = None
        self.pipeline.register_to_config(feature_extractor=[None, None])

        # Replace the custom attention processors with defaults
        attn_procs = {}
        for name, value in self.pipeline.unet.attn_processors.items():
            attn_processor_class = AttnProcessor2_0()  # raises if not torch 2
            attn_procs[name] = (
                attn_processor_class
                if isinstance(value, IPAdapterAttnProcessor2_0)
                else value.__class__()
            )
        self.pipeline.unet.set_attn_processor(attn_procs)
        self.ip_adapter = ""

    def unload_all(
        self,
        pipeline_id="",
        ip_adapter_model="",
        model="",
        controlnet_annotator="",
        deepcache_interval=1,
        scale=1,
    ):
        if self.should_unload_deepcache(deepcache_interval):  # remove deepcache first
            self.log.info("Disabling DeepCache")
            self.pipeline.deepcache.disable()
            delattr(self.pipeline, "deepcache")

        if self.should_unload_ip_adapter(ip_adapter_model):
            self.log.info("Unloading IP-Adapter")
            self.unload_ip_adapter()

        if self.should_unload_controlnet(pipeline_id, controlnet_annotator):
            self.log.info("Unloading ControlNet")
            self.controlnet = None
            self.annotator = ""

        if self.should_unload_upscaler(scale):
            self.log.info("Unloading upscaler")
            self.upscaler = None

        if self.should_unload_pipeline(model):
            self.log.info("Unloading pipeline")
            self.pipeline = None
            self.model = ""

    def should_load_upscaler(self, scale=1):
        return self.upscaler is None and scale > 1

    def should_load_deepcache(self, cache_interval=1):
        has_deepcache = hasattr(self.pipeline, "deepcache")
        if not has_deepcache and cache_interval > 1:
            return True
        return False

    def should_load_controlnet(self, pipeline_id=""):
        return self.controlnet is None and pipeline_id.startswith("controlnet_")

    def should_load_ip_adapter(self, ip_adapter_model=""):
        has_ip_adapter = (
            hasattr(self.pipeline.unet, "encoder_hid_proj")
            and self.pipeline.unet.config.encoder_hid_dim_type == "ip_image_proj"
        )
        return not has_ip_adapter and ip_adapter_model != ""

    def should_load_scheduler(self, cls, use_karras=False):
        has_karras = hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
        if not isinstance(self.pipeline.scheduler, cls):
            return True
        if has_karras and self.pipeline.scheduler.config.use_karras_sigmas != use_karras:
            return True
        return False

    def should_load_pipeline(self, pipeline_id=""):
        if self.pipeline is None:
            return True
        if not isinstance(self.pipeline, Config.PIPELINES[pipeline_id]):
            return True
        return False

    def load_upscaler(self, scale=1):
        with timer(f"Loading {scale}x upscaler", logger=self.log.info):
            self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
            self.upscaler.load_weights()

    def load_deepcache(self, cache_interval=1):
        self.log.info(f"Enabling DeepCache interval {cache_interval}")
        self.pipeline.deepcache = DeepCacheSDHelper(self.pipeline)
        self.pipeline.deepcache.set_params(cache_interval=cache_interval)
        self.pipeline.deepcache.enable()

    def load_controlnet(self, controlnet_annotator):
        with timer("Loading ControlNet", logger=self.log.info):
            self.controlnet = ControlNetModel.from_pretrained(
                Config.ANNOTATORS[controlnet_annotator],
                variant="fp16",
                torch_dtype=torch.float16,
            )
            self.annotator = controlnet_annotator

    def load_ip_adapter(self, ip_adapter_model=""):
        with timer("Loading IP-Adapter", logger=self.log.info):
            self.pipeline.load_ip_adapter(
                "h94/IP-Adapter",
                subfolder="models",
                weight_name=f"ip-adapter-{ip_adapter_model}_sd15.safetensors",
            )
            self.pipeline.set_ip_adapter_scale(0.5)  # 50% works the best
            self.ip_adapter = ip_adapter_model

    def load_scheduler(self, cls, use_karras=False, **kwargs):
        self.log.info(f"Loading {cls.__name__}{' with Karras' if use_karras else ''}")
        self.pipeline.scheduler = cls(**kwargs)

    def load_pipeline(
        self,
        pipeline_id,
        model,
        **kwargs,
    ):
        Pipeline = Config.PIPELINES[pipeline_id]

        # Load from scratch
        if self.pipeline is None:
            with timer(f"Loading {model} ({pipeline_id})", logger=self.log.info):
                if self.controlnet is not None:
                    kwargs["controlnet"] = self.controlnet
                if model in Config.SINGLE_FILE_MODELS:
                    checkpoint = Config.HF_REPOS[model][0]
                    self.pipeline = Pipeline.from_single_file(
                        f"https://huggingface.co/{model}/{checkpoint}",
                        **kwargs,
                    ).to("cuda")
                else:
                    self.pipeline = Pipeline.from_pretrained(model, **kwargs).to("cuda")

        # Change to a different one
        else:
            with timer(f"Changing pipeline to {pipeline_id}", logger=self.log.info):
                kwargs = {}
                if self.controlnet is not None:
                    kwargs["controlnet"] = self.controlnet
                self.pipeline = Pipeline.from_pipe(
                    self.pipeline,
                    **kwargs,
                ).to("cuda")

        # Update model and disable terminal progress bars
        self.model = model
        self.pipeline.set_progress_bar_config(disable=True)

    def load(
        self,
        pipeline_id,
        ip_adapter_model,
        model,
        scheduler,
        controlnet_annotator,
        deepcache_interval,
        scale,
        use_karras,
    ):
        Scheduler = Config.SCHEDULERS[scheduler]

        scheduler_kwargs = {
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "timestep_spacing": "leading",
            "steps_offset": 1,
        }

        if scheduler not in ["Euler a"]:
            scheduler_kwargs["use_karras_sigmas"] = use_karras

        pipeline_kwargs = {
            "torch_dtype": torch.float16,  # defaults to fp32
            "safety_checker": None,
            "requires_safety_checker": False,
            "scheduler": Scheduler(**scheduler_kwargs),
        }

        # Single-file models don't need a variant
        if model not in Config.SINGLE_FILE_MODELS:
            pipeline_kwargs["variant"] = "fp16"
        else:
            pipeline_kwargs["variant"] = None

        # Prepare state for loading checks
        self.unload_all(
            pipeline_id,
            ip_adapter_model,
            model,
            controlnet_annotator,
            deepcache_interval,
            scale,
        )

        # Load controlnet model before pipeline
        if self.should_load_controlnet(pipeline_id):
            self.load_controlnet(controlnet_annotator)

        if self.should_load_pipeline(pipeline_id):
            self.load_pipeline(pipeline_id, model, **pipeline_kwargs)

        if self.should_load_scheduler(Scheduler, use_karras):
            self.load_scheduler(Scheduler, use_karras, **scheduler_kwargs)

        if self.should_load_deepcache(deepcache_interval):
            self.load_deepcache(deepcache_interval)

        if self.should_load_ip_adapter(ip_adapter_model):
            self.load_ip_adapter(ip_adapter_model)

        if self.should_load_upscaler(scale):
            self.load_upscaler(scale)


# Get a singleton or a new instance of the Loader
def get_loader(singleton=False):
    if not singleton:
        return Loader()
    else:
        if not hasattr(get_loader, "_instance"):
            get_loader._instance = Loader()
        assert isinstance(get_loader._instance, Loader)
        return get_loader._instance
```
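`get_loader()` caches the shared instance on the function object itself; the removed `from threading import Lock` import suggests the old singleton was lock-guarded. A quick sketch of the resulting behavior:

```python
from lib.loader import get_loader

a = get_loader(singleton=True)
b = get_loader(singleton=True)
assert a is b  # one shared Loader kept on get_loader._instance

c = get_loader()  # singleton=False: a fresh Loader per call
assert c is not a
```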