Commit
·
da9ea7d
1
Parent(s):
779884f
Fix permission errors and environment variables for model loading: use the /tmp directory for writable model storage; set OMP_NUM_THREADS and the cache directories at startup; add better error handling and permission checks.
Browse files - app/main_sdxl.py +41 -10
app/main_sdxl.py
CHANGED
|
@@ -3,6 +3,15 @@ FastAPI application for Text-Guided Image Colorization using SDXL + ControlNet
|
|
| 3 |
Based on fffiloni/text-guided-image-colorization
|
| 4 |
"""
|
| 5 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import io
|
| 7 |
import uuid
|
| 8 |
import logging
|
|
@@ -174,17 +183,39 @@ async def startup_event():
|
|
| 174 |
try:
|
| 175 |
logger.info("🔄 Loading SDXL + ControlNet colorization models...")
|
| 176 |
|
| 177 |
-
#
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
# Download controlnet model snapshot
|
| 181 |
try:
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
except Exception as e:
|
| 187 |
-
logger.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
# Device and precision setup
|
| 190 |
accelerator = Accelerator(mixed_precision="fp16")
|
|
@@ -196,7 +227,7 @@ async def startup_event():
|
|
| 196 |
# Pretrained paths
|
| 197 |
base_model_path = settings.BASE_MODEL_ID
|
| 198 |
safetensors_ckpt = settings.LIGHTNING_WEIGHTS
|
| 199 |
-
controlnet_path
|
| 200 |
|
| 201 |
# Load diffusion components
|
| 202 |
logger.info("Loading VAE...")
|
|
|
|
| 3 |
Based on fffiloni/text-guided-image-colorization
|
| 4 |
"""
|
| 5 |
import os
|
| 6 |
+
# Set environment variables BEFORE any imports
|
| 7 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
| 8 |
+
os.environ["HF_HOME"] = "/tmp/hf_cache"
|
| 9 |
+
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
|
| 10 |
+
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
|
| 11 |
+
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/hf_cache"
|
| 12 |
+
os.environ["XDG_CACHE_HOME"] = "/tmp/hf_cache"
|
| 13 |
+
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
|
| 14 |
+
|
| 15 |
import io
|
| 16 |
import uuid
|
| 17 |
import logging
|
|
|
|
| 183 |
try:
|
| 184 |
logger.info("🔄 Loading SDXL + ControlNet colorization models...")
|
| 185 |
|
| 186 |
+
# Use writable directory for model downloads
|
| 187 |
+
controlnet_dir = "/tmp/sdxl_light_caption_output"
|
|
|
|
|
|
|
| 188 |
try:
|
| 189 |
+
os.makedirs(controlnet_dir, exist_ok=True)
|
| 190 |
+
# Test write permissions
|
| 191 |
+
test_file = os.path.join(controlnet_dir, ".test_write")
|
| 192 |
+
with open(test_file, "w") as f:
|
| 193 |
+
f.write("test")
|
| 194 |
+
os.remove(test_file)
|
| 195 |
+
logger.info(f"Using directory: {controlnet_dir}")
|
| 196 |
+
except PermissionError as e:
|
| 197 |
+
logger.error(f"Permission denied for directory {controlnet_dir}: {e}")
|
| 198 |
+
raise
|
| 199 |
except Exception as e:
|
| 200 |
+
logger.error(f"Failed to create directory {controlnet_dir}: {e}")
|
| 201 |
+
raise
|
| 202 |
+
|
| 203 |
+
# Download controlnet model snapshot
|
| 204 |
+
controlnet_path = os.path.join(controlnet_dir, "checkpoint-30000", "controlnet")
|
| 205 |
+
if os.path.exists(controlnet_path):
|
| 206 |
+
logger.info(f"ControlNet model already exists at {controlnet_path}")
|
| 207 |
+
else:
|
| 208 |
+
try:
|
| 209 |
+
logger.info("Downloading ControlNet model...")
|
| 210 |
+
snapshot_download(
|
| 211 |
+
repo_id='nickpai/sdxl_light_caption_output',
|
| 212 |
+
local_dir=controlnet_dir
|
| 213 |
+
)
|
| 214 |
+
logger.info("ControlNet model downloaded successfully")
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.error(f"Could not download controlnet snapshot: {e}")
|
| 217 |
+
if not os.path.exists(controlnet_path):
|
| 218 |
+
raise
|
| 219 |
|
| 220 |
# Device and precision setup
|
| 221 |
accelerator = Accelerator(mixed_precision="fp16")
|
|
|
|
| 227 |
# Pretrained paths
|
| 228 |
base_model_path = settings.BASE_MODEL_ID
|
| 229 |
safetensors_ckpt = settings.LIGHTNING_WEIGHTS
|
| 230 |
+
# controlnet_path already defined above
|
| 231 |
|
| 232 |
# Load diffusion components
|
| 233 |
logger.info("Loading VAE...")
|