adamelliotfields committed
Commit 9769856
1 Parent(s): 1b15230

Rewrite loading and inference

Files changed (2)
  1. lib/inference.py +68 -118
  2. lib/loader.py +242 -260
lib/inference.py CHANGED
@@ -5,153 +5,112 @@ from datetime import datetime
 import torch
 from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
-from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
-from spaces import GPU
+from gradio import Error, Info, Progress
+from spaces import GPU, config
 
-from .config import Config
-from .loader import Loader
+from .loader import get_loader
 from .logger import Logger
-from .utils import (
-    annotate_image,
-    clear_cuda_cache,
-    resize_image,
-    safe_progress,
-    timer,
-)
-
-
-# Dynamic signature for the GPU duration function
-def gpu_duration(**kwargs):
-    loading = 20
-    duration = 10
-    width = kwargs.get("width", 512)
-    height = kwargs.get("height", 512)
-    scale = kwargs.get("scale", 1)
-    num_images = kwargs.get("num_images", 1)
-    size = width * height
-    if size > 500_000:
-        duration += 5
-    if scale == 4:
-        duration += 5
-    return loading + (duration * num_images)
-
-
-# Request GPU when deployed to Hugging Face
-@GPU(duration=gpu_duration)
+from .utils import annotate_image, cuda_collect, resize_image, timer
+
+
+@GPU
 def generate(
-    positive_prompt,
+    positive_prompt="",
     negative_prompt="",
-    image_prompt=None,
-    control_image_prompt=None,
-    ip_image_prompt=None,
+    image_input=None,
+    controlnet_input=None,
+    ip_adapter_input=None,
     seed=None,
-    model="Lykon/dreamshaper-8",
-    scheduler="DDIM",
-    annotator="canny",
+    model="XpucT/Reliberate",
+    scheduler="UniPC",
+    controlnet_annotator="canny",
     width=512,
     height=512,
     guidance_scale=6.0,
     inference_steps=40,
     denoising_strength=0.8,
-    deepcache=1,
+    deepcache_interval=1,
     scale=1,
     num_images=1,
-    karras=False,
-    ip_face=False,
-    Error=Exception,
-    Info=None,
-    progress=None,
+    use_karras=False,
+    use_ip_adapter_face=False,
+    _=Progress(track_tqdm=True),
 ):
-    start = time.perf_counter()
-    log = Logger("generate")
-    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
-
-    if Config.ZERO_GPU:
-        safe_progress(progress, 100, 100, "ZeroGPU init")
-
     if not torch.cuda.is_available():
         raise Error("CUDA not available")
 
-    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
-    if seed is None or seed < 0:
-        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
+    if positive_prompt.strip() == "":
+        raise Error("You must enter a prompt")
 
-    CURRENT_STEP = 0
-    CURRENT_IMAGE = 1
+    start = time.perf_counter()
+    log = Logger("generate")
+    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
 
-    KIND = "img2img" if image_prompt is not None else "txt2img"
-    KIND = f"controlnet_{KIND}" if control_image_prompt is not None else KIND
+    KIND = "img2img" if image_input is not None else "txt2img"
+    KIND = f"controlnet_{KIND}" if controlnet_input is not None else KIND
 
     EMBEDDINGS_TYPE = ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
 
     FAST_NEGATIVE = "<fast_negative>" in negative_prompt
 
-    if ip_image_prompt:
-        IP_ADAPTER = "full-face" if ip_face else "plus"
+    if ip_adapter_input:
+        IP_KIND = "full-face" if use_ip_adapter_face else "plus"
     else:
-        IP_ADAPTER = ""
-
-    # Custom progress bar for multiple images
-    def callback_on_step_end(pipeline, step, timestep, latents):
-        nonlocal CURRENT_STEP, CURRENT_IMAGE
-        if progress is not None:
-            # calculate total steps for img2img based on denoising strength
-            strength = denoising_strength if KIND == "img2img" else 1
-            total_steps = min(int(inference_steps * strength), inference_steps)
-            CURRENT_STEP = step + 1
-            progress(
-                (CURRENT_STEP, total_steps),
-                desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
-            )
-        return latents
+        IP_KIND = ""
 
-    loader = Loader()
+    # ZeroGPU is serverless so you want ephemeral instances
+    # You want a singleton on localhost so the pipeline stays in memory
+    loader = get_loader(singleton=not config.Config.zero_gpu)
     loader.load(
         KIND,
-        IP_ADAPTER,
+        IP_KIND,
         model,
         scheduler,
-        annotator,
-        deepcache,
+        controlnet_annotator,
+        deepcache_interval,
         scale,
-        karras,
-        progress,
+        use_karras,
     )
 
-    if loader.pipe is None:
-        raise Error(f"Error loading {model}")
-
-    pipe = loader.pipe
+    pipeline = loader.pipeline
     upscaler = loader.upscaler
 
+    # Probably a typo in the config
+    if pipeline is None:
+        raise Error(f"Error loading {model}")
+
     # Load fast negative embedding
     if FAST_NEGATIVE:
         embeddings_dir = os.path.abspath(
             os.path.join(os.path.dirname(__file__), "..", "embeddings")
         )
-        pipe.load_textual_inversion(
+        pipeline.load_textual_inversion(
             pretrained_model_name_or_path=f"{embeddings_dir}/fast_negative.pt",
             token="<fast_negative>",
         )
 
     # Embed prompts with weights
     compel = Compel(
-        device=pipe.device,
-        tokenizer=pipe.tokenizer,
+        device=pipeline.device,
+        tokenizer=pipeline.tokenizer,
         truncate_long_prompts=False,
-        text_encoder=pipe.text_encoder,
+        text_encoder=pipeline.text_encoder,
         returned_embeddings_type=EMBEDDINGS_TYPE,
-        dtype_for_device_getter=lambda _: pipe.dtype,
-        textual_inversion_manager=DiffusersTextualInversionManager(pipe),
+        dtype_for_device_getter=lambda _: pipeline.dtype,
+        textual_inversion_manager=DiffusersTextualInversionManager(pipeline),
     )
 
+    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
+    if seed is None or seed < 0:
+        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
+
+    # Increment the seed after each iteration
     images = []
     current_seed = seed
-    safe_progress(progress, 0, num_images, f"Generating image 0/{num_images}")
 
     for i in range(num_images):
         try:
-            generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
+            generator = torch.Generator(device=pipeline.device).manual_seed(current_seed)
             positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
                 [compel(positive_prompt), compel(negative_prompt)]
             )
@@ -169,53 +128,44 @@ def generate(
             "output_type": "np" if scale > 1 else "pil",
         }
 
-        if progress is not None:
-            kwargs["callback_on_step_end"] = callback_on_step_end
-
-        # Resizing so the initial latents are the same size as the generated image
-        if KIND == "img2img":
+        if KIND == "img2img" or KIND == "controlnet_img2img":
             kwargs["strength"] = denoising_strength
-            kwargs["image"] = resize_image(image_prompt, (width, height))
+            kwargs["image"] = resize_image(image_input, (width, height))
 
         if KIND == "controlnet_txt2img":
-            kwargs["image"] = annotate_image(control_image_prompt, annotator)
+            kwargs["image"] = annotate_image(controlnet_input, controlnet_annotator)
 
         if KIND == "controlnet_img2img":
-            kwargs["control_image"] = annotate_image(control_image_prompt, annotator)
+            kwargs["control_image"] = annotate_image(controlnet_input, controlnet_annotator)
 
-        if IP_ADAPTER:
-            kwargs["ip_adapter_image"] = resize_image(ip_image_prompt)
+        if IP_KIND:
+            # No size means preserve aspect ratio
+            kwargs["ip_adapter_image"] = resize_image(ip_adapter_input)
 
         try:
-            image = pipe(**kwargs).images[0]
-            images.append((image, str(current_seed)))
+            image = pipeline(**kwargs).images[0]
+            images.append((image, str(current_seed)))  # tuple with seed for gallery caption
             current_seed += 1
         finally:
             if FAST_NEGATIVE:
-                pipe.unload_textual_inversion()
-
-        CURRENT_STEP = 0
-        CURRENT_IMAGE += 1
+                pipeline.unload_textual_inversion()
 
     # Upscale
     if scale > 1:
-        msg = f"Upscaling {scale}x"
-        with timer(msg, logger=log.info):
-            safe_progress(progress, 0, num_images, desc=msg)
+        with timer(f"Upscaling {num_images} images {scale}x", logger=log.info):
             for i, image in enumerate(images):
                 image = upscaler.predict(image[0])
-                images[i] = image
-                safe_progress(progress, i + 1, num_images, desc=msg)
-
-    # Flush memory after generating
-    clear_cuda_cache()
+                seed = images[i][1]
+                images[i] = (image, seed)  # tuple again
 
     end = time.perf_counter()
     msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} took {end - start:.2f}s"
     log.info(msg)
 
-    # Alert if notifier provided
     if Info:
         Info(msg)
 
+    # Flush cache before returning
+    cuda_collect()
+
     return images
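
Note: the rewritten function derives a default seed from the wall clock and then increments it once per image, so each gallery caption can reproduce its image. A minimal sketch of that pattern, assuming only torch is installed (`random_seed` and `make_generators` are hypothetical helper names, not part of this commit):

    import torch
    from datetime import datetime

    def random_seed():
        # Fold a microsecond timestamp into torch's 64-bit seed range
        # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
        return int(datetime.now().timestamp() * 1_000_000) % (2**64)

    def make_generators(seed, num_images, device="cpu"):
        # One generator per image, seeded consecutively, so any image in a
        # batch can be regenerated later from the seed shown in its caption
        return [
            torch.Generator(device=device).manual_seed(seed + i)
            for i in range(num_images)
        ]

    generators = make_generators(random_seed(), num_images=4)
    latents = [torch.randn((4, 64, 64), generator=g) for g in generators]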
lib/loader.py CHANGED
@@ -1,6 +1,3 @@
-import gc
-from threading import Lock
-
 import torch
 from DeepCache import DeepCacheSDHelper
 from diffusers import ControlNetModel
@@ -9,328 +6,313 @@ from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttn
 from .config import Config
 from .logger import Logger
 from .upscaler import RealESRGAN
-from .utils import clear_cuda_cache, safe_progress, timer
+from .utils import timer
 
 
 class Loader:
-    _instance = None
-    _lock = Lock()
-
-    def __new__(cls):
-        with cls._lock:
-            if cls._instance is None:
-                cls._instance = super().__new__(cls)
-                cls._instance.pipe = None
-                cls._instance.model = None
-                cls._instance.upscaler = None
-                cls._instance.controlnet = None
-                cls._instance.ip_adapter = None
-                cls._instance.log = Logger("Loader")
-        return cls._instance
-
-    def _should_unload_upscaler(self, scale=1):
-        if self.upscaler is not None and self.upscaler.scale != scale:
-            return True
-        return False
-
-    def _should_unload_deepcache(self, interval=1):
-        has_deepcache = hasattr(self.pipe, "deepcache")
-        if has_deepcache and interval == 1:
+    """
+    A lazy-loading resource manager for Stable Diffusion pipelines. Lifecycles are managed by
+    comparing the current state with desired. Can be used as a singleton when created by the
+    `get_loader()` helper.
+
+    Usage:
+        loader = get_loader(singleton=True)
+        loader.load(
+            pipeline_id="controlnet_txt2img",
+            ip_adapter_model="full-face",
+            model="XpucT/Reliberate",
+            scheduler="UniPC",
+            controlnet_annotator="canny",
+            deepcache_interval=2,
+            scale=2,
+            use_karras=True
+        )
+    """
+
+    def __init__(self):
+        self.model = ""
+        self.pipeline = None
+        self.upscaler = None
+        self.controlnet = None
+        self.annotator = ""  # controlnet annotator (canny)
+        self.ip_adapter = ""  # ip-adapter kind (full-face or plus)
+        self.log = Logger("Loader")
+
+    def should_unload_upscaler(self, scale=1):
+        return self.upscaler is not None and self.upscaler.scale != scale
+
+    def should_unload_deepcache(self, cache_interval=1):
+        has_deepcache = hasattr(self.pipeline, "deepcache")
+        if has_deepcache and cache_interval == 1:
             return True
-        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
+        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != cache_interval:
+            # Unload if interval is different so it can be reloaded
             return True
         return False
 
-    def _should_unload_ip_adapter(self, model="", ip_adapter=""):
-        # unload if model changed
-        if self.model and self.model.lower() != model.lower():
+    def should_unload_ip_adapter(self, ip_adapter_model=""):
+        if not self.ip_adapter:
+            return False
+        if not ip_adapter_model:
             return True
-        if self.ip_adapter and not ip_adapter:
+        if self.ip_adapter != ip_adapter_model:
+            # Unload if model is different so it can be reloaded
             return True
         return False
 
-    def _should_unload_controlnet(self, kind="", controlnet=""):
+    def should_unload_controlnet(self, pipeline_id="", annotator=""):
         if self.controlnet is None:
             return False
-        if self.controlnet.lower() != controlnet.lower():
+        if self.annotator != annotator:
             return True
-        if not kind.startswith("controlnet_"):
+        if not pipeline_id.startswith("controlnet_"):
             return True
         return False
 
-    def _should_unload_pipeline(self, kind="", model="", controlnet=""):
-        if self.pipe is None:
+    def should_unload_pipeline(self, model=""):
+        if self.pipeline is None:
             return False
-        if self.model.lower() != model.lower():
-            return True
-        if kind == "txt2img" and not isinstance(self.pipe, Config.PIPELINES["txt2img"]):
-            return True
-        if kind == "img2img" and not isinstance(self.pipe, Config.PIPELINES["img2img"]):
-            return True
-        if kind == "controlnet_txt2img" and not isinstance(
-            self.pipe,
-            Config.PIPELINES["controlnet_txt2img"],
-        ):
-            return True
-        if kind == "controlnet_img2img" and not isinstance(
-            self.pipe,
-            Config.PIPELINES["controlnet_img2img"],
-        ):
-            return True
-        if self._should_unload_controlnet(kind, controlnet):
+        if self.model != model:
             return True
         return False
 
-    def _unload_upscaler(self):
-        if self.upscaler is not None:
-            with timer(f"Unloading {self.upscaler.scale}x upscaler", logger=self.log.info):
-                self.upscaler.to("cpu")
-
-    def _unload_deepcache(self):
-        if self.pipe.deepcache is not None:
-            self.log.info("Disabling DeepCache")
-            self.pipe.deepcache.disable()
-            delattr(self.pipe, "deepcache")
-
     # Copied from https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
-    def _unload_ip_adapter(self):
-        if self.ip_adapter is not None:
-            with timer("Unloading IP-Adapter", logger=self.log.info):
-                if not isinstance(self.pipe, Config.PIPELINES["img2img"]):
-                    self.pipe.image_encoder = None
-                    self.pipe.register_to_config(image_encoder=[None, None])
-                self.pipe.feature_extractor = None
-                self.pipe.unet.encoder_hid_proj = None
-                self.pipe.unet.config.encoder_hid_dim_type = None
-                self.pipe.register_to_config(feature_extractor=[None, None])
-                attn_procs = {}
-                for name, value in self.pipe.unet.attn_processors.items():
-                    attn_processor_class = AttnProcessor2_0()  # raises if not torch 2
-                    attn_procs[name] = (
-                        attn_processor_class
-                        if isinstance(value, IPAdapterAttnProcessor2_0)
-                        else value.__class__()
-                    )
-                self.pipe.unet.set_attn_processor(attn_procs)
-
-    def _unload_pipeline(self):
-        if self.pipe is not None:
-            with timer(f"Unloading {self.model}", logger=self.log.info):
-                self.pipe.to("cpu")
-
-    def _unload(
+    def unload_ip_adapter(self):
+        # Remove the image encoder if text-to-image
+        if isinstance(self.pipeline, Config.PIPELINES["txt2img"]):
+            self.pipeline.image_encoder = None
+            self.pipeline.register_to_config(image_encoder=[None, None])
+
+        # Remove hidden projection layer added by IP-Adapter
+        self.pipeline.unet.encoder_hid_proj = None
+        self.pipeline.unet.config.encoder_hid_dim_type = None
+
+        # Remove the feature extractor
+        self.pipeline.feature_extractor = None
+        self.pipeline.register_to_config(feature_extractor=[None, None])
+
+        # Replace the custom attention processors with defaults
+        attn_procs = {}
+        for name, value in self.pipeline.unet.attn_processors.items():
+            attn_processor_class = AttnProcessor2_0()  # raises if not torch 2
+            attn_procs[name] = (
+                attn_processor_class
+                if isinstance(value, IPAdapterAttnProcessor2_0)
+                else value.__class__()
+            )
+        self.pipeline.unet.set_attn_processor(attn_procs)
+        self.ip_adapter = ""
+
+    def unload_all(
         self,
-        kind="",
+        pipeline_id="",
+        ip_adapter_model="",
         model="",
-        controlnet="",
-        ip_adapter="",
-        deepcache=1,
+        controlnet_annotator="",
+        deepcache_interval=1,
         scale=1,
     ):
-        to_unload = []
-        if self._should_unload_deepcache(deepcache):  # remove deepcache first
-            self._unload_deepcache()
-
-        if self._should_unload_upscaler(scale):
-            self._unload_upscaler()
-            to_unload.append("upscaler")
-
-        if self._should_unload_ip_adapter(model, ip_adapter):
-            self._unload_ip_adapter()
-            to_unload.append("ip_adapter")
-
-        if self._should_unload_controlnet(kind, controlnet):
-            to_unload.append("controlnet")
-
-        if self._should_unload_pipeline(kind, model, controlnet):
-            self._unload_pipeline()
-            to_unload.append("model")
-            to_unload.append("pipe")
-
-        # Flush cache and run garbage collector
-        clear_cuda_cache()
-        for component in to_unload:
-            setattr(self, component, None)
-        gc.collect()
-
-    def _should_load_upscaler(self, scale=1):
-        if self.upscaler is None and scale > 1:
+        if self.should_unload_deepcache(deepcache_interval):  # remove deepcache first
+            self.log.info("Disabling DeepCache")
+            self.pipeline.deepcache.disable()
+            delattr(self.pipeline, "deepcache")
+
+        if self.should_unload_ip_adapter(ip_adapter_model):
+            self.log.info("Unloading IP-Adapter")
+            self.unload_ip_adapter()
+
+        if self.should_unload_controlnet(pipeline_id, controlnet_annotator):
+            self.log.info("Unloading ControlNet")
+            self.controlnet = None
+            self.annotator = ""
+
+        if self.should_unload_upscaler(scale):
+            self.log.info("Unloading upscaler")
+            self.upscaler = None
+
+        if self.should_unload_pipeline(model):
+            self.log.info("Unloading pipeline")
+            self.pipeline = None
+            self.model = ""
+
+    def should_load_upscaler(self, scale=1):
+        return self.upscaler is None and scale > 1
+
+    def should_load_deepcache(self, cache_interval=1):
+        has_deepcache = hasattr(self.pipeline, "deepcache")
+        if not has_deepcache and cache_interval > 1:
             return True
         return False
 
-    def _should_load_deepcache(self, interval=1):
-        has_deepcache = hasattr(self.pipe, "deepcache")
-        if not has_deepcache and interval != 1:
-            return True
-        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
+    def should_load_controlnet(self, pipeline_id=""):
+        return self.controlnet is None and pipeline_id.startswith("controlnet_")
+
+    def should_load_ip_adapter(self, ip_adapter_model=""):
+        has_ip_adapter = (
+            hasattr(self.pipeline.unet, "encoder_hid_proj")
+            and self.pipeline.unet.config.encoder_hid_dim_type == "ip_image_proj"
+        )
+        return not has_ip_adapter and ip_adapter_model != ""
+
+    def should_load_scheduler(self, cls, use_karras=False):
+        has_karras = hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
+        if not isinstance(self.pipeline.scheduler, cls):
             return True
-        return False
-
-    def _should_load_ip_adapter(self, ip_adapter=""):
-        if not self.ip_adapter and ip_adapter:
+        if has_karras and self.pipeline.scheduler.config.use_karras_sigmas != use_karras:
             return True
         return False
 
-    def _should_load_pipeline(self):
-        if self.pipe is None:
+    def should_load_pipeline(self, pipeline_id=""):
+        if self.pipeline is None:
+            return True
+        if not isinstance(self.pipeline, Config.PIPELINES[pipeline_id]):
             return True
         return False
 
-    def _load_upscaler(self, scale=1):
-        if self._should_load_upscaler(scale):
-            try:
-                msg = f"Loading {scale}x upscaler"
-                with timer(msg, logger=self.log.info):
-                    self.upscaler = RealESRGAN(scale, device=self.pipe.device)
-                    self.upscaler.load_weights()
-            except Exception as e:
-                self.log.error(f"Error loading {scale}x upscaler: {e}")
-                self.upscaler = None
-
-    def _load_deepcache(self, interval=1):
-        if self._should_load_deepcache(interval):
-            self.log.info("Enabling DeepCache")
-            self.pipe.deepcache = DeepCacheSDHelper(self.pipe)
-            self.pipe.deepcache.set_params(cache_interval=interval)
-            self.pipe.deepcache.enable()
-
-    def _load_ip_adapter(self, ip_adapter=""):
-        if self._should_load_ip_adapter(ip_adapter):
-            msg = "Loading IP-Adapter"
-            with timer(msg, logger=self.log.info):
-                self.pipe.load_ip_adapter(
-                    "h94/IP-Adapter",
-                    subfolder="models",
-                    weight_name=f"ip-adapter-{ip_adapter}_sd15.safetensors",
-                )
-                # 50% works the best
-                self.pipe.set_ip_adapter_scale(0.5)
-                self.ip_adapter = ip_adapter
-
-    def _load_pipeline(
+    def load_upscaler(self, scale=1):
+        with timer(f"Loading {scale}x upscaler", logger=self.log.info):
+            self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
+            self.upscaler.load_weights()
+
+    def load_deepcache(self, cache_interval=1):
+        self.log.info(f"Enabling DeepCache interval {cache_interval}")
+        self.pipeline.deepcache = DeepCacheSDHelper(self.pipeline)
+        self.pipeline.deepcache.set_params(cache_interval=cache_interval)
+        self.pipeline.deepcache.enable()
+
+    def load_controlnet(self, controlnet_annotator):
+        with timer("Loading ControlNet", logger=self.log.info):
+            self.controlnet = ControlNetModel.from_pretrained(
+                Config.ANNOTATORS[controlnet_annotator],
+                variant="fp16",
+                torch_dtype=torch.float16,
+            )
+            self.annotator = controlnet_annotator
+
+    def load_ip_adapter(self, ip_adapter_model=""):
+        with timer("Loading IP-Adapter", logger=self.log.info):
+            self.pipeline.load_ip_adapter(
+                "h94/IP-Adapter",
+                subfolder="models",
+                weight_name=f"ip-adapter-{ip_adapter_model}_sd15.safetensors",
+            )
+            self.pipeline.set_ip_adapter_scale(0.5)  # 50% works the best
+            self.ip_adapter = ip_adapter_model
+
+    def load_pipeline(
         self,
-        kind,
+        pipeline_id,
         model,
-        progress,
         **kwargs,
     ):
-        pipeline = Config.PIPELINES[kind]
-        if self._should_load_pipeline():
-            try:
-                with timer(f"Loading {model} ({kind})", logger=self.log.info):
-                    self.model = model
-                    if model.lower() in Config.MODEL_CHECKPOINTS.keys():
-                        self.pipe = pipeline.from_single_file(
-                            f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
-                            progress,
-                            **kwargs,
-                        ).to("cuda")
-                    else:
-                        self.pipe = pipeline.from_pretrained(model, progress, **kwargs).to("cuda")
-            except Exception as e:
-                self.log.error(f"Error loading {model}: {e}")
-                self.model = None
-                self.pipe = None
-                return
-        if not isinstance(self.pipe, pipeline):
-            self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
-        if self.pipe is not None:
-            self.pipe.set_progress_bar_config(disable=progress is not None)
+        Pipeline = Config.PIPELINES[pipeline_id]
+
+        # Load from scratch
+        if self.pipeline is None:
+            with timer(f"Loading {model} ({pipeline_id})", logger=self.log.info):
+                if self.controlnet is not None:
+                    kwargs["controlnet"] = self.controlnet
+                if model in Config.SINGLE_FILE_MODELS:
+                    checkpoint = Config.HF_REPOS[model][0]
+                    self.pipeline = Pipeline.from_single_file(
+                        f"https://huggingface.co/{model}/{checkpoint}",
+                        **kwargs,
+                    ).to("cuda")
+                else:
+                    self.pipeline = Pipeline.from_pretrained(model, **kwargs).to("cuda")
+
+        # Change to a different one
+        else:
+            with timer(f"Changing pipeline to {pipeline_id}", logger=self.log.info):
+                kwargs = {}
+                if self.controlnet is not None:
+                    kwargs["controlnet"] = self.controlnet
+                self.pipeline = Pipeline.from_pipe(
+                    self.pipeline,
+                    **kwargs,
+                ).to("cuda")
+
+        # Update model and disable terminal progress bars
+        self.model = model
+        self.pipeline.set_progress_bar_config(disable=True)
 
     def load(
         self,
-        kind,
-        ip_adapter,
+        pipeline_id,
+        ip_adapter_model,
         model,
         scheduler,
-        annotator,
-        deepcache,
+        controlnet_annotator,
+        deepcache_interval,
         scale,
-        karras,
-        progress,
+        use_karras,
     ):
+        Scheduler = Config.SCHEDULERS[scheduler]
+
         scheduler_kwargs = {
-            "beta_schedule": "scaled_linear",
-            "timestep_spacing": "leading",
             "beta_start": 0.00085,
             "beta_end": 0.012,
+            "beta_schedule": "scaled_linear",
+            "timestep_spacing": "leading",
             "steps_offset": 1,
         }
 
-        if scheduler not in ["DDIM", "Euler a", "PNDM"]:
-            scheduler_kwargs["use_karras_sigmas"] = karras
-
-        # https://github.com/huggingface/diffusers/blob/8a3f0c1/scripts/convert_original_stable_diffusion_to_diffusers.py#L939
-        if scheduler == "DDIM":
-            scheduler_kwargs["clip_sample"] = False
-            scheduler_kwargs["set_alpha_to_one"] = False
+        if scheduler not in ["Euler a"]:
+            scheduler_kwargs["use_karras_sigmas"] = use_karras
 
-        pipe_kwargs = {
+        pipeline_kwargs = {
+            "torch_dtype": torch.float16,  # defaults to fp32
             "safety_checker": None,
             "requires_safety_checker": False,
-            "scheduler": Config.SCHEDULERS[scheduler](**scheduler_kwargs),
+            "scheduler": Scheduler(**scheduler_kwargs),
         }
 
-        # diffusers fp16 variant
-        if model.lower() not in Config.MODEL_CHECKPOINTS.keys():
-            pipe_kwargs["variant"] = "fp16"
+        # Single-file models don't need a variant
+        if model not in Config.SINGLE_FILE_MODELS:
+            pipeline_kwargs["variant"] = "fp16"
         else:
-            pipe_kwargs["variant"] = None
-
-        # converts to fp32 by default
-        pipe_kwargs["torch_dtype"] = torch.float16
+            pipeline_kwargs["variant"] = None
+
+        # Prepare state for loading checks
+        self.unload_all(
+            pipeline_id,
+            ip_adapter_model,
+            model,
+            controlnet_annotator,
+            deepcache_interval,
+            scale,
+        )
 
-        # config maps the repo to the ID: canny -> lllyasviel/control_sd15_canny
-        if kind.startswith("controlnet_"):
-            pipe_kwargs["controlnet"] = ControlNetModel.from_pretrained(
-                Config.ANNOTATORS[annotator],
-                torch_dtype=torch.float16,
-                variant="fp16",
-            )
-            self.controlnet = annotator
+        # Load controlnet model before pipeline
+        if self.should_load_controlnet(pipeline_id):
+            self.load_controlnet(controlnet_annotator)
 
-        self._unload(kind, model, annotator, ip_adapter, deepcache, scale)
-        self._load_pipeline(kind, model, progress, **pipe_kwargs)
+        if self.should_load_pipeline(pipeline_id):
+            self.load_pipeline(pipeline_id, model, **pipeline_kwargs)
 
-        # error loading model
-        if self.pipe is None:
-            return
+        if self.should_load_scheduler(Scheduler, use_karras):
+            self.load_scheduler(Scheduler, use_karras, **scheduler_kwargs)
 
-        same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
-        same_karras = (
-            not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
-            or self.pipe.scheduler.config.use_karras_sigmas == karras
-        )
+        if self.should_load_deepcache(deepcache_interval):
+            self.load_deepcache(deepcache_interval)
 
-        # same model, different scheduler
-        if self.model.lower() == model.lower():
-            if not same_scheduler:
-                self.log.info(f"Enabling {scheduler} scheduler")
-            if not same_karras:
-                self.log.info(f"{'Enabling' if karras else 'Disabling'} Karras sigmas")
-            if not same_scheduler or not same_karras:
-                self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
-
-        CURRENT_STEP = 1
-        TOTAL_STEPS = sum(
-            [
-                self._should_load_deepcache(deepcache),
-                self._should_load_ip_adapter(ip_adapter),
-                self._should_load_upscaler(scale),
-            ]
-        )
+        if self.should_load_ip_adapter(ip_adapter_model):
+            self.load_ip_adapter(ip_adapter_model)
 
-        desc = "Configuring pipeline"
-        if self._should_load_deepcache(deepcache):
-            self._load_deepcache(deepcache)
-            safe_progress(progress, CURRENT_STEP, TOTAL_STEPS, desc)
-            CURRENT_STEP += 1
-
-        if self._should_load_ip_adapter(ip_adapter):
-            self._load_ip_adapter(ip_adapter)
-            safe_progress(progress, CURRENT_STEP, TOTAL_STEPS, desc)
-            CURRENT_STEP += 1
-
-        if self._should_load_upscaler(scale):
-            self._load_upscaler(scale)
-            safe_progress(progress, CURRENT_STEP, TOTAL_STEPS, desc)
+        if self.should_load_upscaler(scale):
+            self.load_upscaler(scale)
+
+    def load_scheduler(self, cls, use_karras=False, **kwargs):
+        self.log.info(f"Loading {cls.__name__}{' with Karras' if use_karras else ''}")
+        self.pipeline.scheduler = cls(**kwargs)
+
+
+# Get a singleton or a new instance of the Loader
+def get_loader(singleton=False):
+    if not singleton:
+        return Loader()
+    else:
+        if not hasattr(get_loader, "_instance"):
+            get_loader._instance = Loader()
+        assert isinstance(get_loader._instance, Loader)
+        return get_loader._instance
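
Note: the old `Lock`-guarded `__new__` singleton is replaced by `get_loader()`, which caches one instance on a function attribute and hands out fresh instances for serverless (ZeroGPU) workers. A minimal sketch of the same pattern, using a hypothetical `Resource` class as a stand-in for `Loader`:

    class Resource:
        # Stand-in for Loader: expensive state worth keeping warm between calls
        def __init__(self):
            self.pipeline = None

    def get_resource(singleton=False):
        if not singleton:
            return Resource()  # ephemeral: fresh state per serverless invocation
        if not hasattr(get_resource, "_instance"):
            get_resource._instance = Resource()  # created once, then reused
        return get_resource._instance

    assert get_resource(singleton=True) is get_resource(singleton=True)
    assert get_resource() is not get_resource()

Unlike the old `__new__` implementation, this caching is not guarded by a `Lock`, so a first-call race is possible in principle; on a single local Gradio process that is typically harmless, and on ZeroGPU the singleton path is not taken at all.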