karthikeya1212 committed on
Commit
17379cc
·
verified ·
1 Parent(s): 32e7681

Update core/image_generator.py

Browse files
Files changed (1) hide show
  1. core/image_generator.py +41 -59
core/image_generator.py CHANGED
@@ -9,20 +9,27 @@ from typing import Dict, Any
9
  from PIL import Image
10
  from io import BytesIO
11
  import base64
 
12
 
13
  # --------------------------------------------------------------
14
- # ✅ PERSISTENT STORAGE SETUP (for Hugging Face Spaces)
15
  # --------------------------------------------------------------
16
  HF_CACHE_DIR = Path("/tmp/hf_cache")
17
- MODEL_DIR = Path("/tmp/models/realvisxl_v4")
18
- SEED_DIR = Path("/tmp/seed_images")
19
- TMP_DIR = Path("/tmp/generated_images")
20
 
21
- for d in [HF_CACHE_DIR, MODEL_DIR, SEED_DIR, TMP_DIR]:
22
- d.mkdir(parents=True, exist_ok=True)
 
 
23
 
 
 
 
 
24
 
25
- # Ensure all relevant environment variables are correctly set
 
 
26
  os.environ.update({
27
  "HF_HOME": str(HF_CACHE_DIR),
28
  "HF_HUB_CACHE": str(HF_CACHE_DIR),
@@ -35,23 +42,19 @@ os.environ.update({
35
  "CACHE_DIR": str(HF_CACHE_DIR),
36
  })
37
 
38
- # Force Python's tempfile to use /tmp/hf_cache too
39
- import tempfile
40
  tempfile.tempdir = str(HF_CACHE_DIR)
41
 
42
- # 🔒 Extra layer: patch os.path.expanduser to block “/.cache”
43
- import os.path
44
- def safe_expanduser(path):
45
- if path.startswith("~") or path.startswith("/.cache"):
46
- return str(HF_CACHE_DIR)
47
- return os.path.expanduser_original(path)
48
-
49
- os.path.expanduser_original = os.path.expanduser
50
- os.path.expanduser = safe_expanduser
51
-
52
  print("[DEBUG] ✅ Hugging Face and Diffusers cache fully redirected to:", HF_CACHE_DIR)
53
 
 
 
 
 
 
 
54
 
 
 
55
 
56
  print("[DEBUG] ✅ Using persistent Hugging Face cache at:", HF_CACHE_DIR)
57
  print("[DEBUG] ✅ Model directory:", MODEL_DIR)
@@ -88,7 +91,6 @@ def download_model() -> Path:
88
  print("[ImageGen] ✅ Model already exists at:", model_path)
89
  return model_path
90
 
91
-
92
  # --------------------------------------------------------------
93
  # MEMORY-SAFE PIPELINE MANAGER
94
  # --------------------------------------------------------------
@@ -116,10 +118,27 @@ def unload_pipelines(target="all"):
116
  torch.cuda.empty_cache()
117
  print("[ImageGen] ✅ Memory cleared.")
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def load_pipeline():
121
  global pipe
122
- unload_pipelines(target="pipe") # only clear old txt2img
123
  model_path = download_model()
124
  print("[ImageGen] Loading main (txt2img) pipeline...")
125
  pipe = safe_load_pipeline(StableDiffusionXLPipeline, model_path)
@@ -130,10 +149,9 @@ def load_pipeline():
130
  print("[ImageGen] ✅ Text-to-image pipeline ready.")
131
  return pipe
132
 
133
-
134
  def load_img2img_pipeline():
135
  global img2img_pipe
136
- unload_pipelines(target="img2img_pipe") # only clear old img2img
137
  model_path = download_model()
138
  print("[ImageGen] Loading img2img pipeline...")
139
  img2img_pipe = safe_load_pipeline(StableDiffusionXLImg2ImgPipeline, model_path)
@@ -144,41 +162,6 @@ def load_img2img_pipeline():
144
  print("[ImageGen] ✅ Img2Img pipeline ready.")
145
  return img2img_pipe
146
 
147
-
148
-
149
- def load_img2img_pipeline():
150
- """Load img2img pipeline into RAM."""
151
- global img2img_pipe
152
- unload_pipelines() # Ensure txt2img is removed first
153
- model_path = download_model()
154
- print("[ImageGen] Loading img2img pipeline...")
155
-
156
- img2img_pipe = safe_load_pipeline(StableDiffusionXLImg2ImgPipeline, model_path)
157
- device = "cuda" if torch.cuda.is_available() else "cpu"
158
- img2img_pipe.to(device)
159
- img2img_pipe.safety_checker = None
160
- img2img_pipe.enable_attention_slicing()
161
- print("[ImageGen] ✅ Img2Img pipeline ready.")
162
- return img2img_pipe
163
-
164
- def safe_load_pipeline(pipeline_class, model_path):
165
- """Safely load a pipeline with retry logic and memory handling."""
166
- try:
167
- print(f"[ImageGen] 🔄 Loading {pipeline_class.__name__} from {model_path} ...")
168
- pipe = pipeline_class.from_single_file(
169
- model_path,
170
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
171
- )
172
- print(f"[ImageGen] ✅ Successfully loaded {pipeline_class.__name__}.")
173
- return pipe
174
- except Exception as e:
175
- print(f"[ImageGen] ❌ Failed to load {pipeline_class.__name__}: {e}")
176
- unload_pipelines()
177
- gc.collect()
178
- if torch.cuda.is_available():
179
- torch.cuda.empty_cache()
180
- raise e
181
-
182
  # --------------------------------------------------------------
183
  # UTILITY: PIL β†’ BASE64
184
  # --------------------------------------------------------------
@@ -187,7 +170,6 @@ def pil_to_base64(img: Image.Image) -> str:
187
  img.save(buffered, format="PNG")
188
  return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"
189
 
190
-
191
  # --------------------------------------------------------------
192
  # UNIFIED IMAGE GENERATION FUNCTION
193
  # --------------------------------------------------------------
 
9
  from PIL import Image
10
  from io import BytesIO
11
  import base64
12
+ import tempfile
13
 
14
  # --------------------------------------------------------------
15
+ # 🚨 ABSOLUTE FIX FOR PermissionError('/.cache')
16
  # --------------------------------------------------------------
17
  HF_CACHE_DIR = Path("/tmp/hf_cache")
18
+ HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
19
 
20
+ # Patch expanduser BEFORE any library imports that might touch ~/.cache
21
+ import os.path
22
+ if not hasattr(os.path, "expanduser_original"):
23
+ os.path.expanduser_original = os.path.expanduser
24
 
25
+ def safe_expanduser(path):
26
+ if path.startswith("~") or path.startswith("/.cache"):
27
+ return str(HF_CACHE_DIR)
28
+ return os.path.expanduser_original(path)
29
 
30
+ os.path.expanduser = safe_expanduser
31
+
32
+ # Set environment variables AFTER patching
33
  os.environ.update({
34
  "HF_HOME": str(HF_CACHE_DIR),
35
  "HF_HUB_CACHE": str(HF_CACHE_DIR),
 
42
  "CACHE_DIR": str(HF_CACHE_DIR),
43
  })
44
 
 
 
45
  tempfile.tempdir = str(HF_CACHE_DIR)
46
 
 
 
 
 
 
 
 
 
 
 
47
  print("[DEBUG] ✅ Hugging Face and Diffusers cache fully redirected to:", HF_CACHE_DIR)
48
 
49
+ # --------------------------------------------------------------
50
+ # ✅ PERSISTENT STORAGE SETUP (for Hugging Face Spaces)
51
+ # --------------------------------------------------------------
52
+ MODEL_DIR = Path("/tmp/models/realvisxl_v4")
53
+ SEED_DIR = Path("/tmp/seed_images")
54
+ TMP_DIR = Path("/tmp/generated_images")
55
 
56
+ for d in [MODEL_DIR, SEED_DIR, TMP_DIR]:
57
+ d.mkdir(parents=True, exist_ok=True)
58
 
59
  print("[DEBUG] ✅ Using persistent Hugging Face cache at:", HF_CACHE_DIR)
60
  print("[DEBUG] ✅ Model directory:", MODEL_DIR)
 
91
  print("[ImageGen] ✅ Model already exists at:", model_path)
92
  return model_path
93
 
 
94
  # --------------------------------------------------------------
95
  # MEMORY-SAFE PIPELINE MANAGER
96
  # --------------------------------------------------------------
 
118
  torch.cuda.empty_cache()
119
  print("[ImageGen] ✅ Memory cleared.")
120
 
121
+ def safe_load_pipeline(pipeline_class, model_path):
122
+ """Safely load a pipeline with retry logic and memory handling."""
123
+ try:
124
+ print(f"[ImageGen] 🔄 Loading {pipeline_class.__name__} from {model_path} ...")
125
+ pipe = pipeline_class.from_single_file(
126
+ model_path,
127
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
128
+ )
129
+ print(f"[ImageGen] ✅ Successfully loaded {pipeline_class.__name__}.")
130
+ return pipe
131
+ except Exception as e:
132
+ print(f"[ImageGen] ❌ Failed to load {pipeline_class.__name__}: {e}")
133
+ unload_pipelines()
134
+ gc.collect()
135
+ if torch.cuda.is_available():
136
+ torch.cuda.empty_cache()
137
+ raise e
138
 
139
  def load_pipeline():
140
  global pipe
141
+ unload_pipelines(target="pipe")
142
  model_path = download_model()
143
  print("[ImageGen] Loading main (txt2img) pipeline...")
144
  pipe = safe_load_pipeline(StableDiffusionXLPipeline, model_path)
 
149
  print("[ImageGen] ✅ Text-to-image pipeline ready.")
150
  return pipe
151
 
 
152
  def load_img2img_pipeline():
153
  global img2img_pipe
154
+ unload_pipelines(target="img2img_pipe")
155
  model_path = download_model()
156
  print("[ImageGen] Loading img2img pipeline...")
157
  img2img_pipe = safe_load_pipeline(StableDiffusionXLImg2ImgPipeline, model_path)
 
162
  print("[ImageGen] ✅ Img2Img pipeline ready.")
163
  return img2img_pipe
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  # --------------------------------------------------------------
166
  # UTILITY: PIL β†’ BASE64
167
  # --------------------------------------------------------------
 
170
  img.save(buffered, format="PNG")
171
  return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"
172
 
 
173
  # --------------------------------------------------------------
174
  # UNIFIED IMAGE GENERATION FUNCTION
175
  # --------------------------------------------------------------