adamelliotfields commited on
Commit
bb42c8d
1 Parent(s): 163a3a9

Fix deepcache loading

Browse files
Files changed (1) hide show
  1. lib/loader.py +35 -22
lib/loader.py CHANGED
@@ -38,12 +38,25 @@ class Loader:
38
  return True
39
  return False
40
 
 
 
 
 
 
 
 
 
41
  def _unload_deepcache(self):
42
  if self.pipe.deepcache is None:
43
  return
44
  print("Unloading DeepCache")
45
  self.pipe.deepcache.disable()
46
  delattr(self.pipe, "deepcache")
 
 
 
 
 
47
 
48
  # don't unload refiner
49
  def _unload(self, model, deepcache):
@@ -59,6 +72,28 @@ class Loader:
59
  for component in to_unload:
60
  setattr(self, component, None)
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def _load_pipeline(self, kind, model, progress, **kwargs):
63
  pipeline = Config.PIPELINES[kind]
64
  if self.pipe is None:
@@ -119,28 +154,6 @@ class Loader:
119
  print(f"Error loading 4x upscaler: {e}")
120
  self.upscaler_4x = None
121
 
122
- def _load_deepcache(self, interval=1):
123
- pipe_has_deepcache = hasattr(self.pipe, "deepcache")
124
- if not pipe_has_deepcache and interval == 1:
125
- return
126
- if pipe_has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
127
- return
128
- print("Loading DeepCache")
129
- self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
130
- self.pipe.deepcache.set_params(cache_interval=interval)
131
- self.pipe.deepcache.enable()
132
-
133
- if self.refiner is not None:
134
- refiner_has_deepcache = hasattr(self.refiner, "deepcache")
135
- if not refiner_has_deepcache and interval == 1:
136
- return
137
- if refiner_has_deepcache and self.refiner.deepcache.params["cache_interval"] == interval:
138
- return
139
- print("Loading DeepCache for refiner")
140
- self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
141
- self.refiner.deepcache.set_params(cache_interval=interval)
142
- self.refiner.deepcache.enable()
143
-
144
  def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, progress):
145
  scheduler_kwargs = {
146
  "beta_start": 0.00085,
 
38
  return True
39
  return False
40
 
41
+ def _should_unload_deepcache(self, interval=1):
42
+ has_deepcache = hasattr(self.pipe, "deepcache")
43
+ if has_deepcache and interval == 1:
44
+ return True
45
+ if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
46
+ return True
47
+ return False
48
+
49
  def _unload_deepcache(self):
50
  if self.pipe.deepcache is None:
51
  return
52
  print("Unloading DeepCache")
53
  self.pipe.deepcache.disable()
54
  delattr(self.pipe, "deepcache")
55
+ if self.refiner is not None:
56
+ if hasattr(self.refiner, "deepcache"):
57
+ print("Unloading DeepCache for refiner")
58
+ self.refiner.deepcache.disable()
59
+ delattr(self.refiner, "deepcache")
60
 
61
  # don't unload refiner
62
  def _unload(self, model, deepcache):
 
72
  for component in to_unload:
73
  setattr(self, component, None)
74
 
75
+ def _load_deepcache(self, interval=1):
76
+ pipe_has_deepcache = hasattr(self.pipe, "deepcache")
77
+ if not pipe_has_deepcache and interval == 1:
78
+ return
79
+ if pipe_has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
80
+ return
81
+ print("Loading DeepCache")
82
+ self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
83
+ self.pipe.deepcache.set_params(cache_interval=interval)
84
+ self.pipe.deepcache.enable()
85
+
86
+ if self.refiner is not None:
87
+ refiner_has_deepcache = hasattr(self.refiner, "deepcache")
88
+ if not refiner_has_deepcache and interval == 1:
89
+ return
90
+ if refiner_has_deepcache and self.refiner.deepcache.params["cache_interval"] == interval:
91
+ return
92
+ print("Loading DeepCache for refiner")
93
+ self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
94
+ self.refiner.deepcache.set_params(cache_interval=interval)
95
+ self.refiner.deepcache.enable()
96
+
97
  def _load_pipeline(self, kind, model, progress, **kwargs):
98
  pipeline = Config.PIPELINES[kind]
99
  if self.pipe is None:
 
154
  print(f"Error loading 4x upscaler: {e}")
155
  self.upscaler_4x = None
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, progress):
158
  scheduler_kwargs = {
159
  "beta_start": 0.00085,