mamathew committed · Commit c85ba32 · verified · Parent(s): 83791c2

Update app.py

Files changed (1):
  app.py +42 -18
app.py CHANGED
@@ -15,9 +15,9 @@ import torch
 from transformers import CLIPModel, CLIPProcessor
 
 # ========= CONFIG (override in Space → Settings → Variables) =========
-TEXT_MODEL_REPO = os.environ.get("TEXT_MODEL_REPO", "<your-username>/text-ft-food-rag")
-CLIP_MODEL_REPO = os.environ.get("CLIP_MODEL_REPO", "<your-username>/clip-ft-food-rag")
-DATASET_REPO = os.environ.get("DATASET_REPO", "<your-username>/food-rag-index")
+TEXT_MODEL_REPO = os.environ.get("TEXT_MODEL_REPO", "mamathew/text-ft-food-rag")
+CLIP_MODEL_REPO = os.environ.get("CLIP_MODEL_REPO", "mamathew/clip-ft-food-rag")
+DATASET_REPO = os.environ.get("DATASET_REPO", "mamathew/food-rag-index")
 
 # Inference API chat model (Gemma IT by default).
 LLM_ID = os.environ.get("LLM_ID", "google/gemma-2-2b-it")
@@ -57,7 +57,20 @@ except Exception:
 client = None
 
 # ---------------------- utils & dataclasses ----------------------
+from PIL import Image
+
+def _resolve_path(rel_or_abs: str) -> str:
+    # If relative, make it under the dataset snapshot root
+    p = rel_or_abs if os.path.isabs(rel_or_abs) else os.path.join(DATA_DIR, rel_or_abs)
+    # Resolve symlinks to a canonical path (helps in HF cache)
+    return os.path.realpath(p)
 
+def _open_image_safe(path: str):
+    try:
+        return Image.open(path).convert("RGB")
+    except Exception:
+        return None
+
 def normalize_fa(s: str) -> str:
     if not s: return s
     return (s.replace("ي","ی").replace("ك","ک").replace("\u200c"," ").strip())
@@ -257,24 +270,30 @@ def call_llm(prompt: str) -> str:
 
 # ---------------------- gallery helpers ----------------------
 
-def display_gallery_pairs(pairs: List[Pair]) -> List[Tuple[str, str]]:
+def display_gallery_pairs(pairs):
     items = []
     for p in pairs:
-        if not p.image_path: continue
-        local_path = os.path.join(DATA_DIR, p.image_path) if not os.path.isabs(p.image_path) else p.image_path
+        if not p.image_path:
+            continue
+        local_path = _resolve_path(p.image_path)
         if os.path.exists(local_path):
-            caption = f"#{p.rank} — {p.title or ''}\nscore={(p.hscore if p.hscore is not None else p.score):.3f}"
-            items.append((local_path, caption))
+            img = _open_image_safe(local_path)
+            if img is not None:
+                caption = f"#{p.rank} — {p.title or ''}\nscore={(p.hscore if p.hscore is not None else p.score):.3f}"
+                items.append((img, caption))  # PIL image instead of path
     return items
 
-def display_gallery_images(img_hits: List[ImgHit]) -> List[Tuple[str, str]]:
+def display_gallery_images(img_hits):
     items = []
     for h in img_hits:
-        if not h.image_path: continue
-        local_path = os.path.join(DATA_DIR, h.image_path) if not os.path.isabs(h.image_path) else h.image_path
+        if not h.image_path:
+            continue
+        local_path = _resolve_path(h.image_path)
        if os.path.exists(local_path):
-            caption = f"#{h.rank} — {h.title or ''}\nscore={h.score:.3f}"
-            items.append((local_path, caption))
+            img = _open_image_safe(local_path)
+            if img is not None:
+                caption = f"#{h.rank} — {h.title or ''}\nscore={h.score:.3f}"
+                items.append((img, caption))  # PIL image instead of path
     return items
 
 # ---------------------- main app logic ----------------------
@@ -298,13 +317,13 @@ def answer(question: str, image: Optional[Image.Image], topk: int, k_ctx: int, u
     img_hits = search_image_by_text(question, topk=min(8, topk))
     gallery = display_gallery_images(img_hits)
 
-    top_image_path = gallery[0][0] if gallery else None
+    top_image = gallery[0][0] if gallery else None
 
     # Table
     def fmt(x): return "—" if x is None else f"{x:.3f}"
     table = [[p.rank, p.title or "", fmt(p.tscore), fmt(p.iscore), fmt(p.hscore or p.score), p.doc_id] for p in top_pairs]
 
-    return gen, table, gallery, top_image_path
+    return gen, table, gallery, top_image
 
 # ---------------------- UI ----------------------
 
@@ -325,13 +344,18 @@ with gr.Blocks() as demo:
     out_text = gr.Textbox(label="پاسخ (Answer)")
     out_table = gr.Dataframe(headers=["Rank", "Title", "Text S", "Image S", "Hybrid S", "Doc ID"], label="Top-K retrieval")
     out_gallery = gr.Gallery(label="تصاویر مرتبط (Image matches)", columns=5, height=240)
-    out_img_top = gr.Image(label="بهترین تصویر")
+    out_img_top = gr.Image(label="Top image match")
 
     btn.click(
         answer,
         inputs=[q, img, topk, kctx, use_img, alpha],
         outputs=[out_text, out_table, out_gallery, out_img_top]
     )
-
+ALLOWED = [
+    DATA_DIR,
+    os.path.join(DATA_DIR, "data"),
+    os.path.join(DATA_DIR, "data", "interim"),
+    os.path.join(DATA_DIR, "data", "interim", "images_cache"),
+]
 if __name__ == "__main__":
-    demo.launch(allowed_paths=[DATA_DIR])
+    demo.launch(allowed_paths=[os.path.realpath(p) for p in ALLOWED])
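
Why the helpers now return PIL images instead of paths: gr.Gallery accepts (image, caption) tuples where the image may be a file path or a PIL.Image. A path is served as a file and must fall under launch(allowed_paths=...), while a PIL image is encoded and sent to the client directly, so returning PIL objects sidesteps path allow-listing. A minimal standalone sketch of that behavior (not part of this commit; assumes a current Gradio release and a hypothetical local file sample.jpg):

# Sketch only: demonstrates the (PIL.Image, caption) tuple shape used above.
import gradio as gr
from PIL import Image

def gallery_items():
    img = Image.open("sample.jpg").convert("RGB")  # hypothetical local file
    return [(img, "#1 - demo\nscore=0.987")]       # same tuple shape as app.py

with gr.Blocks() as demo:
    gal = gr.Gallery(columns=5, height=240)
    gr.Button("Show").click(gallery_items, outputs=[gal])

demo.launch()  # no allowed_paths needed when returning PIL images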
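
Why allowed_paths now resolves symlinks: when DATA_DIR points at a huggingface_hub snapshot, the files inside it are symlinks into the cache's blob store, and Gradio compares the resolved location of a requested file, so an unresolved snapshot directory in allowed_paths can still block every image. A short sketch of the mismatch (assumes huggingface_hub is installed; README.md stands in for any file in the dataset repo):

# Sketch only: snapshot files resolve to a different prefix than DATA_DIR.
import os
from huggingface_hub import snapshot_download

data_dir = snapshot_download(repo_id="mamathew/food-rag-index", repo_type="dataset")
f = os.path.join(data_dir, "README.md")  # hypothetical member file
print(f)                    # .../snapshots/<revision>/README.md (a symlink)
print(os.path.realpath(f))  # .../blobs/<sha256> (different prefix)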
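
For reference, the unchanged normalize_fa context line maps Arabic codepoints to their Persian equivalents and turns zero-width non-joiners into spaces. A quick illustration with a made-up sample string:

# Sketch only: Arabic yeh (U+064A) -> Persian yeh (U+06CC),
# Arabic kaf (U+0643) -> Persian kaf (U+06A9), ZWNJ -> space.
def normalize_fa(s: str) -> str:  # copied from app.py above
    if not s: return s
    return (s.replace("ي","ی").replace("ك","ک").replace("\u200c"," ").strip())

s = "كباب كوبيده\u200cمخصوص"
print(normalize_fa(s))  # -> "کباب کوبیده مخصوص"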