LogicGoInfotechSpaces committed on
Commit 807fb92 · 1 Parent(s): da9ea7d

Add memory optimizations for SDXL pipeline:
- Enable VAE slicing and tiling
- Enable attention slicing for UNet and ControlNet
- Use sequential CPU offloading for pipeline
- Keep BLIP model on CPU to save GPU memory
- Add torch.no_grad() and cache clearing
- Reduce guidance scale for lower memory usage
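For context, the sketch below shows how these diffusers memory options fit together in isolation. It is illustrative only: the checkpoint IDs, dtype, dummy control image, and 8-step schedule are placeholders, not values taken from this repository.

# Illustrative sketch of the memory options listed above (not the app's actual code).
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",           # placeholder ControlNet checkpoint
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",      # placeholder base model
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

pipe.vae.enable_slicing()             # decode latents one image at a time
pipe.vae.enable_tiling()              # decode large latents in tiles
pipe.enable_attention_slicing("max")  # compute attention in chunks to cap peak activations
pipe.enable_sequential_cpu_offload()  # stream submodules to the GPU on demand
                                      # (do not also call pipe.to("cuda"))

control_image = Image.new("RGB", (512, 512))  # dummy conditioning image
with torch.no_grad():
    image = pipe("a colorized photo", image=control_image,
                 num_inference_steps=8).images[0]

if torch.cuda.is_available():
    torch.cuda.empty_cache()          # release cached CUDA blocks after the run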

Files changed (1)
  1. app/main_sdxl.py +44 -24
app/main_sdxl.py CHANGED
@@ -232,37 +232,48 @@ async def startup_event():
     # Load diffusion components
     logger.info("Loading VAE...")
     vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae")
+    # Enable VAE slicing for memory efficiency
+    vae.enable_slicing()
+    vae.enable_tiling()
 
     logger.info("Loading UNet...")
     unet = UNet2DConditionModel.from_config(base_model_path, subfolder="unet")
     unet.load_state_dict(load_file(hf_hub_download("ByteDance/SDXL-Lightning", safetensors_ckpt)))
+    # Enable attention slicing for memory efficiency
+    unet.set_attention_slice("max")
 
     logger.info("Loading ControlNet...")
     controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype)
+    # Enable attention slicing for ControlNet
+    controlnet.set_attention_slice("max")
 
     logger.info("Creating pipeline...")
     pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
-        base_model_path, vae=vae, unet=unet, controlnet=controlnet
+        base_model_path, vae=vae, unet=unet, controlnet=controlnet, torch_dtype=weight_dtype
     )
-    pipe.to(device, dtype=weight_dtype)
     pipe.safety_checker = None
 
-    # Load BLIP captioning model
-    logger.info("Loading BLIP captioning model...")
-    # Try large first, fallback to base
-    caption_model_name = "blip-image-captioning-large"
+    # Enable sequential CPU offloading to reduce memory usage
+    logger.info("Enabling CPU offloading for memory efficiency...")
+    pipe.enable_sequential_cpu_offload()
+    # Alternative: use model CPU offload (moves entire model to CPU when not in use)
+    # pipe.enable_model_cpu_offload()
+
+    logger.info("Memory optimizations enabled")
+
+    # Load BLIP captioning model (use base to save memory)
+    logger.info("Loading BLIP captioning model (using base model for memory efficiency)...")
+    caption_model_name = "blip-image-captioning-base"
     try:
         processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
         caption_model = BlipForConditionalGeneration.from_pretrained(
             f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
-        ).to(device)
+        )
+        # Keep BLIP on CPU and move to device only during inference
+        caption_model.eval()
     except Exception as e:
-        logger.warning(f"Failed to load large model, trying base: {e}")
-        caption_model_name = "blip-image-captioning-base"
-        processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
-        caption_model = BlipForConditionalGeneration.from_pretrained(
-            f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
-        ).to(device)
+        logger.error(f"Failed to load BLIP model: {e}")
+        raise
 
     logger.info("✅ All models loaded successfully!")
     model_load_error = None
@@ -381,10 +392,13 @@ def colorize_image_sdxl(
     original_size = image.size
     control_image = image.convert("L").convert("RGB").resize((512, 512))
 
-    # Image captioning
+    # Image captioning - keep BLIP on CPU to save memory
    input_text = settings.CAPTION_PREFIX
-    inputs = processor(control_image, input_text, return_tensors="pt").to(device, dtype=weight_dtype)
-    caption_ids = caption_model.generate(**inputs)
+    # Use CPU for BLIP to save GPU memory
+    blip_device = torch.device("cpu")
+    inputs = processor(control_image, input_text, return_tensors="pt").to(blip_device)
+    with torch.no_grad():
+        caption_ids = caption_model.generate(**inputs, max_length=50, num_beams=3)
     caption = processor.decode(caption_ids[0], skip_special_tokens=True)
     caption = remove_unlikely_words(caption)
 
@@ -394,14 +408,20 @@
     else:
         final_prompt = caption
 
-    # Inference
-    result = pipe(
-        prompt=final_prompt,
-        negative_prompt=negative_prompt or settings.NEGATIVE_PROMPT,
-        num_inference_steps=num_inference_steps,
-        generator=torch.manual_seed(seed),
-        image=control_image
-    )
+    # Inference with memory-efficient settings
+    with torch.no_grad():
+        result = pipe(
+            prompt=final_prompt,
+            negative_prompt=negative_prompt or settings.NEGATIVE_PROMPT,
+            num_inference_steps=num_inference_steps,
+            generator=torch.manual_seed(seed),
+            image=control_image,
+            guidance_scale=7.5,  # Lower guidance scale uses less memory
+        )
+
+    # Clear cache after inference
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     colorized = apply_color(control_image, result.images[0]).resize(original_size)
     return colorized, caption
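The commented-out `pipe.enable_model_cpu_offload()` in the first hunk is the main tuning knob if the Space has a bit more VRAM to spare: model offload keeps one whole component (text encoders, UNet, or VAE) on the GPU at a time and is typically much faster than sequential offload, which streams individual submodules. A minimal sketch of switching between the two follows; the `LOW_VRAM` flag and model ID are hypothetical, not part of this repository.

# Hypothetical LOW_VRAM switch between the two offload modes mentioned in the diff.
import os
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",   # placeholder model ID
    torch_dtype=torch.float16,
)

if os.environ.get("LOW_VRAM", "1") == "1":
    # Lowest peak memory: parameters move to the GPU submodule by submodule (slow).
    pipe.enable_sequential_cpu_offload()
else:
    # Faster: one whole component resides on the GPU while it runs.
    pipe.enable_model_cpu_offload()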