Space / app.py
Jaywalker061707's picture
Update app.py
0bd8b05 verified
import gradio as gr
from datasets import load_dataset
from itertools import islice
import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor
import torch.nn.functional as F
import os, json, time
# ---------- utils ----------
def flux_to_gray(flux_array):
    """Render a flux array as an 8-bit grayscale PIL image (mode "L").

    The input is cast to float32 and squeezed. If it is still 3-D, the
    smallest axis (presumably a band/channel axis -- NOTE(review): confirm
    against the dataset schema) is collapsed with a NaN-aware mean.
    NaN/+inf/-inf are then replaced with 0 and the values are
    contrast-stretched between their 1st and 99th percentiles. It falls back
    to min/max when that range is empty or non-finite. The stretched values
    are clipped to [0, 1] and scaled to 0..255.
    """
    img = np.squeeze(np.array(flux_array, dtype=np.float32))
    if img.ndim == 3:
        # Average away the shortest dimension (first one on ties).
        img = np.nanmean(img, axis=int(np.argmin(img.shape)))
    img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)

    low = np.nanpercentile(img, 1)
    high = np.nanpercentile(img, 99)
    stretch_ok = np.isfinite(low) and np.isfinite(high) and high > low
    if not stretch_ok:
        # Degenerate percentile window (e.g. a near-constant image).
        low, high = float(np.nanmin(img)), float(np.nanmax(img))

    # The epsilon keeps a constant image from dividing by zero.
    scaled = np.clip((img - low) / (high - low + 1e-9), 0, 1)
    return Image.fromarray((scaled * 255).astype(np.uint8), mode="L")
# ---------- model ----------
model_id = "openai/clip-vit-base-patch32"
# One CLIP checkpoint provides both get_text_features and get_image_features,
# whose outputs are compared directly in the search code.
# nn.Module.eval() returns the module itself, so it can be chained.
model = CLIPModel.from_pretrained(model_id).eval()
processor = CLIPProcessor.from_pretrained(model_id)
# ---------- in-memory index ----------
# Process-wide search index, (re)filled by build_index(). The list fields are
# parallel: position i in each list describes the same record.
INDEX = dict(
    feats=None,  # torch.Tensor [N, 512] of L2-normalised CLIP image embeddings
    ids=[],      # list[str]: object_id of each record
    thumbs=[],   # list[PIL.Image]: thumbnails shown in result galleries
    bands=[],    # list[str]: image band of each record
)
def build_index(n=200):
    """Embed the first ``n`` JWST records with CLIP and store them in ``INDEX``.

    Streams the ``MultimodalUniverse/jwst`` train split. For each record it:
    converts ``image.flux`` to an RGB image via ``flux_to_gray``; keeps a copy
    shrunk to fit within 128x128 as a thumbnail; stores the L2-normalised CLIP
    image embedding; and records ``object_id`` and ``image.band`` as strings.

    Returns a status string for the UI. ``INDEX`` is only overwritten when at
    least one record was embedded.
    """
    stream = load_dataset("MultimodalUniverse/jwst", split="train", streaming=True)
    embeddings, object_ids, previews, band_names = [], [], [], []

    for record in islice(stream, int(n)):
        image_info = record["image"]
        rgb = flux_to_gray(image_info["flux"]).convert("RGB")
        preview = rgb.copy()
        preview.thumbnail((128, 128))  # in place, keeps aspect ratio

        with torch.no_grad():
            pixel_inputs = processor(images=rgb, return_tensors="pt")
            raw = model.get_image_features(**pixel_inputs)  # [1, 512]
            vec = F.normalize(raw, p=2, dim=-1)[0]  # [512]

        embeddings.append(vec)
        object_ids.append(str(record.get("object_id")))
        band_names.append(str(image_info.get("band")))
        previews.append(preview)

    if not embeddings:
        return "No records indexed."

    INDEX.update(
        feats=torch.stack(embeddings),  # [N, 512]
        ids=object_ids,
        thumbs=previews,
        bands=band_names,
    )
    return f"Index built: {len(object_ids)} images."
