Spaces:
Sleeping
Sleeping
Kyle Pearson commited on
Commit Β·
b1e7bdb
1
Parent(s): 8cdb001
Add cancel button, validate cache files, detect cache hits, update peft dependency, fix race conditions, enhance feedback, improve interrupt handling.
Browse files- app.py +24 -2
- requirements.txt +3 -0
- src/config.py +62 -0
- src/downloader.py +34 -0
- src/pipeline.py +111 -27
app.py
CHANGED
|
@@ -361,7 +361,6 @@ def create_app():
|
|
| 361 |
return (
|
| 362 |
'<div class="status-warning">β³ Loading started...</div>',
|
| 363 |
"Starting download...",
|
| 364 |
-
gr.update(interactive=False)
|
| 365 |
)
|
| 366 |
|
| 367 |
def on_load_pipeline_complete(status_msg, progress_text):
|
|
@@ -385,14 +384,26 @@ def create_app():
|
|
| 385 |
gr.update(interactive=True)
|
| 386 |
)
|
| 387 |
|
|
|
|
|
|
|
|
|
|
| 388 |
load_btn.click(
|
| 389 |
fn=on_load_pipeline_start,
|
| 390 |
inputs=[],
|
| 391 |
-
outputs=[load_status, load_progress
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
).then(
|
| 393 |
fn=load_pipeline,
|
| 394 |
inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
|
| 395 |
outputs=[load_status, load_progress],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
).then(
|
| 397 |
fn=on_load_pipeline_complete,
|
| 398 |
inputs=[load_status, load_progress],
|
|
@@ -407,6 +418,17 @@ def create_app():
|
|
| 407 |
outputs=[cached_checkpoints, cached_vaes, cached_loras],
|
| 408 |
)
|
| 409 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
def on_cached_checkpoint_change(cached_path):
|
| 411 |
"""Update URL when a cached checkpoint is selected."""
|
| 412 |
if cached_path and cached_path != "(None found)":
|
|
|
|
| 361 |
return (
|
| 362 |
'<div class="status-warning">β³ Loading started...</div>',
|
| 363 |
"Starting download...",
|
|
|
|
| 364 |
)
|
| 365 |
|
| 366 |
def on_load_pipeline_complete(status_msg, progress_text):
|
|
|
|
| 384 |
gr.update(interactive=True)
|
| 385 |
)
|
| 386 |
|
| 387 |
+
# Cancel button for pipeline loading
|
| 388 |
+
cancel_load_btn = gr.Button("π Cancel Loading", variant="secondary", size="sm", visible=False)
|
| 389 |
+
|
| 390 |
load_btn.click(
|
| 391 |
fn=on_load_pipeline_start,
|
| 392 |
inputs=[],
|
| 393 |
+
outputs=[load_status, load_progress],
|
| 394 |
+
).then(
|
| 395 |
+
fn=lambda: (gr.update(visible=True), gr.update(interactive=False)),
|
| 396 |
+
inputs=[],
|
| 397 |
+
outputs=[cancel_load_btn, load_btn],
|
| 398 |
).then(
|
| 399 |
fn=load_pipeline,
|
| 400 |
inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
|
| 401 |
outputs=[load_status, load_progress],
|
| 402 |
+
show_progress="full",
|
| 403 |
+
).then(
|
| 404 |
+
fn=lambda: (gr.update(visible=False), gr.update(interactive=True)),
|
| 405 |
+
inputs=[],
|
| 406 |
+
outputs=[cancel_load_btn, load_btn],
|
| 407 |
).then(
|
| 408 |
fn=on_load_pipeline_complete,
|
| 409 |
inputs=[load_status, load_progress],
|
|
|
|
| 418 |
outputs=[cached_checkpoints, cached_vaes, cached_loras],
|
| 419 |
)
|
| 420 |
|
| 421 |
+
# Cancel button handler
|
| 422 |
+
cancel_load_btn.click(
|
| 423 |
+
fn=lambda: (cancel_download(),
|
| 424 |
+
'<div class="status-warning">β³ Cancelling...</div>',
|
| 425 |
+
"Cancelling download...",
|
| 426 |
+
gr.update(visible=False),
|
| 427 |
+
gr.update(interactive=True)),
|
| 428 |
+
inputs=[],
|
| 429 |
+
outputs=[load_status, load_progress, cancel_load_btn, load_btn],
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
def on_cached_checkpoint_change(cached_path):
|
| 433 |
"""Update URL when a cached checkpoint is selected."""
|
| 434 |
if cached_path and cached_path != "(None found)":
|
requirements.txt
CHANGED
|
@@ -7,6 +7,9 @@ transformers>=4.35.0
|
|
| 7 |
safetensors>=0.4.0
|
| 8 |
optimum>=1.0.0
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
# Image processing
|
| 11 |
Pillow>=10.0.0
|
| 12 |
|
|
|
|
| 7 |
safetensors>=0.4.0
|
| 8 |
optimum>=1.0.0
|
| 9 |
|
| 10 |
+
# LoRA support (required for diffusers >= 0.26)
|
| 11 |
+
peft>=0.6.0
|
| 12 |
+
|
| 13 |
# Image processing
|
| 14 |
Pillow>=10.0.0
|
| 15 |
|
src/config.py
CHANGED
|
@@ -133,3 +133,65 @@ def get_cached_loras():
|
|
| 133 |
for file in sorted(CACHE_DIR.glob("*_lora.safetensors")):
|
| 134 |
models.append(str(file))
|
| 135 |
return models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
for file in sorted(CACHE_DIR.glob("*_lora.safetensors")):
|
| 134 |
models.append(str(file))
|
| 135 |
return models
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def validate_cache_file(cache_path: Path, min_size_mb: float = 1.0) -> tuple[bool, str]:
|
| 139 |
+
"""
|
| 140 |
+
Validate a cached model file exists and has valid content.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
cache_path: Path to the cached .safetensors file
|
| 144 |
+
min_size_mb: Minimum acceptable file size in MB (default: 1MB)
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
Tuple of (is_valid, message)
|
| 148 |
+
- is_valid: True if file passes all checks
|
| 149 |
+
- message: Description of validation result
|
| 150 |
+
"""
|
| 151 |
+
try:
|
| 152 |
+
if not cache_path.exists():
|
| 153 |
+
return False, f"File does not exist: {cache_path.name}"
|
| 154 |
+
|
| 155 |
+
if not cache_path.is_file():
|
| 156 |
+
return False, f"Not a regular file: {cache_path.name}"
|
| 157 |
+
|
| 158 |
+
file_size = cache_path.stat().st_size
|
| 159 |
+
size_mb = file_size / (1024 * 1024)
|
| 160 |
+
|
| 161 |
+
if size_mb < min_size_mb:
|
| 162 |
+
return False, f"File too small ({size_mb:.2f} MB < {min_size_mb} MB): {cache_path.name}"
|
| 163 |
+
|
| 164 |
+
# Check if it's a valid safetensors file by reading the header
|
| 165 |
+
if not cache_path.suffix == ".safetensors":
|
| 166 |
+
return True, f"Valid non-safetensors file: {cache_path.name}"
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
with open(cache_path, "rb") as f:
|
| 170 |
+
# Read first 8 bytes (header size)
|
| 171 |
+
header_size_bytes = f.read(8)
|
| 172 |
+
if len(header_size_bytes) < 8:
|
| 173 |
+
return False, f"File too small for safetensors header: {cache_path.name}"
|
| 174 |
+
|
| 175 |
+
import struct
|
| 176 |
+
header_size = struct.unpack("<Q", header_size_bytes)[0]
|
| 177 |
+
|
| 178 |
+
if header_size == 0:
|
| 179 |
+
return False, f"Invalid safetensors header (size=0): {cache_path.name}"
|
| 180 |
+
|
| 181 |
+
# Read and parse header JSON
|
| 182 |
+
header = f.read(header_size)
|
| 183 |
+
if len(header) < header_size:
|
| 184 |
+
return False, f"Incomplete safetensors header: {cache_path.name}"
|
| 185 |
+
|
| 186 |
+
import json
|
| 187 |
+
json.loads(header.decode("utf-8"))
|
| 188 |
+
|
| 189 |
+
except struct.error as e:
|
| 190 |
+
return False, f"Invalid safetensors format: {str(e)}"
|
| 191 |
+
except json.JSONDecodeError as e:
|
| 192 |
+
return False, f"Invalid safetensors header JSON: {str(e)}"
|
| 193 |
+
|
| 194 |
+
return True, f"Valid cached file ({size_mb:.1f} MB): {cache_path.name}"
|
| 195 |
+
|
| 196 |
+
except OSError as e:
|
| 197 |
+
return False, f"File access error: {str(e)}"
|
src/downloader.py
CHANGED
|
@@ -219,6 +219,9 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
|
|
| 219 |
if local_path.exists():
|
| 220 |
import shutil
|
| 221 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
| 222 |
# Copy the file to cache location
|
| 223 |
shutil.copy2(str(local_path), str(output_path))
|
| 224 |
|
|
@@ -229,6 +232,8 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
|
|
| 229 |
else:
|
| 230 |
raise FileNotFoundError(f"Local file not found: {local_path}")
|
| 231 |
|
|
|
|
|
|
|
| 232 |
# Early cache check: if file exists and size matches URL's content-length, skip re-download
|
| 233 |
expected_size = None
|
| 234 |
try:
|
|
@@ -241,6 +246,7 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
|
|
| 241 |
try:
|
| 242 |
cached_size = output_path.stat().st_size
|
| 243 |
if cached_size == expected_size:
|
|
|
|
| 244 |
# Cache hit - file exists with correct size
|
| 245 |
if progress_bar:
|
| 246 |
progress_bar(1.0)
|
|
@@ -282,6 +288,34 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
|
|
| 282 |
output_path.unlink(missing_ok=True)
|
| 283 |
raise
|
| 284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
return output_path
|
| 286 |
|
| 287 |
|
|
|
|
| 219 |
if local_path.exists():
|
| 220 |
import shutil
|
| 221 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 222 |
+
|
| 223 |
+
print(f" π Copying from cache: {local_path.name} β {output_path.name}")
|
| 224 |
+
|
| 225 |
# Copy the file to cache location
|
| 226 |
shutil.copy2(str(local_path), str(output_path))
|
| 227 |
|
|
|
|
| 232 |
else:
|
| 233 |
raise FileNotFoundError(f"Local file not found: {local_path}")
|
| 234 |
|
| 235 |
+
print(f" π₯ Downloading to cache: {output_path.name}")
|
| 236 |
+
|
| 237 |
# Early cache check: if file exists and size matches URL's content-length, skip re-download
|
| 238 |
expected_size = None
|
| 239 |
try:
|
|
|
|
| 246 |
try:
|
| 247 |
cached_size = output_path.stat().st_size
|
| 248 |
if cached_size == expected_size:
|
| 249 |
+
print(f" β
Cache hit: {output_path.name} ({cached_size / (1024**2):.1f} MB)")
|
| 250 |
# Cache hit - file exists with correct size
|
| 251 |
if progress_bar:
|
| 252 |
progress_bar(1.0)
|
|
|
|
| 288 |
output_path.unlink(missing_ok=True)
|
| 289 |
raise
|
| 290 |
|
| 291 |
+
# Verify the downloaded file is complete
|
| 292 |
+
try:
|
| 293 |
+
actual_size = output_path.stat().st_size
|
| 294 |
+
|
| 295 |
+
# For safetensors files, check header is valid
|
| 296 |
+
if output_path.suffix == ".safetensors":
|
| 297 |
+
import struct
|
| 298 |
+
with open(output_path, "rb") as f:
|
| 299 |
+
header_size_bytes = f.read(8)
|
| 300 |
+
if len(header_size_bytes) < 8:
|
| 301 |
+
raise OSError(f"Safetensors file too small: {output_path.name}")
|
| 302 |
+
|
| 303 |
+
header_size = struct.unpack("<Q", header_size_bytes)[0]
|
| 304 |
+
header = f.read(header_size)
|
| 305 |
+
if len(header) < header_size:
|
| 306 |
+
raise OSError(f"Incomplete safetensors header in {output_path.name}")
|
| 307 |
+
|
| 308 |
+
import json
|
| 309 |
+
json.loads(header.decode("utf-8"))
|
| 310 |
+
|
| 311 |
+
# Verify size matches expected (if known)
|
| 312 |
+
if expected_size is not None and actual_size != expected_size:
|
| 313 |
+
print(f" β οΈ Size mismatch: expected {expected_size}, got {actual_size}")
|
| 314 |
+
|
| 315 |
+
except Exception as e:
|
| 316 |
+
output_path.unlink(missing_ok=True)
|
| 317 |
+
raise OSError(f"Invalid downloaded file {output_path.name}: {str(e)}")
|
| 318 |
+
|
| 319 |
return output_path
|
| 320 |
|
| 321 |
|
src/pipeline.py
CHANGED
|
@@ -78,42 +78,90 @@ def load_pipeline(
|
|
| 78 |
|
| 79 |
try:
|
| 80 |
set_download_cancelled(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
|
| 82 |
checkpoint_path = CACHE_DIR / checkpoint_filename
|
| 83 |
|
| 84 |
# Check if checkpoint is already cached
|
| 85 |
checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
# VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
|
| 88 |
-
vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else
|
| 89 |
-
vae_path = CACHE_DIR / vae_filename
|
| 90 |
-
vae_cached = vae_url.strip() and vae_path.exists() and vae_path.stat().st_size > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
# Download checkpoint (skips if already cached)
|
| 93 |
if progress:
|
| 94 |
progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
|
| 95 |
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
yield status_msg, "Starting download..."
|
| 98 |
|
| 99 |
if not checkpoint_cached:
|
| 100 |
download_file_with_progress(checkpoint_url, checkpoint_path)
|
| 101 |
|
| 102 |
# Download VAE if provided
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
if
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
|
|
|
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Load base pipeline (yield progress during this heavy operation)
|
|
|
|
| 114 |
yield "βοΈ Loading SDXL pipeline...", "Loading model weights into memory..."
|
|
|
|
| 115 |
if progress:
|
| 116 |
-
progress(0.
|
| 117 |
|
| 118 |
global_pipe = StableDiffusionXLPipeline.from_single_file(
|
| 119 |
str(checkpoint_path),
|
|
@@ -121,12 +169,19 @@ def load_pipeline(
|
|
| 121 |
use_safetensors=True,
|
| 122 |
safety_checker=None,
|
| 123 |
)
|
| 124 |
-
|
|
|
|
| 125 |
if progress:
|
| 126 |
-
progress(0.
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
global_pipe.vae = vae.to(device=device, dtype=dtype)
|
|
|
|
| 130 |
|
| 131 |
# Parse LoRA URLs & ensure strengths list matches
|
| 132 |
lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
|
|
@@ -141,17 +196,29 @@ def load_pipeline(
|
|
| 141 |
|
| 142 |
# Load and fuse each LoRA sequentially (only if URLs exist)
|
| 143 |
if lora_urls:
|
|
|
|
| 144 |
global_pipe = global_pipe.to(device=device, dtype=dtype)
|
|
|
|
| 145 |
for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
|
| 146 |
lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
|
| 147 |
lora_path = CACHE_DIR / lora_filename
|
| 148 |
lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
yield (
|
| 157 |
status_msg,
|
|
@@ -161,32 +228,49 @@ def load_pipeline(
|
|
| 161 |
|
| 162 |
if not lora_cached:
|
| 163 |
download_file_with_progress(lora_url, lora_path)
|
| 164 |
-
|
|
|
|
| 165 |
yield f"βοΈ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
|
| 166 |
if progress:
|
| 167 |
progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
|
| 168 |
|
| 169 |
adapter_name = f"lora_{i}"
|
| 170 |
global_pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
|
|
|
|
| 171 |
global_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
|
| 172 |
global_pipe.unload_lora_weights()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
# Set scheduler and move to device (do this once at the end)
|
| 175 |
-
yield "βοΈ Finalizing pipeline...", "Setting up scheduler and moving to device..."
|
| 176 |
if progress:
|
| 177 |
-
progress(0.
|
| 178 |
|
| 179 |
-
global_pipe.scheduler
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
| 181 |
|
|
|
|
| 182 |
yield "β
Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
|
| 183 |
return ("β
Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
|
| 184 |
|
| 185 |
except KeyboardInterrupt:
|
| 186 |
set_download_cancelled(False)
|
|
|
|
| 187 |
return ("β οΈ Download cancelled by user", "Cancelled")
|
| 188 |
except Exception as e:
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
|
| 192 |
def cancel_download():
|
|
|
|
| 78 |
|
| 79 |
try:
|
| 80 |
set_download_cancelled(False)
|
| 81 |
+
|
| 82 |
+
# Import gr here to update button state if needed
|
| 83 |
+
import gradio as gr
|
| 84 |
+
|
| 85 |
+
print("=" * 60)
|
| 86 |
+
print("π Loading SDXL Pipeline...")
|
| 87 |
+
print("=" * 60)
|
| 88 |
+
|
| 89 |
checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
|
| 90 |
checkpoint_path = CACHE_DIR / checkpoint_filename
|
| 91 |
|
| 92 |
# Check if checkpoint is already cached
|
| 93 |
checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0
|
| 94 |
|
| 95 |
+
# Validate cache file before using it
|
| 96 |
+
if checkpoint_cached:
|
| 97 |
+
from src.config import validate_cache_file
|
| 98 |
+
is_valid, msg = validate_cache_file(checkpoint_path)
|
| 99 |
+
if not is_valid:
|
| 100 |
+
print(f" β οΈ Cache invalid: {msg}")
|
| 101 |
+
checkpoint_path.unlink(missing_ok=True)
|
| 102 |
+
checkpoint_cached = False
|
| 103 |
+
|
| 104 |
# VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
|
| 105 |
+
vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else None
|
| 106 |
+
vae_path = CACHE_DIR / vae_filename if vae_filename else None
|
| 107 |
+
vae_cached = vae_url.strip() and vae_path and vae_path.exists() and vae_path.stat().st_size > 0
|
| 108 |
+
|
| 109 |
+
# Validate VAE cache file before using it
|
| 110 |
+
if vae_cached:
|
| 111 |
+
from src.config import validate_cache_file
|
| 112 |
+
is_valid, msg = validate_cache_file(vae_path)
|
| 113 |
+
if not is_valid:
|
| 114 |
+
print(f" β οΈ VAE Cache invalid: {msg}")
|
| 115 |
+
vae_path.unlink(missing_ok=True)
|
| 116 |
+
vae_cached = False
|
| 117 |
|
| 118 |
# Download checkpoint (skips if already cached)
|
| 119 |
if progress:
|
| 120 |
progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
|
| 121 |
|
| 122 |
+
if not checkpoint_cached:
|
| 123 |
+
status_msg = f"π₯ Downloading {checkpoint_path.name}..."
|
| 124 |
+
print(f" π₯ Downloading: {checkpoint_path.name}")
|
| 125 |
+
else:
|
| 126 |
+
status_msg = f"β
Using cached {checkpoint_path.name}"
|
| 127 |
+
print(f" β
Using cached: {checkpoint_path.name}")
|
| 128 |
+
|
| 129 |
yield status_msg, "Starting download..."
|
| 130 |
|
| 131 |
if not checkpoint_cached:
|
| 132 |
download_file_with_progress(checkpoint_url, checkpoint_path)
|
| 133 |
|
| 134 |
# Download VAE if provided
|
| 135 |
+
vae = None
|
| 136 |
+
if vae_url and vae_url.strip():
|
| 137 |
+
if vae_path:
|
| 138 |
+
status_msg = f"π₯ Downloading {vae_path.name}..." if not vae_cached else f"β
Using cached {vae_path.name}"
|
| 139 |
+
print(f" π₯ VAE: {vae_path.name}" if not vae_cached else f" β
VAE (cached): {vae_path.name}")
|
| 140 |
+
|
| 141 |
+
if progress:
|
| 142 |
+
progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
|
| 143 |
+
|
| 144 |
+
yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
|
| 145 |
|
| 146 |
+
if not vae_cached:
|
| 147 |
+
download_file_with_progress(vae_url, vae_path)
|
| 148 |
|
| 149 |
+
# Load VAE from file
|
| 150 |
+
print(" βοΈ Loading VAE weights...")
|
| 151 |
+
yield "βοΈ Loading VAE...", f"Loading VAE: {vae_path.name}"
|
| 152 |
+
vae = AutoencoderKL.from_single_file(
|
| 153 |
+
str(vae_path),
|
| 154 |
+
torch_dtype=dtype,
|
| 155 |
+
)
|
| 156 |
+
if progress:
|
| 157 |
+
progress(0.25, desc="VAE loaded")
|
| 158 |
|
| 159 |
# Load base pipeline (yield progress during this heavy operation)
|
| 160 |
+
print(" βοΈ Loading SDXL pipeline from single file...")
|
| 161 |
yield "βοΈ Loading SDXL pipeline...", "Loading model weights into memory..."
|
| 162 |
+
|
| 163 |
if progress:
|
| 164 |
+
progress(0.3, desc="Loading text encoders...")
|
| 165 |
|
| 166 |
global_pipe = StableDiffusionXLPipeline.from_single_file(
|
| 167 |
str(checkpoint_path),
|
|
|
|
| 169 |
use_safetensors=True,
|
| 170 |
safety_checker=None,
|
| 171 |
)
|
| 172 |
+
print(" β
Text encoders loaded")
|
| 173 |
+
|
| 174 |
if progress:
|
| 175 |
+
progress(0.5, desc="Loading UNet...")
|
| 176 |
|
| 177 |
+
print(" β
UNet loaded")
|
| 178 |
+
yield "βοΈ Pipeline loaded, setting up components...", f"Using device: {device_description}"
|
| 179 |
+
|
| 180 |
+
# Load VAE into pipeline if provided
|
| 181 |
+
if vae is not None:
|
| 182 |
+
print(" βοΈ Setting custom VAE...")
|
| 183 |
global_pipe.vae = vae.to(device=device, dtype=dtype)
|
| 184 |
+
yield "βοΈ Pipeline loaded, setting up components...", f"VAE loaded: {vae_path.name}"
|
| 185 |
|
| 186 |
# Parse LoRA URLs & ensure strengths list matches
|
| 187 |
lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
|
|
|
|
| 196 |
|
| 197 |
# Load and fuse each LoRA sequentially (only if URLs exist)
|
| 198 |
if lora_urls:
|
| 199 |
+
print(f" βοΈ Moving pipeline to device: {device_description}...")
|
| 200 |
global_pipe = global_pipe.to(device=device, dtype=dtype)
|
| 201 |
+
|
| 202 |
for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
|
| 203 |
lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
|
| 204 |
lora_path = CACHE_DIR / lora_filename
|
| 205 |
lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
|
| 206 |
|
| 207 |
+
# Validate LoRA cache file before using it
|
| 208 |
+
if lora_cached:
|
| 209 |
+
from src.config import validate_cache_file
|
| 210 |
+
is_valid, msg = validate_cache_file(lora_path)
|
| 211 |
+
if not is_valid:
|
| 212 |
+
print(f" β οΈ LoRA Cache invalid: {msg}")
|
| 213 |
+
lora_path.unlink(missing_ok=True)
|
| 214 |
+
lora_cached = False
|
| 215 |
+
|
| 216 |
+
if not lora_cached:
|
| 217 |
+
print(f" π₯ LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
|
| 218 |
+
status_msg = f"π₯ Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
|
| 219 |
+
else:
|
| 220 |
+
print(f" β
LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
|
| 221 |
+
status_msg = f"β
Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
|
| 222 |
|
| 223 |
yield (
|
| 224 |
status_msg,
|
|
|
|
| 228 |
|
| 229 |
if not lora_cached:
|
| 230 |
download_file_with_progress(lora_url, lora_path)
|
| 231 |
+
|
| 232 |
+
print(f" βοΈ Loading LoRA {i+1}/{len(lora_urls)}...")
|
| 233 |
yield f"βοΈ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
|
| 234 |
if progress:
|
| 235 |
progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
|
| 236 |
|
| 237 |
adapter_name = f"lora_{i}"
|
| 238 |
global_pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
|
| 239 |
+
print(f" βοΈ Fusing LoRA {i+1} with strength={strength}...")
|
| 240 |
global_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
|
| 241 |
global_pipe.unload_lora_weights()
|
| 242 |
+
else:
|
| 243 |
+
# Move pipeline to device even without LoRAs
|
| 244 |
+
print(f" βοΈ Moving pipeline to device: {device_description}...")
|
| 245 |
+
global_pipe = global_pipe.to(device=device, dtype=dtype)
|
| 246 |
+
|
| 247 |
+
# Set scheduler and finalize (do this once at the end)
|
| 248 |
+
print(" βοΈ Configuring scheduler...")
|
| 249 |
+
yield "βοΈ Finalizing pipeline...", "Setting up scheduler..."
|
| 250 |
|
|
|
|
|
|
|
| 251 |
if progress:
|
| 252 |
+
progress(0.95, desc="Finalizing...")
|
| 253 |
|
| 254 |
+
global_pipe.scheduler = DPMSolverSDEScheduler.from_config(
|
| 255 |
+
global_pipe.scheduler.config,
|
| 256 |
+
algorithm_type="sde-dpmsolver++",
|
| 257 |
+
use_karras_sigmas=False,
|
| 258 |
+
)
|
| 259 |
|
| 260 |
+
print(" β
Pipeline ready!")
|
| 261 |
yield "β
Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
|
| 262 |
return ("β
Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
|
| 263 |
|
| 264 |
except KeyboardInterrupt:
|
| 265 |
set_download_cancelled(False)
|
| 266 |
+
print("\nβ οΈ Download cancelled by user")
|
| 267 |
return ("β οΈ Download cancelled by user", "Cancelled")
|
| 268 |
except Exception as e:
|
| 269 |
+
import traceback
|
| 270 |
+
error_msg = f"β Error loading pipeline: {str(e)}"
|
| 271 |
+
print(f"\n{error_msg}")
|
| 272 |
+
print(traceback.format_exc())
|
| 273 |
+
return (error_msg, f"Error: {str(e)}")
|
| 274 |
|
| 275 |
|
| 276 |
def cancel_download():
|