adamelliotfields committed
Commit: b00d4fe
Parent: 6ad0411

Simplify loading and inference

Files changed (6)
  1. app.py +23 -24
  2. lib/__init__.py +0 -2
  3. lib/config.py +11 -24
  4. lib/inference.py +37 -101
  5. lib/loader.py +40 -51
  6. lib/utils.py +8 -21
app.py CHANGED
@@ -1,11 +1,20 @@
 import argparse
+import os
+from importlib.util import find_spec
+
+# Improved GPU handling and progress bars
+os.environ["ZEROGPU_V2"] = "1"
+
+# Use Rust-based downloader
+if find_spec("hf_transfer"):
+    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 import gradio as gr
+from huggingface_hub._snapshot_download import snapshot_download
 
 from lib import (
     Config,
-    # disable_progress_bars,
-    download_repo_files,
+    disable_progress_bars,
     generate,
     read_file,
     read_json,
@@ -60,24 +69,6 @@ random_prompt_js = f"""
 }}
 """
 
-
-# Transform the raw inputs before generation
-def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
-    if len(args) > 0:
-        prompt = args[0]
-    else:
-        prompt = None
-    if prompt is None or prompt.strip() == "":
-        raise gr.Error("You must enter a prompt")
-    try:
-        # if Config.ZERO_GPU:
-        #     progress((0, 100), desc="ZeroGPU init")
-        images = generate(*args, Error=gr.Error, Info=gr.Info, progress=progress)
-    except RuntimeError:
-        raise gr.Error("Error: Please try again")
-    return images
-
-
 with gr.Blocks(
     head=read_file("./partials/head.html"),
     css="./app.css",
@@ -244,10 +235,10 @@ with gr.Blocks(
                 label="Scale",
             )
             seed = gr.Number(
-                value=Config.SEED,
-                label="Seed",
                 minimum=-1,
                 maximum=(2**64) - 1,
+                label="Seed",
+                value=-1,
             )
             with gr.Row():
                 use_karras = gr.Checkbox(
@@ -293,7 +284,7 @@ with gr.Blocks(
     # Generate images
     gr.on(
         triggers=[generate_btn.click, prompt.submit],
-        fn=generate_fn,
+        fn=generate,
         api_name="generate",
         outputs=[output_images],
         inputs=[
@@ -321,8 +312,16 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # disable_progress_bars()
+    token = os.environ.get("HF_TOKEN", None)
     for repo_id, allow_patterns in Config.HF_REPOS.items():
-        download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
+        snapshot_download(
+            repo_id=repo_id,
+            repo_type="model",
+            revision="main",
+            token=token,
+            allow_patterns=allow_patterns,
+            ignore_patterns=None,
+        )
 
     # https://www.gradio.app/docs/gradio/interface#interface-queue
     demo.queue(default_concurrency_limit=1).launch(
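Note: the startup changes move the environment flags out of lib/config.py and into app.py, ahead of any import that reads them, and replace the download_repo_files helper with a direct snapshot_download loop. A minimal standalone sketch of the same flow, using the public huggingface_hub import path and a single illustrative repo standing in for Config.HF_REPOS:

import os
from importlib.util import find_spec

# Flags must be exported before importing spaces / huggingface_hub
os.environ["ZEROGPU_V2"] = "1"
if find_spec("hf_transfer"):  # enable the Rust downloader only if it is installed
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download  # public import path

# Illustrative stand-in for Config.HF_REPOS
repos = {"ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"]}

token = os.environ.get("HF_TOKEN", None)
for repo_id, allow_patterns in repos.items():
    # Prefetch only the matching files into the local Hugging Face cache
    snapshot_download(repo_id=repo_id, token=token, allow_patterns=allow_patterns)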
lib/__init__.py CHANGED
@@ -2,7 +2,6 @@ from .config import Config
 from .inference import generate
 from .utils import (
     disable_progress_bars,
-    download_repo_files,
     read_file,
     read_json,
 )
@@ -10,7 +9,6 @@ from .utils import (
 __all__ = [
     "Config",
     "disable_progress_bars",
-    "download_repo_files",
    "generate",
    "read_file",
    "read_json",
lib/config.py CHANGED
@@ -1,6 +1,3 @@
-import os
-from importlib import import_module
-from importlib.util import find_spec
 from types import SimpleNamespace
 from warnings import filterwarnings
 
@@ -16,13 +13,6 @@ from diffusers import (
 from diffusers.utils import logging as diffusers_logging
 from transformers import logging as transformers_logging
 
-# Improved GPU handling and progress bars; set before importing spaces
-os.environ["ZEROGPU_V2"] = "1"
-
-# Use Rust-based downloader; errors if enabled and not installed
-if find_spec("hf_transfer"):
-    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
 filterwarnings("ignore", category=FutureWarning, module="diffusers")
 filterwarnings("ignore", category=FutureWarning, module="transformers")
 
@@ -60,20 +50,10 @@ _sdxl_files_with_vae = [*_sdxl_files, "vae_1_0/config.json"]
 
 # Using namespace instead of dataclass for simplicity
 Config = SimpleNamespace(
-    HF_TOKEN=os.environ.get("HF_TOKEN", None),
-    ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
     PIPELINES={
         "txt2img": StableDiffusionXLPipeline,
         "img2img": StableDiffusionXLImg2ImgPipeline,
     },
-    MODEL="segmind/Segmind-Vega",
-    MODELS=[
-        "cyberdelia/CyberRealsticXL",
-        "fluently/Fluently-XL-Final",
-        "segmind/Segmind-Vega",
-        "SG161222/RealVisXL_V5.0",
-        "stabilityai/stable-diffusion-xl-base-1.0",
-    ],
     HF_REPOS={
         "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
         "cyberdelia/CyberRealsticXL": ["CyberRealisticXLPlay_V1.0.safetensors"],
@@ -84,10 +64,18 @@ Config = SimpleNamespace(
         "stabilityai/stable-diffusion-xl-base-1.0": _sdxl_files_with_vae,
         "stabilityai/stable-diffusion-xl-refiner-1.0": _sdxl_refiner_files,
     },
+    MODEL="segmind/Segmind-Vega",
+    MODELS=[
+        "cyberdelia/CyberRealsticXL",
+        "fluently/Fluently-XL-Final",
+        "segmind/Segmind-Vega",
+        "SG161222/RealVisXL_V5.0",
+        "stabilityai/stable-diffusion-xl-base-1.0",
+    ],
     SINGLE_FILE_MODELS=[
-        "cyberdelia/cyberrealsticxl",
-        "fluently/fluently-xl-final",
-        "sg161222/realvisxl_v5.0",
+        "cyberdelia/CyberRealsticXL",
+        "fluently/Fluently-XL-Final",
+        "SG161222/RealVisXL_V5.0",
     ],
     VAE_MODEL="madebyollin/sdxl-vae-fp16-fix",
     REFINER_MODEL="stabilityai/stable-diffusion-xl-refiner-1.0",
@@ -102,7 +90,6 @@ Config = SimpleNamespace(
     WIDTH=1024,
     HEIGHT=1024,
     NUM_IMAGES=1,
-    SEED=-1,
     GUIDANCE_SCALE=6,
     INFERENCE_STEPS=40,
     DEEPCACHE_INTERVAL=1,
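Note: with HF_TOKEN, ZERO_GPU, and SEED gone, lib/config.py no longer reads the environment or imports spaces at import time; it is a plain SimpleNamespace of constants. The SINGLE_FILE_MODELS entries now keep the original repo casing, so the loader can compare repo IDs directly instead of lowercasing them. A tiny sketch of the pattern, with a trimmed-down namespace standing in for the real Config:

from types import SimpleNamespace

# Trimmed-down stand-in for lib/config.py's Config
Config = SimpleNamespace(
    MODEL="segmind/Segmind-Vega",
    SINGLE_FILE_MODELS=["cyberdelia/CyberRealsticXL", "SG161222/RealVisXL_V5.0"],
)

# Attribute access without dataclass boilerplate; membership checks use exact casing
print(Config.MODEL in Config.SINGLE_FILE_MODELS)  # False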
lib/inference.py CHANGED
@@ -4,40 +4,17 @@ from datetime import datetime
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
+from gradio import Error, Info, Progress
 from spaces import GPU
 
-from .config import Config
 from .loader import Loader
 from .logger import Logger
-from .utils import cuda_collect, safe_progress, timer
-
-
-# Dynamic signature for the GPU duration function; max 60s per image
-def gpu_duration(**kwargs):
-    loading = 15
-    duration = 15
-    width = kwargs.get("width", 1024)
-    height = kwargs.get("height", 1024)
-    scale = kwargs.get("scale", 1)
-    num_images = kwargs.get("num_images", 1)
-    use_refiner = kwargs.get("use_refiner", False)
-    size = width * height
-    if use_refiner:
-        loading += 10
-    if size > 1_100_000:
-        duration += 5
-    if size > 1_600_000:
-        duration += 5
-    if scale == 2:
-        duration += 5
-    if scale == 4:
-        duration += 10
-    return loading + (duration * num_images)
-
-
-@GPU(duration=gpu_duration)
+from .utils import cuda_collect, get_output_types, timer
+
+
+@GPU
 def generate(
-    positive_prompt,
+    positive_prompt="",
     negative_prompt="",
     seed=None,
     model="stabilityai/stable-diffusion-xl-base-1.0",
@@ -51,50 +28,21 @@ def generate(
     num_images=1,
     use_karras=False,
     use_refiner=False,
-    Error=Exception,
-    Info=None,
-    progress=None,
+    progress=Progress(track_tqdm=True),
 ):
+    if not torch.cuda.is_available():
+        raise Error("CUDA not available")
+
+    if positive_prompt.strip() == "":
+        raise Error("You must enter a prompt")
+
     KIND = "txt2img"
-    CURRENT_STEP = 0
-    CURRENT_IMAGE = 1
     EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
 
     start = time.perf_counter()
     log = Logger("generate")
     log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
 
-    if Config.ZERO_GPU:
-        safe_progress(progress, 100, 100, "ZeroGPU init")
-
-    if not torch.cuda.is_available():
-        raise Error("CUDA not available")
-
-    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
-    if seed is None or seed < 0:
-        seed = int(datetime.now().timestamp() * 1e6) % (2**64)
-
-    # custom progress bar for multiple images
-    def callback_on_step_end(pipeline, step, timestep, latents):
-        nonlocal CURRENT_IMAGE, CURRENT_STEP
-        if progress is not None:
-            # calculate total steps for img2img based on denoising strength
-            strength = 1
-            total_steps = min(int(inference_steps * strength), inference_steps)
-
-            # if steps are different we're in the refiner
-            refining = False
-            if CURRENT_STEP == step:
-                CURRENT_STEP = step + 1
-            else:
-                refining = True
-                CURRENT_STEP += 1
-            progress(
-                (CURRENT_STEP, total_steps),
-                desc=f"{'Refining' if refining else 'Generating'} image {CURRENT_IMAGE}/{num_images}",
-            )
-        return latents
-
     loader = Loader()
     loader.load(
         KIND,
@@ -111,10 +59,11 @@ def generate(
     pipeline = loader.pipeline
     upscaler = loader.upscaler
 
+    # Probably a typo in the config
     if pipeline is None:
         raise Error(f"Error loading {model}")
 
-    # prompt embeds for base and refiner
+    # Prompt embeddings for base and refiner
    compel_1 = Compel(
        text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
        tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
@@ -132,9 +81,13 @@ def generate(
         device=pipeline.device,
     )
 
+    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
+    if seed is None or seed < 0:
+        seed = int(datetime.now().timestamp() * 1e6) % (2**64)
+
+    # Increment the seed after each iteration
     images = []
     current_seed = seed
-    safe_progress(progress, 0, num_images, f"Generating image 0/{num_images}")
 
     for i in range(num_images):
         try:
@@ -144,23 +97,14 @@ def generate(
         except PromptParser.ParsingException:
             raise Error("Invalid prompt")
 
-        # refiner expects latents; upscaler expects numpy array
-        pipe_output_type = "pil"
-        refiner_output_type = "pil"
-        if use_refiner:
-            pipe_output_type = "latent"
-            if scale > 1:
-                refiner_output_type = "np"
-        else:
-            if scale > 1:
-                pipe_output_type = "np"
-
-        pipe_kwargs = {
+        pipeline_output_type, refiner_output_type = get_output_types(scale, use_refiner)
+
+        pipeline_kwargs = {
             "width": width,
             "height": height,
             "denoising_end": 0.8 if use_refiner else None,
             "generator": generator,
-            "output_type": pipe_output_type,
+            "output_type": pipeline_output_type,
             "guidance_scale": guidance_scale,
             "num_inference_steps": inference_steps,
             "prompt_embeds": conditioning_1[0:1],
@@ -181,39 +125,31 @@ def generate(
             "negative_pooled_prompt_embeds": pooled_2[1:2],
         }
 
-        if progress is not None:
-            pipe_kwargs["callback_on_step_end"] = callback_on_step_end
-            refiner_kwargs["callback_on_step_end"] = callback_on_step_end
+        image = pipeline(**pipeline_kwargs).images[0]
 
-        try:
-            image = pipeline(**pipe_kwargs).images[0]
-            if use_refiner:
-                refiner_kwargs["image"] = image
-                image = refiner(**refiner_kwargs).images[0]
-            images.append((image, str(current_seed)))
-            current_seed += 1
-        finally:
-            CURRENT_STEP = 0
-            CURRENT_IMAGE += 1
+        if use_refiner:
+            refiner_kwargs["image"] = image
+            image = refiner(**refiner_kwargs).images[0]
+
+        # Use a tuple so gallery images get captions
+        images.append((image, str(current_seed)))
+        current_seed += 1
 
     # Upscale
     if scale > 1:
-        msg = f"Upscaling {scale}x"
-        with timer(msg):
-            safe_progress(progress, 0, num_images, desc=msg)
+        with timer(f"Upscaling {num_images} images {scale}x", logger=log.info):
             for i, image in enumerate(images):
-                images = upscaler.predict(image[0])
-                images[i] = image
-                safe_progress(progress, i + 1, num_images, desc=msg)
+                image = upscaler.predict(image[0])
+                seed = images[i][1]
+                images[i] = (image, seed)
 
-    # Flush memory after generating
+    # Flush cache after generating
     cuda_collect()
 
     end = time.perf_counter()
     msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
     log.info(msg)
 
-    # Alert if notifier provided
     if Info:
         Info(msg)
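Note: because generate now imports Error, Info, and Progress from gradio and does its own prompt/CUDA validation, app.py can bind it straight to the events without the old generate_fn wrapper. A rough sketch of that wiring under the same assumptions (component names are illustrative, and a text output stands in for the real image gallery):

import gradio as gr

def generate(positive_prompt="", progress=gr.Progress(track_tqdm=True)):
    # Raising gr.Error shows a toast in the UI instead of a server traceback
    if positive_prompt.strip() == "":
        raise gr.Error("You must enter a prompt")
    return f"would generate an image for: {positive_prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")  # the real app uses a gr.Gallery
    button = gr.Button("Generate")
    # Bind the handler directly; no wrapper is needed to translate errors
    gr.on(triggers=[button.click, prompt.submit], fn=generate, inputs=[prompt], outputs=[output])

if __name__ == "__main__":
    demo.launch()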
 
lib/loader.py CHANGED
@@ -1,5 +1,3 @@
-# import gc
-
 import torch
 from DeepCache import DeepCacheSDHelper
 from diffusers.models import AutoencoderKL
@@ -33,7 +31,7 @@ class Loader:
         return False
 
     def should_unload_pipeline(self, model=""):
-        return self.pipeline is not None and self.model.lower() != model.lower()
+        return self.pipeline is not None and self.model != model
 
     def should_load_refiner(self, use_refiner=False):
         return self.refiner is None and use_refiner
@@ -53,8 +51,6 @@ class Loader:
         return self.pipeline is None
 
     def unload(self, model, use_refiner, deepcache_interval, scale):
-        needs_gc = False
-
         if self.should_unload_deepcache(deepcache_interval):
             self.log.info("Disabling DeepCache")
             self.pipeline.deepcache.disable()
@@ -64,37 +60,41 @@ class Loader:
             delattr(self.refiner, "deepcache")
 
         if self.should_unload_refiner(use_refiner):
-            with timer("Unloading refiner"):
-                self.refiner.to("cpu", silence_dtype_warnings=True)
-                self.refiner = None
-                needs_gc = True
+            self.log.info("Unloading refiner")
+            self.refiner = None
 
         if self.should_unload_upscaler(scale):
-            with timer(f"Unloading {self.upscaler.scale}x upscaler"):
-                self.upscaler.to("cpu")
-                self.upscaler = None
-                needs_gc = True
+            self.log.info("Unloading upscaler")
+            self.upscaler = None
 
         if self.should_unload_pipeline(model):
-            with timer(f"Unloading {self.model}"):
-                self.pipeline.to("cpu", silence_dtype_warnings=True)
-                if self.refiner:
-                    self.refiner.vae = None
-                    self.refiner.scheduler = None
-                    self.refiner.tokenizer_2 = None
-                    self.refiner.text_encoder_2 = None
-                self.pipeline = None
-                self.model = None
-                needs_gc = True
-
-        if needs_gc:
-            cuda_collect()
-            # gc.collect()
-
-    def load_refiner(self, refiner_kwargs={}, progress=None):
+            self.log.info(f"Unloading {self.model}")
+            if self.refiner:
+                self.refiner.vae = None
+                self.refiner.scheduler = None
+                self.refiner.tokenizer_2 = None
+                self.refiner.text_encoder_2 = None
+            self.pipeline = None
+            self.model = None
+
+        # Flush cache
+        cuda_collect()
+
+    def load_refiner(self, progress=None):
         model = Config.REFINER_MODEL
         try:
-            with timer(f"Loading {model}"):
+            with timer(f"Loading {model}", logger=self.log.info):
+                refiner_kwargs = {
+                    "variant": "fp16",
+                    "torch_dtype": self.pipeline.dtype,
+                    "add_watermarker": False,
+                    "requires_aesthetics_score": True,
+                    "force_zeros_for_empty_prompt": False,
+                    "vae": self.pipeline.vae,
+                    "scheduler": self.pipeline.scheduler,
+                    "tokenizer_2": self.pipeline.tokenizer_2,
+                    "text_encoder_2": self.pipeline.text_encoder_2,
+                }
                 Pipeline = Config.PIPELINES["img2img"]
                 self.refiner = Pipeline.from_pretrained(model, **refiner_kwargs).to("cuda")
         except Exception as e:
@@ -107,7 +107,7 @@ class Loader:
     def load_upscaler(self, scale=1):
         if self.should_load_upscaler(scale):
             try:
-                with timer(f"Loading {scale}x upscaler"):
+                with timer(f"Loading {scale}x upscaler", logger=self.log.info):
                     self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
                     self.upscaler.load_weights()
             except Exception as e:
@@ -125,7 +125,7 @@ class Loader:
             self.refiner.deepcache.set_params(cache_interval=interval)
             self.refiner.deepcache.enable()
 
-    def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress):
+    def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress=None):
         scheduler_kwargs = {
             "beta_start": 0.00085,
             "beta_end": 0.012,
@@ -141,13 +141,13 @@ class Loader:
             scheduler_kwargs["clip_sample"] = False
             scheduler_kwargs["set_alpha_to_one"] = False
 
-        if model.lower() not in Config.SINGLE_FILE_MODELS:
+        if model not in Config.SINGLE_FILE_MODELS:
             variant = "fp16"
         else:
             variant = None
 
         dtype = torch.float16
-        pipe_kwargs = {
+        pipeline_kwargs = {
             "variant": variant,
             "torch_dtype": dtype,
             "add_watermarker": False,
@@ -161,16 +161,16 @@ class Loader:
         Scheduler = Config.SCHEDULERS[scheduler]
 
         try:
-            with timer(f"Loading {model}"):
+            with timer(f"Loading {model}", logger=self.log.info):
                 self.model = model
-                if model.lower() in Config.SINGLE_FILE_MODELS:
+                if model in Config.SINGLE_FILE_MODELS:
                     checkpoint = Config.HF_REPOS[model][0]
                     self.pipeline = Pipeline.from_single_file(
                         f"https://huggingface.co/{model}/{checkpoint}",
-                        **pipe_kwargs,
+                        **pipeline_kwargs,
                     ).to("cuda")
                 else:
-                    self.pipeline = Pipeline.from_pretrained(model, **pipe_kwargs).to("cuda")
+                    self.pipeline = Pipeline.from_pretrained(model, **pipeline_kwargs).to("cuda")
         except Exception as e:
             self.log.error(f"Error loading {model}: {e}")
             self.model = None
@@ -190,7 +190,7 @@ class Loader:
             or self.pipeline.scheduler.config.use_karras_sigmas == use_karras
         )
 
-        if self.model.lower() == model.lower():
+        if self.model == model:
             if not same_scheduler:
                 self.log.info(f"Enabling {scheduler}")
             if not same_karras:
@@ -201,18 +201,7 @@ class Loader:
                 self.refiner.scheduler = self.pipeline.scheduler
 
         if self.should_load_refiner(use_refiner):
-            refiner_kwargs = {
-                "variant": "fp16",
-                "torch_dtype": dtype,
-                "add_watermarker": False,
-                "requires_aesthetics_score": True,
-                "force_zeros_for_empty_prompt": False,
-                "vae": self.pipeline.vae,
-                "scheduler": self.pipeline.scheduler,
-                "tokenizer_2": self.pipeline.tokenizer_2,
-                "text_encoder_2": self.pipeline.text_encoder_2,
-            }
-            self.load_refiner(refiner_kwargs, progress)
+            self.load_refiner(progress)
 
         if self.should_load_deepcache(deepcache_interval):
             self.load_deepcache(deepcache_interval)
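Note: the unload path no longer moves modules to the CPU or tracks a needs_gc flag; it simply drops the Python references and lets cuda_collect() flush the cache on every unload call. A condensed sketch of that pattern (the real class also handles the refiner, upscaler, and DeepCache):

import torch

def cuda_collect():
    # Release cached CUDA memory once the Python references are gone
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

class Loader:
    def __init__(self):
        self.model = None
        self.pipeline = None

    def should_unload_pipeline(self, model=""):
        # Exact comparison works now that Config keeps the original repo casing
        return self.pipeline is not None and self.model != model

    def unload(self, model):
        if self.should_unload_pipeline(model):
            self.pipeline = None
            self.model = None
        cuda_collect()  # flush unconditionally, as in the new unload()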
lib/utils.py CHANGED
@@ -5,8 +5,6 @@ from contextlib import contextmanager
 
 import torch
 from diffusers.utils import logging as diffusers_logging
-from huggingface_hub._snapshot_download import snapshot_download
-from huggingface_hub.utils import are_progress_bars_disabled
 from transformers import logging as transformers_logging
 
 
@@ -45,9 +43,14 @@ def enable_progress_bars():
     diffusers_logging.enable_progress_bar()
 
 
-def safe_progress(progress, current=0, total=0, desc=""):
-    if progress is not None:
-        progress((current, total), desc=desc)
+def get_output_types(scale=1, use_refiner=False):
+    if use_refiner:
+        pipeline_type = "latent"
+        refiner_type = "np" if scale > 1 else "pil"
+    else:
+        refiner_type = "pil"
+        pipeline_type = "np" if scale > 1 else "pil"
+    return (pipeline_type, refiner_type)
 
 
 def cuda_collect():
@@ -56,19 +59,3 @@ def cuda_collect():
     torch.cuda.ipc_collect()
     torch.cuda.reset_peak_memory_stats()
     torch.cuda.synchronize()
-
-
-def download_repo_files(repo_id, allow_patterns, token=None):
-    was_disabled = are_progress_bars_disabled()
-    enable_progress_bars()
-    snapshot_path = snapshot_download(
-        repo_id=repo_id,
-        repo_type="model",
-        revision="main",
-        token=token,
-        allow_patterns=allow_patterns,
-        ignore_patterns=None,
-    )
-    if was_disabled:
-        disable_progress_bars()
-    return snapshot_path
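Note: get_output_types replaces the output-type branching that was previously inlined in generate: the base pipeline emits latents when the refiner will run, and whichever stage runs last emits a NumPy array when the upscaler needs one. Expected behavior, per the implementation above:

from lib.utils import get_output_types

# No refiner, no upscaling: the pipeline returns PIL images directly
assert get_output_types(scale=1, use_refiner=False) == ("pil", "pil")

# Upscaling only: the pipeline hands a NumPy array to the upscaler
assert get_output_types(scale=2, use_refiner=False) == ("np", "pil")

# Refiner plus upscaling: pipeline -> latents -> refiner -> NumPy -> upscaler
assert get_output_types(scale=4, use_refiner=True) == ("latent", "np")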