Spaces · Build error

Commit 9f89b03
1 Parent(s): 10fc621

Add Gradio UI logic

Files changed:
- app.py +270 -1
- configs.py +21 -1
- dataset_utils.py +90 -1
- encoders.py +139 -1
- index_builder.py +168 -1
app.py
CHANGED

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import os
import time
from typing import Callable, Optional, Tuple

import gradio as gr
from PIL import Image

from configs import (
    DEFAULT_SPLIT,
    DEFAULT_TOP_K,
    DEFAULT_MAX_SAMPLES,
    DEFAULT_IMAGE_COL,
    DEFAULT_TEXT_COL,
    DEFAULT_INDEX_DIR,
    EXAMPLE_QUERIES,
    DATASET_NAME,
)
from dataset_utils import load_fashion_dataset, SampleAccessor
from encoders import SiglipEncoder
from index_builder import (
    IndexStatus,
    ensure_index,
    load_faiss_index,
    search_faiss,
    index_signature_from_env,
)


# --------------------------
# Globals initialized lazily
# --------------------------
_encoder: Optional[SiglipEncoder] = None
_accessor: Optional[SampleAccessor] = None
_index_ref = {"index": None, "sig": None, "dim": None}


def _get_encoder(log: Callable[[str], None]) -> SiglipEncoder:
    global _encoder
    if _encoder is None:
        ckpt = os.getenv("SIGLIP_CHECKPOINT_DIR", "siglip_checkpoint")
        log(f"Loading SigLIP checkpoint from: {ckpt}")
        _encoder = SiglipEncoder.from_checkpoint_dir(ckpt, log=log)
        log(f"Device: {_encoder.device}, dtype: {_encoder.dtype}")
    return _encoder


def _get_accessor(split: str, image_col: str, text_col: str, max_samples: int, log: Callable[[str], None]) -> SampleAccessor:
    global _accessor
    # Always create a new accessor matching the current UI state
    _accessor = load_fashion_dataset(
        dataset_name=DATASET_NAME,
        split=split,
        image_col=image_col,
        text_col=text_col,
        max_samples=max_samples,
        log=log,
    )
    log(f"Dataset ready: {len(_accessor)} samples.")
    return _accessor


def _maybe_load_index(sig: str, index_dir: str, log: Callable[[str], None]):
    """Try to load an existing FAISS index matching `sig`."""
    faiss_index, dim = load_faiss_index(index_dir=index_dir, signature=sig, log=log)
    _index_ref["index"] = faiss_index
    _index_ref["sig"] = sig
    _index_ref["dim"] = dim
    return faiss_index, dim


def _build_index(
    split: str,
    image_col: str,
    text_col: str,
    max_samples: int,
    index_dir: str,
    log: Callable[[str], None],
) -> Tuple[IndexStatus, str]:
    """Ensure an index exists and return its status plus signature string."""
    encoder = _get_encoder(log)
    accessor = _get_accessor(split, image_col, text_col, max_samples, log)

    sig = index_signature_from_env(
        dataset_name=DATASET_NAME,
        split=split,
        max_samples=max_samples,
        ckpt_dir=encoder.ckpt_dir,
        image_col=image_col,
        text_col=text_col,
    )
    status = ensure_index(
        accessor=accessor,
        encoder=encoder,
        index_dir=index_dir,
        signature=sig,
        log=log,
    )
    return status, sig


def _log_to_console(msg: str):
    print(msg, flush=True)


def ui_rebuild_index(split, image_col, text_col, max_samples, index_dir):
    logs = []

    def log(s: str):
        logs.append(s)
        _log_to_console(s)

    status, sig = _build_index(split, image_col, text_col, max_samples, index_dir, log)
    _maybe_load_index(sig, index_dir, log)

    footer = f"Index status: {status.value} | signature: {sig}"
    if _index_ref["index"] is not None:
        footer += f" | dim={_index_ref['dim']}"

    # Single return value: the status textbox is the only output component.
    return "\n".join(logs + [footer])


def ui_search(
    query_text: str,
    split: str,
    image_col: str,
    text_col: str,
    max_samples: int,
    top_k: int,
    index_dir: str,
):
    if not query_text or not query_text.strip():
        return [], "Please enter a non-empty query."

    # Prepare a logger to capture build/search messages for the status box
    logs = []

    def log(s: str):
        logs.append(s)
        _log_to_console(s)

    # Make sure the encoder, dataset accessor, and index are aligned with the UI state
    encoder = _get_encoder(log)
    accessor = _get_accessor(split, image_col, text_col, max_samples, log)

    sig = index_signature_from_env(
        dataset_name=DATASET_NAME,
        split=split,
        max_samples=max_samples,
        ckpt_dir=encoder.ckpt_dir,
        image_col=image_col,
        text_col=text_col,
    )
    if _index_ref["index"] is None or _index_ref["sig"] != sig:
        # Try to load from disk; if not present, build
        idx, _ = _maybe_load_index(sig, index_dir, log)
        if idx is None:
            log("Index not found on disk. Building now...")
            status, sig = _build_index(split, image_col, text_col, max_samples, index_dir, log)
            log(f"Index status after build: {status.value}")
            _maybe_load_index(sig, index_dir, log)

    if _index_ref["index"] is None:
        return [], "Index is unavailable. Check logs."

    # Encode the query
    tic = time.time()
    q_emb = encoder.encode_texts([query_text])  # (1, D), already L2-normalized
    encode_ms = (time.time() - tic) * 1000.0

    # Search
    tic = time.time()
    scores, ids = search_faiss(_index_ref["index"], q_emb, top_k=top_k)
    search_ms = (time.time() - tic) * 1000.0

    # Prepare gallery entries: (image, caption) pairs with rank and score
    results = []
    for rank, (idx, score) in enumerate(zip(ids[0], scores[0]), start=1):
        if idx < 0:  # FAISS pads with -1 when fewer than top_k hits exist
            continue
        sample = accessor.get(int(idx))
        img: Image.Image = sample.image
        cap: str = sample.text
        caption = f"#{rank} | score={score:.4f}\n{cap}"
        results.append((img, caption))

    footer = f"Encoded in {encode_ms:.1f} ms, searched in {search_ms:.1f} ms | idx sig: {sig}"
    if logs:
        footer += "\n" + "\n".join(logs)

    return results, footer


def build_ui():
    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown(
            """
            # 🔎 Text → Image Retrieval (SigLIP + FAISS)
            Dataset: `tomytjandra/h-and-m-fashion-caption` • Index is cached on disk • Works on CPU (default) and uses GPU FAISS if available.
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                query = gr.Textbox(
                    label="Enter a text query",
                    placeholder="e.g., 'women's red floral dress with long sleeves'",
                )
                gr.Examples(
                    examples=[[q] for q in EXAMPLE_QUERIES],
                    inputs=[query],
                    label="Examples",
                )
                top_k = gr.Slider(1, 50, value=DEFAULT_TOP_K, step=1, label="Top-K")
                search_btn = gr.Button("Search", variant="primary")

            with gr.Column(scale=1):
                split = gr.Dropdown(
                    choices=["train", "validation", "test"],
                    value=DEFAULT_SPLIT,
                    label="Dataset split",
                )
                image_col = gr.Textbox(value=DEFAULT_IMAGE_COL, label="IMAGE_COL")
                text_col = gr.Textbox(value=DEFAULT_TEXT_COL, label="TEXT_COL")
                max_samples = gr.Slider(
                    minimum=100, maximum=200_000, value=DEFAULT_MAX_SAMPLES, step=100,
                    label="MAX_SAMPLES (cap for demo)",
                )
                index_dir = gr.Textbox(value=DEFAULT_INDEX_DIR, label="INDEX_DIR")
                rebuild_btn = gr.Button("(Re)Build Index")

        with gr.Row():
            gallery = gr.Gallery(
                label="Results",
                columns=5,
                height=520,
                preview=True,
                show_label=True,
            )
        status = gr.Textbox(label="Status / Logs", interactive=False)

        # Wire actions
        search_btn.click(
            ui_search,
            inputs=[query, split, image_col, text_col, max_samples, top_k, index_dir],
            outputs=[gallery, status],
        )
        rebuild_btn.click(
            ui_rebuild_index,
            inputs=[split, image_col, text_col, max_samples, index_dir],
            outputs=[status],
        )

        gr.Markdown(
            """
            **Notes**
            - Set `SIGLIP_CHECKPOINT_DIR` to your local SigLIP checkpoint folder (uploaded to this Space).
            - The first run builds an index and caches it under `INDEX_DIR`.
            - Cosine similarity is computed via L2-normalized embeddings on `IndexFlatIP`. If GPU FAISS is available, it is used automatically.
            """
        )
    return demo


if __name__ == "__main__":
    build_ui().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
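For quick verification outside the Space UI, a minimal headless smoke test of the search path could look like the sketch below. The checkpoint path and environment variable value are assumptions; adjust them to your setup.

# smoke_test.py -- hedged sketch, not part of this commit.
# Assumes a SigLIP checkpoint folder exists at ./siglip_checkpoint.
import os

os.environ.setdefault("SIGLIP_CHECKPOINT_DIR", "./siglip_checkpoint")

from app import ui_search
from configs import (
    DEFAULT_SPLIT, DEFAULT_IMAGE_COL, DEFAULT_TEXT_COL,
    DEFAULT_MAX_SAMPLES, DEFAULT_TOP_K, DEFAULT_INDEX_DIR,
)

# Drive the same function the Search button calls, with the configs defaults.
results, status = ui_search(
    "red floral summer dress",
    DEFAULT_SPLIT, DEFAULT_IMAGE_COL, DEFAULT_TEXT_COL,
    DEFAULT_MAX_SAMPLES, DEFAULT_TOP_K, DEFAULT_INDEX_DIR,
)
print(status)
print(f"{len(results)} gallery entries returned")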
configs.py
CHANGED

from __future__ import annotations

DATASET_NAME = "tomytjandra/h-and-m-fashion-caption"

# Default dataset/view options (can be changed in the UI)
DEFAULT_SPLIT = "validation"  # "train" | "validation" | "test"
DEFAULT_IMAGE_COL = "image"   # change if your dataset variant differs
DEFAULT_TEXT_COL = "caption"  # change if your dataset variant differs

DEFAULT_MAX_SAMPLES = 5000  # cap for demo builds; adjustable in the UI
DEFAULT_TOP_K = 12

# Index cache directory inside the Space's persistent storage
DEFAULT_INDEX_DIR = "./index_cache"

EXAMPLE_QUERIES = [
    "red floral summer dress",
    "men's black leather jacket",
    "white sneakers with chunky sole",
    "blue denim jeans for women",
    "kids' yellow raincoat",
]
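These constants are the single source of truth for the demo defaults. If you wanted Space-level overrides without touching the UI, one hypothetical pattern (not part of this commit) is to read environment variables with these values as fallbacks:

# Hypothetical override pattern -- a sketch only; the committed code does not do this.
import os

from configs import DEFAULT_SPLIT, DEFAULT_MAX_SAMPLES

split = os.getenv("SPLIT", DEFAULT_SPLIT)              # e.g. a Space variable
max_samples = int(os.getenv("MAX_SAMPLES", str(DEFAULT_MAX_SAMPLES)))
print(f"Effective defaults: split={split}, max_samples={max_samples}")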
dataset_utils.py
CHANGED

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional, List

from datasets import load_dataset, Dataset
from PIL import Image


@dataclass
class Sample:
    idx: int
    image: Image.Image
    text: str


class SampleAccessor:
    """A thin wrapper for random access into a loaded HF dataset with normalized columns."""

    def __init__(self, hf_ds: Dataset, image_col: str, text_col: str):
        self.ds = hf_ds
        self.image_col = image_col
        self.text_col = text_col

    def __len__(self) -> int:
        return len(self.ds)

    @staticmethod
    def _to_rgb(img) -> Image.Image:
        if not isinstance(img, Image.Image):
            # datasets may decode to a numpy array depending on features
            img = Image.fromarray(img)
        if img.mode != "RGB":
            img = img.convert("RGB")  # drop alpha/palette for encoders
        return img

    def get(self, i: int) -> Sample:
        row = self.ds[i]
        img = self._to_rgb(row[self.image_col])
        text = str(row[self.text_col])
        return Sample(idx=i, image=img, text=text)

    def batched_images(self, start: int, end: int) -> List[Image.Image]:
        # Slicing an HF Dataset returns a dict of column -> list of values,
        # so iterate the image column directly rather than over "rows".
        cols = self.ds[start:end]
        return [self._to_rgb(img) for img in cols[self.image_col]]

    def texts(self, start: int, end: int) -> List[str]:
        cols = self.ds[start:end]
        return [str(t) for t in cols[self.text_col]]


def load_fashion_dataset(
    dataset_name: str,
    split: str,
    image_col: str,
    text_col: str,
    max_samples: int,
    log: Optional[Callable[[str], None]] = None,
) -> SampleAccessor:
    """Load and normalize the H&M fashion caption dataset.

    Dataset versions may vary in column names, so user-specified columns are accepted.
    """
    if log:
        log(f"Loading dataset: {dataset_name} [{split}] (max_samples={max_samples})")
    ds = load_dataset(dataset_name, split=split, streaming=False)
    total = len(ds)
    if log:
        log(f"Dataset size (split={split}): {total}")

    # Trim to max_samples for the demo
    if max_samples is not None and max_samples < total:
        ds = ds.select(range(max_samples))

    # Validate columns
    for col in (image_col, text_col):
        if col not in ds.column_names:
            raise KeyError(
                f"Column '{col}' not found in dataset. Available: {ds.column_names}. "
                "Adjust IMAGE_COL/TEXT_COL in the UI."
            )

    return SampleAccessor(hf_ds=ds, image_col=image_col, text_col=text_col)
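A minimal sketch of exercising SampleAccessor directly, assuming the configs.py default column names (adjust image_col/text_col if your copy of the dataset differs, as the docstring above notes), with a tiny sample cap so it runs quickly:

# Hedged usage sketch for dataset_utils; column names are the configs defaults.
from dataset_utils import load_fashion_dataset

acc = load_fashion_dataset(
    dataset_name="tomytjandra/h-and-m-fashion-caption",
    split="train",
    image_col="image",
    text_col="caption",
    max_samples=8,
    log=print,
)
s = acc.get(0)
print(s.idx, s.image.size, s.image.mode, s.text[:60])
print(len(acc.batched_images(0, 4)), "images;", acc.texts(0, 4))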
encoders.py
CHANGED

from __future__ import annotations

import contextlib
import os
from dataclasses import dataclass
from typing import List, Optional, Callable

import torch
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor


def _pick_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Apple Silicon
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _pick_dtype(device: torch.device) -> torch.dtype:
    if device.type == "cuda":
        # Prefer bf16 if supported; else fp16
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    # mps/cpu: stay in float32 for accuracy
    return torch.float32


@dataclass
class SiglipEncoder:
    model: AutoModel
    processor: AutoProcessor
    device: torch.device
    dtype: torch.dtype
    ckpt_dir: str

    @classmethod
    def from_checkpoint_dir(cls, ckpt_dir: str, log: Optional[Callable[[str], None]] = None) -> "SiglipEncoder":
        if not os.path.isdir(ckpt_dir):
            raise FileNotFoundError(
                f"SIGLIP_CHECKPOINT_DIR not found: {ckpt_dir}. "
                "Upload your SigLIP checkpoint folder to the Space and set the env var."
            )
        device = _pick_device()
        dtype = _pick_dtype(device)

        if log:
            log(f"Loading processor/model from {ckpt_dir} (device={device}, dtype={dtype})")

        processor = AutoProcessor.from_pretrained(ckpt_dir, trust_remote_code=True)
        model = AutoModel.from_pretrained(ckpt_dir, trust_remote_code=True)

        model.to(device)
        model.eval()
        return cls(model=model, processor=processor, device=device, dtype=dtype, ckpt_dir=ckpt_dir)

    # ---------- Embedding helpers ----------

    def _maybe_autocast(self):
        # CUDA AMP context; a no-op context on mps/cpu
        if self.device.type == "cuda" and self.dtype in (torch.float16, torch.bfloat16):
            return torch.autocast(device_type="cuda", dtype=self.dtype)
        return contextlib.nullcontext()

    def _normalize(self, x: np.ndarray) -> np.ndarray:
        norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
        return x / norms

    def _pool_mean(self, last_hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        # Mean pooling with attention mask
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-6)
        return summed / counts

    def _forward_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Try common signatures: get_image_features or forward(...).image_embeds;
        # fall back to mean-pooling the vision tower's last_hidden_state.
        if hasattr(self.model, "get_image_features"):
            return self.model.get_image_features(pixel_values=pixel_values)
        out = self.model(pixel_values=pixel_values)
        if hasattr(out, "image_embeds") and out.image_embeds is not None:
            return out.image_embeds
        if hasattr(out, "last_hidden_state"):
            return out.last_hidden_state.mean(dim=1)
        raise RuntimeError("Unable to extract image embeddings from model outputs.")

    def _forward_text(self, **text_inputs) -> torch.Tensor:
        if hasattr(self.model, "get_text_features"):
            return self.model.get_text_features(**text_inputs)
        out = self.model(**text_inputs)
        if hasattr(out, "text_embeds") and out.text_embeds is not None:
            return out.text_embeds
        if hasattr(out, "last_hidden_state"):
            return self._pool_mean(out.last_hidden_state, text_inputs.get("attention_mask"))
        raise RuntimeError("Unable to extract text embeddings from model outputs.")

    @torch.no_grad()
    def encode_images(self, images: List[Image.Image], batch_size: int = 64) -> np.ndarray:
        """Encode a list of PIL images to L2-normalized embeddings."""
        feats: List[np.ndarray] = []
        with self._maybe_autocast():
            for i in range(0, len(images), batch_size):
                batch = images[i : i + batch_size]
                # Ensure RGB
                batch = [im.convert("RGB") if im.mode != "RGB" else im for im in batch]
                inputs = self.processor(images=batch, return_tensors="pt")
                pixel_values = inputs["pixel_values"].to(
                    self.device, dtype=self.dtype if self.device.type == "cuda" else torch.float32
                )
                embs = self._forward_image(pixel_values)  # (B, D)
                feats.append(embs.float().cpu().numpy())
        feats_np = np.concatenate(feats, axis=0)
        return self._normalize(feats_np)

    @torch.no_grad()
    def encode_texts(self, texts: List[str], batch_size: int = 128) -> np.ndarray:
        """Encode a list of texts to L2-normalized embeddings."""
        feats: List[np.ndarray] = []
        with self._maybe_autocast():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                embs = self._forward_text(**inputs)  # (B, D)
                feats.append(embs.float().cpu().numpy())
        feats_np = np.concatenate(feats, axis=0)
        return self._normalize(feats_np)
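Since both encode paths L2-normalize their outputs, inner products between embeddings are cosine similarities. A small sanity-check sketch, assuming a valid checkpoint at ./siglip_checkpoint (an illustrative path):

# Hedged sketch: verify unit-norm embeddings, so dot product == cosine similarity.
import numpy as np

from encoders import SiglipEncoder

enc = SiglipEncoder.from_checkpoint_dir("./siglip_checkpoint", log=print)
embs = enc.encode_texts(["red dress", "black leather jacket"])
print(embs.shape)                    # (2, D)
print(np.linalg.norm(embs, axis=1))  # both norms ~1.0
print(float(embs[0] @ embs[1]))      # cosine similarity in [-1, 1]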
index_builder.py
CHANGED

from __future__ import annotations

import os
import json
import time
import hashlib
from enum import Enum
from typing import Optional, Callable

import numpy as np

try:
    import faiss  # type: ignore
except Exception as e:
    raise RuntimeError(
        "Failed to import faiss. Ensure 'faiss-cpu' is in requirements.txt."
    ) from e

from dataset_utils import SampleAccessor
from encoders import SiglipEncoder


class IndexStatus(Enum):
    CREATED = "CREATED"
    LOADED = "LOADED"
    SKIPPED_FOUND = "SKIPPED_FOUND"
    UPDATED = "UPDATED"


def index_signature_from_env(
    dataset_name: str,
    split: str,
    max_samples: int,
    ckpt_dir: str,
    image_col: str,
    text_col: str,
) -> str:
    """Create a stable signature for the on-disk index cache."""
    # Include a hash of the checkpoint's config.json if it exists
    cfg_path = os.path.join(ckpt_dir, "config.json")
    cfg_hash = "nocfg"
    if os.path.isfile(cfg_path):
        try:
            with open(cfg_path, "rb") as f:
                cfg_hash = hashlib.md5(f.read()).hexdigest()[:10]
        except Exception:
            pass
    base = json.dumps(
        {
            "dataset": dataset_name,
            "split": split,
            "max_samples": int(max_samples),
            "ckpt": os.path.basename(os.path.abspath(ckpt_dir)),
            "cfg": cfg_hash,
            "image_col": image_col,
            "text_col": text_col,
        },
        sort_keys=True,
    )
    return hashlib.sha1(base.encode("utf-8")).hexdigest()[:16]


def _index_paths(index_dir: str, signature: str):
    os.makedirs(index_dir, exist_ok=True)
    idx_path = os.path.join(index_dir, f"{signature}.faiss")
    meta_path = os.path.join(index_dir, f"{signature}.meta.json")
    return idx_path, meta_path


def _maybe_gpu(index):
    """If FAISS GPU is available, move the index to GPU; else return it as-is."""
    try:
        if faiss.get_num_gpus() > 0:
            res = faiss.StandardGpuResources()
            return faiss.index_cpu_to_gpu(res, 0, index)
    except Exception:
        pass
    return index


def _normalize_rows(x: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
    return x / norms


def ensure_index(
    accessor: SampleAccessor,
    encoder: SiglipEncoder,
    index_dir: str,
    signature: str,
    log: Optional[Callable[[str], None]] = None,
) -> IndexStatus:
    """Create the FAISS index if not present; otherwise leave it untouched."""
    idx_path, meta_path = _index_paths(index_dir, signature)

    if os.path.isfile(idx_path) and os.path.isfile(meta_path):
        if log:
            log(f"Index already exists at {idx_path}")
        return IndexStatus.SKIPPED_FOUND

    # Encode all images in batches
    n = len(accessor)
    if log:
        log(f"Encoding {n} images to build index ...")

    batch = 512
    feats = []
    t0 = time.time()
    for start in range(0, n, batch):
        end = min(n, start + batch)
        imgs = accessor.batched_images(start, end)
        emb = encoder.encode_images(imgs)  # (B, D), L2-normalized
        feats.append(emb)
        if log:
            pct = (end / n) * 100.0
            log(f"Progress: {end}/{n} ({pct:.1f}%)")

    feats_np = np.concatenate(feats, axis=0).astype("float32", copy=False)
    dim = feats_np.shape[1]

    # Cosine similarity via inner product on normalized vectors
    cpu_index = faiss.IndexFlatIP(dim)
    cpu_index.add(feats_np)

    # Save to disk (CPU index for compatibility)
    faiss.write_index(cpu_index, idx_path)

    # Save meta information
    meta = {
        "signature": signature,
        "size": int(n),
        "dim": int(dim),
        "created_at": time.time(),
        "index_path": os.path.basename(idx_path),
        "notes": "Embeddings are L2-normalized; cosine == inner product.",
    }
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    if log:
        log(f"Index built in {(time.time() - t0):.2f}s. Saved to {idx_path}")
    return IndexStatus.CREATED


def load_faiss_index(index_dir: str, signature: str, log: Optional[Callable[[str], None]] = None):
    idx_path, meta_path = _index_paths(index_dir, signature)
    if not (os.path.isfile(idx_path) and os.path.isfile(meta_path)):
        return None, None

    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    idx = faiss.read_index(idx_path)
    dim = int(meta.get("dim", idx.d))
    # Try moving to GPU
    idx = _maybe_gpu(idx)
    if log:
        log(f"Loaded FAISS index: {idx_path} (dim={dim})")
    return idx, dim


def search_faiss(index, query_embs: np.ndarray, top_k: int = 10):
    """Search FAISS (inner product) with L2-normalized query embeddings."""
    assert query_embs.ndim == 2
    # Ensure L2-normalized
    q = _normalize_rows(query_embs.astype("float32", copy=False))
    scores, ids = index.search(q, int(top_k))
    return scores, ids
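Putting the pieces together, a hedged end-to-end sketch of the build-then-search flow these helpers implement (the checkpoint path, split, sample cap, and column names are illustrative; app.py drives the same calls from the UI):

# End-to-end sketch: compute the cache signature, build or reuse the index, search it.
from dataset_utils import load_fashion_dataset
from encoders import SiglipEncoder
from index_builder import (
    index_signature_from_env, ensure_index, load_faiss_index, search_faiss,
)

encoder = SiglipEncoder.from_checkpoint_dir("./siglip_checkpoint", log=print)
accessor = load_fashion_dataset(
    dataset_name="tomytjandra/h-and-m-fashion-caption",
    split="train", image_col="image", text_col="caption",
    max_samples=500, log=print,
)
sig = index_signature_from_env(
    dataset_name="tomytjandra/h-and-m-fashion-caption", split="train",
    max_samples=500, ckpt_dir=encoder.ckpt_dir,
    image_col="image", text_col="caption",
)
status = ensure_index(accessor, encoder, "./index_cache", sig, log=print)
index, dim = load_faiss_index("./index_cache", sig, log=print)

q = encoder.encode_texts(["white sneakers with chunky sole"])
scores, ids = search_faiss(index, q, top_k=5)
for score, i in zip(scores[0], ids[0]):
    print(f"{score:.4f}  {accessor.get(int(i)).text[:70]}")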