Update app.py
app.py (CHANGED)
@@ -15,9 +15,9 @@ import torch
 from transformers import CLIPModel, CLIPProcessor
 
 # ========= CONFIG (override in Space → Settings → Variables) =========
-TEXT_MODEL_REPO = os.environ.get("TEXT_MODEL_REPO", "
-CLIP_MODEL_REPO = os.environ.get("CLIP_MODEL_REPO", "
-DATASET_REPO = os.environ.get("DATASET_REPO", "
+TEXT_MODEL_REPO = os.environ.get("TEXT_MODEL_REPO", "mamathew/text-ft-food-rag")
+CLIP_MODEL_REPO = os.environ.get("CLIP_MODEL_REPO", "mamathew/clip-ft-food-rag")
+DATASET_REPO = os.environ.get("DATASET_REPO", "mamathew/food-rag-index")
 
 # Inference API chat model (Gemma IT by default).
 LLM_ID = os.environ.get("LLM_ID", "google/gemma-2-2b-it")
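Note on the new defaults: all three repo IDs go through os.environ.get, so they can be overridden per-Space (Settings → Variables) without touching app.py. A minimal sketch of that pattern; the override value below is a hypothetical placeholder, not a real repo:

import os

# With no variable set, the committed default wins.
DATASET_REPO = os.environ.get("DATASET_REPO", "mamathew/food-rag-index")

# Setting DATASET_REPO in Space -> Settings -> Variables is equivalent to:
os.environ["DATASET_REPO"] = "my-org/my-food-index"  # hypothetical repo ID
assert os.environ.get("DATASET_REPO", "mamathew/food-rag-index") == "my-org/my-food-index"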
@@ -57,7 +57,20 @@ except Exception:
 client = None
 
 # ---------------------- utils & dataclasses ----------------------
+from PIL import Image
+
+def _resolve_path(rel_or_abs: str) -> str:
+    # If relative, make it under the dataset snapshot root
+    p = rel_or_abs if os.path.isabs(rel_or_abs) else os.path.join(DATA_DIR, rel_or_abs)
+    # Resolve symlinks to a canonical path (helps in HF cache)
+    return os.path.realpath(p)
+
+def _open_image_safe(path: str):
+    try:
+        return Image.open(path).convert("RGB")
+    except Exception:
+        return None
 
 def normalize_fa(s: str) -> str:
     if not s: return s
     return (s.replace("ي","ی").replace("ك","ک").replace("\u200c"," ").strip())
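These two helpers are the core of the fix: image_path values in the index are relative, and on Spaces the dataset snapshot lives in a symlinked HF cache, so _resolve_path anchors relative paths under DATA_DIR and canonicalizes them, while _open_image_safe returns a decoded RGB PIL image or None instead of raising. A usage sketch, assuming DATA_DIR is produced elsewhere in app.py roughly as shown (the .jpg entry is hypothetical):

from huggingface_hub import snapshot_download

# Assumed shape of the existing setup: DATA_DIR is the dataset snapshot root.
DATA_DIR = snapshot_download(repo_id=DATASET_REPO, repo_type="dataset")

local = _resolve_path("data/interim/images_cache/0001.jpg")  # hypothetical index entry
img = _open_image_safe(local)  # PIL.Image in RGB, or None if missing/corrupt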
@@ -257,24 +270,30 @@ def call_llm(prompt: str) -> str:
 
 # ---------------------- gallery helpers ----------------------
 
-def display_gallery_pairs(pairs
+def display_gallery_pairs(pairs):
     items = []
     for p in pairs:
-        if not p.image_path:
-
+        if not p.image_path:
+            continue
+        local_path = _resolve_path(p.image_path)
         if os.path.exists(local_path):
-
-
+            img = _open_image_safe(local_path)
+            if img is not None:
+                caption = f"#{p.rank} — {p.title or ''}\nscore={(p.hscore if p.hscore is not None else p.score):.3f}"
+                items.append((img, caption))  # PIL image instead of path
     return items
 
-def display_gallery_images(img_hits
+def display_gallery_images(img_hits):
     items = []
     for h in img_hits:
-        if not h.image_path:
-
+        if not h.image_path:
+            continue
+        local_path = _resolve_path(h.image_path)
         if os.path.exists(local_path):
-
-
+            img = _open_image_safe(local_path)
+            if img is not None:
+                caption = f"#{h.rank} — {h.title or ''}\nscore={h.score:.3f}"
+                items.append((img, caption))  # PIL image instead of path
     return items
 
 # ---------------------- main app logic ----------------------
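The significant change here is what the gallery items contain: gr.Gallery accepts a list of (image, caption) tuples, and the image may be a PIL object rather than a file path, in which case Gradio serves the pixels itself and never needs filesystem access to the original file. A self-contained sketch of that contract with toy data (not the app's retrieval objects):

import gradio as gr
from PIL import Image

def toy_gallery():
    # Stand-in for display_gallery_pairs(): a list of (PIL image, caption) tuples.
    red = Image.new("RGB", (64, 64), "red")
    blue = Image.new("RGB", (64, 64), "blue")
    return [(red, "#1 score=0.912"), (blue, "#2 score=0.871")]

with gr.Blocks() as demo:
    gal = gr.Gallery(columns=5, height=240)
    gr.Button("Show").click(toy_gallery, outputs=[gal])

demo.launch()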
@@ -298,13 +317,13 @@ def answer(question: str, image: Optional[Image.Image], topk: int, k_ctx: int, u
     img_hits = search_image_by_text(question, topk=min(8, topk))
     gallery = display_gallery_images(img_hits)
 
-
+    top_image = gallery[0][0] if gallery else None
 
     # Table
     def fmt(x): return "—" if x is None else f"{x:.3f}"
     table = [[p.rank, p.title or "", fmt(p.tscore), fmt(p.iscore), fmt(p.hscore or p.score), p.doc_id] for p in top_pairs]
 
-    return gen, table, gallery,
+    return gen, table, gallery, top_image
 
 # ---------------------- UI ----------------------
 
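Because each gallery item is an (image, caption) tuple, gallery[0][0] is the PIL image of the best-ranked hit; when nothing matched, returning None leaves the gr.Image output blank, which Gradio accepts. Schematically (img_a and img_b are hypothetical PIL images):

gallery = [(img_a, "#1 ..."), (img_b, "#2 ...")]  # as built by the helpers above
top_image = gallery[0][0] if gallery else None    # img_a, or None to clear the component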
@@ -325,13 +344,18 @@ with gr.Blocks() as demo:
     out_text = gr.Textbox(label="پاسخ (Answer)")
     out_table = gr.Dataframe(headers=["Rank", "Title", "Text S", "Image S", "Hybrid S", "Doc ID"], label="Top-K retrieval")
     out_gallery = gr.Gallery(label="تصاویر مرتبط (Image matches)", columns=5, height=240)
-    out_img_top = gr.Image(label="
+    out_img_top = gr.Image(label="Top image match")
 
     btn.click(
         answer,
         inputs=[q, img, topk, kctx, use_img, alpha],
         outputs=[out_text, out_table, out_gallery, out_img_top]
     )
-
+ALLOWED = [
+    DATA_DIR,
+    os.path.join(DATA_DIR, "data"),
+    os.path.join(DATA_DIR, "data", "interim"),
+    os.path.join(DATA_DIR, "data", "interim", "images_cache"),
+]
 if __name__ == "__main__":
-    demo.launch(allowed_paths=[
+    demo.launch(allowed_paths=[os.path.realpath(p) for p in ALLOWED])
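Finally, the launch change: allowed_paths is Gradio's allowlist of directories it may serve files from, and the check is done against resolved paths, so on Spaces, where cached files are symlinks into ~/.cache/huggingface, an unresolved root can still yield a blocked file. Canonicalizing the roots with os.path.realpath, as the diff does, keeps both sides consistent. Roughly (a simplification of the idea, not Gradio's actual check):

import os

def under_allowed(path: str, roots: list[str]) -> bool:
    # A file is servable only if its canonical path sits under a canonical root.
    real = os.path.realpath(path)
    for root in roots:
        root = os.path.realpath(root)
        if real == root or real.startswith(root + os.sep):
            return True
    return False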
|