adamelliotfields commited on
Commit
7e19bd9
1 Parent(s): 7b8e908

Loader improvements

Browse files
Files changed (1) hide show
  1. lib/loader.py +106 -74
lib/loader.py CHANGED
@@ -18,9 +18,12 @@ from .upscaler import RealESRGAN
18
 
19
  __import__("warnings").filterwarnings("ignore", category=FutureWarning, module="diffusers")
20
 
 
 
 
 
 
21
 
22
- # inspired by ComfyUI
23
- # https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_management.py
24
  class Loader:
25
  _instance = None
26
 
@@ -32,40 +35,69 @@ class Loader:
32
  cls._instance.ip_adapter = None
33
  return cls._instance
34
 
35
- def _load_upscaler(self, device=None, scale=4):
36
- same_scale = self.upscaler is not None and self.upscaler.scale == scale
37
- if scale == 1:
38
- self.upscaler = None
39
- if scale > 1 and not same_scale:
40
- self.upscaler = RealESRGAN(device=device, scale=scale)
41
- self.upscaler.load_weights()
42
 
43
- def _load_deepcache(self, interval=1):
44
- has_deepcache = hasattr(self.pipe, "deepcache")
45
- if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
46
- return
47
- if has_deepcache:
48
- self.pipe.deepcache.disable()
49
- else:
50
- self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
51
- self.pipe.deepcache.set_params(cache_interval=interval)
52
- self.pipe.deepcache.enable()
53
 
54
- def _load_freeu(self, freeu=False):
55
- # https://github.com/huggingface/diffusers/blob/v0.30.0/src/diffusers/models/unets/unet_2d_condition.py
56
- block = self.pipe.unet.up_blocks[0]
57
- attrs = ["b1", "b2", "s1", "s2"]
58
- has_freeu = all(getattr(block, attr, None) is not None for attr in attrs)
59
- if has_freeu and not freeu:
60
- print("Disabling FreeU...")
61
- self.pipe.disable_freeu()
62
- elif not has_freeu and freeu:
63
- # https://github.com/ChenyangSi/FreeU
64
- print("Enabling FreeU...")
65
- self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def _load_ip_adapter(self, ip_adapter=None):
68
- if self.ip_adapter is None and self.ip_adapter != ip_adapter:
69
  print(f"Loading IP Adapter: {ip_adapter}...")
70
  self.pipe.load_ip_adapter(
71
  "h94/IP-Adapter",
@@ -76,27 +108,19 @@ class Loader:
76
  self.pipe.set_ip_adapter_scale(0.5)
77
  self.ip_adapter = ip_adapter
78
 
79
- if self.ip_adapter is not None and ip_adapter is None:
80
- print("Unloading IP Adapter...")
81
- if not isinstance(self.pipe, StableDiffusionImg2ImgPipeline):
82
- self.pipe.image_encoder = None
83
- self.pipe.register_to_config(image_encoder=[None, None])
84
-
85
- self.pipe.feature_extractor = None
86
- self.pipe.unet.encoder_hid_proj = None
87
- self.pipe.unet.config.encoder_hid_dim_type = None
88
- self.pipe.register_to_config(feature_extractor=[None, None])
89
-
90
- attn_procs = {}
91
- for name, value in self.pipe.unet.attn_processors.items():
92
- attn_processor_class = AttnProcessor2_0() # raises if not torch 2
93
- attn_procs[name] = (
94
- attn_processor_class
95
- if isinstance(value, IPAdapterAttnProcessor2_0)
96
- else value.__class__()
97
- )
98
- self.pipe.unet.set_attn_processor(attn_procs)
99
- self.pipe.ip_adapter = None
100
 
101
  def _load_vae(self, taesd=False, model_name=None, variant=None):
102
  vae_type = type(self.pipe.vae)
@@ -127,16 +151,29 @@ class Loader:
127
  model=model,
128
  )
129
 
130
- def _load_pipeline(self, kind, model, device, **kwargs):
131
- pipelines = {
132
- "txt2img": StableDiffusionPipeline,
133
- "img2img": StableDiffusionImg2ImgPipeline,
134
- }
135
- if self.pipe is None:
136
- self.pipe = pipelines[kind].from_pretrained(model, **kwargs).to(device)
137
- if not isinstance(self.pipe, pipelines[kind]):
138
- self.pipe = pipelines[kind].from_pipe(self.pipe).to(device)
139
- self.ip_adapter = None
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  def load(
142
  self,
@@ -153,6 +190,7 @@ class Loader:
153
  dtype,
154
  ):