def search(text_query, image_query, k=5):
if INDEX["feats"] is None:
return [], "Build the index first."
with torch.no_grad():
if text_query and str(text_query).strip():
inputs = processor(text=[str(text_query).strip()], return_tensors="pt")
q = model.get_text_features(**inputs) # [1, 512]
elif image_query is not None:
pil = image_query.convert("RGB")
inputs = processor(images=pil, return_tensors="pt")
q = model.get_image_features(**inputs) # [1, 512]
else:
return [], "Enter text or upload an image."
q = F.normalize(q, p=2, dim=-1)[0] # [512]
sims = (INDEX["feats"] @ q).cpu() # [N]
k = min(int(k), sims.shape[0])
topk = torch.topk(sims, k=k)
items = []
for idx in topk.indices.tolist():
cap = f"id: {INDEX['ids'][idx]} score: {float(sims[idx]):.3f} band: {INDEX['bands'][idx]}"
items.append((INDEX["thumbs"][idx], cap))
return items, f"Returned {k} results."
# ---------- evaluation helpers ----------
def _search_topk_for_eval(text_query, image_query, k=5):
if INDEX["feats"] is None:
return [], [], "Build the index first."
with torch.no_grad():
if text_query and str(text_query).strip():
inputs = processor(text=[str(text_query).strip()], return_tensors="pt")
q = model.get_text_features(**inputs)
elif image_query is not None:
pil = image_query.convert("RGB")
inputs = processor(images=pil, return_tensors="pt")
q = model.get_image_features(**inputs)
else:
return [], [], "Enter text or upload an image."
q = F.normalize(q, p=2, dim=-1)[0]
sims = (INDEX["feats"] @ q).cpu()
k = min(int(k), sims.shape[0])
topk = torch.topk(sims, k=k)
idxs = topk.indices.tolist()
# reuse thumbs and captions like your main search
items = []
for idx in idxs:
cap = f"id: {INDEX['ids'][idx]} score: {float(sims[idx]):.3f} band: {INDEX['bands'][idx]}"
items.append((INDEX["thumbs"][idx], cap))
return items, idxs, f"Eval preview: top {k} ready."
def _format_eval_summary(query, k, hits, p_at_k):
lines = []
lines.append(f"Query: {query or '[image query]'}")
lines.append(f"K: {k}")
lines.append(f"Relevant marked: {hits} of {k}")
lines.append(f"Precision@{k}: {p_at_k:.2f}")
lines.append("Saved to eval_runs.jsonl")
return "\n".join(lines)
def _save_eval_run(record):
try:
with open("eval_runs.jsonl", "a", encoding="utf-8") as f:
f.write(json.dumps(record) + "\n")
except Exception:
pass
def _compute_avg_from_file():
try:
total = 0.0
n = 0
with open("eval_runs.jsonl", "r", encoding="utf-8") as f:
for line in f:
rec = json.loads(line)
if "precision_at_k" in rec:
total += float(rec["precision_at_k"])
n += 1
if n == 0:
return "No runs recorded yet."
return f"Macro average Precision@K across {n} runs: {total/n:.2f}"
except FileNotFoundError:
return "No eval_runs.jsonl yet. Run at least one evaluation."
# ---------- UI ----------
with gr.Blocks() as demo:
    gr.Markdown("JWST multimodal search build the index")

    # Build
    n = gr.Slider(50, 1000, value=200, step=10, label="How many images to index")
    build_btn = gr.Button("Build index")
    status = gr.Textbox(label="Status", lines=2)
    build_btn.click(build_index, inputs=n, outputs=status)

    # Search
    gr.Markdown("Search the index with text or an example image")
    q_text = gr.Textbox(label="Text query", placeholder="e.g., spiral galaxy")
    q_img = gr.Image(label="Image query", type="pil")
    k = gr.Slider(1, 12, value=6, step=1, label="Top K")
    search_btn = gr.Button("Search")
    gallery = gr.Gallery(label="Results", columns=6, height=300)
    info2 = gr.Textbox(label="Search status", lines=1)
    search_btn.click(search, inputs=[q_text, q_img, k], outputs=[gallery, info2])

    # ---------- Evaluation (guided) ----------
    with gr.Accordion("Evaluation", open=False):
        gr.Markdown(
            "### What this does\n"
            "We evaluate text to image retrieval using Precision at K.\n"
            "Steps: pick a preset or type a query, click **Run and label**, "
            "tick the results that match the rule shown, then click **Compute metrics**. "
            "Each run is saved so you can average later."
        )

        # Preset prompts with plain English relevance rules
        PRESETS = {
            "star with spikes": "Relevant = bright point source with clear 4 to 6 diffraction spikes. Minimal extended glow.",
            "edge-on galaxy": "Relevant = thin elongated streak. Looks like a narrow line. No round diffuse blob.",
            "spiral galaxy": "Relevant = visible spiral arms or a spiral outline. Arms can be faint.",
            "diffuse nebula": "Relevant = fuzzy cloud like structure. No sharp round core.",
            "ring or annulus": "Relevant = ring or donut shape is the main feature.",
            "two merging objects": "Relevant = two bright blobs touching or overlapping.",
        }

        with gr.Row():
            preset = gr.Dropdown(choices=list(PRESETS.keys()), label="Preset query (optional)")
            eval_k = gr.Slider(1, 12, value=6, step=1, label="K for evaluation")
        eval_query = gr.Textbox(label="Evaluation query (you can edit or type your own)")
        eval_img = gr.Image(label="Evaluation image (optional)", type="pil")
        rules_md = gr.Markdown()
        run_and_label = gr.Button("Run and label this query")
        eval_gallery = gr.Gallery(label="Eval top K results", columns=6, height=300)
        relevant_picker = gr.CheckboxGroup(label="Select indices of relevant results (1..K)")
        eval_md = gr.Markdown()

        # Per-session state for this panel: ranked INDEX positions of the last
        # run, the K requested for it, and its query label.
        eval_state = gr.State({"result_indices": [], "k": 5, "query": ""})

        def _on_preset_change(name):
            """Copy the chosen preset into the query box and show its relevance rule."""
            if name in PRESETS:
                return gr.update(value=name), PRESETS[name]
            return gr.update(), ""

        preset.change(fn=_on_preset_change, inputs=preset, outputs=[eval_query, rules_md])

        def _run_eval_query(q_txt, q_img_in, k_in, state):
            """Run the evaluation search, fill the gallery and checkboxes, and record the run.

            Returns ``(gallery_items, checkbox_update, markdown, state)``. When
            the search cannot run (index not built, or neither text nor image
            given), its status message is shown instead of the relevance rule.
            The stored run is also cleared, so **Compute metrics** cannot save a
            meaningless 0.00 score. The message used to be discarded.
            """
            items, idxs, msg = _search_topk_for_eval(q_txt, q_img_in, k_in)
            text = str(q_txt).strip() if q_txt else ""
            state["k"] = int(k_in)
            state["result_indices"] = idxs
            if not idxs:
                state["query"] = ""
                return ([], gr.update(choices=[], value=[]), f"**{msg}**", state)

            state["query"] = text or "[image query]"
            # Checkbox labels are 1-based ranks of the displayed results.
            choice_labels = [str(i + 1) for i in range(len(idxs))]
            help_text = PRESETS.get(text.lower(), "Mark results that match the concept you typed.")
            return (items,
                    gr.update(choices=choice_labels, value=[]),
                    f"**Relevance rule:** {help_text}\n\nThen click **Compute metrics**.",
                    state)

        run_and_label.click(
            fn=_run_eval_query,
            inputs=[eval_query, eval_img, eval_k, eval_state],
            outputs=[eval_gallery, relevant_picker, eval_md, eval_state],
        )

        compute_btn = gr.Button("Compute metrics")

        def _compute_pk(selected_indices, state):
            """Score the last run as Precision@K, append it to eval_runs.jsonl, and summarise it.

            Precision@K = (number of ticked results) / K, where K is the value
            requested for the run. Nothing is saved when no successful run
            exists yet.
            """
            if not state.get("result_indices"):
                # No run yet, or the last run failed: don't record a bogus 0.00.
                return "Run a query with **Run and label this query** first."
            selected = selected_indices or []
            k_val = int(state.get("k", 5))
            query = state.get("query", "")
            hits = len(selected)
            p_at_k = hits / max(k_val, 1)
            record = {
                "ts": int(time.time()),
                "query": query,
                "k": k_val,
                "relevant_indices": sorted(int(s) for s in selected),
                "precision_at_k": p_at_k,
            }
            _save_eval_run(record)
            return _format_eval_summary(query, k_val, hits, p_at_k)

        compute_btn.click(fn=_compute_pk, inputs=[relevant_picker, eval_state], outputs=eval_md)

        avg_btn = gr.Button("Compute average across saved runs")
        avg_md = gr.Markdown()
        avg_btn.click(fn=_compute_avg_from_file, outputs=avg_md)

demo.launch()