vungocthach1112 committed on
Commit 9f89b03 · 1 Parent(s): 10fc621

Add Gradio UI logic

Files changed (5)
  1. app.py +270 -1
  2. configs.py +21 -1
  3. dataset_utils.py +90 -1
  4. encoders.py +139 -1
  5. index_builder.py +168 -1
app.py CHANGED
@@ -1,3 +1,272 @@
  #!/usr/bin/env python3
  # -*- coding: utf-8 -*-
- ... (truncated for brevity, full content from assistant's previous message) ...
+
+ from __future__ import annotations
+
+ import os
+ import json
+ import time
+ import hashlib
+ from pathlib import Path
+ from typing import List, Tuple, Optional
+
+ import gradio as gr
+ from PIL import Image
+
+ from configs import (
+     DEFAULT_SPLIT,
+     DEFAULT_TOP_K,
+     DEFAULT_MAX_SAMPLES,
+     DEFAULT_IMAGE_COL,
+     DEFAULT_TEXT_COL,
+     DEFAULT_INDEX_DIR,
+     EXAMPLE_QUERIES,
+     DATASET_NAME,
+ )
+ from dataset_utils import load_fashion_dataset, SampleAccessor
+ from encoders import SiglipEncoder
+ from index_builder import (
+     IndexStatus,
+     ensure_index,
+     load_faiss_index,
+     search_faiss,
+     index_signature_from_env,
+ )
+
+
+ # --------------------------
+ # Globals initialized lazily
+ # --------------------------
+ _encoder: Optional[SiglipEncoder] = None
+ _accessor: Optional[SampleAccessor] = None
+ _index_ref = {"index": None, "sig": None, "dim": None}
+
+
+ def _get_encoder(log: callable) -> SiglipEncoder:
+     global _encoder
+     if _encoder is None:
+         ckpt = os.getenv("SIGLIP_CHECKPOINT_DIR", "siglip_checkpoint")
+         log(f"Loading SigLIP checkpoint from: {ckpt}")
+         _encoder = SiglipEncoder.from_checkpoint_dir(ckpt, log=log)
+         log(f"Device: {_encoder.device}, dtype: {_encoder.dtype}")
+     return _encoder
+
+
+ def _get_accessor(split: str, image_col: str, text_col: str, max_samples: int, log: callable) -> SampleAccessor:
+     global _accessor
+     # Always create a new accessor matching current UI state
+     _accessor = load_fashion_dataset(
+         dataset_name=DATASET_NAME,
+         split=split,
+         image_col=image_col,
+         text_col=text_col,
+         max_samples=max_samples,
+         log=log,
+     )
+     log(f"Dataset ready: {len(_accessor)} samples.")
+     return _accessor
+
+
+ def _maybe_load_index(sig: str, index_dir: str, log: callable):
+     """Try to load an existing FAISS index matching 'sig'."""
+     faiss_index, dim = load_faiss_index(index_dir=index_dir, signature=sig, log=log)
+     _index_ref["index"] = faiss_index
+     _index_ref["sig"] = sig
+     _index_ref["dim"] = dim
+     return faiss_index, dim
+
+
+ def _build_index(
+     split: str,
+     image_col: str,
+     text_col: str,
+     max_samples: int,
+     index_dir: str,
+     log: callable,
+ ) -> Tuple[IndexStatus, str]:
+     """Ensure an index exists and return its status + signature string."""
+     encoder = _get_encoder(log)
+     accessor = _get_accessor(split, image_col, text_col, max_samples, log)
+
+     sig = index_signature_from_env(
+         dataset_name=DATASET_NAME,
+         split=split,
+         max_samples=max_samples,
+         ckpt_dir=encoder.ckpt_dir,
+         image_col=image_col,
+         text_col=text_col,
+     )
+     status = ensure_index(
+         accessor=accessor,
+         encoder=encoder,
+         index_dir=index_dir,
+         signature=sig,
+         log=log,
+     )
+     return status, sig
+
+
+ def _log_to_console(msg: str):
+     print(msg, flush=True)
+
+
+ def ui_rebuild_index(split, image_col, text_col, max_samples, index_dir):
+     logs = []
+
+     def log(s: str):
+         logs.append(s)
+         _log_to_console(s)
+
+     status, sig = _build_index(split, image_col, text_col, max_samples, index_dir, log)
+     _maybe_load_index(sig, index_dir, log)
+
+     footer = f"Index status: {status.value} | signature: {sig}"
+     if _index_ref["index"] is not None:
+         footer += f" | dim={_index_ref['dim']}"
+
+     # Single string for the status textbox: captured logs, then the summary line
+     return "\n".join(logs + [footer])
+
+
+ def ui_search(
+     query_text: str,
+     split: str,
+     image_col: str,
+     text_col: str,
+     max_samples: int,
+     top_k: int,
+     index_dir: str,
+ ):
+     if not query_text or not query_text.strip():
+         return [], "Please enter a non-empty query."
+
+     # Prepare logger to capture build/search messages in the footer/status
+     logs = []
+
+     def log(s: str):
+         logs.append(s)
+         _log_to_console(s)
+
+     # Make sure encoder, dataset accessor, and index are aligned to UI state
+     encoder = _get_encoder(log)
+     accessor = _get_accessor(split, image_col, text_col, max_samples, log)
+
+     sig = index_signature_from_env(
+         dataset_name=DATASET_NAME,
+         split=split,
+         max_samples=max_samples,
+         ckpt_dir=encoder.ckpt_dir,
+         image_col=image_col,
+         text_col=text_col,
+     )
+     if _index_ref["index"] is None or _index_ref["sig"] != sig:
+         # Try to load from disk; if not present, build
+         idx, _ = _maybe_load_index(sig, index_dir, log)
+         if idx is None:
+             log("Index not found on disk. Building now...")
+             status, sig = _build_index(split, image_col, text_col, max_samples, index_dir, log)
+             log(f"Index status after build: {status.value}")
+             _maybe_load_index(sig, index_dir, log)
+
+     if _index_ref["index"] is None:
+         return [], "Index is unavailable. Check logs."
+
+     # Encode query
+     tic = time.time()
+     q_emb = encoder.encode_texts([query_text])  # (1, D), already L2-normalized
+     encode_ms = (time.time() - tic) * 1000.0
+
+     # Search
+     tic = time.time()
+     scores, ids = search_faiss(_index_ref["index"], q_emb, top_k=top_k)
+     search_ms = (time.time() - tic) * 1000.0
+
+     # Prepare gallery entries: (image, caption) pairs titled with rank and score
+     results = []
+     for rank, (idx, score) in enumerate(zip(ids[0], scores[0]), start=1):
+         if idx < 0:
+             continue  # FAISS pads with -1 when fewer than top_k hits exist
+         sample = accessor.get(int(idx))
+         img: Image.Image = sample.image
+         cap: str = sample.text
+         # Gradio Gallery expects a list of (image, caption) tuples
+         caption = f"#{rank} | score={score:.4f}\n{cap}"
+         results.append((img, caption))
+
+     footer = f"Encoded in {encode_ms:.1f} ms, searched in {search_ms:.1f} ms | idx sig: {sig}"
+     if logs:
+         footer += "\n" + "\n".join(logs)
+
+     return results, footer
+
+
+ def build_ui():
+     with gr.Blocks(css="footer {visibility: hidden}") as demo:
+         gr.Markdown(
+             """
+             # 🔎 Text → Image Retrieval (SigLIP + FAISS)
+             Dataset: `tomytjandra/h-and-m-fashion-caption` • Index is cached on disk • Works on CPU (default) and uses GPU FAISS if available.
+             """
+         )
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 query = gr.Textbox(
+                     label="Enter a text query",
+                     placeholder="e.g., 'women's red floral dress with long sleeves'",
+                 )
+                 examples = gr.Examples(
+                     examples=[[q] for q in EXAMPLE_QUERIES],
+                     inputs=[query],
+                     label="Examples",
+                 )
+                 top_k = gr.Slider(1, 50, value=DEFAULT_TOP_K, step=1, label="Top‑K")
+                 search_btn = gr.Button("Search", variant="primary")
+
+             with gr.Column(scale=1):
+                 split = gr.Dropdown(
+                     choices=["train", "validation", "test"],
+                     value=DEFAULT_SPLIT,
+                     label="Dataset split",
+                 )
+                 image_col = gr.Textbox(value=DEFAULT_IMAGE_COL, label="IMAGE_COL")
+                 text_col = gr.Textbox(value=DEFAULT_TEXT_COL, label="TEXT_COL")
+                 max_samples = gr.Slider(
+                     minimum=100, maximum=200_000, value=DEFAULT_MAX_SAMPLES, step=100,
+                     label="MAX_SAMPLES (cap for demo)"
+                 )
+                 index_dir = gr.Textbox(value=DEFAULT_INDEX_DIR, label="INDEX_DIR")
+                 rebuild_btn = gr.Button("(Re)Build Index")
+
+         with gr.Row():
+             gallery = gr.Gallery(
+                 label="Results",
+                 columns=5,
+                 height=520,
+                 preview=True,
+                 show_label=True,
+             )
+         status = gr.Textbox(label="Status / Logs", interactive=False)
+
+         # Wire actions
+         search_btn.click(
+             ui_search,
+             inputs=[query, split, image_col, text_col, max_samples, top_k, index_dir],
+             outputs=[gallery, status],
+         )
+         rebuild_btn.click(
+             ui_rebuild_index,
+             inputs=[split, image_col, text_col, max_samples, index_dir],
+             outputs=[status],
+         )
+
+         gr.Markdown(
+             """
+             **Notes**
+             - Set `SIGLIP_CHECKPOINT_DIR` to your local SigLIP checkpoint folder (uploaded to this Space).
+             - The first run builds an index and caches it under `INDEX_DIR`.
+             - Uses cosine similarity via L2-normalized embeddings on `IndexFlatIP`. If GPU FAISS is available, it is used automatically.
+             """
+         )
+     return demo
+
+
+ if __name__ == "__main__":
+     build_ui().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
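
The Notes block above covers the runtime setup (checkpoint folder, cached index). As a minimal sketch, assuming the checkpoint directory and the modules from this commit are present, the search handler can be exercised outside Gradio like this (the query string and sample cap are illustrative):

# smoke_test.py -- illustrative only; assumes a valid SigLIP checkpoint folder
import os

os.environ.setdefault("SIGLIP_CHECKPOINT_DIR", "siglip_checkpoint")  # hypothetical local path

from app import ui_search
from configs import DEFAULT_IMAGE_COL, DEFAULT_INDEX_DIR, DEFAULT_SPLIT, DEFAULT_TEXT_COL

results, status = ui_search(
    query_text="red floral summer dress",
    split=DEFAULT_SPLIT,
    image_col=DEFAULT_IMAGE_COL,
    text_col=DEFAULT_TEXT_COL,
    max_samples=500,   # small cap so a first-time index build stays quick
    top_k=5,
    index_dir=DEFAULT_INDEX_DIR,
)
print(status)
print(f"{len(results)} gallery items")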
configs.py CHANGED
@@ -1,2 +1,22 @@
  from __future__ import annotations
- ... (full content from assistant's previous message) ...
+
+ DATASET_NAME = "tomytjandra/h-and-m-fashion-caption"
+
+ # Default dataset/view options (can be changed in the UI)
+ DEFAULT_SPLIT = "validation"  # "train" | "validation" | "test"
+ DEFAULT_IMAGE_COL = "image"   # change if your dataset variant differs
+ DEFAULT_TEXT_COL = "caption"  # change if your dataset variant differs
+
+ DEFAULT_MAX_SAMPLES = 5000  # cap for demo builds; adjustable in the UI
+ DEFAULT_TOP_K = 12
+
+ # Index cache directory inside the Space persistent storage
+ DEFAULT_INDEX_DIR = "./index_cache"
+
+ EXAMPLE_QUERIES = [
+     "red floral summer dress",
+     "men's black leather jacket",
+     "white sneakers with chunky sole",
+     "blue denim jeans for women",
+     "kids' yellow raincoat",
+ ]
dataset_utils.py CHANGED
@@ -1,2 +1,91 @@
  from __future__ import annotations
- ... (full content from assistant's previous message) ...
+
+ from dataclasses import dataclass
+ from typing import Callable, Optional, List
+
+ from datasets import load_dataset, Dataset
+ from PIL import Image
+
+
+ @dataclass
+ class Sample:
+     idx: int
+     image: Image.Image
+     text: str
+
+
+ class SampleAccessor:
+     """A thin wrapper for random access into a loaded HF dataset with normalized columns."""
+
+     def __init__(self, hf_ds: Dataset, image_col: str, text_col: str):
+         self.ds = hf_ds
+         self.image_col = image_col
+         self.text_col = text_col
+
+     def __len__(self) -> int:
+         return len(self.ds)
+
+     def get(self, i: int) -> Sample:
+         row = self.ds[i]
+         img = row[self.image_col]
+         if not isinstance(img, Image.Image):
+             # datasets may hand back a numpy array instead of a decoded PIL image
+             img = Image.fromarray(img)
+         if img.mode != "RGB":
+             img = img.convert("RGB")  # drop alpha/palette for encoders
+         text = str(row[self.text_col])
+         return Sample(idx=i, image=img, text=text)
+
+     def batched_images(self, start: int, end: int) -> List[Image.Image]:
+         # Slicing a HF Dataset returns a dict of column -> list, so index the image column directly
+         images = []
+         for img in self.ds[start:end][self.image_col]:
+             if not isinstance(img, Image.Image):
+                 img = Image.fromarray(img)
+             if img.mode != "RGB":
+                 img = img.convert("RGB")
+             images.append(img)
+         return images
+
+     def texts(self, start: int, end: int) -> List[str]:
+         return [str(t) for t in self.ds[start:end][self.text_col]]
+
+
+ def load_fashion_dataset(
+     dataset_name: str,
+     split: str,
+     image_col: str,
+     text_col: str,
+     max_samples: int,
+     log: Optional[Callable[[str], None]] = None,
+ ) -> SampleAccessor:
+     """Load and normalize the H&M fashion caption dataset.
+
+     Dataset versions may vary in column names, so user-specified columns are accepted.
+     """
+     if log:
+         log(f"Loading dataset: {dataset_name} [{split}] (max_samples={max_samples})")
+     ds = load_dataset(dataset_name, split=split, streaming=False)
+     total = len(ds)
+     if log:
+         log(f"Dataset size (split={split}): {total}")
+
+     # Trim to max_samples for the demo
+     if max_samples is not None and max_samples < total:
+         ds = ds.select(range(max_samples))
+
+     # Validate columns
+     for col in (image_col, text_col):
+         if col not in ds.column_names:
+             raise KeyError(
+                 f"Column '{col}' not found in dataset. Available: {ds.column_names}. "
+                 "Adjust IMAGE_COL/TEXT_COL in the UI."
+             )
+
+     return SampleAccessor(hf_ds=ds, image_col=image_col, text_col=text_col)
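
As the docstring notes, column names can differ between dataset variants. A quick sanity check of the accessor on a small slice might look like the sketch below (the split name and sample cap are illustrative and assume the split exists):

from configs import DATASET_NAME, DEFAULT_IMAGE_COL, DEFAULT_TEXT_COL
from dataset_utils import load_fashion_dataset

acc = load_fashion_dataset(
    dataset_name=DATASET_NAME,
    split="train",              # pick whichever split the dataset actually provides
    image_col=DEFAULT_IMAGE_COL,
    text_col=DEFAULT_TEXT_COL,
    max_samples=32,
    log=print,
)
sample = acc.get(0)
print(sample.text, sample.image.size)            # caption and (width, height)
print(len(acc.batched_images(0, 8)), "images")   # batched access used by the index builder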
encoders.py CHANGED
@@ -1,2 +1,140 @@
  from __future__ import annotations
- ... (full content from assistant's previous message) ...
+
+ import os
+ from dataclasses import dataclass
+ from typing import List, Tuple, Optional, Callable
+
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import AutoModel, AutoProcessor
+
+
+ def _pick_device() -> torch.device:
+     if torch.cuda.is_available():
+         return torch.device("cuda")
+     # Apple Silicon
+     if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+         return torch.device("mps")
+     return torch.device("cpu")
+
+
+ def _pick_dtype(device: torch.device) -> torch.dtype:
+     if device.type == "cuda":
+         # Prefer bf16 if supported; else fp16
+         if torch.cuda.is_bf16_supported():
+             return torch.bfloat16
+         return torch.float16
+     if device.type == "mps":
+         # mps prefers float32 accuracy
+         return torch.float32
+     return torch.float32
+
+
+ @dataclass
+ class SiglipEncoder:
+     model: AutoModel
+     processor: AutoProcessor
+     device: torch.device
+     dtype: torch.dtype
+     ckpt_dir: str
+
+     @classmethod
+     def from_checkpoint_dir(cls, ckpt_dir: str, log: Optional[Callable[[str], None]] = None) -> "SiglipEncoder":
+         if not os.path.isdir(ckpt_dir):
+             raise FileNotFoundError(
+                 f"SIGLIP_CHECKPOINT_DIR not found: {ckpt_dir}. "
+                 "Upload your SigLIP checkpoint folder to the Space and set the env var."
+             )
+         device = _pick_device()
+         dtype = _pick_dtype(device)
+
+         if log:
+             log(f"Loading processor/model from {ckpt_dir} (device={device}, dtype={dtype})")
+
+         processor = AutoProcessor.from_pretrained(ckpt_dir, trust_remote_code=True)
+         model = AutoModel.from_pretrained(ckpt_dir, trust_remote_code=True)
+
+         model.to(device)
+         model.eval()
+         return cls(model=model, processor=processor, device=device, dtype=dtype, ckpt_dir=ckpt_dir)
+
+     # ---------- Embedding helpers ----------
+
+     def _maybe_autocast(self):
+         # CUDA AMP context; for mps/cpu, no autocast by default
+         if self.device.type == "cuda" and self.dtype in (torch.float16, torch.bfloat16):
+             return torch.autocast(device_type="cuda", dtype=self.dtype)
+
+         class DummyCtx:
+             def __enter__(self):
+                 return None
+
+             def __exit__(self, *args):
+                 return False
+
+         return DummyCtx()
+
+     def _normalize(self, x: np.ndarray) -> np.ndarray:
+         norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
+         return x / norms
+
+     def _pool_mean(self, last_hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
+         # Mean pooling with attention mask
+         if attention_mask is None:
+             return last_hidden_state.mean(dim=1)
+         mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
+         summed = (last_hidden_state * mask).sum(dim=1)
+         counts = mask.sum(dim=1).clamp(min=1e-6)
+         return summed / counts
+
+     def _forward_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         # Try common signatures: get_image_features or forward(...).image_embeds
+         # Fallback: mean-pool last_hidden_state of the vision tower.
+         if hasattr(self.model, "get_image_features"):
+             return self.model.get_image_features(pixel_values=pixel_values)
+         out = self.model(pixel_values=pixel_values)
+         if hasattr(out, "image_embeds") and out.image_embeds is not None:
+             return out.image_embeds
+         if hasattr(out, "last_hidden_state"):
+             return out.last_hidden_state.mean(dim=1)
+         raise RuntimeError("Unable to extract image embeddings from model outputs.")
+
+     def _forward_text(self, **text_inputs) -> torch.Tensor:
+         if hasattr(self.model, "get_text_features"):
+             return self.model.get_text_features(**text_inputs)
+         out = self.model(**text_inputs)
+         if hasattr(out, "text_embeds") and out.text_embeds is not None:
+             return out.text_embeds
+         if hasattr(out, "last_hidden_state"):
+             return self._pool_mean(out.last_hidden_state, text_inputs.get("attention_mask"))
+         raise RuntimeError("Unable to extract text embeddings from model outputs.")
+
+     @torch.no_grad()
+     def encode_images(self, images: List[Image.Image], batch_size: int = 64) -> np.ndarray:
+         """Encode a list of PIL images to L2-normalized embeddings."""
+         feats: List[np.ndarray] = []
+         with self._maybe_autocast():
+             for i in range(0, len(images), batch_size):
+                 batch = images[i : i + batch_size]
+                 # Ensure RGB
+                 batch = [im.convert("RGB") if im.mode != "RGB" else im for im in batch]
+                 inputs = self.processor(images=batch, return_tensors="pt")
+                 pixel_values = inputs["pixel_values"].to(
+                     self.device, dtype=self.dtype if self.device.type == "cuda" else torch.float32
+                 )
+                 embs = self._forward_image(pixel_values)  # (B, D)
+                 embs = embs.float().cpu().numpy()
+                 feats.append(embs)
+         feats_np = np.concatenate(feats, axis=0)
+         return self._normalize(feats_np)
+
+     @torch.no_grad()
+     def encode_texts(self, texts: List[str], batch_size: int = 128) -> np.ndarray:
+         """Encode a list of texts to L2-normalized embeddings."""
+         feats: List[np.ndarray] = []
+         with self._maybe_autocast():
+             for i in range(0, len(texts), batch_size):
+                 batch = texts[i : i + batch_size]
+                 inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True)
+                 inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                 embs = self._forward_text(**inputs)  # (B, D)
+                 embs = embs.float().cpu().numpy()
+                 feats.append(embs)
+         feats_np = np.concatenate(feats, axis=0)
+         return self._normalize(feats_np)
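
Both encode_* methods return L2-normalized rows, so a plain dot product already yields cosine similarity. A small sketch, assuming `SIGLIP_CHECKPOINT_DIR` points at a usable checkpoint (the solid-color test image is purely synthetic):

import os

from PIL import Image

from encoders import SiglipEncoder

enc = SiglipEncoder.from_checkpoint_dir(
    os.getenv("SIGLIP_CHECKPOINT_DIR", "siglip_checkpoint"), log=print
)
img_emb = enc.encode_images([Image.new("RGB", (224, 224), "red")])        # (1, D)
txt_emb = enc.encode_texts(["a plain red image", "a striped blue coat"])  # (2, D)
print(txt_emb @ img_emb.T)  # cosine similarities, shape (2, 1)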
index_builder.py CHANGED
@@ -1,2 +1,169 @@
  from __future__ import annotations
- ... (full content from assistant's previous message) ...
+
+ import os
+ import io
+ import json
+ import time
+ import hashlib
+ from enum import Enum
+ from typing import Optional, Callable, Tuple
+
+ import numpy as np
+
+ try:
+     import faiss  # type: ignore
+ except Exception as e:
+     raise RuntimeError(
+         "Failed to import faiss. Ensure 'faiss-cpu' is in requirements.txt."
+     ) from e
+
+ from dataset_utils import SampleAccessor
+ from encoders import SiglipEncoder
+
+
+ class IndexStatus(Enum):
+     CREATED = "CREATED"
+     LOADED = "LOADED"
+     SKIPPED_FOUND = "SKIPPED_FOUND"
+     UPDATED = "UPDATED"
+
+
+ def index_signature_from_env(
+     dataset_name: str,
+     split: str,
+     max_samples: int,
+     ckpt_dir: str,
+     image_col: str,
+     text_col: str,
+ ) -> str:
+     """Create a stable signature for the on-disk index cache."""
+     # Include a hash of the checkpoint's config.json if it exists
+     cfg_path = os.path.join(ckpt_dir, "config.json")
+     cfg_hash = "nocfg"
+     if os.path.isfile(cfg_path):
+         try:
+             with open(cfg_path, "rb") as f:
+                 cfg_hash = hashlib.md5(f.read()).hexdigest()[:10]
+         except Exception:
+             pass
+     base = json.dumps(
+         {
+             "dataset": dataset_name,
+             "split": split,
+             "max_samples": int(max_samples),
+             "ckpt": os.path.basename(os.path.abspath(ckpt_dir)),
+             "cfg": cfg_hash,
+             "image_col": image_col,
+             "text_col": text_col,
+         },
+         sort_keys=True,
+     )
+     return hashlib.sha1(base.encode("utf-8")).hexdigest()[:16]
+
+
+ def _index_paths(index_dir: str, signature: str):
+     os.makedirs(index_dir, exist_ok=True)
+     idx_path = os.path.join(index_dir, f"{signature}.faiss")
+     meta_path = os.path.join(index_dir, f"{signature}.meta.json")
+     return idx_path, meta_path
+
+
+ def _maybe_gpu(index):
+     """If FAISS GPU is available, move the index to GPU; else return it as-is."""
+     try:
+         import faiss  # noqa
+         if faiss.get_num_gpus() > 0:
+             res = faiss.StandardGpuResources()
+             return faiss.index_cpu_to_gpu(res, 0, index)
+     except Exception:
+         pass
+     return index
+
+
+ def _normalize_rows(x: np.ndarray) -> np.ndarray:
+     norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
+     return x / norms
+
+
+ def ensure_index(
+     accessor: SampleAccessor,
+     encoder: SiglipEncoder,
+     index_dir: str,
+     signature: str,
+     log: Optional[Callable[[str], None]] = None,
+ ) -> IndexStatus:
+     """Create the FAISS index if not present; otherwise leave it as-is."""
+     idx_path, meta_path = _index_paths(index_dir, signature)
+
+     if os.path.isfile(idx_path) and os.path.isfile(meta_path):
+         if log:
+             log(f"Index already exists at {idx_path}")
+         return IndexStatus.SKIPPED_FOUND
+
+     # Encode all images in batches
+     n = len(accessor)
+     if log:
+         log(f"Encoding {n} images to build index ...")
+
+     batch = 512
+     feats = []
+     t0 = time.time()
+     for start in range(0, n, batch):
+         end = min(n, start + batch)
+         imgs = accessor.batched_images(start, end)
+         emb = encoder.encode_images(imgs)  # (B, D), L2-normalized
+         feats.append(emb)
+         if log:
+             pct = (end / n) * 100.0
+             log(f"Progress: {end}/{n} ({pct:.1f}%)")
+
+     feats_np = np.concatenate(feats, axis=0).astype("float32", copy=False)
+     dim = feats_np.shape[1]
+
+     # Cosine similarity via inner product on normalized vectors
+     cpu_index = faiss.IndexFlatIP(dim)
+     cpu_index.add(feats_np)
+
+     # Save to disk (CPU index for compatibility)
+     faiss.write_index(cpu_index, idx_path)
+
+     # Save meta information (size, dim, provenance)
+     meta = {
+         "signature": signature,
+         "size": int(n),
+         "dim": int(dim),
+         "created_at": time.time(),
+         "index_path": os.path.basename(idx_path),
+         "notes": "Embeddings are L2-normalized; cosine == inner product.",
+     }
+     with open(meta_path, "w", encoding="utf-8") as f:
+         json.dump(meta, f, ensure_ascii=False, indent=2)
+
+     if log:
+         log(f"Index built in {(time.time() - t0):.2f}s. Saved to {idx_path}")
+     return IndexStatus.CREATED
+
+
+ def load_faiss_index(index_dir: str, signature: str, log: Optional[Callable[[str], None]] = None):
+     idx_path, meta_path = _index_paths(index_dir, signature)
+     if not (os.path.isfile(idx_path) and os.path.isfile(meta_path)):
+         return None, None
+
+     with open(meta_path, "r", encoding="utf-8") as f:
+         meta = json.load(f)
+     idx = faiss.read_index(idx_path)
+     dim = int(meta.get("dim", idx.d))
+     # Try moving to GPU
+     idx = _maybe_gpu(idx)
+     if log:
+         log(f"Loaded FAISS index: {idx_path} (dim={dim})")
+     return idx, dim
+
+
+ def search_faiss(index, query_embs: np.ndarray, top_k: int = 10):
+     """Search FAISS (inner product) with normalized query embeddings."""
+     assert query_embs.ndim == 2
+     # Ensure L2-normalized
+     q = _normalize_rows(query_embs.astype("float32", copy=False))
+     scores, ids = index.search(q, int(top_k))
+     return scores, ids
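
The meta file records that the stored vectors are L2-normalized, so `IndexFlatIP` scores are cosine similarities. A tiny self-contained check of that property with random vectors (dimensions and sizes are arbitrary):

import faiss
import numpy as np

rng = np.random.default_rng(0)
db = rng.normal(size=(100, 16)).astype("float32")
db /= np.linalg.norm(db, axis=1, keepdims=True)  # unit-length rows

index = faiss.IndexFlatIP(db.shape[1])           # inner product == cosine here
index.add(db)

scores, ids = index.search(db[:1], 3)            # query with the first vector itself
print(ids[0], scores[0])                         # best hit is id 0 with score ~1.0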