155
  model_lower = model.lower()
 
156
 
157
  schedulers = {
158
  "DDIM": DDIMScheduler,
@@ -197,33 +235,27 @@ class Loader:
197
  "variant": variant,
198
  }
199
 
200
- if self.pipe is None:
201
- print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
202
 
203
- self._load_pipeline(kind, model_lower, device, **pipe_kwargs)
204
- model_name = self.pipe.config._name_or_path
205
- same_model = model_name.lower() == model_lower
206
  same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
207
  same_karras = (
208
  not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
209
  or self.pipe.scheduler.config.use_karras_sigmas == karras
210
  )
211
 
212
- if same_model:
 
213
  if not same_scheduler:
214
  print(f"Switching to {scheduler}...")
215
  if not same_karras:
216
  print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
217
  if not same_scheduler or not same_karras:
218
  self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
219
- else:
220
- self.pipe = None
221
- self._load_pipeline(kind, model_lower, device, **pipe_kwargs)
222
 
 
223
  self._load_ip_adapter(ip_adapter)
224
  self._load_vae(taesd, model_lower, variant)
225
  self._load_freeu(freeu)
226
  self._load_deepcache(deepcache)
227
- self._load_upscaler(device, scale)
228
- torch.cuda.empty_cache()
229
  return self.pipe, self.upscaler
 
18
 
19
  __import__("warnings").filterwarnings("ignore", category=FutureWarning, module="diffusers")
20
 
21
+ PIPELINES = {
22
+ "txt2img": StableDiffusionPipeline,
23
+ "img2img": StableDiffusionImg2ImgPipeline,
24
+ }
25
+
26
 
 
 
27
  class Loader:
28
  _instance = None
29
 
 
35
  cls._instance.ip_adapter = None
36
  return cls._instance
37
 
38
+ def _should_unload_upscaler(self, scale=1):
39
+ return self.upscaler is not None and scale == 1
 
 
 
 
 
40
 
41
+ def _should_unload_ip_adapter(self, ip_adapter=None):
42
+ return self.ip_adapter is not None and ip_adapter is None
 
 
 
 
 
 
 
 
43
 
44
+ def _should_unload_pipeline(self, kind="", model=""):
45
+ if self.pipe is None:
46
+ return False
47
+ if self.pipe.config._name_or_path.lower() != model.lower():
48
+ return True
49
+ if kind == "txt2img" and not isinstance(self.pipe, StableDiffusionPipeline):
50
+ return True # txt2img -> img2img
51
+ if kind == "img2img" and not isinstance(self.pipe, StableDiffusionImg2ImgPipeline):
52
+ return True # img2img -> txt2img
53
+ return False
54
+
55
+ def _unload_ip_adapter(self):
56
+ print("Unloading IP Adapter...")
57
+ if not isinstance(self.pipe, StableDiffusionImg2ImgPipeline):
58
+ self.pipe.image_encoder = None
59
+ self.pipe.register_to_config(image_encoder=[None, None])
60
+
61
+ self.pipe.feature_extractor = None
62
+ self.pipe.unet.encoder_hid_proj = None
63
+ self.pipe.unet.config.encoder_hid_dim_type = None
64
+ self.pipe.register_to_config(feature_extractor=[None, None])
65
+
66
+ attn_procs = {}
67
+ for name, value in self.pipe.unet.attn_processors.items():
68
+ attn_processor_class = AttnProcessor2_0() # raises if not torch 2
69
+ attn_procs[name] = (
70
+ attn_processor_class
71
+ if isinstance(value, IPAdapterAttnProcessor2_0)
72
+ else value.__class__()
73
+ )
74
+ self.pipe.unet.set_attn_processor(attn_procs)
75
+
76
+ def _unload(self, kind="", model="", ip_adapter=None, scale=1):
77
+ to_unload = []
78
+
79
+ if self._should_unload_upscaler(scale):
80
+ to_unload.append("upscaler")
81
+
82
+ if self._should_unload_ip_adapter(ip_adapter):
83
+ self._unload_ip_adapter()
84
+ to_unload.append("ip_adapter")
85
+
86
+ if self._should_unload_pipeline(kind, model):
87
+ to_unload.append("pipe")
88
+
89
+ for component in to_unload:
90
+ if hasattr(self, component):
91
+ delattr(self, component)
92
+
93
+ torch.cuda.empty_cache()
94
+ torch.cuda.ipc_collect()
95
+
96
+ for component in to_unload:
97
+ setattr(self, component, None)
98
 
