LogicGoInfotechSpaces commited on
Commit
da9ea7d
·
1 Parent(s): 779884f

Fix permission errors and environment variables for model loading - Use /tmp directory for writable model storage - Set OMP_NUM_THREADS and cache directories at startup - Add better error handling and permission checks

Browse files
Files changed (1) hide show
  1. app/main_sdxl.py +41 -10
app/main_sdxl.py CHANGED
@@ -3,6 +3,15 @@ FastAPI application for Text-Guided Image Colorization using SDXL + ControlNet
3
  Based on fffiloni/text-guided-image-colorization
4
  """
5
  import os
 
 
 
 
 
 
 
 
 
6
  import io
7
  import uuid
8
  import logging
@@ -174,17 +183,39 @@ async def startup_event():
174
  try:
175
  logger.info("🔄 Loading SDXL + ControlNet colorization models...")
176
 
177
- # Ensure required directories exist
178
- os.makedirs("sdxl_light_caption_output", exist_ok=True)
179
-
180
- # Download controlnet model snapshot
181
  try:
182
- snapshot_download(
183
- repo_id='nickpai/sdxl_light_caption_output',
184
- local_dir='sdxl_light_caption_output'
185
- )
 
 
 
 
 
 
186
  except Exception as e:
187
- logger.warning(f"Could not download controlnet snapshot: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  # Device and precision setup
190
  accelerator = Accelerator(mixed_precision="fp16")
@@ -196,7 +227,7 @@ async def startup_event():
196
  # Pretrained paths
197
  base_model_path = settings.BASE_MODEL_ID
198
  safetensors_ckpt = settings.LIGHTNING_WEIGHTS
199
- controlnet_path = "sdxl_light_caption_output/checkpoint-30000/controlnet"
200
 
201
  # Load diffusion components
202
  logger.info("Loading VAE...")
 
3
  Based on fffiloni/text-guided-image-colorization
4
  """
5
  import os
6
+ # Set environment variables BEFORE any imports
7
+ os.environ["OMP_NUM_THREADS"] = "1"
8
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
9
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
+ os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
11
+ os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/hf_cache"
12
+ os.environ["XDG_CACHE_HOME"] = "/tmp/hf_cache"
13
+ os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
14
+
15
  import io
16
  import uuid
17
  import logging
 
183
  try:
184
  logger.info("🔄 Loading SDXL + ControlNet colorization models...")
185
 
186
+ # Use writable directory for model downloads
187
+ controlnet_dir = "/tmp/sdxl_light_caption_output"
 
 
188
  try:
189
+ os.makedirs(controlnet_dir, exist_ok=True)
190
+ # Test write permissions
191
+ test_file = os.path.join(controlnet_dir, ".test_write")
192
+ with open(test_file, "w") as f:
193
+ f.write("test")
194
+ os.remove(test_file)
195
+ logger.info(f"Using directory: {controlnet_dir}")
196
+ except PermissionError as e:
197
+ logger.error(f"Permission denied for directory {controlnet_dir}: {e}")
198
+ raise
199
  except Exception as e:
200
+ logger.error(f"Failed to create directory {controlnet_dir}: {e}")
201
+ raise
202
+
203
+ # Download controlnet model snapshot
204
+ controlnet_path = os.path.join(controlnet_dir, "checkpoint-30000", "controlnet")
205
+ if os.path.exists(controlnet_path):
206
+ logger.info(f"ControlNet model already exists at {controlnet_path}")
207
+ else:
208
+ try:
209
+ logger.info("Downloading ControlNet model...")
210
+ snapshot_download(
211
+ repo_id='nickpai/sdxl_light_caption_output',
212
+ local_dir=controlnet_dir
213
+ )
214
+ logger.info("ControlNet model downloaded successfully")
215
+ except Exception as e:
216
+ logger.error(f"Could not download controlnet snapshot: {e}")
217
+ if not os.path.exists(controlnet_path):
218
+ raise
219
 
220
  # Device and precision setup
221
  accelerator = Accelerator(mixed_precision="fp16")
 
227
  # Pretrained paths
228
  base_model_path = settings.BASE_MODEL_ID
229
  safetensors_ckpt = settings.LIGHTNING_WEIGHTS
230
+ # controlnet_path already defined above
231
 
232
  # Load diffusion components
233
  logger.info("Loading VAE...")