saad003 committed on
Commit bc08e23 · verified · 1 Parent(s): 94c989b

Update app.py

Files changed (1): app.py +227 -0
app.py CHANGED
@@ -0,0 +1,227 @@
+ import os
+ import re
+ from typing import List, Tuple
+
+ import faiss
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+
+ import torch
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from transformers import (
+     CLIPModel,
+     CLIPProcessor,
+     AutoProcessor,
+     BlipForConditionalGeneration,
+ )
+
+ # =========================
+ # CONFIG
+ # =========================
+
+ DATASET_REPO = "saad003/Dataset_final"  # where embeddings + faiss + metadata live
+ IMAGES_REPO = "saad003/images"  # where the radiology images live
+
+ CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
+ CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"  # BLIP radiology
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # =========================
+ # LOAD MODELS
+ # =========================
+
+ print("Loading CLIP model...")
+ clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
+ clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
+ clip_model.eval()
+
+ print("Loading caption model...")
+ caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
+ caption_model = BlipForConditionalGeneration.from_pretrained(
+     CAPTION_MODEL_ID
+ ).to(DEVICE)
+ caption_model.eval()
+
+ # =========================
+ # LOAD INDEX + METADATA
+ # =========================
+
+ print("Loading FAISS index + embeddings + metadata...")
+
+ embeddings_path = hf_hub_download(DATASET_REPO, "embeddings.npy")
+ index_path = hf_hub_download(DATASET_REPO, "image_index.faiss")
+
+ EMBEDDINGS = np.load(embeddings_path).astype("float32")
+ INDEX = faiss.read_index(index_path)
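+ # Assumption: image_index.faiss was built from the same L2-normalized CLIP image
+ # embeddings stored in embeddings.npy, so search scores behave like cosine similarity.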
+
+ # metadata: parquet preferred, else csv
+ try:
+     meta_path = hf_hub_download(DATASET_REPO, "metadata.parquet")
+     METADATA = pd.read_parquet(meta_path)
+     print("Loaded metadata.parquet")
+ except Exception:
+     meta_path = hf_hub_download(DATASET_REPO, "metadata.csv")
+     METADATA = pd.read_csv(meta_path)
+     print("Loaded metadata.csv")
+
+ print("Metadata columns:", list(METADATA.columns))
+
+
+ def pick_column(candidates: List[str]) -> str:
+     """Pick first existing column name from candidates."""
+     for c in candidates:
+         if c in METADATA.columns:
+             return c
+     raise RuntimeError(
+         f"None of {candidates} found in metadata columns: {list(METADATA.columns)}"
+     )
+
+
+ # Adjust these candidates if they don't match your metadata file on HF
+ IMAGE_COL = pick_column(
+     ["image_path", "img_path", "filepath", "image", "image_file", "path"]
+ )
+ CAPTION_COL = pick_column(["caption", "report", "text", "caption_text"])
+
+ print("Using IMAGE_COL =", IMAGE_COL)
+ print("Using CAPTION_COL =", CAPTION_COL)
+
+ # =========================
+ # HELPER FUNCTIONS
+ # =========================
+
+
+ def load_image_for_row(row: pd.Series) -> Image.Image:
+     """
+     Load one image given a metadata row.
+     Assumes metadata[IMAGE_COL] is a relative path inside saad003/images repo.
+     """
+     rel_path = str(row[IMAGE_COL])
+     local_path = hf_hub_download(IMAGES_REPO, rel_path)
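+     # hf_hub_download caches files locally, so repeated neighbors are not re-fetched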
+     img = Image.open(local_path).convert("RGB")
+     return img
+
+
+ @torch.no_grad()
+ def embed_query_image(image: Image.Image) -> np.ndarray:
+     """Embed query image with the same CLIP model used during indexing."""
+     inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
+     features = clip_model.get_image_features(**inputs)
+     # normalize for cosine similarity
+     features = features / features.norm(dim=-1, keepdim=True)
+     return features.cpu().numpy().astype("float32")
+
+
+ def retrieve_similar(image: Image.Image, k: int = 5) -> pd.DataFrame:
+     """Return top-k similar rows from METADATA."""
+     query_emb = embed_query_image(image)
+     D, I = INDEX.search(query_emb, k)
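+     # D: raw FAISS scores (inner product or L2, depending on how the index was built)
+     # I: row positions into METADATA, in ranked order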
+     rows = METADATA.iloc[I[0]].copy()
+     rows["distance"] = D[0]
+     return rows
+
+
+ @torch.no_grad()
+ def generate_caption(image: Image.Image, neighbors: pd.DataFrame) -> str:
+     """Generate caption for query image, using neighbors' captions as context."""
+     neighbor_captions = neighbors[CAPTION_COL].astype(str).tolist()
+     context = " | ".join(neighbor_captions[:3])
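+     # only the 3 closest captions go into the prompt, to keep it short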
+
+     prompt = (
+         "Radiology image. Similar case descriptions: "
+         f"{context}. Generate a concise radiology-style caption for this new image."
+     )
+
+     inputs = caption_processor(
+         images=image,
+         text=prompt,
+         return_tensors="pt",
+     ).to(DEVICE)
+
+     out = caption_model.generate(
+         **inputs,
+         max_new_tokens=64,
+         num_beams=3,
+         do_sample=False,
+     )
+
+     caption = caption_processor.decode(out[0], skip_special_tokens=True).strip()
+     # BLIP echoes the conditioning text; drop the prompt prefix if it is present
+     if caption.lower().startswith(prompt.lower()):
+         caption = caption[len(prompt):].strip()
+     return caption
+
+
+ def detect_modality(text: str) -> str:
+     t = text.lower()
+     modalities = {
+         "CT": ["ct", "computed tomography"],
+         "X-ray": ["x-ray", "xray", "radiograph", "chest x-ray", "cxr"],
+         "MRI": ["mri", "magnetic resonance"],
+         "Ultrasound": ["ultrasound", "sonography", "usg"],
+         "PET": ["pet scan", "pet-ct", "positron emission tomography"],
+         "Mammography": ["mammogram", "mammography"],
+     }
+
+     # match whole words so short keywords like "ct" don't fire inside words such as "fracture"
+     for name, kws in modalities.items():
+         if any(re.search(rf"\b{re.escape(kw)}\b", t) for kw in kws):
+             return name
+     return "Unknown"
+
+
+ def run_pipeline(
+     query_image: Image.Image, k: int = 5
+ ) -> Tuple[List[Tuple[Image.Image, str]], str, str]:
+     """
+     Full pipeline:
+     - retrieve neighbors
+     - load their images
+     - generate caption for query
+     - detect modality
+     """
+     neighbors = retrieve_similar(query_image, k=k)
+
+     neighbor_images = [load_image_for_row(row) for _, row in neighbors.iterrows()]
+     neighbor_captions = neighbors[CAPTION_COL].astype(str).tolist()
+
+     gallery = [(img, cap) for img, cap in zip(neighbor_images, neighbor_captions)]
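+     # gr.Gallery accepts (image, caption) tuples, so each neighbor shows with its caption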
+
+     generated_caption = generate_caption(query_image, neighbors)
+
+     modality = detect_modality(
+         generated_caption + " " + " ".join(neighbor_captions)
+     )
+
+     return gallery, generated_caption, modality
+
+
+ # =========================
+ # GRADIO APP
+ # =========================
+
+
+ def gradio_infer(image, k):
+     if image is None:
+         return [], "No image provided", ""
+
+     k = int(k)
+     gallery, caption, modality = run_pipeline(image, k=k)
+     return gallery, caption, modality
+
+
+ demo = gr.Interface(
+     fn=gradio_infer,
+     inputs=[
+         gr.Image(type="pil", label="Query radiology image"),
+         gr.Slider(1, 12, value=5, step=1, label="Number of similar images"),
+     ],
+     outputs=[
+         gr.Gallery(label="Similar images (with captions)", preview=True),
+         gr.Textbox(label="Generated caption for query image"),
+         gr.Textbox(label="Detected modality"),
+     ],
+     title="Radiology Image Retrieval + Captioning",
+     description="Research demo. Not for clinical use.",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()