karthikeya1212 commited on
Commit
1791ffb
·
verified ·
1 Parent(s): c7a41ed

Update core/image_generator.py

Browse files
Files changed (1) hide show
  1. core/image_generator.py +35 -30
core/image_generator.py CHANGED
@@ -118,24 +118,20 @@
118
 
119
  # core/image_generator.py
120
  # ---------------- CACHE & MODEL DIRECTORIES (FIXED) ----------------
 
121
  import os
122
  from pathlib import Path
123
- import torch
124
- from diffusers import StableDiffusionXLPipeline
125
- from huggingface_hub import hf_hub_download
126
- from typing import List
127
- from io import BytesIO
128
- import base64
129
- from PIL import Image
130
 
131
- # Force all Hugging Face caches to /tmp/hf_cache
 
 
132
  HF_CACHE_DIR = Path("/tmp/hf_cache")
133
- HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
134
-
135
  MODEL_DIR = Path("/tmp/models/realvisxl_v4")
136
- MODEL_DIR.mkdir(parents=True, exist_ok=True)
137
 
138
- # MUST be set before importing diffusers/transformers
 
 
 
139
  os.environ["HF_HOME"] = str(HF_CACHE_DIR)
140
  os.environ["HF_HUB_CACHE"] = str(HF_CACHE_DIR)
141
  os.environ["DIFFUSERS_CACHE"] = str(HF_CACHE_DIR)
@@ -143,22 +139,33 @@ os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)
143
  os.environ["XDG_CACHE_HOME"] = str(HF_CACHE_DIR)
144
  os.environ["HF_DATASETS_CACHE"] = str(HF_CACHE_DIR)
145
  os.environ["HF_MODULES_CACHE"] = str(HF_CACHE_DIR)
 
 
146
 
147
  print("[DEBUG] Hugging Face cache directory set to:", HF_CACHE_DIR)
148
  print("[DEBUG] Model directory set to:", MODEL_DIR)
149
- # ---------------- MODEL CONFIG ----------------
150
- MODEL_REPO = "SG161222/RealVisXL_V4.0"
151
- MODEL_FILENAME = "RealVisXL_V4.0.safetensors"
152
-
153
 
 
 
 
 
 
 
 
 
 
 
154
 
 
 
 
 
 
155
 
156
- # ---------------- MODEL DOWNLOAD ----------------
 
 
157
  def download_model() -> Path:
158
- """
159
- Downloads RealVisXL V4.0 model if not present.
160
- Returns local path.
161
- """
162
  model_path = MODEL_DIR / MODEL_FILENAME
163
  if not model_path.exists():
164
  print("[ImageGen] Downloading RealVisXL V4.0 model...")
@@ -168,7 +175,7 @@ def download_model() -> Path:
168
  filename=MODEL_FILENAME,
169
  cache_dir=str(HF_CACHE_DIR),
170
  force_download=False,
171
- resume_download=True, # safer if download interrupted
172
  )
173
  )
174
  print(f"[ImageGen] Model downloaded to: {model_path}")
@@ -176,17 +183,17 @@ def download_model() -> Path:
176
  print("[ImageGen] Model already exists. Skipping download.")
177
  return model_path
178
 
179
- # ---------------- PIPELINE LOAD ----------------
 
 
180
  def load_pipeline() -> StableDiffusionXLPipeline:
181
- """
182
- Loads the RealVisXL V4.0 model for image generation.
183
- """
184
  model_path = download_model()
185
  print("[ImageGen] Loading model into pipeline...")
186
 
187
  pipe = StableDiffusionXLPipeline.from_single_file(
188
  str(model_path),
189
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
190
  )
191
 
192
  if torch.cuda.is_available():
@@ -194,10 +201,8 @@ def load_pipeline() -> StableDiffusionXLPipeline:
194
  else:
195
  pipe.to("cpu")
196
 
197
- # Optional: skip safety checker to save memory/performance
198
- pipe.safety_checker = None
199
- # Enable attention slicing for memory-efficient CPU usage
200
- pipe.enable_attention_slicing()
201
 
202
  print("[ImageGen] Model ready.")
