Spaces:
Sleeping
Sleeping
Update core/image_generator.py
Browse files- core/image_generator.py +32 -33
core/image_generator.py
CHANGED
|
@@ -321,52 +321,36 @@ def download_model() -> Path:
|
|
| 321 |
# --------------------------------------------------------------
|
| 322 |
# MEMORY-SAFE PIPELINE MANAGER
|
| 323 |
# --------------------------------------------------------------
|
| 324 |
-
def unload_pipelines():
|
| 325 |
-
"""Unload
|
| 326 |
global pipe, img2img_pipe
|
| 327 |
-
print("[ImageGen] 🧹 Clearing
|
| 328 |
-
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
pipe = None
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
| 335 |
img2img_pipe = None
|
| 336 |
-
|
| 337 |
-
pass
|
| 338 |
gc.collect()
|
| 339 |
if torch.cuda.is_available():
|
| 340 |
torch.cuda.empty_cache()
|
| 341 |
print("[ImageGen] ✅ Memory cleared.")
|
| 342 |
|
| 343 |
|
| 344 |
-
def safe_load_pipeline(pipeline_cls, model_path: Path):
|
| 345 |
-
try:
|
| 346 |
-
# Try local-only load first
|
| 347 |
-
return pipeline_cls.from_single_file(
|
| 348 |
-
str(model_path),
|
| 349 |
-
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 350 |
-
local_files_only=True,
|
| 351 |
-
cache_dir=str(HF_CACHE_DIR) # <-- ensure all caches go to /tmp/hf_cache
|
| 352 |
-
)
|
| 353 |
-
except Exception as e:
|
| 354 |
-
print(f"[WARN] Local-only load failed ({e}). Retrying with network access...")
|
| 355 |
-
return pipeline_cls.from_single_file(
|
| 356 |
-
str(model_path),
|
| 357 |
-
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 358 |
-
local_files_only=False, # allow network download
|
| 359 |
-
cache_dir=str(HF_CACHE_DIR) # <-- must specify this
|
| 360 |
-
)
|
| 361 |
-
|
| 362 |
-
|
| 363 |
def load_pipeline():
|
| 364 |
-
"""Load text-to-image pipeline into RAM."""
|
| 365 |
global pipe
|
| 366 |
-
unload_pipelines() #
|
| 367 |
model_path = download_model()
|
| 368 |
print("[ImageGen] Loading main (txt2img) pipeline...")
|
| 369 |
-
|
| 370 |
pipe = safe_load_pipeline(StableDiffusionXLPipeline, model_path)
|
| 371 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 372 |
pipe.to(device)
|
|
@@ -376,6 +360,21 @@ def load_pipeline():
|
|
| 376 |
return pipe
|
| 377 |
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
def load_img2img_pipeline():
|
| 380 |
"""Load img2img pipeline into RAM."""
|
| 381 |
global img2img_pipe
|
|
|
|
| 321 |
# --------------------------------------------------------------
|
| 322 |
# MEMORY-SAFE PIPELINE MANAGER
|
| 323 |
# --------------------------------------------------------------
|
| 324 |
+
def unload_pipelines(target="all"):
|
| 325 |
+
"""Unload specific or all pipelines."""
|
| 326 |
global pipe, img2img_pipe
|
| 327 |
+
print("[ImageGen] 🧹 Clearing pipelines from memory...")
|
| 328 |
+
|
| 329 |
+
if target in ("pipe", "all"):
|
| 330 |
+
try:
|
| 331 |
+
del pipe
|
| 332 |
+
except:
|
| 333 |
+
pass
|
| 334 |
pipe = None
|
| 335 |
+
|
| 336 |
+
if target in ("img2img_pipe", "all"):
|
| 337 |
+
try:
|
| 338 |
+
del img2img_pipe
|
| 339 |
+
except:
|
| 340 |
+
pass
|
| 341 |
img2img_pipe = None
|
| 342 |
+
|
|
|
|
| 343 |
gc.collect()
|
| 344 |
if torch.cuda.is_available():
|
| 345 |
torch.cuda.empty_cache()
|
| 346 |
print("[ImageGen] ✅ Memory cleared.")
|
| 347 |
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
def load_pipeline():
|
|
|
|
| 350 |
global pipe
|
| 351 |
+
unload_pipelines(target="pipe") # only clear old txt2img
|
| 352 |
model_path = download_model()
|
| 353 |
print("[ImageGen] Loading main (txt2img) pipeline...")
|
|
|
|
| 354 |
pipe = safe_load_pipeline(StableDiffusionXLPipeline, model_path)
|
| 355 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 356 |
pipe.to(device)
|
|
|
|
| 360 |
return pipe
|
| 361 |
|
| 362 |
|
| 363 |
+
def load_img2img_pipeline():
|
| 364 |
+
global img2img_pipe
|
| 365 |
+
unload_pipelines(target="img2img_pipe") # only clear old img2img
|
| 366 |
+
model_path = download_model()
|
| 367 |
+
print("[ImageGen] Loading img2img pipeline...")
|
| 368 |
+
img2img_pipe = safe_load_pipeline(StableDiffusionXLImg2ImgPipeline, model_path)
|
| 369 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 370 |
+
img2img_pipe.to(device)
|
| 371 |
+
img2img_pipe.safety_checker = None
|
| 372 |
+
img2img_pipe.enable_attention_slicing()
|
| 373 |
+
print("[ImageGen] ✅ Img2Img pipeline ready.")
|
| 374 |
+
return img2img_pipe
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
|
| 378 |
def load_img2img_pipeline():
|
| 379 |
"""Load img2img pipeline into RAM."""
|
| 380 |
global img2img_pipe
|