manisharma494 commited on
Commit
18a001a
·
verified ·
1 Parent(s): f65d508

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -70
app.py CHANGED
@@ -23,6 +23,17 @@ import datetime
23
  from typing import Optional, Tuple, List
24
  import torch
25
  from transformers import CLIPProcessor, CLIPModel
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # -----------------------
28
  # Configuration
@@ -37,14 +48,15 @@ MAX_IMAGES = 250 # Set to 250 as requested
37
  JPEG_QUALITY = 85
38
  TARGET_MAX_SIZE = (800, 800)
39
  MAX_WORKERS = 6 # Reduced for stability
40
- RETRY_COUNT = 3
41
  BATCH_SIZE = 20
 
42
 
43
  EMB_NPY = EMBED_DIR / "image_embeddings.npy"
44
  EMB_INDEX_JSON = EMBED_DIR / "index.json"
45
  # Removed HIST_BINS_PER_CHANNEL and HIST_RANGE as they are no longer used for embedding generation
46
 
47
- CLIP_MODEL = "openai/clip-vit-base-patch32"
48
 
49
  @st.cache_resource
50
  def load_clip_model():
@@ -52,10 +64,20 @@ def load_clip_model():
52
  print(f"Loading CLIP model: {CLIP_MODEL}...")
53
  processor = CLIPProcessor.from_pretrained(CLIP_MODEL)
54
  model = CLIPModel.from_pretrained(CLIP_MODEL)
 
 
 
 
 
 
 
 
55
  print("CLIP model loaded successfully.")
56
- return processor, model
 
 
57
 
58
- CLIP_PROCESSOR, CLIP_MODEL_LOCAL = load_clip_model()
59
 
60
  # Phase Constants
61
  PHASE_IDLE = "idle"
@@ -126,12 +148,23 @@ progress_tracker = SafeProgressTracker()
126
  # Utility Functions
127
  # -----------------------
128
  def ensure_dirs():
129
- """Create directories if they don't exist"""
130
  try:
131
  IMAGES_DIR.mkdir(parents=True, exist_ok=True)
132
  EMBED_DIR.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
133
  except Exception as e:
134
- print(f"Directory creation error: {e}")
135
 
136
  def seq_filename(i: int) -> str:
137
  return f"{i:04d}.jpg"
@@ -217,6 +250,7 @@ def download_single_image(i: int, url: str) -> bool:
217
  response = requests.get(url, stream=True, timeout=(30, 90))
218
  if response.status_code != 200:
219
  if attempt == RETRY_COUNT - 1:
 
220
  return False
221
  time.sleep(2 ** attempt) # Exponential backoff
222
  continue
@@ -310,16 +344,70 @@ def create_safe_embedding(img_path: Path) -> np.ndarray:
310
  return np.zeros(CLIP_MODEL_LOCAL.config.projection_dim, dtype=np.float32)
311
 
312
  img = Image.open(img_path).convert("RGB")
 
313
  inputs = CLIP_PROCESSOR(images=img, return_tensors="pt")
 
314
 
315
  with torch.no_grad():
316
- embeddings = CLIP_MODEL_LOCAL.get_image_features(**inputs)
317
 
318
- return embeddings.squeeze().cpu().numpy().astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
319
  except Exception as e:
320
  print(f"Embedding creation error for {img_path}: {e}")
321
  return np.zeros(CLIP_MODEL_LOCAL.config.projection_dim, dtype=np.float32)
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  def process_embeddings_thread_safe() -> bool:
324
  """Create embeddings in background thread - NO Streamlit APIs"""
