Spaces: Running on Zero
adamelliotfields committed on
Commit • 163a3a9
1 Parent(s): 933318d
Loading and inferencing improvements
Browse files
- lib/config.py +43 -5
- lib/inference.py +62 -90
- lib/loader.py +30 -28
- lib/upscaler.py +5 -7
lib/config.py
CHANGED
@@ -1,3 +1,5 @@
+import os
+from importlib import import_module
 from types import SimpleNamespace
 
 from diffusers import (
@@ -10,7 +12,38 @@ from diffusers import (
     StableDiffusionXLPipeline,
 )
 
+# improved GPU handling and progress bars; set before importing spaces
+os.environ["ZEROGPU_V2"] = "true"
+
+_sdxl_refiner_files = [
+    "scheduler/scheduler_config.json",
+    "text_encoder_2/config.json",
+    "text_encoder_2/model.fp16.safetensors",
+    "tokenizer_2/merges.txt",
+    "tokenizer_2/special_tokens_map.json",
+    "tokenizer_2/tokenizer_config.json",
+    "tokenizer_2/vocab.json",
+    "unet/config.json",
+    "unet/diffusion_pytorch_model.fp16.safetensors",
+    "vae/config.json",
+    "vae/diffusion_pytorch_model.fp16.safetensors",
+    "model_index.json",
+]
+
+_sdxl_files = [
+    *_sdxl_refiner_files,
+    "text_encoder/config.json",
+    "text_encoder/model.fp16.safetensors",
+    "tokenizer/merges.txt",
+    "tokenizer/special_tokens_map.json",
+    "tokenizer/tokenizer_config.json",
+    "tokenizer/vocab.json",
+]
+
 Config = SimpleNamespace(
+    HF_TOKEN=os.environ.get("HF_TOKEN", None),
+    CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
+    ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
     MONO_FONTS=["monospace"],
     SANS_FONTS=[
         "sans-serif",
@@ -23,6 +56,11 @@ Config = SimpleNamespace(
         "txt2img": StableDiffusionXLPipeline,
         "img2img": StableDiffusionXLImg2ImgPipeline,
     },
+    HF_MODELS={
+        "segmind/Segmind-Vega": [*_sdxl_files],
+        "stabilityai/stable-diffusion-xl-base-1.0": [*_sdxl_files, "vae_1_0/config.json"],
+        "stabilityai/stable-diffusion-xl-refiner-1.0": [*_sdxl_refiner_files],
+    },
     MODEL="segmind/Segmind-Vega",
     MODELS=[
         "cagliostrolab/animagine-xl-3.1",
@@ -49,13 +87,13 @@ Config = SimpleNamespace(
         "Euler": EulerDiscreteScheduler,
         "Euler a": EulerAncestralDiscreteScheduler,
     },
-    STYLE="
-    WIDTH=
-    HEIGHT=
+    STYLE="enhance",
+    WIDTH=1024,
+    HEIGHT=1024,
     NUM_IMAGES=1,
     SEED=-1,
-    GUIDANCE_SCALE=
-    INFERENCE_STEPS=
+    GUIDANCE_SCALE=7.5,
+    INFERENCE_STEPS=40,
     DEEPCACHE_INTERVAL=1,
     SCALE=1,
     SCALES=[1, 2, 4],
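The new HF_MODELS map lists, per repo, exactly which files the app needs. A minimal sketch of how such a list could drive selective downloads with huggingface_hub follows; the prefetch helper is an assumption for illustration and is not part of this commit:

# Hypothetical prefetch helper built on Config.HF_MODELS (not part of this commit).
from huggingface_hub import hf_hub_download

from lib.config import Config

def prefetch(repo_id: str) -> None:
    # Download only the files listed for this repo; already-cached files are skipped.
    for filename in Config.HF_MODELS.get(repo_id, []):
        hf_hub_download(repo_id, filename, token=Config.HF_TOKEN)

prefetch("segmind/Segmind-Vega")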
lib/inference.py
CHANGED
@@ -1,77 +1,51 @@
-import functools
-import inspect
-import json
 import re
 import time
 from datetime import datetime
 from itertools import product
-from typing import Callable, TypeVar
 
-import anyio
-import spaces
 import torch
-from anyio import Semaphore
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
-from
+from spaces import GPU
 
+from .config import Config
 from .loader import Loader
+from .utils import load_json
 
-__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
-__import__("transformers").logging.set_verbosity_error()
 
-
-P = ParamSpec("P")
-
-MAX_CONCURRENT_THREADS = 1
-MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
-
-with open("./data/styles.json") as f:
-    STYLES = json.load(f)
-
-
-# like the original but supports args and kwargs instead of a dict
-# https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
-async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
-    async with MAX_THREADS_GUARD:
-        sig = inspect.signature(fn)
-        bound_args = sig.bind(*args, **kwargs)
-        bound_args.apply_defaults()
-        partial_fn = functools.partial(fn, **bound_args.arguments)
-        return await anyio.to_thread.run_sync(partial_fn)
-
-
-# parse prompts with arrays
-def parse_prompt(prompt: str) -> list[str]:
+def parse_prompt_with_arrays(prompt: str) -> list[str]:
     arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
 
     if not arrays:
         return [prompt]
 
-    tokens = [item.split(",") for item in arrays]
-    combinations = list(product(*tokens))
-    prompts = []
+    tokens = [item.split(",") for item in arrays]  # [("a", "b"), (1, 2)]
+    combinations = list(product(*tokens))  # [("a", 1), ("a", 2), ("b", 1), ("b", 2)]
 
+    # find all the arrays in the prompt and replace them with tokens
+    prompts = []
     for combo in combinations:
         current_prompt = prompt
         for i, token in enumerate(combo):
             current_prompt = current_prompt.replace(f"[[{arrays[i]}]]", token.strip(), 1)
         prompts.append(current_prompt)
-
     return prompts
 
 
-def apply_style(
-    return
+def apply_style(positive_prompt, negative_prompt, style_id="none"):
+    if style_id.lower() == "none":
+        return (positive_prompt, negative_prompt)
+
+    styles = load_json("./data/styles.json")
+    style = styles.get(style_id)
+    if style is None:
+        return (positive_prompt, negative_prompt)
+
+    style_base = style.get("_base", {})
+    return (
+        style.get("positive").format(prompt=positive_prompt, _base=style_base.get("positive")).strip(),
+        style.get("negative").format(prompt=negative_prompt, _base=style_base.get("negative")).strip(),
+    )
 
 
 # max 60s per image
@@ -97,7 +71,7 @@ def gpu_duration(**kwargs):
     return loading + (duration * num_images)
 
 
-@
+@GPU(duration=gpu_duration)
 def generate(
     positive_prompt,
     negative_prompt="",
@@ -114,53 +88,51 @@ def generate(
     num_images=1,
     use_karras=False,
     use_refiner=False,
-    Info: Callable[[str], None] = None,
     Error=Exception,
+    Info=None,
     progress=None,
 ):
     if not torch.cuda.is_available():
-        raise Error("
+        raise Error("CUDA not available")
 
     # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
     if seed is None or seed < 0:
-        seed = int(datetime.now().timestamp() *
+        seed = int(datetime.now().timestamp() * 1e6) % (2**64)
 
     KIND = "txt2img"
     CURRENT_STEP = 0
     CURRENT_IMAGE = 1
     EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
 
-        TQDM = False
-        progress((0, inference_steps), desc=f"Generating image 1/{num_images}")
-    else:
-        TQDM = True
-
+    # custom progress bar for multiple images
     def callback_on_step_end(pipeline, step, timestep, latents):
         nonlocal CURRENT_IMAGE, CURRENT_STEP
 
-        if progress is None:
-        CURRENT_STEP += 1
-
-        progress(
-            (CURRENT_STEP, total_steps),
-            desc=f"{'Refining' if refining else 'Generating'} image {CURRENT_IMAGE}/{num_images}",
-        )
+        if progress is not None:
+            # calculate total steps for img2img based on denoising strength
+            strength = 1
+            total_steps = min(int(inference_steps * strength), inference_steps)
+
+            # if steps are different we're in the refiner
+            refining = False
+            if CURRENT_STEP == step:
+                CURRENT_STEP = step + 1
+            else:
+                refining = True
+                CURRENT_STEP += 1
+
+            progress(
+                (CURRENT_STEP, total_steps),
+                desc=f"{'Refining' if refining else 'Generating'} image {CURRENT_IMAGE}/{num_images}",
+            )
         return latents
 
     start = time.perf_counter()
+    print(f"Generating {num_images} image{'s' if num_images > 1 else ''}")
+
+    if Config.ZERO_GPU and progress is not None:
+        progress((100, 100), desc="ZeroGPU init")
+
     loader = Loader()
     loader.load(
         KIND,
@@ -170,11 +142,11 @@ def generate(
         scale,
         use_karras,
         use_refiner,
-
+        progress,
     )
 
     if loader.pipe is None:
-        raise Error(f"
+        raise Error(f"Error loading {model}")
 
     pipe = loader.pipe
     refiner = loader.refiner
@@ -205,21 +177,21 @@ def generate(
 
     images = []
    current_seed = seed
-
     for i in range(num_images):
-        # seeded generator for each iteration
         generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
 
         try:
-
+            positive_prompts = parse_prompt_with_arrays(positive_prompt)
+            index = i % len(positive_prompts)
+            positive_styled, negative_styled = apply_style(positive_prompts[index], negative_prompt, style)
+
+            if negative_styled.startswith("(), "):
+                negative_styled = negative_styled[4:]
+
+            conditioning_1, pooled_1 = compel_1([positive_styled, negative_styled])
+            conditioning_2, pooled_2 = compel_2([positive_styled, negative_styled])
         except PromptParser.ParsingException:
-            raise Error("
+            raise Error("Invalid prompt")
 
         # refiner expects latents; upscaler expects numpy array
         pipe_output_type = "pil"
@@ -272,12 +244,12 @@ def generate(
             if scale > 1:
                 image = upscaler.predict(image)
             images.append((image, str(current_seed)))
+            current_seed += 1
         except Exception as e:
-            raise Error(f"
+            raise Error(f"{e}")
         finally:
             CURRENT_STEP = 0
             CURRENT_IMAGE += 1
-            current_seed += 1
 
     diff = time.perf_counter() - start
     if Info:
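For reference, the [[a,b]] array syntax expands a prompt into every combination of the bracketed tokens. A small standalone sketch of the same expansion logic as parse_prompt_with_arrays, for illustration only:

# Standalone sketch of the [[a,b]] prompt-array expansion (same logic as above).
import re
from itertools import product

def expand(prompt: str) -> list[str]:
    arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
    if not arrays:
        return [prompt]
    tokens = [item.split(",") for item in arrays]
    prompts = []
    for combo in product(*tokens):
        current = prompt
        for i, token in enumerate(combo):
            current = current.replace(f"[[{arrays[i]}]]", token.strip(), 1)
        prompts.append(current)
    return prompts

print(expand("a [[red,blue]] ball on a [[table,chair]]"))
# ['a red ball on a table', 'a red ball on a chair',
#  'a blue ball on a table', 'a blue ball on a chair']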
lib/loader.py
CHANGED
@@ -1,6 +1,5 @@
 import gc
 from threading import Lock
-from warnings import filterwarnings
 
 import torch
 from DeepCache import DeepCacheSDHelper
@@ -9,10 +8,6 @@ from diffusers.models import AutoencoderKL
 from .config import Config
 from .upscaler import RealESRGAN
 
-__import__("diffusers").logging.set_verbosity_error()
-filterwarnings("ignore", category=FutureWarning, module="torch")
-filterwarnings("ignore", category=FutureWarning, module="diffusers")
-
 
 class Loader:
     _instance = None
@@ -33,7 +28,6 @@ class Loader:
         gc.collect()
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
-        torch.cuda.reset_max_memory_allocated()
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
@@ -44,8 +38,18 @@ class Loader:
             return True
         return False
 
-    def
+    def _unload_deepcache(self):
+        if self.pipe.deepcache is None:
+            return
+        print("Unloading DeepCache")
+        self.pipe.deepcache.disable()
+        delattr(self.pipe, "deepcache")
+
+    # don't unload refiner
+    def _unload(self, model, deepcache):
         to_unload = []
+        if self._should_unload_deepcache(deepcache):
+            self._unload_deepcache()
         if self._should_unload_pipeline(model):
             to_unload.append("model")
             to_unload.append("pipe")
@@ -55,7 +59,7 @@ class Loader:
         for component in to_unload:
             setattr(self, component, None)
 
-    def _load_pipeline(self, kind, model,
+    def _load_pipeline(self, kind, model, progress, **kwargs):
         pipeline = Config.PIPELINES[kind]
         if self.pipe is None:
             try:
@@ -81,9 +85,9 @@ class Loader:
         if not isinstance(self.pipe, pipeline):
             self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
         if self.pipe is not None:
-            self.pipe.set_progress_bar_config(disable=not
+            self.pipe.set_progress_bar_config(disable=progress is not None)
 
-    def _load_refiner(self, refiner,
+    def _load_refiner(self, refiner, progress, **kwargs):
         if refiner and self.refiner is None:
             model = Config.REFINER_MODEL
             pipeline = Config.PIPELINES["img2img"]
@@ -95,7 +99,7 @@ class Loader:
             self.refiner = None
             return
         if self.refiner is not None:
-            self.refiner.set_progress_bar_config(disable=not
+            self.refiner.set_progress_bar_config(disable=progress is not None)
 
     def _load_upscaler(self, scale=1):
         if scale == 2 and self.upscaler_2x is None:
@@ -117,29 +121,27 @@ class Loader:
 
     def _load_deepcache(self, interval=1):
         pipe_has_deepcache = hasattr(self.pipe, "deepcache")
+        if not pipe_has_deepcache and interval == 1:
+            return
         if pipe_has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
             return
-        else:
-            self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
+        print("Loading DeepCache")
+        self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
         self.pipe.deepcache.set_params(cache_interval=interval)
         self.pipe.deepcache.enable()
 
         if self.refiner is not None:
             refiner_has_deepcache = hasattr(self.refiner, "deepcache")
+            if not refiner_has_deepcache and interval == 1:
+                return
             if refiner_has_deepcache and self.refiner.deepcache.params["cache_interval"] == interval:
                 return
-            else:
-                self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
+            print("Loading DeepCache for refiner")
+            self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
             self.refiner.deepcache.set_params(cache_interval=interval)
             self.refiner.deepcache.enable()
 
-    def load(self, kind, model, scheduler, deepcache, scale, karras, refiner,
-        model_lower = model.lower()
-
+    def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, progress):
         scheduler_kwargs = {
             "beta_start": 0.00085,
             "beta_end": 0.012,
@@ -156,7 +158,7 @@ class Loader:
             scheduler_kwargs["clip_sample"] = False
             scheduler_kwargs["set_alpha_to_one"] = False
 
-        if
+        if model.lower() not in Config.MODEL_CHECKPOINTS.keys():
             variant = "fp16"
         else:
             variant = None
@@ -170,8 +172,8 @@ class Loader:
             "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
         }
 
-        self._unload(model)
-        self._load_pipeline(kind, model,
+        self._unload(model, deepcache)
+        self._load_pipeline(kind, model, progress, **pipe_kwargs)
 
         # error loading model
         if self.pipe is None:
@@ -184,7 +186,7 @@ class Loader:
         )
 
         # same model, different scheduler
-        if self.model.lower() ==
+        if self.model.lower() == model.lower():
             if not same_scheduler:
                 print(f"Switching to {scheduler}...")
             if not same_karras:
@@ -207,6 +209,6 @@ class Loader:
             "text_encoder_2": self.pipe.text_encoder_2,
         }
 
-        self._load_refiner(refiner,
-        self._load_upscaler(scale)
+        self._load_refiner(refiner, progress, **refiner_kwargs)
         self._load_deepcache(deepcache)
+        self._load_upscaler(scale)
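Note that _unload calls self._should_unload_deepcache(deepcache), whose body falls outside the hunks shown here. One plausible shape for it, assuming it mirrors how _load_deepcache compares the requested interval to the one currently enabled (a hypothetical sketch, not the author's code):

# Hypothetical sketch; the real method body is outside the hunks shown in this diff.
def _should_unload_deepcache(self, interval=1):
    # nothing to unload if there is no pipe or DeepCache was never attached to it
    if self.pipe is None or not hasattr(self.pipe, "deepcache"):
        return False
    # unload when the requested interval differs from the currently enabled one
    return self.pipe.deepcache.params["cache_interval"] != interval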
lib/upscaler.py
CHANGED
@@ -55,17 +55,15 @@ HF_MODELS = {
 
 
 def pad_reflect(image, pad_size):
-    # fmt: off
     image_size = image.shape
     height, width = image_size[:2]
     new_image = np.zeros([height + pad_size * 2, width + pad_size * 2, image_size[2]]).astype(np.uint8)
     new_image[pad_size:-pad_size, pad_size:-pad_size, :] = image
-    new_image[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)
-    new_image[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)
-    new_image[:, 0:pad_size, :] = np.flip(new_image[:, pad_size : pad_size * 2, :], axis=1)
+    new_image[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)  # top
+    new_image[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)  # bottom
+    new_image[:, 0:pad_size, :] = np.flip(new_image[:, pad_size : pad_size * 2, :], axis=1)  # left
     new_image[:, -pad_size:, :] = np.flip(new_image[:, -pad_size * 2 : -pad_size, :], axis=1)  # right
     return new_image
-    # fmt: on
 
 
 def unpad_image(image, pad_size):
@@ -279,9 +277,8 @@ class RealESRGAN:
         self.model.load_state_dict(loadnet, strict=True)
         self.model.eval().to(device=self.device)
 
-    @torch.
+    @torch.autocast("cuda")
     def predict(self, lr_image, batch_size=4, patches_size=192, padding=24, pad_size=15):
-        scale = self.scale
         if not isinstance(lr_image, np.ndarray):
             lr_image = np.array(lr_image)
         if lr_image.min() < 0.0:
@@ -302,6 +299,7 @@ class RealESRGAN:
         for i in range(batch_size, image.shape[0], batch_size):
             res = torch.cat((res, self.model(image[i : i + batch_size])), 0)
 
+        scale = self.scale
         sr_image = einops.rearrange(res.clamp(0, 1), "b c h w -> b h w c").cpu().numpy()
         padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
         scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
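To see what pad_reflect produces, a quick standalone check of the same reflect-padding steps on a tiny array (shapes only, for illustration):

# Standalone check of the reflect-padding steps used by pad_reflect above.
import numpy as np

image = np.arange(4 * 5 * 3, dtype=np.uint8).reshape(4, 5, 3)  # H=4, W=5, RGB
pad = 2

padded = np.zeros((4 + pad * 2, 5 + pad * 2, 3), dtype=np.uint8)
padded[pad:-pad, pad:-pad, :] = image
padded[:pad, pad:-pad, :] = np.flip(image[:pad, :, :], axis=0)        # top
padded[-pad:, pad:-pad, :] = np.flip(image[-pad:, :, :], axis=0)      # bottom
padded[:, :pad, :] = np.flip(padded[:, pad:pad * 2, :], axis=1)       # left
padded[:, -pad:, :] = np.flip(padded[:, -pad * 2:-pad, :], axis=1)    # right

print(padded.shape)  # (8, 9, 3): the original size plus pad on every side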