VedikaP commited on
Commit
9fd74fd
Β·
verified Β·
1 Parent(s): 9251e17

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +280 -0
  2. custom_cnn.h5 +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ import json
5
+ import os
6
+ from tensorflow.keras.models import load_model
7
+
8
+ # ─── Load model ───────────────────────────────────────────────────────────────
9
+ model = load_model("custom_cnn.h5")
10
+ IMG_SIZE = 224
11
+ NUM_OUTPUTS = model.output_shape[-1] # auto-detects 3-class or 16-class
12
+
13
+ # ─── Class / cluster labels ───────────────────────────────────────────────────
14
+ # Priority 1: class_labels.json saved alongside the model (from the 16-class notebook)
15
+ # Priority 2: fallback cluster names for the 3-class K-Means model
16
+ if os.path.exists("class_labels.json"):
17
+ with open("class_labels.json") as f:
18
+ CLASS_NAMES = json.load(f)["classes"]
19
+ else:
20
+ # 3-class K-Means cluster model fallback
21
+ CLASS_NAMES = [f"Cluster {i}" for i in range(NUM_OUTPUTS)]
22
+
23
+ # ─── Which actual pathology classes are dominant in each cluster ──────────────
24
+ # These come from analysing your K-Means cluster assignments vs ground-truth labels.
25
+ # REPLACE these lists with the real counts from your own cluster analysis notebook.
26
+ CLUSTER_DOMINANT = {
27
+ "Cluster 0": [
28
+ ("Normal", 0.38),
29
+ ("Mild Ventriculomegaly", 0.22),
30
+ ("Arnold–Chiari Malformation",0.15),
31
+ ("Moderate Ventriculomegaly", 0.14),
32
+ ("Hydranencephaly", 0.11),
33
+ ],
34
+ "Cluster 1": [
35
+ ("Severe Ventriculomegaly", 0.35),
36
+ ("Dandy–Walker Malformation", 0.25),
37
+ ("Holoprosencephaly", 0.18),
38
+ ("Agenesis of Corpus Callosum",0.13),
39
+ ("Intracranial Tumors", 0.09),
40
+ ],
41
+ "Cluster 2": [
42
+ ("Intracranial Tumors", 0.30),
43
+ ("Intracranial Hemorrhages", 0.28),
44
+ ("Holoprosencephaly", 0.20),
45
+ ("Dandy–Walker Malformation", 0.12),
46
+ ("Agenesis of Corpus Callosum",0.10),
47
+ ],
48
+ }
49
+
50
+ # For the 16-class model, dominant "classes in cluster" = top-5 softmax outputs
51
+ USE_SOFTMAX_DOMINANT = (NUM_OUTPUTS > 3)
52
+
53
+ # ─── All 16 ground-truth class names for the dropdown ────────────────────────
54
+ ALL_GT_CLASSES = [
55
+ "Normal",
56
+ "Mild Ventriculomegaly",
57
+ "Moderate Ventriculomegaly",
58
+ "Severe Ventriculomegaly",
59
+ "Arnold–Chiari Malformation",
60
+ "Hydranencephaly",
61
+ "Agenesis of Corpus Callosum",
62
+ "Dandy–Walker Malformation",
63
+ "Intracranial Tumors",
64
+ "Intracranial Hemorrhages",
65
+ "Holoprosencephaly",
66
+ "Cerebellar Hypoplasia",
67
+ "Microcephaly",
68
+ "Macrocephaly",
69
+ "Lissencephaly",
70
+ "Unknown / Not provided",
71
+ ]
72
+
73
+ # ─── Preprocessing β€” mirrors the paper Β§3B pipeline ──────────────────────────
74
+ def preprocess(image: np.ndarray) -> np.ndarray:
75
+ """Gaussian blur β†’ median filter β†’ CLAHE β†’ normalize [0,1]."""
76
+ if image is None:
77
+ return None
78
+ img = image.astype(np.uint8)
79
+ # To grayscale
80
+ if img.ndim == 3 and img.shape[2] == 3:
81
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
82
+ else:
83
+ gray = img if img.ndim == 2 else img[:, :, 0]
84
+ # Β§3B-2: Gaussian + median
85
+ blurred = cv2.GaussianBlur(gray, (5, 5), sigmaX=1.0)
86
+ median = cv2.medianBlur(blurred, 5)
87
+ # Β§3B-3: CLAHE
88
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
89
+ enhanced = clahe.apply(median)
90
+ # Back to RGB float32 [0,1]
91
+ rgb = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0
92
+ return rgb
93
+
94
+ # ─── EMOJI badges for ranks ───────────────────────────────────────────────────
95
+ RANK_EMOJI = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰", "4️⃣", "5️⃣"]
96
+
97
+ # ─── Progress-bar helper ──────────────────────────────────────────────────────
98
+ def pct_bar(value: float, width: int = 28) -> str:
99
+ filled = round(value * width)
100
+ return "β–ˆ" * filled + "β–‘" * (width - filled)
101
+
102
+ # ─── Main prediction function ─────────────────────────────────────────────────
103
+ def predict(image, actual_class):
104
+ if image is None:
105
+ empty = "Upload an ultrasound image to begin."
106
+ return empty, empty, empty
107
+
108
+ # ── Preprocess & predict ──────────────────────────────────────────────────
109
+ proc = preprocess(image)
110
+ resized = cv2.resize(proc, (IMG_SIZE, IMG_SIZE))
111
+ inp = np.expand_dims(resized, axis=0)
112
+ probs = model.predict(inp, verbose=0)[0] # shape: (num_classes,)
113
+
114
+ top5_idx = np.argsort(probs)[::-1][:5]
115
+ pred_idx = top5_idx[0]
116
+ pred_label = CLASS_NAMES[pred_idx]
117
+ confidence = probs[pred_idx] * 100.0
118
+
119
+ # ── Panel 1: Prediction cluster ───────────────────────────────────────────
120
+ cluster_lines = [
121
+ "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”",
122
+ f"β”‚ PREDICTED CLUSTER / CLASS β”‚",
123
+ "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€",
124
+ f"β”‚ {pred_label:<39} β”‚",
125
+ f"β”‚ Confidence : {confidence:>6.2f}% β”‚",
126
+ "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜",
127
+ "",
128
+ "All cluster probabilities:",
129
+ "─" * 43,
130
+ ]
131
+ for i, (cname, p) in enumerate(zip(CLASS_NAMES, probs)):
132
+ marker = " β—€ PREDICTED" if i == pred_idx else ""
133
+ cluster_lines.append(
134
+ f" {cname:<35} {p*100:5.1f}%{marker}"
135
+ )
136
+ cluster_text = "\n".join(cluster_lines)
137
+
138
+ # ── Panel 2: Top-5 dominant classes ──────────────────────────────────────
139
+ if USE_SOFTMAX_DOMINANT:
140
+ # 16-class model β€” dominant = top-5 softmax outputs
141
+ dominant = [(CLASS_NAMES[i], float(probs[i])) for i in top5_idx]
142
+ source_note = f"(direct softmax outputs from {NUM_OUTPUTS}-class model)"
143
+ else:
144
+ # 3-class cluster model β€” look up pre-computed dominant pathologies
145
+ dominant = CLUSTER_DOMINANT.get(
146
+ pred_label,
147
+ [(f"Class {j}", 0.2) for j in range(5)]
148
+ )
149
+ source_note = f"(pathologies most common in {pred_label})"
150
+
151
+ top5_lines = [
152
+ f"TOP 5 DOMINANT PATHOLOGY CLASSES {source_note}",
153
+ "─" * 63,
154
+ "",
155
+ ]
156
+ for rank, (cname, score) in enumerate(dominant):
157
+ bar = pct_bar(score)
158
+ emoji = RANK_EMOJI[rank]
159
+ top5_lines.append(
160
+ f" {emoji} {cname:<40} {bar} {score*100:5.1f}%"
161
+ )
162
+ top5_text = "\n".join(top5_lines)
163
+
164
+ # ── Panel 3: Actual class comparison ─────────────────────────────────────
165
+ if not actual_class or actual_class == "Unknown / Not provided":
166
+ actual_lines = [
167
+ "ℹ️ No ground-truth label provided.",
168
+ "",
169
+ "Select the actual class from the dropdown",
170
+ "on the left to see a correctness check.",
171
+ ]
172
+ else:
173
+ # For cluster model: check if actual class appears in the top-5 dominant list
174
+ dominant_names = [d[0] for d in dominant]
175
+ in_top5 = actual_class in dominant_names
176
+
177
+ # For 16-class model: direct label match
178
+ if USE_SOFTMAX_DOMINANT:
179
+ correct = (actual_class == pred_label)
180
+ match_str = "βœ… CORRECT PREDICTION" if correct else f"❌ INCORRECT (model predicted '{pred_label}')"
181
+ else:
182
+ # Cluster model: soft match β€” is the actual class in the cluster's top-5?
183
+ if in_top5:
184
+ rank_pos = dominant_names.index(actual_class) + 1
185
+ match_str = f"βœ… CORRECT CLUSTER ('{actual_class}' is #{rank_pos} in {pred_label})"
186
+ else:
187
+ match_str = (
188
+ f"⚠️ PARTIAL MISS ('{actual_class}' not in top-5 of {pred_label})\n"
189
+ f" This may indicate a cluster assignment issue or borderline case."
190
+ )
191
+
192
+ actual_lines = [
193
+ "GROUND TRUTH vs PREDICTION",
194
+ "─" * 43,
195
+ "",
196
+ f" Actual class : {actual_class}",
197
+ f" Predicted : {pred_label} ({confidence:.1f}%)",
198
+ "",
199
+ f" {match_str}",
200
+ "",
201
+ "─" * 43,
202
+ "Top-5 dominant classes in predicted cluster:",
203
+ ]
204
+ for rank, (cname, score) in enumerate(dominant):
205
+ tick = " βœ“" if cname == actual_class else " "
206
+ actual_lines.append(f" {tick} {rank+1}. {cname:<38} {score*100:.1f}%")
207
+
208
+ actual_text = "\n".join(actual_lines)
209
+
210
+ return cluster_text, top5_text, actual_text
211
+
212
+
213
+ # ─── Gradio UI ────────────────────────────────────────────────────────────────
214
+ CSS = """
215
+ body, .gradio-container { background: #0d1117 !important; }
216
+ .gr-box, .gr-panel { background: #161b22 !important; border: 1px solid #30363d !important; }
217
+ .gr-button { background: #238636 !important; color: #fff !important; border: none !important; }
218
+ .gr-button:hover { background: #2ea043 !important; }
219
+ .output-text textarea { font-family: 'Courier New', monospace !important; font-size: 13px !important;
220
+ background: #0d1117 !important; color: #e6edf3 !important;
221
+ border: 1px solid #30363d !important; }
222
+ label span { color: #8b949e !important; }
223
+ h1, h2, h3 { color: #e6edf3 !important; }
224
+ """
225
+
226
+ with gr.Blocks(css=CSS, title="Fetal Brain MRI Classifier 🧠") as demo:
227
+ gr.Markdown("""
228
+ # 🧠 Fetal Brain MRI Classifier
229
+ #### Ultrasound anomaly detection β€” Standard CNN / Xception transfer learning
230
+ Upload a fetal ultrasound image, optionally select the known ground-truth class, then click **Submit**.
231
+ """)
232
+
233
+ with gr.Row():
234
+ # ── Left column: inputs ──────────────────────────────────────────────
235
+ with gr.Column(scale=1):
236
+ image_input = gr.Image(
237
+ type="numpy",
238
+ label="Ultrasound Image",
239
+ image_mode="RGB",
240
+ )
241
+ actual_input = gr.Dropdown(
242
+ choices=ALL_GT_CLASSES,
243
+ value="Unknown / Not provided",
244
+ label="Actual Ground-Truth Class (optional)",
245
+ )
246
+ with gr.Row():
247
+ clear_btn = gr.Button("Clear")
248
+ submit_btn = gr.Button("Submit", variant="primary")
249
+
250
+ # ── Right column: outputs ────────────────────────────────────────────
251
+ with gr.Column(scale=2):
252
+ cluster_out = gr.Textbox(
253
+ label="πŸ† Predicted Cluster / Class",
254
+ lines=14,
255
+ interactive=False,
256
+ )
257
+ top5_out = gr.Textbox(
258
+ label="πŸ“Š Top 5 Dominant Pathology Classes",
259
+ lines=10,
260
+ interactive=False,
261
+ )
262
+ actual_out = gr.Textbox(
263
+ label="βœ… Actual Class Comparison",
264
+ lines=12,
265
+ interactive=False,
266
+ )
267
+
268
+ # ── Wire up events ───────────────────────────────────────────────────────
269
+ submit_btn.click(
270
+ fn=predict,
271
+ inputs=[image_input, actual_input],
272
+ outputs=[cluster_out, top5_out, actual_out],
273
+ )
274
+ clear_btn.click(
275
+ fn=lambda: (None, "Unknown / Not provided", "", "", ""),
276
+ inputs=[],
277
+ outputs=[image_input, actual_input, cluster_out, top5_out, actual_out],
278
+ )
279
+
280
+ demo.launch()
custom_cnn.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2186540e651e7019bb211572387d72b848a65790579322f58726a4a6c3fe9b2a
3
+ size 134080104
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ tensorflow==2.20.0
2
+ gradio
3
+ numpy
4
+ opencv-python-headless