99
  def _load_ip_adapter(self, ip_adapter=None):
100
+ if self.ip_adapter is None and ip_adapter is not None:
101
  print(f"Loading IP Adapter: {ip_adapter}...")
102
  self.pipe.load_ip_adapter(
103
  "h94/IP-Adapter",
 
108
  self.pipe.set_ip_adapter_scale(0.5)
109
  self.ip_adapter = ip_adapter
110
 
111
+ def _load_upscaler(self, device=None, scale=1):
112
+ if scale > 1 and self.upscaler is None:
113
+ print(f"Loading {scale}x upscaler...")
114
+ self.upscaler = RealESRGAN(device=device, scale=scale)
115
+ self.upscaler.load_weights()
116
+
117
+ def _load_pipeline(self, kind, model, taesd, device, **kwargs):
118
+ pipeline = PIPELINES[kind]
119
+ if self.pipe is None:
120
+ print(f"Loading {model.lower()} with {'Tiny' if taesd else 'KL'} VAE...")
121
+ self.pipe = pipeline.from_pretrained(model, **kwargs).to(device)
122
+ if not isinstance(self.pipe, pipeline):
123
+ self.pipe = pipeline.from_pipe(self.pipe).to(device)
 
 
 
 
 
 
 
 
124
 
125
  def _load_vae(self, taesd=False, model_name=None, variant=None):
126
  vae_type = type(self.pipe.vae)
 
151
  model=model,
152
  )
153
 
154
+ def _load_deepcache(self, interval=1):
155
+ has_deepcache = hasattr(self.pipe, "deepcache")
156
+ if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
157
+ return
158
+ if has_deepcache:
159
+ self.pipe.deepcache.disable()
160
+ else:
161
+ self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
162
+ self.pipe.deepcache.set_params(cache_interval=interval)
163
+ self.pipe.deepcache.enable()
164
+
165
+ def _load_freeu(self, freeu=False):
166
+ # https://github.com/huggingface/diffusers/blob/v0.30.0/src/diffusers/models/unets/unet_2d_condition.py
167
+ block = self.pipe.unet.up_blocks[0]
168
+ attrs = ["b1", "b2", "s1", "s2"]
169
+ has_freeu = all(getattr(block, attr, None) is not None for attr in attrs)
170
+ if has_freeu and not freeu:
171
+ print("Disabling FreeU...")
172
+ self.pipe.disable_freeu()
173
+ elif not has_freeu and freeu:
174
+ # https://github.com/ChenyangSi/FreeU
175
+ print("Enabling FreeU...")
176
+ self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
177
 
178
  def load(
179
  self,
 
190
  dtype,
191
  ):
192
  model_lower = model.lower()
193
+ model_name = self.pipe.config._name_or_path.lower() if self.pipe is not None else ""
194
 
195
  schedulers = {
196
  "DDIM": DDIMScheduler,
 
235
  "variant": variant,
236
  }
237
 
238
+ self._unload(kind, model, ip_adapter, scale)
239
+ self._load_pipeline(kind, model, taesd, device, **pipe_kwargs)
240
 
 
 
 
241
  same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
242
  same_karras = (
243
  not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
244
  or self.pipe.scheduler.config.use_karras_sigmas == karras
245
  )
246
 
247
+ # same model, different scheduler
248
+ if model_name == model_lower:
249
  if not same_scheduler:
250
  print(f"Switching to {scheduler}...")
251
  if not same_karras:
252
  print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
253
  if not same_scheduler or not same_karras:
254
  self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
 
 
 
255
 
256
+ self._load_upscaler(device, scale)
257
  self._load_ip_adapter(ip_adapter)
258
  self._load_vae(taesd, model_lower, variant)
259
  self._load_freeu(freeu)
260
  self._load_deepcache(deepcache)
 
 
261
  return self.pipe, self.upscaler