Update app.py
app.py CHANGED
@@ -23,6 +23,17 @@ import datetime
 from typing import Optional, Tuple, List
 import torch
 from transformers import CLIPProcessor, CLIPModel
+import PIL
+
+# Reduce thread contention in tokenizers
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+st.set_page_config(
+    page_title="Visual Search System",
+    page_icon="🔍",
+    layout="wide",
+    initial_sidebar_state="collapsed"
+)

 # -----------------------
 # Configuration
@@ -37,14 +48,15 @@ MAX_IMAGES = 250 # Set to 250 as requested
 JPEG_QUALITY = 85
 TARGET_MAX_SIZE = (800, 800)
 MAX_WORKERS = 6 # Reduced for stability
-RETRY_COUNT =
+RETRY_COUNT = 5
 BATCH_SIZE = 20
+EMBED_BATCH_SIZE = 8

 EMB_NPY = EMBED_DIR / "image_embeddings.npy"
 EMB_INDEX_JSON = EMBED_DIR / "index.json"
 # Removed HIST_BINS_PER_CHANNEL and HIST_RANGE as they are no longer used for embedding generation

-CLIP_MODEL = "openai/clip-vit-
+CLIP_MODEL = "openai/clip-vit-small-patch16" # Switched to smaller model

 @st.cache_resource
 def load_clip_model():
@@ -52,10 +64,20 @@ def load_clip_model():
     print(f"Loading CLIP model: {CLIP_MODEL}...")
     processor = CLIPProcessor.from_pretrained(CLIP_MODEL)
     model = CLIPModel.from_pretrained(CLIP_MODEL)
+    device = torch.device("cpu")
+    model.to(device)
+    model.eval()
+    # Limit CPU threads to avoid oversubscription on Spaces/limited CPUs
+    try:
+        torch.set_num_threads(max(1, min(4, os.cpu_count() or 2)))
+    except Exception:
+        pass
     print("CLIP model loaded successfully.")
-    return processor, model
+    return processor, model, device
+
+CLIP_PROCESSOR, CLIP_MODEL_LOCAL, CLIP_DEVICE = load_clip_model()

-
+# Removed HF_TOKEN, API_URL, HEADERS as they are no longer used for image embedding

 # Phase Constants
 PHASE_IDLE = "idle"
@@ -126,12 +148,23 @@ progress_tracker = SafeProgressTracker()
 # Utility Functions
 # -----------------------
 def ensure_dirs():
-    """Create directories if they don't exist"""
+    """Create directories if they don't exist and clean up old progress/temp files"""
     try:
         IMAGES_DIR.mkdir(parents=True, exist_ok=True)
         EMBED_DIR.mkdir(parents=True, exist_ok=True)
+
+        # Clean up old progress and temp embedding files for a fresh start
+        if PROGRESS_FILE.exists():
+            PROGRESS_FILE.unlink()
+        if SETUP_COMPLETE_FILE.exists():
+            SETUP_COMPLETE_FILE.unlink()
+        for f in EMBED_DIR.glob("*.tmp"): # Clean up any temp embedding files
+            f.unlink()
+        for f in IMAGES_DIR.glob("*.tmp"): # Clean up any temp image files
+            f.unlink()
+
     except Exception as e:
-        print(f"Directory
+        print(f"Directory or cleanup error: {e}")

 def seq_filename(i: int) -> str:
     return f"{i:04d}.jpg"
@@ -217,6 +250,7 @@ def download_single_image(i: int, url: str) -> bool:
             response = requests.get(url, stream=True, timeout=(30, 90))
             if response.status_code != 200:
                 if attempt == RETRY_COUNT - 1:
+                    print(f"Final download attempt failed for {url}. Status: {response.status_code}")
                     return False
                 time.sleep(2 ** attempt) # Exponential backoff
                 continue
@@ -310,16 +344,70 @@ def create_safe_embedding(img_path: Path) -> np.ndarray:
             return np.zeros(CLIP_MODEL_LOCAL.config.projection_dim, dtype=np.float32)

         img = Image.open(img_path).convert("RGB")
+        print(f"Embedding image: {img_path.name}, size={img.size}, mode={img.mode}")
         inputs = CLIP_PROCESSOR(images=img, return_tensors="pt")
+        inputs = {k: v.to(CLIP_DEVICE) for k, v in inputs.items()}

         with torch.no_grad():
-
+            image_features = CLIP_MODEL_LOCAL.get_image_features(**inputs)

-
+        if torch.isnan(image_features).any() or torch.isinf(image_features).any():
+            print(f"NaN/Inf detected in features for {img_path.name}")
+            return np.zeros(CLIP_MODEL_LOCAL.config.projection_dim, dtype=np.float32)
+
+        vec = image_features.squeeze().detach().cpu().numpy().astype(np.float32)
+        print(f"Feature vector shape: {vec.shape}, dtype: {vec.dtype}")
+        if vec.ndim != 1:
+            vec = vec.reshape(-1)
+
+        if vec.size != CLIP_MODEL_LOCAL.config.projection_dim:
+            print(f"Warning: feature dim {vec.size} != projection_dim {CLIP_MODEL_LOCAL.config.projection_dim}")
+
+        return vec
     except Exception as e:
         print(f"Embedding creation error for {img_path}: {e}")
         return np.zeros(CLIP_MODEL_LOCAL.config.projection_dim, dtype=np.float32)

+def create_embeddings_batch(image_paths: List[Path]) -> np.ndarray:
+    """Create embeddings for a batch of images efficiently on CPU.
+    Returns array of shape (batch_size, projection_dim). Fills zeros on failures.
+    """
+    images = []
+    fallback_indices = []
+    for idx, p in enumerate(image_paths):
+        try:
+            if not p.exists() or p.stat().st_size == 0:
+                fallback_indices.append(idx)
+                images.append(Image.new("RGB", (224, 224), color=(0, 0, 0)))
+                continue
+            img = Image.open(p).convert("RGB")
+            # Pre-resize to 224 to reduce CPU and memory
+            img = img.resize((224, 224), Image.Resampling.LANCZOS)
+            images.append(img)
+        except (PIL.UnidentifiedImageError, IOError) as image_err:
+            print(f"Image loading error for {p.name}: {image_err}. Using blank image.")
+            fallback_indices.append(idx)
+            images.append(Image.new("RGB", (224, 224), color=(0, 0, 0)))
+        except Exception as e:
+            print(f"Unexpected error loading image {p.name}: {e}. Using blank image.")
+            fallback_indices.append(idx)
+            images.append(Image.new("RGB", (224, 224), color=(0, 0, 0)))
+
+    try:
+        inputs = CLIP_PROCESSOR(images=images, return_tensors="pt")
+        inputs = {k: v.to(CLIP_DEVICE) for k, v in inputs.items()}
+        with torch.no_grad():
+            feats = CLIP_MODEL_LOCAL.get_image_features(**inputs)
+        feats = feats.detach().cpu().numpy().astype(np.float32)
+        # Replace fallback rows with zeros explicitly
+        for i in fallback_indices:
+            feats[i, :] = np.zeros_like(feats[i, :])
+        return feats
+    except Exception as e:
+        print(f"Batch embedding error for {len(image_paths)} images: {e}")
+        # Return None to signal caller to fallback to smaller batch
+        return None
+
 def process_embeddings_thread_safe() -> bool:
     """Create embeddings in background thread - NO Streamlit APIs"""
     image_files = sorted([f for f in IMAGES_DIR.glob("*.jpg")
@@ -328,6 +416,14 @@ def process_embeddings_thread_safe() -> bool:
     if not image_files:
         progress_tracker.update(PHASE_ERROR, 0, 1, 1, "❌ No images found", "")
         return False
+
+    # Quick self-test on the first image to detect failures early
+    try:
+        test_vec = create_safe_embedding(image_files[0])
+        if not np.any(test_vec):
+            print(f"Self-test failed on first image: {image_files[0].name}")
+    except Exception as e:
+        print(f"Self-test exception: {e}")

     # Check if embeddings already exist and are current
     try:
@@ -348,68 +444,97 @@ def process_embeddings_thread_safe() -> bool:
     index = []
     processed = 0
     failed = 0
-
+
     progress_tracker.update(PHASE_2_EMBEDDING, 0, total, 0,
                             f"🧠 Creating embeddings for {total} images...",
                             "Processing visual features")
-
+
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # Atomic save
-        temp_npy = EMB_NPY.with_suffix('.tmp')
-        temp_json = EMB_INDEX_JSON.with_suffix('.tmp')
-
-        np.save(temp_npy, embeddings_array)
-        with open(temp_json, 'w') as f:
-            json.dump(index, f, indent=2)
-
-        # Atomic move
-        temp_npy.replace(EMB_NPY)
-        temp_json.replace(EMB_INDEX_JSON)
+        current_batch_size = EMBED_BATCH_SIZE
+        for start in range(0, total, current_batch_size):
+            # adaptively chunk
+            end = min(start + current_batch_size, total)
+            batch_files = image_files[start:end]
+
+            # Try with current batch size; fallback by halving on failure
+            attempts = 0
+            feats = None
+            while attempts < 3:
+                feats = create_embeddings_batch(batch_files)
+                if feats is None:
+                    attempts += 1
+                    if current_batch_size > 4:
+                        current_batch_size = max(4, current_batch_size // 2)
+                        end = min(start + current_batch_size, total)
+                        batch_files = image_files[start:end]
+                        print(f"⚠️ Falling back to smaller batch size: {current_batch_size}")
+                        continue
                     else:
-            #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        # Hard failure at smallest batch: compute per-image to maximize success
+                        per_feats = []
+                        for p in batch_files:
+                            vec = create_safe_embedding(p)
+                            per_feats.append(vec)
+                        feats = np.vstack(per_feats).astype(np.float32)
+                        break
+                break
+
+            # Count failures in this batch (rows that are all zeros)
+            if feats.ndim != 2 or feats.shape[0] != len(batch_files):
+                print(f"Unexpected batch feature shape: {feats.shape}, expected ({len(batch_files)}, D)")
+            batch_failed = int((np.linalg.norm(feats, axis=1) < 1e-12).sum()) if feats.size else 0
+            failed += batch_failed
+            embeddings.append(feats)
+            index.extend([p.name for p in batch_files])
+            processed = end
+
+            # Periodic save by batch for resilience
+            try:
+                if embeddings:
+                    embeddings_array = np.vstack(embeddings).astype(np.float32)
+                    temp_npy = EMB_NPY.with_suffix('.tmp')
+                    temp_json = EMB_INDEX_JSON.with_suffix('.tmp')
+                    np.save(temp_npy, embeddings_array)
+                    with open(temp_json, 'w') as f:
+                        json.dump(index, f, indent=2)
+                    temp_npy.replace(EMB_NPY)
+                    temp_json.replace(EMB_INDEX_JSON)
+            except Exception as e:
+                progress_tracker.update(PHASE_ERROR, processed, total, failed,
+                                        f"❌ Save failed: {e}", "")
+                return False
+
+            # Free memory after each batch
+            try:
+                import gc
+                del feats
+                gc.collect()
+            except Exception:
+                pass
+
+            success_rate = ((processed - failed) / processed * 100) if processed > 0 else 0
+            batch_success_count = len(batch_files) - batch_failed
+            print(f"Batch {start//current_batch_size + 1} completed: {batch_success_count} success, {batch_failed} failed.")
+            details = f"💾 Saved up to {processed} • 📊 failures {failed}"
+            message = f"🧠 Processed {processed}/{total} ({success_rate:.1f}%)"
+            progress_tracker.update(PHASE_2_EMBEDDING, processed, total, failed,
+                                    message, details)
+
+        # Final validation
+        embeddings_array = np.vstack(embeddings).astype(np.float32) if embeddings else np.zeros((0, CLIP_MODEL_LOCAL.config.projection_dim), dtype=np.float32)
+        if embeddings_array.shape[0] != len(index) or len(index) != total:
+            print(f"⚠️ Final size mismatch: emb_rows={embeddings_array.shape[0]}, index={len(index)}, total={total}")
+        print(f"Embedding processing completed. Total failed: {failed}/{total}")

+        # Ensure files saved
+        temp_npy = EMB_NPY.with_suffix('.tmp')
+        temp_json = EMB_INDEX_JSON.with_suffix('.tmp')
+        np.save(temp_npy, embeddings_array)
+        with open(temp_json, 'w') as f:
+            json.dump(index, f, indent=2)
+        temp_npy.replace(EMB_NPY)
+        temp_json.replace(EMB_INDEX_JSON)
+
         return True

     except Exception as e:
@@ -667,12 +792,6 @@ def init_session_state():

 def main():
     """Main application - All session state access here"""
-    st.set_page_config(
-        page_title="Visual Search System",
-        page_icon="🔍",
-        layout="wide",
-        initial_sidebar_state="collapsed"
-    )

     apply_styling()
     init_session_state() # Safe - main thread only