Commit 807fb92
Parent(s): da9ea7d

Add memory optimizations for SDXL pipeline

- Enable VAE slicing and tiling
- Enable attention slicing for UNet and ControlNet
- Use sequential CPU offloading for pipeline
- Keep BLIP model on CPU to save GPU memory
- Add torch.no_grad() and cache clearing
- Reduce guidance scale for lower memory usage

1 changed file: app/main_sdxl.py (+44, -24)
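The savings this commit wires in per module are also available as pipeline-level toggles in diffusers. The sketch below shows that variant for context; it is not part of the commit, and it assumes base_model_path, controlnet, and a float16 dtype as set up in the diff that follows (which instead calls per-module equivalents such as vae.enable_slicing() and unet.set_attention_slice("max")).

import torch
from diffusers import StableDiffusionXLControlNetPipeline

# Assumes base_model_path and controlnet are prepared as in the diff below.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_vae_slicing()             # decode latents in slices
pipe.enable_vae_tiling()              # tile the VAE for large images
pipe.enable_attention_slicing("max")  # slice attention in UNet and ControlNet
pipe.enable_sequential_cpu_offload()  # stream submodules to the GPU on demand
# Faster but less aggressive alternative:
# pipe.enable_model_cpu_offload()

Sequential CPU offload gives the lowest VRAM floor but is noticeably slower, since weights are shuttled to the GPU layer by layer at every step; enable_model_cpu_offload() keeps each whole model resident while it runs and is usually the better trade-off when memory allows.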
app/main_sdxl.py

@@ -232,37 +232,48 @@ async def startup_event():
     # Load diffusion components
     logger.info("Loading VAE...")
     vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae")
+    # Enable VAE slicing for memory efficiency
+    vae.enable_slicing()
+    vae.enable_tiling()

     logger.info("Loading UNet...")
     unet = UNet2DConditionModel.from_config(base_model_path, subfolder="unet")
     unet.load_state_dict(load_file(hf_hub_download("ByteDance/SDXL-Lightning", safetensors_ckpt)))
+    # Enable attention slicing for memory efficiency
+    unet.set_attention_slice("max")

     logger.info("Loading ControlNet...")
     controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype)
+    # Enable attention slicing for ControlNet
+    controlnet.set_attention_slice("max")

     logger.info("Creating pipeline...")
     pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
-        base_model_path, vae=vae, unet=unet, controlnet=controlnet
+        base_model_path, vae=vae, unet=unet, controlnet=controlnet, torch_dtype=weight_dtype
     )
-    pipe.to(device, dtype=weight_dtype)
     pipe.safety_checker = None

-    #
-    logger.info("
-
-
+    # Enable sequential CPU offloading to reduce memory usage
+    logger.info("Enabling CPU offloading for memory efficiency...")
+    pipe.enable_sequential_cpu_offload()
+    # Alternative: use model CPU offload (moves entire model to CPU when not in use)
+    # pipe.enable_model_cpu_offload()
+
+    logger.info("Memory optimizations enabled")
+
+    # Load BLIP captioning model (use base to save memory)
+    logger.info("Loading BLIP captioning model (using base model for memory efficiency)...")
+    caption_model_name = "blip-image-captioning-base"
     try:
         processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
         caption_model = BlipForConditionalGeneration.from_pretrained(
             f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
-        )
+        )
+        # Keep BLIP on CPU and move to device only during inference
+        caption_model.eval()
     except Exception as e:
-        logger.
-
-        processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
-        caption_model = BlipForConditionalGeneration.from_pretrained(
-            f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
-        ).to(device)
+        logger.error(f"Failed to load BLIP model: {e}")
+        raise

     logger.info("✅ All models loaded successfully!")
     model_load_error = None

@@ -381,10 +392,13 @@ def colorize_image_sdxl(
     original_size = image.size
     control_image = image.convert("L").convert("RGB").resize((512, 512))

-    # Image captioning
+    # Image captioning - keep BLIP on CPU to save memory
     input_text = settings.CAPTION_PREFIX
-
-
+    # Use CPU for BLIP to save GPU memory
+    blip_device = torch.device("cpu")
+    inputs = processor(control_image, input_text, return_tensors="pt").to(blip_device)
+    with torch.no_grad():
+        caption_ids = caption_model.generate(**inputs, max_length=50, num_beams=3)
     caption = processor.decode(caption_ids[0], skip_special_tokens=True)
     caption = remove_unlikely_words(caption)

@@ -394,14 +408,20 @@ def colorize_image_sdxl(
     else:
         final_prompt = caption

-    # Inference
-
-
-
-
-
-
-
+    # Inference with memory-efficient settings
+    with torch.no_grad():
+        result = pipe(
+            prompt=final_prompt,
+            negative_prompt=negative_prompt or settings.NEGATIVE_PROMPT,
+            num_inference_steps=num_inference_steps,
+            generator=torch.manual_seed(seed),
+            image=control_image,
+            guidance_scale=7.5,  # Lower guidance scale uses less memory
+        )
+
+    # Clear cache after inference
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()

     colorized = apply_color(control_image, result.images[0]).resize(original_size)
     return colorized, caption
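To confirm that the offloading and slicing actually lower peak VRAM, one option is to bracket a test call with torch's CUDA peak-memory counters. The helper below is a sketch for that purpose, not part of the commit; the commented call is illustrative and assumes the colorize_image_sdxl function shown in the diff.

import torch

def report_peak_vram(fn, *args, **kwargs):
    # Reset the allocator's high-water mark, run one call, and print the peak.
    torch.cuda.reset_peak_memory_stats()
    out = fn(*args, **kwargs)
    peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"Peak CUDA memory: {peak_gib:.2f} GiB")
    return out

# Example (arguments are illustrative):
# colorized, caption = report_peak_vram(colorize_image_sdxl, image, ...)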