Spaces • Running on Zero
adamelliotfields committed • Commit b00d4fe
Parent(s): 6ad0411

Simplify loading and inference

Files changed:
- app.py +23 -24
- lib/__init__.py +0 -2
- lib/config.py +11 -24
- lib/inference.py +37 -101
- lib/loader.py +40 -51
- lib/utils.py +8 -21
app.py
CHANGED
@@ -1,11 +1,20 @@
 import argparse
+import os
+from importlib.util import find_spec
+
+# Improved GPU handling and progress bars
+os.environ["ZEROGPU_V2"] = "1"
+
+# Use Rust-based downloader
+if find_spec("hf_transfer"):
+    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 import gradio as gr
+from huggingface_hub._snapshot_download import snapshot_download
 
 from lib import (
     Config,
-    …
-    download_repo_files,
+    disable_progress_bars,
     generate,
     read_file,
     read_json,
@@ -60,24 +69,6 @@ random_prompt_js = f"""
 }}
 """
 
-
-# Transform the raw inputs before generation
-def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
-    if len(args) > 0:
-        prompt = args[0]
-    else:
-        prompt = None
-    if prompt is None or prompt.strip() == "":
-        raise gr.Error("You must enter a prompt")
-    try:
-        # if Config.ZERO_GPU:
-        #     progress((0, 100), desc="ZeroGPU init")
-        images = generate(*args, Error=gr.Error, Info=gr.Info, progress=progress)
-    except RuntimeError:
-        raise gr.Error("Error: Please try again")
-    return images
-
-
 with gr.Blocks(
     head=read_file("./partials/head.html"),
     css="./app.css",
@@ -244,10 +235,10 @@ with gr.Blocks(
                 label="Scale",
             )
             seed = gr.Number(
-                value=Config.SEED,
-                label="Seed",
                 minimum=-1,
                 maximum=(2**64) - 1,
+                label="Seed",
+                value=-1,
             )
         with gr.Row():
             use_karras = gr.Checkbox(
@@ -293,7 +284,7 @@ with gr.Blocks(
     # Generate images
     gr.on(
         triggers=[generate_btn.click, prompt.submit],
-        fn=…
+        fn=generate,
         api_name="generate",
         outputs=[output_images],
         inputs=[
@@ -321,8 +312,16 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # disable_progress_bars()
+    token = os.environ.get("HF_TOKEN", None)
     for repo_id, allow_patterns in Config.HF_REPOS.items():
-        …
+        snapshot_download(
+            repo_id=repo_id,
+            repo_type="model",
+            revision="main",
+            token=token,
+            allow_patterns=allow_patterns,
+            ignore_patterns=None,
+        )
 
     # https://www.gradio.app/docs/gradio/interface#interface-queue
     demo.queue(default_concurrency_limit=1).launch(
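Note: the startup loop above pre-downloads only the files each repo actually needs by passing allow_patterns. A minimal standalone sketch of the same pattern, using the public snapshot_download import rather than the private module path (the repo and patterns below come from Config.HF_REPOS and are only illustrative here):

import os

from huggingface_hub import snapshot_download

# Illustrative repo/pattern pair; the Space iterates Config.HF_REPOS instead
repos = {
    "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
}

token = os.environ.get("HF_TOKEN", None)
for repo_id, allow_patterns in repos.items():
    # Downloads only the matching files into the local HF cache and returns the path
    path = snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        revision="main",
        token=token,
        allow_patterns=allow_patterns,
    )
    print(f"{repo_id} -> {path}")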
lib/__init__.py
CHANGED
@@ -2,7 +2,6 @@ from .config import Config
 from .inference import generate
 from .utils import (
     disable_progress_bars,
-    download_repo_files,
     read_file,
     read_json,
 )
@@ -10,7 +9,6 @@ from .utils import (
 __all__ = [
     "Config",
     "disable_progress_bars",
-    "download_repo_files",
     "generate",
     "read_file",
     "read_json",
lib/config.py
CHANGED
@@ -1,6 +1,3 @@
-import os
-from importlib import import_module
-from importlib.util import find_spec
 from types import SimpleNamespace
 from warnings import filterwarnings
 
@@ -16,13 +13,6 @@ from diffusers import (
 from diffusers.utils import logging as diffusers_logging
 from transformers import logging as transformers_logging
 
-# Improved GPU handling and progress bars; set before importing spaces
-os.environ["ZEROGPU_V2"] = "1"
-
-# Use Rust-based downloader; errors if enabled and not installed
-if find_spec("hf_transfer"):
-    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
 filterwarnings("ignore", category=FutureWarning, module="diffusers")
 filterwarnings("ignore", category=FutureWarning, module="transformers")
 
@@ -60,20 +50,10 @@ _sdxl_files_with_vae = [*_sdxl_files, "vae_1_0/config.json"]
 
 # Using namespace instead of dataclass for simplicity
 Config = SimpleNamespace(
-    HF_TOKEN=os.environ.get("HF_TOKEN", None),
-    ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
     PIPELINES={
         "txt2img": StableDiffusionXLPipeline,
         "img2img": StableDiffusionXLImg2ImgPipeline,
     },
-    MODEL="segmind/Segmind-Vega",
-    MODELS=[
-        "cyberdelia/CyberRealsticXL",
-        "fluently/Fluently-XL-Final",
-        "segmind/Segmind-Vega",
-        "SG161222/RealVisXL_V5.0",
-        "stabilityai/stable-diffusion-xl-base-1.0",
-    ],
     HF_REPOS={
         "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
         "cyberdelia/CyberRealsticXL": ["CyberRealisticXLPlay_V1.0.safetensors"],
@@ -84,10 +64,18 @@ Config = SimpleNamespace(
         "stabilityai/stable-diffusion-xl-base-1.0": _sdxl_files_with_vae,
         "stabilityai/stable-diffusion-xl-refiner-1.0": _sdxl_refiner_files,
     },
+    MODEL="segmind/Segmind-Vega",
+    MODELS=[
+        "cyberdelia/CyberRealsticXL",
+        "fluently/Fluently-XL-Final",
+        "segmind/Segmind-Vega",
+        "SG161222/RealVisXL_V5.0",
+        "stabilityai/stable-diffusion-xl-base-1.0",
+    ],
     SINGLE_FILE_MODELS=[
-        "cyberdelia/…
-        "fluently/…
-        "…
+        "cyberdelia/CyberRealsticXL",
+        "fluently/Fluently-XL-Final",
+        "SG161222/RealVisXL_V5.0",
     ],
     VAE_MODEL="madebyollin/sdxl-vae-fp16-fix",
     REFINER_MODEL="stabilityai/stable-diffusion-xl-refiner-1.0",
@@ -102,7 +90,6 @@ Config = SimpleNamespace(
     WIDTH=1024,
     HEIGHT=1024,
     NUM_IMAGES=1,
-    SEED=-1,
     GUIDANCE_SCALE=6,
     INFERENCE_STEPS=40,
     DEEPCACHE_INTERVAL=1,
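Note: the config module now exposes a plain SimpleNamespace and no longer reads environment variables or imports spaces at import time. A minimal sketch of the SimpleNamespace pattern, with only a few of the fields shown above:

from types import SimpleNamespace

# Stand-in for lib/config.py's Config namespace
Config = SimpleNamespace(
    MODEL="segmind/Segmind-Vega",
    WIDTH=1024,
    HEIGHT=1024,
    NUM_IMAGES=1,
)

# Plain attribute access, no dataclass boilerplate
print(Config.MODEL, Config.WIDTH, Config.HEIGHT)

# Fields stay mutable, unlike a frozen dataclass
Config.NUM_IMAGES = 4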
lib/inference.py
CHANGED
@@ -4,40 +4,17 @@ from datetime import datetime
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
+from gradio import Error, Info, Progress
 from spaces import GPU
 
-from .config import Config
 from .loader import Loader
 from .logger import Logger
-from .utils import cuda_collect, …
-
-
-
-def gpu_duration(**kwargs):
-    loading = 15
-    duration = 15
-    width = kwargs.get("width", 1024)
-    height = kwargs.get("height", 1024)
-    scale = kwargs.get("scale", 1)
-    num_images = kwargs.get("num_images", 1)
-    use_refiner = kwargs.get("use_refiner", False)
-    size = width * height
-    if use_refiner:
-        loading += 10
-    if size > 1_100_000:
-        duration += 5
-    if size > 1_600_000:
-        duration += 5
-    if scale == 2:
-        duration += 5
-    if scale == 4:
-        duration += 10
-    return loading + (duration * num_images)
-
-
-@GPU(duration=gpu_duration)
+from .utils import cuda_collect, get_output_types, timer
+
+
+@GPU
 def generate(
-    positive_prompt,
+    positive_prompt="",
     negative_prompt="",
     seed=None,
     model="stabilityai/stable-diffusion-xl-base-1.0",
@@ -51,50 +28,21 @@ def generate(
     num_images=1,
     use_karras=False,
     use_refiner=False,
-    …
-    Info=None,
-    progress=None,
+    progress=Progress(track_tqdm=True),
 ):
+    if not torch.cuda.is_available():
+        raise Error("CUDA not available")
+
+    if positive_prompt.strip() == "":
+        raise Error("You must enter a prompt")
+
     KIND = "txt2img"
-    CURRENT_STEP = 0
-    CURRENT_IMAGE = 1
     EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
 
     start = time.perf_counter()
     log = Logger("generate")
     log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
 
-    if Config.ZERO_GPU:
-        safe_progress(progress, 100, 100, "ZeroGPU init")
-
-    if not torch.cuda.is_available():
-        raise Error("CUDA not available")
-
-    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
-    if seed is None or seed < 0:
-        seed = int(datetime.now().timestamp() * 1e6) % (2**64)
-
-    # custom progress bar for multiple images
-    def callback_on_step_end(pipeline, step, timestep, latents):
-        nonlocal CURRENT_IMAGE, CURRENT_STEP
-        if progress is not None:
-            # calculate total steps for img2img based on denoising strength
-            strength = 1
-            total_steps = min(int(inference_steps * strength), inference_steps)
-
-            # if steps are different we're in the refiner
-            refining = False
-            if CURRENT_STEP == step:
-                CURRENT_STEP = step + 1
-            else:
-                refining = True
-                CURRENT_STEP += 1
-            progress(
-                (CURRENT_STEP, total_steps),
-                desc=f"{'Refining' if refining else 'Generating'} image {CURRENT_IMAGE}/{num_images}",
-            )
-        return latents
-
     loader = Loader()
     loader.load(
         KIND,
@@ -111,10 +59,11 @@ def generate(
     pipeline = loader.pipeline
     upscaler = loader.upscaler
 
+    # Probably a typo in the config
     if pipeline is None:
         raise Error(f"Error loading {model}")
 
-    # …
+    # Prompt embeddings for base and refiner
     compel_1 = Compel(
         text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
         tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
@@ -132,9 +81,13 @@ def generate(
         device=pipeline.device,
     )
 
+    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
+    if seed is None or seed < 0:
+        seed = int(datetime.now().timestamp() * 1e6) % (2**64)
+
+    # Increment the seed after each iteration
     images = []
     current_seed = seed
-    safe_progress(progress, 0, num_images, f"Generating image 0/{num_images}")
 
     for i in range(num_images):
         try:
@@ -144,23 +97,14 @@ def generate(
         except PromptParser.ParsingException:
             raise Error("Invalid prompt")
 
-        …
-        …
-        …
-        if use_refiner:
-            pipe_output_type = "latent"
-            if scale > 1:
-                refiner_output_type = "np"
-        else:
-            if scale > 1:
-                pipe_output_type = "np"
-
-        pipe_kwargs = {
+        pipeline_output_type, refiner_output_type = get_output_types(scale, use_refiner)
+
+        pipeline_kwargs = {
             "width": width,
             "height": height,
             "denoising_end": 0.8 if use_refiner else None,
             "generator": generator,
-            "output_type": …
+            "output_type": pipeline_output_type,
             "guidance_scale": guidance_scale,
             "num_inference_steps": inference_steps,
             "prompt_embeds": conditioning_1[0:1],
@@ -181,39 +125,31 @@ def generate(
             "negative_pooled_prompt_embeds": pooled_2[1:2],
         }
 
-        …
-        pipe_kwargs["callback_on_step_end"] = callback_on_step_end
-        refiner_kwargs["callback_on_step_end"] = callback_on_step_end
+        image = pipeline(**pipeline_kwargs).images[0]
 
-        …
-        image = …
-        …
-        …
-        …
-        …
-        …
-        finally:
-            CURRENT_STEP = 0
-            CURRENT_IMAGE += 1
+        if use_refiner:
+            refiner_kwargs["image"] = image
+            image = refiner(**refiner_kwargs).images[0]
+
+        # Use a tuple so gallery images get captions
+        images.append((image, str(current_seed)))
+        current_seed += 1
 
     # Upscale
     if scale > 1:
-        …
-        with timer(msg):
-            safe_progress(progress, 0, num_images, desc=msg)
+        with timer(f"Upscaling {num_images} images {scale}x", logger=log.info):
             for i, image in enumerate(images):
-                …
-                images[i]…
-                …
+                image = upscaler.predict(image[0])
+                seed = images[i][1]
+                images[i] = (image, seed)
 
-    # Flush
+    # Flush cache after generating
     cuda_collect()
 
     end = time.perf_counter()
     msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
     log.info(msg)
 
-    # Alert if notifier provided
     if Info:
         Info(msg)
 
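Note: images are appended as (image, caption) tuples so the seed appears under each item in the Gradio output gallery. A minimal standalone sketch of that convention, using placeholder images instead of a diffusion pipeline (not the Space's actual UI code):

import gradio as gr
from PIL import Image


def fake_generate():
    images = []
    current_seed = 42
    for _ in range(2):
        image = Image.new("RGB", (64, 64))  # placeholder image
        # (image, caption) tuples render the caption under each gallery item
        images.append((image, str(current_seed)))
        current_seed += 1
    return images


with gr.Blocks() as demo:
    gallery = gr.Gallery()
    button = gr.Button("Generate")
    button.click(fake_generate, outputs=[gallery])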
lib/loader.py
CHANGED
@@ -1,5 +1,3 @@
-# import gc
-
 import torch
 from DeepCache import DeepCacheSDHelper
 from diffusers.models import AutoencoderKL
@@ -33,7 +31,7 @@ class Loader:
         return False
 
     def should_unload_pipeline(self, model=""):
-        return self.pipeline is not None and self.model …
+        return self.pipeline is not None and self.model != model
 
     def should_load_refiner(self, use_refiner=False):
         return self.refiner is None and use_refiner
@@ -53,8 +51,6 @@ class Loader:
         return self.pipeline is None
 
     def unload(self, model, use_refiner, deepcache_interval, scale):
-        needs_gc = False
-
         if self.should_unload_deepcache(deepcache_interval):
             self.log.info("Disabling DeepCache")
             self.pipeline.deepcache.disable()
@@ -64,37 +60,41 @@ class Loader:
             delattr(self.refiner, "deepcache")
 
         if self.should_unload_refiner(use_refiner):
-            …
-            …
-            self.refiner = None
-            needs_gc = True
+            self.log.info("Unloading refiner")
+            self.refiner = None
 
         if self.should_unload_upscaler(scale):
-            …
-            …
-            self.upscaler = None
-            needs_gc = True
+            self.log.info("Unloading upscaler")
+            self.upscaler = None
 
         if self.should_unload_pipeline(model):
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            # gc.collect()
-
-    def load_refiner(self, refiner_kwargs={}, progress=None):
+            self.log.info(f"Unloading {self.model}")
+            if self.refiner:
+                self.refiner.vae = None
+                self.refiner.scheduler = None
+                self.refiner.tokenizer_2 = None
+                self.refiner.text_encoder_2 = None
+            self.pipeline = None
+            self.model = None
+
+        # Flush cache
+        cuda_collect()
+
+    def load_refiner(self, progress=None):
         model = Config.REFINER_MODEL
         try:
-            with timer(f"Loading {model}"):
+            with timer(f"Loading {model}", logger=self.log.info):
+                refiner_kwargs = {
+                    "variant": "fp16",
+                    "torch_dtype": self.pipeline.dtype,
+                    "add_watermarker": False,
+                    "requires_aesthetics_score": True,
+                    "force_zeros_for_empty_prompt": False,
+                    "vae": self.pipeline.vae,
+                    "scheduler": self.pipeline.scheduler,
+                    "tokenizer_2": self.pipeline.tokenizer_2,
+                    "text_encoder_2": self.pipeline.text_encoder_2,
+                }
                 Pipeline = Config.PIPELINES["img2img"]
                 self.refiner = Pipeline.from_pretrained(model, **refiner_kwargs).to("cuda")
         except Exception as e:
@@ -107,7 +107,7 @@ class Loader:
     def load_upscaler(self, scale=1):
         if self.should_load_upscaler(scale):
             try:
-                with timer(f"Loading {scale}x upscaler"):
+                with timer(f"Loading {scale}x upscaler", logger=self.log.info):
                     self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
                     self.upscaler.load_weights()
             except Exception as e:
@@ -125,7 +125,7 @@ class Loader:
             self.refiner.deepcache.set_params(cache_interval=interval)
             self.refiner.deepcache.enable()
 
-    def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress):
+    def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress=None):
         scheduler_kwargs = {
             "beta_start": 0.00085,
             "beta_end": 0.012,
@@ -141,13 +141,13 @@ class Loader:
             scheduler_kwargs["clip_sample"] = False
             scheduler_kwargs["set_alpha_to_one"] = False
 
-        if model …
+        if model not in Config.SINGLE_FILE_MODELS:
             variant = "fp16"
         else:
             variant = None
 
         dtype = torch.float16
-        …
+        pipeline_kwargs = {
             "variant": variant,
             "torch_dtype": dtype,
             "add_watermarker": False,
@@ -161,16 +161,16 @@ class Loader:
         Scheduler = Config.SCHEDULERS[scheduler]
 
         try:
-            with timer(f"Loading {model}"):
+            with timer(f"Loading {model}", logger=self.log.info):
                 self.model = model
-                if model …
+                if model in Config.SINGLE_FILE_MODELS:
                     checkpoint = Config.HF_REPOS[model][0]
                     self.pipeline = Pipeline.from_single_file(
                         f"https://huggingface.co/{model}/{checkpoint}",
-                        **…
+                        **pipeline_kwargs,
                     ).to("cuda")
                 else:
-                    self.pipeline = Pipeline.from_pretrained(model, **…
+                    self.pipeline = Pipeline.from_pretrained(model, **pipeline_kwargs).to("cuda")
         except Exception as e:
             self.log.error(f"Error loading {model}: {e}")
             self.model = None
@@ -190,7 +190,7 @@ class Loader:
             or self.pipeline.scheduler.config.use_karras_sigmas == use_karras
         )
 
-        if self.model …
+        if self.model == model:
             if not same_scheduler:
                 self.log.info(f"Enabling {scheduler}")
             if not same_karras:
@@ -201,18 +201,7 @@ class Loader:
             self.refiner.scheduler = self.pipeline.scheduler
 
         if self.should_load_refiner(use_refiner):
-            …
-                "variant": "fp16",
-                "torch_dtype": dtype,
-                "add_watermarker": False,
-                "requires_aesthetics_score": True,
-                "force_zeros_for_empty_prompt": False,
-                "vae": self.pipeline.vae,
-                "scheduler": self.pipeline.scheduler,
-                "tokenizer_2": self.pipeline.tokenizer_2,
-                "text_encoder_2": self.pipeline.text_encoder_2,
-            }
-            self.load_refiner(refiner_kwargs, progress)
+            self.load_refiner(progress)
 
         if self.should_load_deepcache(deepcache_interval):
             self.load_deepcache(deepcache_interval)
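Note: loading steps are wrapped in with timer(..., logger=self.log.info). The timer helper itself is not part of this diff; a minimal sketch of what such a context manager might look like (an assumption, not the Space's actual lib/utils.py implementation):

import time
from contextlib import contextmanager


@contextmanager
def timer(message, logger=print):
    # Assumed behavior: log the message, run the block, then log the elapsed time
    start = time.perf_counter()
    logger(message)
    try:
        yield
    finally:
        logger(f"{message} took {time.perf_counter() - start:.2f}s")


# Usage mirroring the loader
with timer("Loading model", logger=print):
    time.sleep(0.1)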
lib/utils.py
CHANGED
@@ -5,8 +5,6 @@ from contextlib import contextmanager
 
 import torch
 from diffusers.utils import logging as diffusers_logging
-from huggingface_hub._snapshot_download import snapshot_download
-from huggingface_hub.utils import are_progress_bars_disabled
 from transformers import logging as transformers_logging
 
 
@@ -45,9 +43,14 @@ def enable_progress_bars():
     diffusers_logging.enable_progress_bar()
 
 
-def …
-    if …
-    …
+def get_output_types(scale=1, use_refiner=False):
+    if use_refiner:
+        pipeline_type = "latent"
+        refiner_type = "np" if scale > 1 else "pil"
+    else:
+        refiner_type = "pil"
+        pipeline_type = "np" if scale > 1 else "pil"
+    return (pipeline_type, refiner_type)
 
 
 def cuda_collect():
@@ -56,19 +59,3 @@ def cuda_collect():
     torch.cuda.ipc_collect()
     torch.cuda.reset_peak_memory_stats()
     torch.cuda.synchronize()
-
-
-def download_repo_files(repo_id, allow_patterns, token=None):
-    was_disabled = are_progress_bars_disabled()
-    enable_progress_bars()
-    snapshot_path = snapshot_download(
-        repo_id=repo_id,
-        repo_type="model",
-        revision="main",
-        token=token,
-        allow_patterns=allow_patterns,
-        ignore_patterns=None,
-    )
-    if was_disabled:
-        disable_progress_bars()
-    return snapshot_path
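Note: get_output_types replaces the nested if/else that previously lived in inference.py; it returns the output type for the base pipeline and for the refiner. A quick check of the four combinations (assuming the function above is importable from lib.utils):

from lib.utils import get_output_types

# (pipeline_type, refiner_type)
print(get_output_types(scale=1, use_refiner=False))  # ('pil', 'pil')
print(get_output_types(scale=2, use_refiner=False))  # ('np', 'pil')
print(get_output_types(scale=1, use_refiner=True))   # ('latent', 'pil')
print(get_output_types(scale=2, use_refiner=True))   # ('latent', 'np')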