adamelliotfields committed
Commit 163a3a9
1 Parent(s): 933318d

Loading and inference improvements

Files changed (4)
  1. lib/config.py +43 -5
  2. lib/inference.py +62 -90
  3. lib/loader.py +30 -28
  4. lib/upscaler.py +5 -7
lib/config.py CHANGED
@@ -1,3 +1,5 @@
+import os
+from importlib import import_module
 from types import SimpleNamespace
 
 from diffusers import (
@@ -10,7 +12,38 @@ from diffusers import (
     StableDiffusionXLPipeline,
 )
 
+# improved GPU handling and progress bars; set before importing spaces
+os.environ["ZEROGPU_V2"] = "true"
+
+_sdxl_refiner_files = [
+    "scheduler/scheduler_config.json",
+    "text_encoder_2/config.json",
+    "text_encoder_2/model.fp16.safetensors",
+    "tokenizer_2/merges.txt",
+    "tokenizer_2/special_tokens_map.json",
+    "tokenizer_2/tokenizer_config.json",
+    "tokenizer_2/vocab.json",
+    "unet/config.json",
+    "unet/diffusion_pytorch_model.fp16.safetensors",
+    "vae/config.json",
+    "vae/diffusion_pytorch_model.fp16.safetensors",
+    "model_index.json",
+]
+
+_sdxl_files = [
+    *_sdxl_refiner_files,
+    "text_encoder/config.json",
+    "text_encoder/model.fp16.safetensors",
+    "tokenizer/merges.txt",
+    "tokenizer/special_tokens_map.json",
+    "tokenizer/tokenizer_config.json",
+    "tokenizer/vocab.json",
+]
+
 Config = SimpleNamespace(
+    HF_TOKEN=os.environ.get("HF_TOKEN", None),
+    CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
+    ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
     MONO_FONTS=["monospace"],
     SANS_FONTS=[
         "sans-serif",
@@ -23,6 +56,11 @@ Config = SimpleNamespace(
         "txt2img": StableDiffusionXLPipeline,
         "img2img": StableDiffusionXLImg2ImgPipeline,
     },
+    HF_MODELS={
+        "segmind/Segmind-Vega": [*_sdxl_files],
+        "stabilityai/stable-diffusion-xl-base-1.0": [*_sdxl_files, "vae_1_0/config.json"],
+        "stabilityai/stable-diffusion-xl-refiner-1.0": [*_sdxl_refiner_files],
+    },
     MODEL="segmind/Segmind-Vega",
     MODELS=[
         "cagliostrolab/animagine-xl-3.1",
@@ -49,13 +87,13 @@ Config = SimpleNamespace(
         "Euler": EulerDiscreteScheduler,
         "Euler a": EulerAncestralDiscreteScheduler,
     },
-    STYLE="sai-enhance",
-    WIDTH=896,
-    HEIGHT=1152,
+    STYLE="enhance",
+    WIDTH=1024,
+    HEIGHT=1024,
     NUM_IMAGES=1,
     SEED=-1,
-    GUIDANCE_SCALE=6,
-    INFERENCE_STEPS=35,
+    GUIDANCE_SCALE=7.5,
+    INFERENCE_STEPS=40,
     DEEPCACHE_INTERVAL=1,
     SCALE=1,
     SCALES=[1, 2, 4],
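
Note: the new HF_MODELS map lists exactly the fp16 files each supported repo needs, which makes it possible to warm the Hugging Face cache without snapshotting whole repositories. A minimal sketch of that idea, assuming huggingface_hub is installed; the prefetch helper below is hypothetical and not part of this commit:

# Hypothetical warm-up: download only the files listed for each repo in Config.HF_MODELS.
# hf_hub_download caches each file under ~/.cache/huggingface by default.
from huggingface_hub import hf_hub_download

from lib.config import Config


def prefetch(token=None):
    for repo_id, filenames in Config.HF_MODELS.items():
        for filename in filenames:
            hf_hub_download(repo_id, filename, token=token)


prefetch(Config.HF_TOKEN)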
lib/inference.py CHANGED
@@ -1,77 +1,51 @@
-import functools
-import inspect
-import json
 import re
 import time
 from datetime import datetime
 from itertools import product
-from typing import Callable, TypeVar
 
-import anyio
-import spaces
 import torch
-from anyio import Semaphore
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
-from typing_extensions import ParamSpec
+from spaces import GPU
 
+from .config import Config
 from .loader import Loader
+from .utils import load_json
 
-__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
-__import__("transformers").logging.set_verbosity_error()
 
-T = TypeVar("T")
-P = ParamSpec("P")
-
-MAX_CONCURRENT_THREADS = 1
-MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
-
-with open("./data/styles.json") as f:
-    STYLES = json.load(f)
-
-
-# like the original but supports args and kwargs instead of a dict
-# https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
-async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
-    async with MAX_THREADS_GUARD:
-        sig = inspect.signature(fn)
-        bound_args = sig.bind(*args, **kwargs)
-        bound_args.apply_defaults()
-        partial_fn = functools.partial(fn, **bound_args.arguments)
-        return await anyio.to_thread.run_sync(partial_fn)
-
-
-# parse prompts with arrays
-def parse_prompt(prompt: str) -> list[str]:
+def parse_prompt_with_arrays(prompt: str) -> list[str]:
     arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
 
     if not arrays:
         return [prompt]
 
-    tokens = [item.split(",") for item in arrays]
-    combinations = list(product(*tokens))
-    prompts = []
+    tokens = [item.split(",") for item in arrays]  # [("a", "b"), (1, 2)]
+    combinations = list(product(*tokens))  # [("a", 1), ("a", 2), ("b", 1), ("b", 2)]
+
+    # find all the arrays in the prompt and replace them with tokens
+    prompts = []
     for combo in combinations:
         current_prompt = prompt
         for i, token in enumerate(combo):
             current_prompt = current_prompt.replace(f"[[{arrays[i]}]]", token.strip(), 1)
         prompts.append(current_prompt)
-
     return prompts
 
 
-def apply_style(prompt, style_id, negative=False):
-    global STYLES
-    if not style_id or style_id == "None":
-        return prompt
-    for style in STYLES:
-        if style["id"] == style_id:
-            if negative:
-                return prompt + " . " + style["negative_prompt"]
-            else:
-                return style["prompt"].format(prompt=prompt)
-    return prompt
+def apply_style(positive_prompt, negative_prompt, style_id="none"):
+    if style_id.lower() == "none":
+        return (positive_prompt, negative_prompt)
+
+    styles = load_json("./data/styles.json")
+    style = styles.get(style_id)
+    if style is None:
+        return (positive_prompt, negative_prompt)
+
+    style_base = style.get("_base", {})
+    return (
+        style.get("positive").format(prompt=positive_prompt, _base=style_base.get("positive")).strip(),
+        style.get("negative").format(prompt=negative_prompt, _base=style_base.get("negative")).strip(),
+    )
 
 
 # max 60s per image
@@ -97,7 +71,7 @@ def gpu_duration(**kwargs):
     return loading + (duration * num_images)
 
 
-@spaces.GPU(duration=gpu_duration)
+@GPU(duration=gpu_duration)
 def generate(
     positive_prompt,
     negative_prompt="",
@@ -114,53 +88,51 @@ def generate(
     num_images=1,
     use_karras=False,
     use_refiner=False,
-    Info: Callable[[str], None] = None,
     Error=Exception,
+    Info=None,
     progress=None,
 ):
     if not torch.cuda.is_available():
-        raise Error("RuntimeError: CUDA not available")
+        raise Error("CUDA not available")
 
     # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
     if seed is None or seed < 0:
-        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
+        seed = int(datetime.now().timestamp() * 1e6) % (2**64)
 
     KIND = "txt2img"
     CURRENT_STEP = 0
     CURRENT_IMAGE = 1
     EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
 
-    if progress is not None:
-        TQDM = False
-        progress((0, inference_steps), desc=f"Generating image 1/{num_images}")
-    else:
-        TQDM = True
-
+    # custom progress bar for multiple images
     def callback_on_step_end(pipeline, step, timestep, latents):
         nonlocal CURRENT_IMAGE, CURRENT_STEP
 
-        if progress is None:
-            return latents
-
-        strength = 1
-        total_steps = min(int(inference_steps * strength), inference_steps)
-
-        # if steps are different we're in the refiner
-        refining = False
-        if CURRENT_STEP == step:
-            CURRENT_STEP = step + 1
-        else:
-            refining = True
-            CURRENT_STEP += 1
-
-        progress(
-            (CURRENT_STEP, total_steps),
-            desc=f"{'Refining' if refining else 'Generating'} image {CURRENT_IMAGE}/{num_images}",
-        )
+        if progress is not None:
+            # calculate total steps for img2img based on denoising strength
+            strength = 1
+            total_steps = min(int(inference_steps * strength), inference_steps)
+
+            # if steps are different we're in the refiner
+            refining = False
+            if CURRENT_STEP == step:
+                CURRENT_STEP = step + 1
+            else:
+                refining = True
+                CURRENT_STEP += 1
+
+            progress(
+                (CURRENT_STEP, total_steps),
+                desc=f"{'Refining' if refining else 'Generating'} image {CURRENT_IMAGE}/{num_images}",
+            )
         return latents
 
     start = time.perf_counter()
+    print(f"Generating {num_images} image{'s' if num_images > 1 else ''}")
+
+    if Config.ZERO_GPU and progress is not None:
+        progress((100, 100), desc="ZeroGPU init")
+
     loader = Loader()
     loader.load(
         KIND,
@@ -170,11 +142,11 @@ def generate(
         scale,
         use_karras,
         use_refiner,
-        TQDM,
+        progress,
     )
 
     if loader.pipe is None:
-        raise Error(f"RuntimeError: Error loading {model}")
+        raise Error(f"Error loading {model}")
 
     pipe = loader.pipe
     refiner = loader.refiner
@@ -205,21 +177,21 @@ def generate(
 
     images = []
     current_seed = seed
-
     for i in range(num_images):
-        # seeded generator for each iteration
         generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
 
         try:
-            styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
-            all_positive_prompts = parse_prompt(positive_prompt)
-            prompt_index = i % len(all_positive_prompts)
-            prompt = all_positive_prompts[prompt_index]
-            styled_prompt = apply_style(prompt, style)
-            conditioning_1, pooled_1 = compel_1([styled_prompt, styled_negative_prompt])
-            conditioning_2, pooled_2 = compel_2([styled_prompt, styled_negative_prompt])
+            positive_prompts = parse_prompt_with_arrays(positive_prompt)
+            index = i % len(positive_prompts)
+            positive_styled, negative_styled = apply_style(positive_prompts[index], negative_prompt, style)
+
+            if negative_styled.startswith("(), "):
+                negative_styled = negative_styled[4:]
+
+            conditioning_1, pooled_1 = compel_1([positive_styled, negative_styled])
+            conditioning_2, pooled_2 = compel_2([positive_styled, negative_styled])
         except PromptParser.ParsingException:
-            raise Error("ValueError: Invalid prompt")
+            raise Error("Invalid prompt")
 
         # refiner expects latents; upscaler expects numpy array
         pipe_output_type = "pil"
@@ -272,12 +244,12 @@ def generate(
             if scale > 1:
                 image = upscaler.predict(image)
             images.append((image, str(current_seed)))
+            current_seed += 1
         except Exception as e:
-            raise Error(f"RuntimeError: {e}")
+            raise Error(f"{e}")
         finally:
             CURRENT_STEP = 0
             CURRENT_IMAGE += 1
-            current_seed += 1
 
     diff = time.perf_counter() - start
     if Info:
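
Note: the [[a, b]] syntax handled by parse_prompt_with_arrays expands into the Cartesian product of the bracketed tokens, and apply_style now reads data/styles.json as a dict keyed by style id. A small sketch of both; the JSON shape is inferred from the .format(prompt=..., _base=...) calls, not copied from the actual data file:

# Expansion follows the function shown above; the last array varies fastest.
prompts = parse_prompt_with_arrays("a [[red, blue]] [[cat, dog]]")
# -> ["a red cat", "a red dog", "a blue cat", "a blue dog"]

# Assumed shape of a data/styles.json entry (keys inferred from apply_style):
# "enhance": {
#     "positive": "{prompt}, {_base}, highly detailed",
#     "negative": "{prompt}, blurry, lowres",
#     "_base": {"positive": "masterpiece", "negative": ""}
# }
positive, negative = apply_style("a red cat", "text, watermark", style_id="enhance")
# with the assumed entry above:
# positive -> "a red cat, masterpiece, highly detailed"
# negative -> "text, watermark, blurry, lowres"

The @GPU(duration=gpu_duration) decorator keeps the dynamic ZeroGPU time budget: when duration is a callable, spaces evaluates it against the call's keyword arguments. A hedged sketch of such a callable, consistent with the visible loading + (duration * num_images) return but with illustrative numbers (the real gpu_duration is only partially shown in this diff):

def gpu_duration(**kwargs):
    loading = 20  # seconds reserved for model loading (illustrative)
    duration = 10  # seconds per image (illustrative)
    num_images = kwargs.get("num_images", 1)
    return loading + (duration * num_images)


@GPU(duration=gpu_duration)
def generate(positive_prompt, negative_prompt="", num_images=1, **kwargs):
    ...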
lib/loader.py CHANGED
@@ -1,6 +1,5 @@
 import gc
 from threading import Lock
-from warnings import filterwarnings
 
 import torch
 from DeepCache import DeepCacheSDHelper
@@ -9,10 +8,6 @@ from diffusers.models import AutoencoderKL
 from .config import Config
 from .upscaler import RealESRGAN
 
-__import__("diffusers").logging.set_verbosity_error()
-filterwarnings("ignore", category=FutureWarning, module="torch")
-filterwarnings("ignore", category=FutureWarning, module="diffusers")
-
 
 class Loader:
     _instance = None
@@ -33,7 +28,6 @@ class Loader:
         gc.collect()
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
-        torch.cuda.reset_max_memory_allocated()
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
 
@@ -44,8 +38,18 @@ class Loader:
             return True
         return False
 
-    def _unload(self, model):
+    def _unload_deepcache(self):
+        if self.pipe.deepcache is None:
+            return
+        print("Unloading DeepCache")
+        self.pipe.deepcache.disable()
+        delattr(self.pipe, "deepcache")
+
+    # don't unload refiner
+    def _unload(self, model, deepcache):
         to_unload = []
+        if self._should_unload_deepcache(deepcache):
+            self._unload_deepcache()
         if self._should_unload_pipeline(model):
             to_unload.append("model")
             to_unload.append("pipe")
@@ -55,7 +59,7 @@ class Loader:
         for component in to_unload:
             setattr(self, component, None)
 
-    def _load_pipeline(self, kind, model, tqdm, **kwargs):
+    def _load_pipeline(self, kind, model, progress, **kwargs):
         pipeline = Config.PIPELINES[kind]
         if self.pipe is None:
             try:
@@ -81,9 +85,9 @@ class Loader:
         if not isinstance(self.pipe, pipeline):
             self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
         if self.pipe is not None:
-            self.pipe.set_progress_bar_config(disable=not tqdm)
+            self.pipe.set_progress_bar_config(disable=progress is not None)
 
-    def _load_refiner(self, refiner, tqdm, **kwargs):
+    def _load_refiner(self, refiner, progress, **kwargs):
         if refiner and self.refiner is None:
             model = Config.REFINER_MODEL
             pipeline = Config.PIPELINES["img2img"]
@@ -95,7 +99,7 @@ class Loader:
             self.refiner = None
             return
         if self.refiner is not None:
-            self.refiner.set_progress_bar_config(disable=not tqdm)
+            self.refiner.set_progress_bar_config(disable=progress is not None)
 
     def _load_upscaler(self, scale=1):
         if scale == 2 and self.upscaler_2x is None:
@@ -117,29 +121,27 @@ class Loader:
 
     def _load_deepcache(self, interval=1):
         pipe_has_deepcache = hasattr(self.pipe, "deepcache")
+        if not pipe_has_deepcache and interval == 1:
+            return
         if pipe_has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
             return
-        if pipe_has_deepcache:
-            self.pipe.deepcache.disable()
-        else:
-            self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
+        print("Loading DeepCache")
+        self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
         self.pipe.deepcache.set_params(cache_interval=interval)
         self.pipe.deepcache.enable()
 
         if self.refiner is not None:
             refiner_has_deepcache = hasattr(self.refiner, "deepcache")
+            if not refiner_has_deepcache and interval == 1:
+                return
             if refiner_has_deepcache and self.refiner.deepcache.params["cache_interval"] == interval:
                 return
-            if refiner_has_deepcache:
-                self.refiner.deepcache.disable()
-            else:
-                self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
+            print("Loading DeepCache for refiner")
+            self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
             self.refiner.deepcache.set_params(cache_interval=interval)
             self.refiner.deepcache.enable()
 
-    def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, tqdm):
-        model_lower = model.lower()
-
+    def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, progress):
         scheduler_kwargs = {
             "beta_start": 0.00085,
             "beta_end": 0.012,
@@ -156,7 +158,7 @@ class Loader:
             scheduler_kwargs["clip_sample"] = False
             scheduler_kwargs["set_alpha_to_one"] = False
 
-        if model_lower not in Config.MODEL_CHECKPOINTS.keys():
+        if model.lower() not in Config.MODEL_CHECKPOINTS.keys():
             variant = "fp16"
         else:
             variant = None
@@ -170,8 +172,8 @@ class Loader:
             "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
         }
 
-        self._unload(model)
-        self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
+        self._unload(model, deepcache)
+        self._load_pipeline(kind, model, progress, **pipe_kwargs)
 
         # error loading model
         if self.pipe is None:
@@ -184,7 +186,7 @@ class Loader:
         )
 
         # same model, different scheduler
-        if self.model.lower() == model_lower:
+        if self.model.lower() == model.lower():
             if not same_scheduler:
                 print(f"Switching to {scheduler}...")
             if not same_karras:
@@ -207,6 +209,6 @@ class Loader:
             "text_encoder_2": self.pipe.text_encoder_2,
         }
 
-        self._load_refiner(refiner, tqdm, **refiner_kwargs)
-        self._load_upscaler(scale)
+        self._load_refiner(refiner, progress, **refiner_kwargs)
         self._load_deepcache(deepcache)
+        self._load_upscaler(scale)
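
Note: the DeepCache handling is simpler now; instead of toggling an existing helper, the Loader tears it down in _unload_deepcache and attaches a fresh DeepCacheSDHelper whenever a non-default interval is requested. Outside the Loader the helper is driven the same way; a minimal standalone sketch of that lifecycle (model id and prompt are illustrative):

import torch
from DeepCache import DeepCacheSDHelper
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "segmind/Segmind-Vega", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# attach the helper and reuse cached UNet features every 2 steps
helper = DeepCacheSDHelper(pipe=pipe)
helper.set_params(cache_interval=2, cache_branch_id=0)
helper.enable()

image = pipe("an astronaut riding a horse").images[0]

# disable (and re-create) before changing the interval, as the Loader does
helper.disable()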
lib/upscaler.py CHANGED
@@ -55,17 +55,15 @@ HF_MODELS = {
 
 
 def pad_reflect(image, pad_size):
-    # fmt: off
     image_size = image.shape
     height, width = image_size[:2]
     new_image = np.zeros([height + pad_size * 2, width + pad_size * 2, image_size[2]]).astype(np.uint8)
     new_image[pad_size:-pad_size, pad_size:-pad_size, :] = image
-    new_image[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0) # top
-    new_image[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0) # bottom
-    new_image[:, 0:pad_size, :] = np.flip(new_image[:, pad_size : pad_size * 2, :], axis=1) # left
+    new_image[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)  # top
+    new_image[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)  # bottom
+    new_image[:, 0:pad_size, :] = np.flip(new_image[:, pad_size : pad_size * 2, :], axis=1)  # left
     new_image[:, -pad_size:, :] = np.flip(new_image[:, -pad_size * 2 : -pad_size, :], axis=1)  # right
     return new_image
-    # fmt: on
 
 
 def unpad_image(image, pad_size):
@@ -279,9 +277,8 @@ class RealESRGAN:
         self.model.load_state_dict(loadnet, strict=True)
         self.model.eval().to(device=self.device)
 
-    @torch.cuda.amp.autocast()
+    @torch.autocast("cuda")
     def predict(self, lr_image, batch_size=4, patches_size=192, padding=24, pad_size=15):
-        scale = self.scale
         if not isinstance(lr_image, np.ndarray):
             lr_image = np.array(lr_image)
         if lr_image.min() < 0.0:
@@ -302,6 +299,7 @@ class RealESRGAN:
         for i in range(batch_size, image.shape[0], batch_size):
             res = torch.cat((res, self.model(image[i : i + batch_size])), 0)
 
+        scale = self.scale
         sr_image = einops.rearrange(res.clamp(0, 1), "b c h w -> b h w c").cpu().numpy()
         padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
         scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
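
Note: torch.cuda.amp.autocast() is deprecated in recent PyTorch releases in favor of the device-agnostic torch.autocast, which is what predict now uses; both run eligible ops in reduced precision on CUDA. A small sketch comparing the decorator and the equivalent context-manager form (model and input are placeholders):

import torch


@torch.autocast("cuda")
def predict_decorated(model, x):
    # eligible ops run in float16 on CUDA inside the decorated call
    return model(x)


def predict_scoped(model, x):
    # same effect, scoped explicitly
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        return model(x)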