adamelliotfields committed
Commit: af35186
Parent: 80551a9

Add timer context manager

Files changed (4)
  1. lib/__init__.py +2 -0
  2. lib/inference.py +4 -2
  3. lib/loader.py +38 -35
  4. lib/utils.py +13 -0
lib/__init__.py CHANGED
@@ -10,6 +10,7 @@ from .utils import (
     enable_progress_bars,
     load_json,
     read_file,
+    timer,
 )

 __all__ = [
@@ -24,4 +25,5 @@ __all__ = [
     "generate",
     "load_json",
     "read_file",
+    "timer",
 ]
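
With the re-export in place, `timer` can be imported straight from the package root. A one-line usage sketch (any block works as the timed body):

    from lib import timer

    with timer("Counting"):
        sum(range(10**6))  # any work to be timed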
lib/inference.py CHANGED
@@ -251,7 +251,9 @@ def generate(
     loader.collect()
     gc.collect()

-    diff = time.perf_counter() - start
+    end = time.perf_counter()
+    msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
+    print(msg)
     if Info:
-        Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
+        Info(msg)
     return images
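
The elapsed-time message is now built once and shared by `print` and the optional `Info` callback. Note that `generate()` keeps its manual `perf_counter` bookkeeping rather than adopting the new `timer` context manager, since the message needs the image count, which is only known at the end. A standalone sketch of the message construction (the `images` list is a stand-in):

    import time

    start = time.perf_counter()
    images = ["a", "b"]  # stand-in for the generated images
    end = time.perf_counter()

    # pluralize "image" and report wall-clock time
    msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
    print(msg)  # e.g. "Generated 2 images in 0.00s"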
lib/loader.py CHANGED
@@ -7,6 +7,7 @@ from diffusers.models import AutoencoderKL

 from .config import Config
 from .upscaler import RealESRGAN
+from .utils import timer


 class Loader:
@@ -23,13 +24,15 @@ class Loader:
         cls._instance.upscaler = None
         return cls._instance

-    def _should_offload_refiner(self, model=""):
+    # switching models
+    def _should_reset_refiner(self, model=""):
         if self.refiner is None:
             return False
         if self.model and self.model.lower() != model.lower():
             return True
         return False

+    # switching refiner
     def _should_unload_refiner(self, refiner=False):
         if self.refiner is None:
             return False
@@ -57,44 +60,45 @@ class Loader:
             return True
         return False

-    def _offload_refiner(self):
+    def _reset_refiner(self):
         if self.refiner is not None:
-            self.refiner.to("cpu", silence_dtype_warnings=True)
             self.refiner.vae = None
             self.refiner.scheduler = None
             self.refiner.tokenizer_2 = None
             self.refiner.text_encoder_2 = None

     def _unload_refiner(self):
-        # already on CPU from offloading
-        print("Unloading refiner")
+        if self.refiner is not None:
+            with timer("Unloading refiner"):
+                self.refiner.to("cpu", silence_dtype_warnings=True)

     def _unload_upscaler(self):
-        print(f"Unloading {self.upscaler.scale}x upscaler")
-        self.upscaler.to("cpu")
+        if self.upscaler is not None:
+            with timer(f"Unloading {self.upscaler.scale}x upscaler"):
+                self.upscaler.to("cpu")

     def _unload_deepcache(self):
         if self.pipe.deepcache is not None:
-            print("Unloading DeepCache")
+            print("Disabling DeepCache")
             self.pipe.deepcache.disable()
             delattr(self.pipe, "deepcache")
         if self.refiner is not None:
             if hasattr(self.refiner, "deepcache"):
-                print("Unloading DeepCache for refiner")
                 self.refiner.deepcache.disable()
                 delattr(self.refiner, "deepcache")

     def _unload_pipeline(self):
-        print(f"Unloading {self.model}")
-        self.pipe.to("cpu", silence_dtype_warnings=True)
+        if self.pipe is not None:
+            with timer(f"Unloading {self.model}"):
+                self.pipe.to("cpu", silence_dtype_warnings=True)

     def _unload(self, model, refiner, deepcache, scale):
         to_unload = []
         if self._should_unload_deepcache(deepcache):  # remove deepcache first
             self._unload_deepcache()

-        if self._should_offload_refiner(model):
-            self._offload_refiner()
+        if self._should_reset_refiner(model):
+            self._reset_refiner()

         if self._should_unload_refiner(refiner):
             self._unload_refiner()
@@ -119,8 +123,8 @@ class Loader:
         model = Config.REFINER_MODEL
         pipeline = Config.PIPELINES["img2img"]
         try:
-            print(f"Loading {model}")
-            self.refiner = pipeline.from_pretrained(model, **kwargs).to("cuda")
+            with timer(f"Loading {model}"):
+                self.refiner = pipeline.from_pretrained(model, **kwargs).to("cuda")
         except Exception as e:
             print(f"Error loading {model}: {e}")
             self.refiner = None
@@ -131,9 +135,9 @@
     def _load_upscaler(self, scale=1):
         if self.upscaler is None and scale > 1:
             try:
-                print(f"Loading {scale}x upscaler")
-                self.upscaler = RealESRGAN(scale, device=self.pipe.device)
-                self.upscaler.load_weights()
+                with timer(f"Loading {scale}x upscaler"):
+                    self.upscaler = RealESRGAN(scale, device=self.pipe.device)
+                    self.upscaler.load_weights()
             except Exception as e:
                 print(f"Error loading {scale}x upscaler: {e}")
                 self.upscaler = None
@@ -144,7 +148,7 @@
             return
         if pipe_has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
             return
-        print("Loading DeepCache")
+        print("Enabling DeepCache")
         self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
         self.pipe.deepcache.set_params(cache_interval=interval)
         self.pipe.deepcache.enable()
@@ -155,7 +159,6 @@
             return
         if refiner_has_deepcache and self.refiner.deepcache.params["cache_interval"] == interval:
             return
-        print("Loading DeepCache for refiner")
        self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
        self.refiner.deepcache.set_params(cache_interval=interval)
        self.refiner.deepcache.enable()
@@ -164,21 +167,21 @@
         pipeline = Config.PIPELINES[kind]
         if self.pipe is None:
             try:
-                print(f"Loading {model}")
-                self.model = model
-                if model.lower() in Config.MODEL_CHECKPOINTS.keys():
-                    self.pipe = pipeline.from_single_file(
-                        f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
-                        **kwargs,
-                    ).to("cuda")
-                else:
-                    self.pipe = pipeline.from_pretrained(model, **kwargs).to("cuda")
-                if self.refiner is not None:
-                    self.refiner.vae = self.pipe.vae
-                    self.refiner.scheduler = self.pipe.scheduler
-                    self.refiner.tokenizer_2 = self.pipe.tokenizer_2
-                    self.refiner.text_encoder_2 = self.pipe.text_encoder_2
-                    self.refiner.to(self.pipe.device)
+                with timer(f"Loading {model}"):
+                    self.model = model
+                    if model.lower() in Config.MODEL_CHECKPOINTS.keys():
+                        self.pipe = pipeline.from_single_file(
+                            f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
+                            **kwargs,
+                        ).to("cuda")
+                    else:
+                        self.pipe = pipeline.from_pretrained(model, **kwargs).to("cuda")
+                    if self.refiner is not None:
+                        self.refiner.vae = self.pipe.vae
+                        self.refiner.scheduler = self.pipe.scheduler
+                        self.refiner.tokenizer_2 = self.pipe.tokenizer_2
+                        self.refiner.text_encoder_2 = self.pipe.text_encoder_2
+                        self.refiner.to(self.pipe.device)
             except Exception as e:
                 print(f"Error loading {model}: {e}")
                 self.model = None
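
Every load/unload path is now guarded against `None` and wrapped in `timer`, so each log message is automatically paired with an elapsed time instead of a bare `print`. A minimal sketch of the pattern with a hypothetical stand-in object (the real code moves a diffusers pipeline between devices):

    import time

    from lib.utils import timer  # added in this commit (see lib/utils.py below)

    class FakePipe:  # hypothetical stand-in for a diffusers pipeline
        def to(self, device, **kwargs):
            time.sleep(0.3)  # stands in for a slow device transfer

    pipe = FakePipe()
    if pipe is not None:
        with timer("Unloading model"):
            pipe.to("cpu")
    # prints "Unloading model", then "Unloading model took 0.30s"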
lib/utils.py CHANGED
@@ -2,6 +2,8 @@ import functools
 import inspect
 import json
 import os
+import time
+from contextlib import contextmanager
 from typing import Callable, TypeVar

 import anyio
@@ -20,6 +22,17 @@ MAX_CONCURRENT_THREADS = 1
 MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)


+@contextmanager
+def timer(message="Operation", logger=print):
+    start = time.perf_counter()
+    logger(message)
+    try:
+        yield
+    finally:
+        end = time.perf_counter()
+        logger(f"{message} took {end - start:.2f}s")
+
+
 @functools.lru_cache()
 def load_json(path: str) -> dict:
     with open(path, "r", encoding="utf-8") as file:
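
For reference, `timer` logs the message on entry and the elapsed time on exit, and the `finally` block ensures the duration is reported even if the timed body raises. The `logger` parameter accepts any callable taking a string, so it also works with the logging module. A usage sketch (timings will vary):

    import logging
    import time

    from lib.utils import timer

    logging.basicConfig(level=logging.INFO)

    with timer("Sleeping", logger=logging.info):
        time.sleep(1)
    # INFO:root:Sleeping
    # INFO:root:Sleeping took 1.00s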