karthikeya1212 committed on
Commit
079d316
·
verified ·
1 Parent(s): 845979b

Update core/image_generator.py

Browse files
Files changed (1) hide show
  1. core/image_generator.py +32 -33
core/image_generator.py CHANGED
@@ -321,52 +321,36 @@ def download_model() -> Path:
321
  # --------------------------------------------------------------
322
  # MEMORY-SAFE PIPELINE MANAGER
323
  # --------------------------------------------------------------
324
def unload_pipelines():
    """Release any existing pipelines from RAM and GPU.

    Rebinds the module-level ``pipe`` and ``img2img_pipe`` globals to ``None``
    so the heavy pipeline objects become unreachable, forces a garbage
    collection, and clears the CUDA allocator cache when a GPU is present.

    Returns:
        None.
    """
    global pipe, img2img_pipe
    print("[ImageGen] 🧹 Clearing old pipelines from memory...")
    # Rebinding to None drops the last strong reference. The previous
    # `del` wrapped in a bare `except: pass` was redundant (assignment alone
    # releases the object) and silently swallowed unrelated errors.
    pipe = None
    img2img_pipe = None
    gc.collect()
    if torch.cuda.is_available():
        # Return cached GPU blocks to the driver so the next load fits.
        torch.cuda.empty_cache()
    print("[ImageGen] ✅ Memory cleared.")
342
 
343
 
344
def safe_load_pipeline(pipeline_cls, model_path: Path):
    """Load ``pipeline_cls`` from a single-file checkpoint, preferring local files.

    First attempts a local-files-only load (no network). If that fails for any
    reason, retries with network access enabled. All Hugging Face caches are
    redirected to ``HF_CACHE_DIR`` so downloads land in the writable cache dir.

    Args:
        pipeline_cls: A diffusers pipeline class exposing ``from_single_file``.
        model_path: Path to the single-file model checkpoint.

    Returns:
        An instantiated pipeline of type ``pipeline_cls``.

    Raises:
        Exception: Whatever the second (networked) load attempt raises.
    """
    # Hoist the invariant dtype choice so both attempts cannot drift apart.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        # Fast path: resolve everything from the local cache, no network.
        return pipeline_cls.from_single_file(
            str(model_path),
            torch_dtype=dtype,
            local_files_only=True,
            cache_dir=str(HF_CACHE_DIR),
        )
    except Exception as e:
        print(f"[WARN] Local-only load failed ({e}). Retrying with network access...")
        return pipeline_cls.from_single_file(
            str(model_path),
            torch_dtype=dtype,
            local_files_only=False,  # allow network download
            cache_dir=str(HF_CACHE_DIR),
        )
361
-
362
-
363
  def load_pipeline():
364
- """Load text-to-image pipeline into RAM."""
365
  global pipe
366
- unload_pipelines() # Ensure no duplicate pipelines in memory
367
  model_path = download_model()
368
  print("[ImageGen] Loading main (txt2img) pipeline...")
369
-
370
  pipe = safe_load_pipeline(StableDiffusionXLPipeline, model_path)
371
  device = "cuda" if torch.cuda.is_available() else "cpu"
372
  pipe.to(device)
@@ -376,6 +360,21 @@ def load_pipeline():
376
  return pipe
377
 
378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  def load_img2img_pipeline():
380
  """Load img2img pipeline into RAM."""
381
  global img2img_pipe
 
321
  # --------------------------------------------------------------
322
  # MEMORY-SAFE PIPELINE MANAGER
323
  # --------------------------------------------------------------
324
def unload_pipelines(target="all"):
    """Unload specific or all pipelines from RAM and GPU.

    Args:
        target: Which pipeline(s) to clear — ``"pipe"`` (txt2img only),
            ``"img2img_pipe"`` (img2img only), or ``"all"`` (default, both).

    Returns:
        None.
    """
    global pipe, img2img_pipe
    print("[ImageGen] 🧹 Clearing pipelines from memory...")
    # Rebinding to None drops the last strong reference. The previous
    # `del` wrapped in a bare `except: pass` was redundant (assignment alone
    # releases the object) and silently swallowed unrelated errors.
    if target in ("pipe", "all"):
        pipe = None
    if target in ("img2img_pipe", "all"):
        img2img_pipe = None
    gc.collect()
    if torch.cuda.is_available():
        # Return cached GPU blocks to the driver so the next load fits.
        torch.cuda.empty_cache()
    print("[ImageGen] ✅ Memory cleared.")
347
 
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  def load_pipeline():
 
350
  global pipe
351
+ unload_pipelines(target="pipe") # only clear old txt2img
352
  model_path = download_model()
353
  print("[ImageGen] Loading main (txt2img) pipeline...")
 
354
  pipe = safe_load_pipeline(StableDiffusionXLPipeline, model_path)
355
  device = "cuda" if torch.cuda.is_available() else "cpu"
356
  pipe.to(device)
 
360
  return pipe
361
 
362
 
363
def load_img2img_pipeline():
    """Load the img2img pipeline into RAM (and onto the GPU when available).

    Clears only any previously loaded img2img pipeline first (leaving the
    txt2img pipeline untouched), resolves the model checkpoint, loads an
    SDXL img2img pipeline from it, moves it to the best available device,
    disables the safety checker, and enables attention slicing to reduce
    peak memory.

    Returns:
        The ready-to-use img2img pipeline (also stored in the module-level
        ``img2img_pipe`` global).
    """
    global img2img_pipe
    unload_pipelines(target="img2img_pipe")  # only clear old img2img
    model_path = download_model()
    print("[ImageGen] Loading img2img pipeline...")
    img2img_pipe = safe_load_pipeline(StableDiffusionXLImg2ImgPipeline, model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    img2img_pipe.to(device)
    # NOTE(review): SDXL pipelines ship without a safety checker; this
    # assignment is kept for parity with the txt2img loader — confirm intent.
    img2img_pipe.safety_checker = None
    img2img_pipe.enable_attention_slicing()  # trade speed for lower VRAM
    print("[ImageGen] ✅ Img2Img pipeline ready.")
    return img2img_pipe
375
+
376
+
377
+
378
  def load_img2img_pipeline():
379
  """Load img2img pipeline into RAM."""
380
  global img2img_pipe