203
  return pipe
 
118
 
119
  # core/image_generator.py
120
  # ---------------- CACHE & MODEL DIRECTORIES (FIXED) ----------------
121
+ # core/image_generator.py
122
  import os
123
  from pathlib import Path
 
 
 
 
 
 
 
124
 
125
+ # --------------------------------------------------------------
126
+ # ✅ FORCE ALL HUGGINGFACE + DIFFUSERS CACHES TO /tmp/hf_cache
127
+ # --------------------------------------------------------------
128
  HF_CACHE_DIR = Path("/tmp/hf_cache")
 
 
129
  MODEL_DIR = Path("/tmp/models/realvisxl_v4")
 
130
 
131
+ for d in [HF_CACHE_DIR, MODEL_DIR]:
132
+ d.mkdir(parents=True, exist_ok=True)
133
+
134
+ # Set all relevant environment variables *before* any imports
135
  os.environ["HF_HOME"] = str(HF_CACHE_DIR)
136
  os.environ["HF_HUB_CACHE"] = str(HF_CACHE_DIR)
137
  os.environ["DIFFUSERS_CACHE"] = str(HF_CACHE_DIR)
 
139
  os.environ["XDG_CACHE_HOME"] = str(HF_CACHE_DIR)
140
  os.environ["HF_DATASETS_CACHE"] = str(HF_CACHE_DIR)
141
  os.environ["HF_MODULES_CACHE"] = str(HF_CACHE_DIR)
142
+ os.environ["TMPDIR"] = str(HF_CACHE_DIR)
143
+ os.environ["CACHE_DIR"] = str(HF_CACHE_DIR)
144
 
145
  print("[DEBUG] Hugging Face cache directory set to:", HF_CACHE_DIR)
146
  print("[DEBUG] Model directory set to:", MODEL_DIR)
 
 
 
 
147
 
148
+ # --------------------------------------------------------------
149
+ # IMPORTS (must be done after env vars are set)
150
+ # --------------------------------------------------------------
151
+ import torch
152
+ from diffusers import StableDiffusionXLPipeline
153
+ from huggingface_hub import hf_hub_download
154
+ from typing import List
155
+ from io import BytesIO
156
+ import base64
157
+ from PIL import Image
158
 
159
+ # --------------------------------------------------------------
160
+ # MODEL CONFIG
161
+ # --------------------------------------------------------------
162
+ MODEL_REPO = "SG161222/RealVisXL_V4.0"
163
+ MODEL_FILENAME = "RealVisXL_V4.0.safetensors"
164
 
165
+ # --------------------------------------------------------------
166
+ # MODEL DOWNLOAD
167
+ # --------------------------------------------------------------
168
  def download_model() -> Path:
 
 
 
 
169
  model_path = MODEL_DIR / MODEL_FILENAME
170
  if not model_path.exists():
171
  print("[ImageGen] Downloading RealVisXL V4.0 model...")
 
175
  filename=MODEL_FILENAME,
176
  cache_dir=str(HF_CACHE_DIR),
177
  force_download=False,
178
+ resume_download=True,
179
  )
180
  )
181
  print(f"[ImageGen] Model downloaded to: {model_path}")
 
183
  print("[ImageGen] Model already exists. Skipping download.")
184
  return model_path
185
 
186
+ # --------------------------------------------------------------
187
+ # PIPELINE LOAD
188
+ # --------------------------------------------------------------
189
  def load_pipeline() -> StableDiffusionXLPipeline:
 
 
 
190
  model_path = download_model()
191
  print("[ImageGen] Loading model into pipeline...")
192
 
193
  pipe = StableDiffusionXLPipeline.from_single_file(
194
  str(model_path),
195
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
196
+ local_files_only=True, # ✅ ensures it never writes to /.cache
197
  )
198
 
199
  if torch.cuda.is_available():
 
201
  else:
202
  pipe.to("cpu")
203
 
204
+ pipe.safety_checker = None
205
+ pipe.enable_attention_slicing()
 
 
206
 
207
  print("[ImageGen] Model ready.")
208
  return pipe