325
  image_files = sorted([f for f in IMAGES_DIR.glob("*.jpg")
@@ -328,6 +416,14 @@ def process_embeddings_thread_safe() -> bool:
328
  if not image_files:
329
  progress_tracker.update(PHASE_ERROR, 0, 1, 1, "❌ No images found", "")
330
  return False
 
 
 
 
 
 
 
 
331
 
332
  # Check if embeddings already exist and are current
333
  try:
@@ -348,68 +444,97 @@ def process_embeddings_thread_safe() -> bool:
348
  index = []
349
  processed = 0
350
  failed = 0
351
-
352
  progress_tracker.update(PHASE_2_EMBEDDING, 0, total, 0,
353
  f"🧠 Creating embeddings for {total} images...",
354
  "Processing visual features")
355
-
356
  try:
357
- for img_file in image_files:
358
- embedding = create_safe_embedding(img_file)
359
-
360
- if np.any(embedding): # Check if embedding is not all zeros
361
- embeddings.append(embedding)
362
- index.append(img_file.name)
363
- else:
364
- failed += 1
365
- # Still add to maintain indexing
366
- embeddings.append(embedding)
367
- index.append(img_file.name)
368
-
369
- processed += 1
370
-
371
- # Save in batches for resilience
372
- if processed % BATCH_SIZE == 0 or processed == total:
373
- try:
374
- if embeddings:
375
- embeddings_array = np.vstack(embeddings).astype(np.float32)
376
-
377
- # Atomic save
378
- temp_npy = EMB_NPY.with_suffix('.tmp')
379
- temp_json = EMB_INDEX_JSON.with_suffix('.tmp')
380
-
381
- np.save(temp_npy, embeddings_array)
382
- with open(temp_json, 'w') as f:
383
- json.dump(index, f, indent=2)
384
-
385
- # Atomic move
386
- temp_npy.replace(EMB_NPY)
387
- temp_json.replace(EMB_INDEX_JSON)
388
  else:
389
- # If no valid embeddings were created, ensure files are empty or removed.
390
- # This prevents partial/corrupted files from being considered complete.
391
- if EMB_NPY.exists():
392
- EMB_NPY.unlink()
393
- if EMB_INDEX_JSON.exists():
394
- EMB_INDEX_JSON.unlink()
395
- print("No valid embeddings to save, clearing existing embedding files.")
396
-
397
- details = f"💾 Batch saved • 📊 {len(embeddings)} embeddings"
398
- if failed > 0:
399
- details += f" ⚠️ {failed} errors"
400
-
401
- message = f"🧠 Processed {processed}/{total}"
402
- if processed == total:
403
- message = "✅ All embeddings created!"
404
-
405
- progress_tracker.update(PHASE_2_EMBEDDING, processed, total, failed,
406
- message, details)
407
-
408
- except Exception as e:
409
- progress_tracker.update(PHASE_ERROR, processed, total, failed,
410
- f"❌ Save failed: {e}", "")
411
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
 
 
 
 
 
 
 
 
 
413
  return True
414
 
415
  except Exception as e:
@@ -667,12 +792,6 @@ def init_session_state():
667
 
668
  def main():
669
  """Main application - All session state access here"""
670
- st.set_page_config(
671
- page_title="Visual Search System",
672
- page_icon="🔍",
673
- layout="wide",
674
- initial_sidebar_state="collapsed"
675
- )
676
 
677
  apply_styling()
678
  init_session_state() # Safe - main thread only
 
23
  from typing import Optional, Tuple, List
24
  import torch
25
  from transformers import CLIPProcessor, CLIPModel
26
+ import PIL
27
+
28
+ # Reduce thread contention in tokenizers
29
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
30
+
31
+ st.set_page_config(
32
+ page_title="Visual Search System",
33
+ page_icon="🔍",
34
+ layout="wide",
35
+ initial_sidebar_state="collapsed"
36
+ )
37
 
38
  # -----------------------
39
  # Configuration
 
48
  JPEG_QUALITY = 85
49
  TARGET_MAX_SIZE = (800, 800)
50
  MAX_WORKERS = 6 # Reduced for stability
51
+ RETRY_COUNT = 5
52
  BATCH_SIZE = 20
53
+ EMBED_BATCH_SIZE = 8
54
 
55
  EMB_NPY = EMBED_DIR / "image_embeddings.npy"
56
  EMB_INDEX_JSON = EMBED_DIR / "index.json"
57
  # Removed HIST_BINS_PER_CHANNEL and HIST_RANGE as they are no longer used for embedding generation
58
 
59
+ CLIP_MODEL = "openai/clip-vit-small-patch16" # Switched to smaller model
60
 
61
  @st.cache_resource
62
  def load_clip_model():
 
64
  print(f"Loading CLIP model: {CLIP_MODEL}...")
65
  processor = CLIPProcessor.from_pretrained(CLIP_MODEL)
66
  model = CLIPModel.from_pretrained(CLIP_MODEL)
67
+ device = torch.device("cpu")
68
+ model.to(device)
69
+ model.eval()
70
+ # Limit CPU threads to avoid oversubscription on Spaces/limited CPUs
71
+ try:
72
+ torch.set_num_threads(max(1, min(4, os.cpu_count() or 2)))
73
+ except Exception:
74
+ pass
75
  print("CLIP model loaded successfully.")
76
+ return processor, model, device
77
+
78
+ CLIP_PROCESSOR, CLIP_MODEL_LOCAL, CLIP_DEVICE = load_clip_model()
79
 
80
+ # Removed HF_TOKEN, API_URL, HEADERS as they are no longer used for image embedding
81
 
82
  # Phase Constants
83
  PHASE_IDLE = "idle"
 
148
  # Utility Functions
149
  # -----------------------
150
  def ensure_dirs():
151
+ """Create directories if they don't exist and clean up old progress/temp files"""
152
  try:
153
  IMAGES_DIR.mkdir(parents=True, exist_ok=True)
154
  EMBED_DIR.mkdir(parents=True, exist_ok=True)
155
+
156
+ # Clean up old progress and temp embedding files for a fresh start
157
+ if PROGRESS_FILE.exists():
158
+ PROGRESS_FILE.unlink()
159
+ if SETUP_COMPLETE_FILE.exists():
160
+ SETUP_COMPLETE_FILE.unlink()
161
+ for f in EMBED_DIR.glob("*.tmp"): # Clean up any temp embedding files
162
+ f.unlink()
163
+ for f in IMAGES_DIR.glob("*.tmp"): # Clean up any temp image files
164
+ f.unlink()
165
+
166
  except Exception as e:
167
+ print(f"Directory or cleanup error: {e}")
168
 
169
  def seq_filename(i: int) -> str:
170
  return f"{i:04d}.jpg"
 
250
  response = requests.get(url, stream=True, timeout=(30, 90))
251
  if response.status_code != 200:
252
  if attempt == RETRY_COUNT - 1:
253
+ print(f"Final download attempt failed for {url}. Status: {response.status_code}")
254
  return False
255
  time.sleep(2 ** attempt) # Exponential backoff
256
  continue
 
344
  return np.zeros(CLIP_MODEL_LOCAL.config.projection_dim, dtype=np.float32)
345
 
346
  img = Image.open(img_path).convert("RGB")
347
+ print(f"Embedding image: {img_path.name}, size={img.size}, mode={img.mode}")
348
  inputs = CLIP_PROCESSOR(images=img, return_tensors="pt")
349
+ inputs = {k: v.to(CLIP_DEVICE) for k, v in inputs.items()}
350
 
351
  with torch.no_grad():
352
+ image_features = CLIP_MODEL_LOCAL.get_image_features(**inputs)
353
 
354
+ if torch.isnan(image_features).any() or torch.isinf(image_features).any():
355
+ print(f"NaN/Inf detected in features for {img_path.name}")
356
+ return np.zeros(CLIP_MODEL_LOCAL.config.projection_dim, dtype=np.float32)
357
+
358
+ vec = image_features.squeeze().detach().cpu().numpy().astype(np.float32)
359
+ print(f"Feature vector shape: {vec.shape}, dtype: {vec.dtype}")
360
+ if vec.ndim != 1:
361
+ vec = vec.reshape(-1)
362
+
363
+ if vec.size != CLIP_MODEL_LOCAL.config.projection_dim:
364
+ print(f"Warning: feature dim {vec.size} != projection_dim {CLIP_MODEL_LOCAL.config.projection_dim}")
365
+
366
+ return vec
367
  except Exception as e:
368
  print(f"Embedding creation error for {img_path}: {e}")
369
  return np.zeros(CLIP_MODEL_LOCAL.config.projection_dim, dtype=np.float32)
370
 
371
+ def create_embeddings_batch(image_paths: List[Path]) -> np.ndarray:
372
+ """Create embeddings for a batch of images efficiently on CPU.
373
+ Returns array of shape (batch_size, projection_dim). Fills zeros on failures.
374
+ """
375
+ images = []
376
+ fallback_indices = []
377
+ for idx, p in enumerate(image_paths):
378
+ try:
379
+ if not p.exists() or p.stat().st_size == 0:
380
+ fallback_indices.append(idx)
381
+ images.append(Image.new("RGB", (224, 224), color=(0, 0, 0)))
382
+ continue
383
+ img = Image.open(p).convert("RGB")
384
+ # Pre-resize to 224 to reduce CPU and memory
385
+ img = img.resize((224, 224), Image.Resampling.LANCZOS)
386
+ images.append(img)
387
+ except (PIL.UnidentifiedImageError, IOError) as image_err:
388
+ print(f"Image loading error for {p.name}: {image_err}. Using blank image.")
389
+ fallback_indices.append(idx)
390
+ images.append(Image.new("RGB", (224, 224), color=(0, 0, 0)))
391
+ except Exception as e:
392
+ print(f"Unexpected error loading image {p.name}: {e}. Using blank image.")
393
+ fallback_indices.append(idx)
394
+ images.append(Image.new("RGB", (224, 224), color=(0, 0, 0)))
395
+
396
+ try:
397
+ inputs = CLIP_PROCESSOR(images=images, return_tensors="pt")
398
+ inputs = {k: v.to(CLIP_DEVICE) for k, v in inputs.items()}
399
+ with torch.no_grad():
400
+ feats = CLIP_MODEL_LOCAL.get_image_features(**inputs)
401
+ feats = feats.detach().cpu().numpy().astype(np.float32)
402
+ # Replace fallback rows with zeros explicitly
403
+ for i in fallback_indices:
404
+ feats[i, :] = np.zeros_like(feats[i, :])
405
+ return feats
406
+ except Exception as e:
407
+ print(f"Batch embedding error for {len(image_paths)} images: {e}")
408
+ # Return None to signal caller to fallback to smaller batch
409
+ return None
410
+
411
  def process_embeddings_thread_safe() -> bool:
412
  """Create embeddings in background thread - NO Streamlit APIs"""
413
  image_files = sorted([f for f in IMAGES_DIR.glob("*.jpg")
 
416
  if not image_files:
417
  progress_tracker.update(PHASE_ERROR, 0, 1, 1, "❌ No images found", "")
418
  return False
419
+
420
+ # Quick self-test on the first image to detect failures early
421
+ try:
422
+ test_vec = create_safe_embedding(image_files[0])
423
+ if not np.any(test_vec):
424
+ print(f"Self-test failed on first image: {image_files[0].name}")
425
+ except Exception as e:
426
+ print(f"Self-test exception: {e}")
427
 
428
  # Check if embeddings already exist and are current
429
  try:
 
444
  index = []
445
  processed = 0
446
  failed = 0
447
+
448
  progress_tracker.update(PHASE_2_EMBEDDING, 0, total, 0,
449
  f"🧠 Creating embeddings for {total} images...",
450
  "Processing visual features")
451
+
452
  try:
453
+ current_batch_size = EMBED_BATCH_SIZE
454
+ for start in range(0, total, current_batch_size):
455
+ # adaptively chunk
456
+ end = min(start + current_batch_size, total)
457
+ batch_files = image_files[start:end]
458
+
459
+ # Try with current batch size; fallback by halving on failure
460
+ attempts = 0
461
+ feats = None
462
+ while attempts < 3:
463
+ feats = create_embeddings_batch(batch_files)
464
+ if feats is None:
465
+ attempts += 1
466
+ if current_batch_size > 4:
467
+ current_batch_size = max(4, current_batch_size // 2)
468
+ end = min(start + current_batch_size, total)
469
+ batch_files = image_files[start:end]
470
+ print(f"⚠️ Falling back to smaller batch size: {current_batch_size}")
471
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
472
  else:
473
+ # Hard failure at smallest batch: compute per-image to maximize success
474
+ per_feats = []
475
+ for p in batch_files:
476
+ vec = create_safe_embedding(p)
477
+ per_feats.append(vec)
478
+ feats = np.vstack(per_feats).astype(np.float32)
479
+ break
480
+ break
481
+
482
+ # Count failures in this batch (rows that are all zeros)
483
+ if feats.ndim != 2 or feats.shape[0] != len(batch_files):
484
+ print(f"Unexpected batch feature shape: {feats.shape}, expected ({len(batch_files)}, D)")
485
+ batch_failed = int((np.linalg.norm(feats, axis=1) < 1e-12).sum()) if feats.size else 0
486
+ failed += batch_failed
487
+ embeddings.append(feats)
488
+ index.extend([p.name for p in batch_files])
489
+ processed = end
490
+
491
+ # Periodic save by batch for resilience
492
+ try:
493
+ if embeddings:
494
+ embeddings_array = np.vstack(embeddings).astype(np.float32)
495
+ temp_npy = EMB_NPY.with_suffix('.tmp')
496
+ temp_json = EMB_INDEX_JSON.with_suffix('.tmp')
497
+ np.save(temp_npy, embeddings_array)
498
+ with open(temp_json, 'w') as f:
499
+ json.dump(index, f, indent=2)
500
+ temp_npy.replace(EMB_NPY)
501
+ temp_json.replace(EMB_INDEX_JSON)
502
+ except Exception as e:
503
+ progress_tracker.update(PHASE_ERROR, processed, total, failed,
504
+ f"❌ Save failed: {e}", "")
505
+ return False
506
+
507
+ # Free memory after each batch
508
+ try:
509
+ import gc
510
+ del feats
511
+ gc.collect()
512
+ except Exception:
513
+ pass
514
+
515
+ success_rate = ((processed - failed) / processed * 100) if processed > 0 else 0
516
+ batch_success_count = len(batch_files) - batch_failed
517
+ print(f"Batch {start//current_batch_size + 1} completed: {batch_success_count} success, {batch_failed} failed.")
518
+ details = f"💾 Saved up to {processed} • 📊 failures {failed}"
519
+ message = f"🧠 Processed {processed}/{total} ({success_rate:.1f}%)"
520
+ progress_tracker.update(PHASE_2_EMBEDDING, processed, total, failed,
521
+ message, details)
522
+
523
+ # Final validation
524
+ embeddings_array = np.vstack(embeddings).astype(np.float32) if embeddings else np.zeros((0, CLIP_MODEL_LOCAL.config.projection_dim), dtype=np.float32)
525
+ if embeddings_array.shape[0] != len(index) or len(index) != total:
526
+ print(f"⚠️ Final size mismatch: emb_rows={embeddings_array.shape[0]}, index={len(index)}, total={total}")
527
+ print(f"Embedding processing completed. Total failed: {failed}/{total}")
528
 
529
+ # Ensure files saved
530
+ temp_npy = EMB_NPY.with_suffix('.tmp')
531
+ temp_json = EMB_INDEX_JSON.with_suffix('.tmp')
532
+ np.save(temp_npy, embeddings_array)
533
+ with open(temp_json, 'w') as f:
534
+ json.dump(index, f, indent=2)
535
+ temp_npy.replace(EMB_NPY)
536
+ temp_json.replace(EMB_INDEX_JSON)
537
+
538
  return True
539
 
540
  except Exception as e:
 
792
 
793
  def main():
794
  """Main application - All session state access here"""
 
 
 
 
 
 
795
 
796
  apply_styling()
797
  init_session_state() # Safe - main thread only