Stefano01 committed on
Commit a084821 · verified · 1 Parent(s): 6344100

Create gradio_app.py

Files changed (1):
  1. app/gradio_app.py +711 -0
app/gradio_app.py ADDED
@@ -0,0 +1,711 @@
import random
from pathlib import Path
import os
import hashlib
import requests
import json

import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm
import torchvision.transforms as T
from PIL import Image
from torchcam.methods import GradCAM, GradCAMpp
from torchcam.utils import overlay_mask
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST

# Global state for model and configuration
app_state = {
    "model": None,
    "classes": None,
    "meta": None,
    "transform": None,
    "target_layer": None,
    "dataset": None,
    "dataset_classes": None,
}
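# NOTE: module-level state is shared by every Gradio session; this is fine for
# a single-user demo but not safe for concurrent users.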

custom_theme = gr.themes.Soft(
    primary_hue="green",    # main brand color
    secondary_hue="green",  # accent color
    neutral_hue="slate",    # backgrounds/borders/text neutrals
)


def download_release_asset(url: str, dest_dir: str = "saved_checkpoints") -> str:
    """Download a remote checkpoint to dest_dir and return its local path."""
    Path(dest_dir).mkdir(parents=True, exist_ok=True)
    url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
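    # Prefix the cached filename with a hash of the URL so identically named
    # assets from different releases don't overwrite each other.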
    fname = Path(url).name or f"asset_{url_hash}.ckpt"
    if not fname.endswith(".ckpt"):
        fname = f"{fname}.ckpt"
    local_path = Path(dest_dir) / f"{url_hash}_{fname}"

    if local_path.exists() and local_path.stat().st_size > 0:
        return str(local_path)

    with requests.get(url, stream=True, timeout=120) as r:
        r.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
    return str(local_path)


def load_release_presets() -> dict:
    """Load release preset URLs from the environment or a local JSON file."""
    # Try environment variable containing JSON mapping
    env_json = os.environ.get("RELEASE_CKPTS_JSON", "").strip()
    if env_json:
        try:
            data = json.loads(env_json)
            if isinstance(data, dict):
                return dict(data)
        except Exception:
            pass

    # Try local JSON files for dev
    for rel in (".streamlit/presets.json", "presets.json"):
        p = Path(rel)
        if p.exists():
            try:
                with open(p, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict) and data:
                    if "release_checkpoints" in data and isinstance(data["release_checkpoints"], dict):
                        return dict(data["release_checkpoints"])
                    return dict(data)
            except Exception:
                pass

    return {}


def get_device(choice="auto"):
    if choice == "cpu":
        return "cpu"
    if choice == "cuda":
        return "cuda"
    return "cuda" if torch.cuda.is_available() else "cpu"


def denorm_to_pil(x, mean, std):
    """Convert normalized tensor to PIL Image."""
    x = x.detach().cpu().clone()
    if len(mean) == 1:
        # grayscale
        m, s = float(mean[0]), float(std[0])
        x = x * s + m
        x = x.clamp(0, 1)
        pil = T.ToPILImage()(x)
        pil = pil.convert("RGB")
        return pil
    else:
        mean = torch.tensor(mean)[:, None, None]
        std = torch.tensor(std)[:, None, None]
        x = x * std + mean
        x = x.clamp(0, 1)
        return T.ToPILImage()(x)


DATASET_CLASSES = {
    "fashion-mnist": [
        "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
    ],
    "mnist": [str(i) for i in range(10)],
    "cifar10": [
        "airplane", "automobile", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck",
    ],
}


def load_raw_dataset(name: str, root="data"):
    """Load the test split with ToTensor() only (for preview)."""
    tt = T.ToTensor()
    if name == "fashion-mnist":
        ds = FashionMNIST(root=root, train=False, download=True, transform=tt)
    elif name == "mnist":
        ds = MNIST(root=root, train=False, download=True, transform=tt)
    elif name == "cifar10":
        ds = CIFAR10(root=root, train=False, download=True, transform=tt)
    else:
        raise ValueError(f"Unknown dataset: {name}")
    classes = getattr(ds, "classes", None) or [str(i) for i in range(10)]
    return ds, classes


def pil_from_tensor(img_tensor, grayscale_to_rgb=True):
    pil = T.ToPILImage()(img_tensor)
    if grayscale_to_rgb and img_tensor.ndim == 3 and img_tensor.shape[0] == 1:
        pil = pil.convert("RGB")
    return pil


class SmallCNN(nn.Module):
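    """Compact CNN for 28x28 grayscale inputs: two 2x2 max-pools shrink the
    feature maps to 7x7, hence the 64 * 7 * 7 classifier input below."""
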
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(64 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


def load_model_from_ckpt(ckpt_path: Path, device: str):
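    # Assumed checkpoint layout (as produced by the companion training script):
    #   {"model_state": <state_dict>, "classes": [...],
    #    "meta": {"model_name": ..., "dataset": ..., "img_size": ..., "mean": [...], "std": [...]}}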
    ckpt = torch.load(str(ckpt_path), map_location=device)
    classes = ckpt.get("classes", None)
    meta = ckpt.get("meta", {})
    num_classes = len(classes) if classes else 10
    model_name = meta.get("model_name", "smallcnn")

    if model_name == "smallcnn":
        model = SmallCNN(num_classes=num_classes).to(device)
        default_target_layer = "conv2"
    elif model_name == "resnet18_cifar":
        m = tvm.resnet18(weights=None)
        # CIFAR-style stem: 3x3 conv and no max-pool, so 32x32 inputs keep resolution
        m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        m.maxpool = nn.Identity()
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        model = m.to(device)
        default_target_layer = "layer4"
    elif model_name == "resnet18_imagenet":
        try:
            w = tvm.ResNet18_Weights.IMAGENET1K_V1
        except Exception:
            w = None
        m = tvm.resnet18(weights=w)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        model = m.to(device)
        default_target_layer = "layer4"
    else:
        raise ValueError(f"Unknown model_name in ckpt: {model_name}")

    model.load_state_dict(ckpt["model_state"])
    model.eval()
    meta.setdefault("default_target_layer", default_target_layer)
    return model, classes, meta


def build_transform_from_meta(meta):
    img_size = int(meta.get("img_size", 28))
    mean = meta.get("mean", [0.2860])
    std = meta.get("std", [0.3530])
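    # Defaults are Fashion-MNIST statistics; a single-channel mean/std selects
    # the grayscale pipeline below.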
    if len(mean) == 1:
        return T.Compose([
            T.Grayscale(num_output_channels=1),
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])
    else:
        return T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])


def predict_and_cam(model, x, device, target_layer, topk=3, method="Grad-CAM"):
    """Predict and generate CAM for top-k classes."""
    cam_cls = GradCAM if method == "Grad-CAM" else GradCAMpp
    cam_extractor = cam_cls(model, target_layer=target_layer)

    logits = model(x.to(device))
    probs = torch.softmax(logits, dim=1)[0].detach().cpu()
    top_vals, top_idxs = probs.topk(topk)

    results = []
    for rank, (p, idx) in enumerate(zip(top_vals.tolist(), top_idxs.tolist())):
        retain = rank < topk - 1
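        # Retain the autograd graph on all but the last backward pass so the
        # same logits can be back-propagated once per top-k class.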
        cams = cam_extractor(idx, logits, retain_graph=retain)
        cam = cams[0].detach().cpu()
        results.append({
            "rank": rank + 1,
            "class_index": int(idx),
            "prob": float(p),
            "cam": cam,
        })
    # Detach the extractor's forward/backward hooks so repeated calls on the
    # same model don't accumulate hooks.
    cam_extractor.remove_hooks()
    return results, probs


def overlay_pil(base_pil_rgb: Image.Image, cam_tensor, alpha=0.5):
    """Create overlay of CAM on base image."""
    cam = cam_tensor.clone()
    cam -= cam.min()
    cam = cam / (cam.max() + 1e-8)
    heat = T.ToPILImage()(cam)
    return overlay_mask(base_pil_rgb, heat, alpha=alpha)


# Gradio interface functions
def load_checkpoint_from_url(url, preset_name):
    """Load checkpoint from URL or preset."""
    presets = load_release_presets()

    if preset_name and preset_name != "None":
        url = presets.get(preset_name, "")

    if not url:
        # Six outputs are wired to this handler; keep the return arity consistent.
        return ("❌ No URL provided", "", gr.update(visible=False),
                gr.update(choices=["(any)"], value="(any)", visible=False),
                gr.update(visible=False), gr.update(visible=False))

    try:
        ckpt_path = download_release_asset(url)
        device = get_device("cpu")
        model, classes, meta = load_model_from_ckpt(Path(ckpt_path), device)

        # Update global state
        app_state["model"] = model
        app_state["classes"] = classes
        app_state["meta"] = meta
        app_state["transform"] = build_transform_from_meta(meta)
        app_state["target_layer"] = meta.get("default_target_layer", "conv2")

        # Load dataset for samples
        ds_name = meta.get("dataset", "fashion-mnist")
        try:
            dataset, dataset_classes = load_raw_dataset(ds_name)
            app_state["dataset"] = dataset
            app_state["dataset_classes"] = dataset_classes
        except Exception:
            app_state["dataset"] = None
            app_state["dataset_classes"] = None

        meta_info = {
            "dataset": meta.get("dataset"),
            "model_name": meta.get("model_name"),
            "img_size": meta.get("img_size"),
            "target_layer": app_state["target_layer"],
            "mean": meta.get("mean"),
            "std": meta.get("std"),
            "classes": len(classes) if classes else "N/A",
        }

        # Create class choices for the filter
        class_choices = ["(any)"] + (dataset_classes if app_state["dataset"] else [])
        max_samples = len(dataset) - 1 if app_state["dataset"] else 0

        return (f"✅ Loaded: {ckpt_path}", json.dumps(meta_info, indent=2),
                gr.update(visible=True), gr.update(choices=class_choices, value="(any)", visible=True),
                gr.update(visible=True, maximum=max_samples, value=0), gr.update(visible=True, value=""))

    except Exception as e:
        return (f"❌ Failed: {str(e)}", "", gr.update(visible=False),
                gr.update(choices=["(any)"], value="(any)"),
                gr.update(visible=False), gr.update(visible=False))


def load_checkpoint_from_file(file):
    """Load checkpoint from uploaded file."""
    if file is None:
        # Keep the return arity consistent with the six wired outputs.
        return ("❌ No file uploaded", "", gr.update(visible=False),
                gr.update(choices=["(any)"], value="(any)", visible=False),
                gr.update(visible=False), gr.update(visible=False))

    try:
        # Copy the uploaded file into the local checkpoint cache
        Path("saved_checkpoints").mkdir(parents=True, exist_ok=True)
        with open(file.name, "rb") as f:
            content = f.read()

        content_hash = hashlib.sha256(content).hexdigest()[:16]
        base_name = Path(file.name).name
        if not base_name.endswith(".ckpt"):
            base_name = f"{base_name}.ckpt"
        local_path = Path("saved_checkpoints") / f"{content_hash}_{base_name}"

        with open(local_path, "wb") as f:
            f.write(content)

        device = get_device("cpu")
        model, classes, meta = load_model_from_ckpt(local_path, device)

        # Update global state
        app_state["model"] = model
        app_state["classes"] = classes
        app_state["meta"] = meta
        app_state["transform"] = build_transform_from_meta(meta)
        app_state["target_layer"] = meta.get("default_target_layer", "conv2")

        # Load dataset for samples
        ds_name = meta.get("dataset", "fashion-mnist")
        try:
            dataset, dataset_classes = load_raw_dataset(ds_name)
            app_state["dataset"] = dataset
            app_state["dataset_classes"] = dataset_classes
        except Exception:
            app_state["dataset"] = None
            app_state["dataset_classes"] = None

        meta_info = {
            "dataset": meta.get("dataset"),
            "model_name": meta.get("model_name"),
            "img_size": meta.get("img_size"),
            "target_layer": app_state["target_layer"],
            "mean": meta.get("mean"),
            "std": meta.get("std"),
            "classes": len(classes) if classes else "N/A",
        }

        # Create class choices for the filter
        class_choices = ["(any)"] + (dataset_classes if app_state["dataset"] else [])
        max_samples = len(dataset) - 1 if app_state["dataset"] else 0

        return (f"✅ Loaded: {local_path}", json.dumps(meta_info, indent=2),
                gr.update(visible=True), gr.update(choices=class_choices, value="(any)", visible=True),
                gr.update(visible=True, maximum=max_samples, value=0), gr.update(visible=True, value=""))

    except Exception as e:
        return (f"❌ Failed: {str(e)}", "", gr.update(visible=False),
                gr.update(choices=["(any)"], value="(any)"),
                gr.update(visible=False), gr.update(visible=False))


def get_random_sample(class_filter="(any)"):
    """Get a random sample from the (optionally filtered) dataset."""
    if app_state["dataset"] is None:
        return None, "No dataset loaded", gr.update(visible=False)

    dataset = app_state["dataset"]
    dataset_classes = app_state["dataset_classes"]

    # Build candidate indices according to filter
    if class_filter != "(any)":
        # Read labels from dataset.targets instead of indexing the dataset,
        # which would decode and transform every image.
        targets = np.asarray(dataset.targets)
        class_id = dataset_classes.index(class_filter)
        filtered_indices = np.where(targets == class_id)[0]
        if len(filtered_indices) == 0:
            return None, f"No samples found for class: {class_filter}", gr.update(visible=True, maximum=0, value=0)
        actual_idx = int(random.choice(filtered_indices))
        # slider index is relative to the filtered list length
        slider_max = len(filtered_indices) - 1
        slider_value = int(np.where(filtered_indices == actual_idx)[0][0])
    else:
        actual_idx = random.randint(0, len(dataset) - 1)
        slider_max = len(dataset) - 1
        slider_value = actual_idx

    img_tensor, label = dataset[actual_idx]
    sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
    sample_img = upscale_for_display(sample_img)
    class_name = dataset_classes[label] if dataset_classes else str(label)
    caption = f"Sample {actual_idx} from {app_state['meta'].get('dataset', 'dataset')} • class: {class_name}"

    # Update slider to the picked index inside the current filter's range
    return sample_img, caption, gr.update(visible=True, maximum=slider_max, value=slider_value)


def get_sample_by_index(idx, class_filter):
    """Get a specific sample by index with optional class filtering."""
    if app_state["dataset"] is None:
        return None, "No dataset loaded"

    idx = int(idx)  # slider values can arrive as floats
    dataset = app_state["dataset"]
    dataset_classes = app_state["dataset_classes"]

    # Apply class filter
    if class_filter != "(any)":
        targets = np.asarray(dataset.targets)
        class_id = dataset_classes.index(class_filter)
        filtered_indices = np.where(targets == class_id)[0]

        if len(filtered_indices) == 0:
            return None, f"No samples found for class: {class_filter}"

        # Clamp index to filtered range
        idx = max(0, min(idx, len(filtered_indices) - 1))
        actual_idx = int(filtered_indices[idx])
    else:
        # Clamp index to dataset range
        idx = max(0, min(idx, len(dataset) - 1))
        actual_idx = idx

    img_tensor, label = dataset[actual_idx]
    sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
    sample_img = upscale_for_display(sample_img)
    class_name = dataset_classes[label] if dataset_classes else str(label)
    caption = f"Sample {actual_idx} from {app_state['meta'].get('dataset', 'dataset')} • class: {class_name}"

    return sample_img, caption


def update_class_filter(class_filter):
    """Update the slider range when class filter changes."""
    if app_state["dataset"] is None:
        return gr.update(visible=False, maximum=0, value=0)

    dataset = app_state["dataset"]
    dataset_classes = app_state["dataset_classes"]

    if class_filter == "(any)":
        max_idx = len(dataset) - 1
    else:
        targets = np.asarray(dataset.targets)
        class_id = dataset_classes.index(class_filter)
        filtered_indices = np.where(targets == class_id)[0]
        max_idx = len(filtered_indices) - 1 if len(filtered_indices) > 0 else 0

    return gr.update(visible=True, maximum=max_idx, value=0)


def upscale_for_display(img: Image.Image) -> Image.Image:
    """Return a copy of the image upscaled 10x (nearest-neighbor) for display."""
    w, h = img.size
    return img.resize((w * 10, h * 10), Image.Resampling.NEAREST)


def process_image(image, method, topk, alpha):
    """Process image and generate Grad-CAM visualizations."""
    if app_state["model"] is None:
        return "❌ No model loaded", [], []

    if image is None:
        return "❌ No image provided", [], []

    try:
        # Convert to PIL if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Prepare image
        pil = image.convert("RGB")
        x = app_state["transform"](pil)
        x_batched = x.unsqueeze(0)

        # Generate base image for overlay
        base_pil = denorm_to_pil(
            x,
            app_state["meta"].get("mean", [0.2860]),
            app_state["meta"].get("std", [0.3530]),
        )

        # Run prediction and CAM
        device = get_device("cpu")
        cam_results, probs = predict_and_cam(
            app_state["model"], x_batched, device,
            app_state["target_layer"], topk=int(topk), method=method,
        )

        # Create predictions table
        predictions = []
        for r in cam_results:
            class_name = app_state["classes"][r["class_index"]] if app_state["classes"] else str(r["class_index"])
            predictions.append([
                r["rank"],
                class_name,
                r["class_index"],
                f"{r['prob']:.4f}",
            ])

        # Create overlay images
        overlays = []
        for r in cam_results:
            class_name = app_state["classes"][r["class_index"]] if app_state["classes"] else str(r["class_index"])
            overlay_img = overlay_pil(base_pil, r["cam"], alpha=alpha)
            overlays.append((overlay_img, f"Top{r['rank']}: {class_name} ({r['prob']:.3f})"))

        return "✅ Processing complete", predictions, overlays

    except Exception as e:
        return f"❌ Processing failed: {str(e)}", [], []


# Create Gradio interface
def create_interface():
    presets = load_release_presets()
    preset_choices = (["None"] + list(presets.keys())) if presets else ["None"]

    with gr.Blocks(css="""
        .alert {
            padding: 10px 15px;
            background-color: #FFF3CD;
            color: #856404;
            border: 1px solid #FFEEBA;
            border-radius: 6px;
            position: relative;
        }
    """, theme=custom_theme) as demo:
        gr.Markdown("# 🔍 Grad-CAM Demo — Upload an image, get top-k predictions + heatmaps")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Settings")

                # Checkpoint loading
                gr.Markdown("### Load Checkpoint")
                with gr.Group():
                    preset_dropdown = gr.Dropdown(
                        choices=preset_choices,
                        value="None",
                        label="Preset (GitHub Releases)",
                    )
                    url_input = gr.Textbox(
                        label="Or paste asset URL",
                        placeholder="https://github.com/user/repo/releases/download/...",
                    )
                    url_button = gr.Button("Download from URL", variant="primary")

                with gr.Group():
                    file_input = gr.File(
                        label="Upload checkpoint (.ckpt)",
                        file_types=[".ckpt"],
                    )
                    file_button = gr.Button("Load uploaded file", variant="primary")

                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    value="No checkpoint loaded",
                )

                meta_display = gr.Code(
                    label="Model Metadata",
                    language="json",
                    interactive=False,
                )

                # Processing options
                gr.Markdown("### Processing Options")
                method_radio = gr.Radio(
                    choices=["Grad-CAM", "Grad-CAM++"],
                    value="Grad-CAM",
                    label="CAM Method",
                )
                topk_slider = gr.Slider(
                    minimum=1, maximum=10, value=3, step=1,
                    label="Top-k classes",
                )
                alpha_slider = gr.Slider(
                    minimum=0.1, maximum=0.9, value=0.5, step=0.05,
                    label="Overlay alpha",
                )

            with gr.Column(scale=2):
                gr.Markdown("## Image Input")

                size_alert = gr.Markdown(
                    value="""
                    <div class="alert">
                        ⚠️ Preview images are upscaled for visibility; they are not shown at the dataset's native resolution.
                    </div>
                    """,
                    elem_id="size-alert",
                )

                with gr.Group():
                    image_input = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=400,
                    )

                    with gr.Row():
                        sample_button = gr.Button("Random Sample", visible=False)

                with gr.Group():
                    gr.Markdown("**Dataset Sample Browser**")
                    class_filter = gr.Dropdown(
                        label="Filter by class",
                        choices=["(any)"],
                        value="(any)",
                        visible=False,
                    )
                    sample_slider = gr.Slider(
                        label="Sample index",
                        minimum=0,
                        maximum=0,
                        value=0,
                        step=1,
                        visible=False,
                        interactive=True,
                    )
                    sample_info = gr.Textbox(
                        label="Sample Info",
                        interactive=False,
                        visible=False,
                    )

                process_button = gr.Button("🔍 Process Image", variant="primary", size="lg")
                process_status = gr.Textbox(
                    label="Processing Status",
                    interactive=False,
                )

        gr.Markdown("## Results")

        with gr.Group():
            gr.Markdown("### Top-k Predictions")
            predictions_table = gr.Dataframe(
                headers=["Rank", "Class", "Index", "Probability"],
                datatype=["number", "str", "number", "str"],
                interactive=False,
            )

        with gr.Group():
            gr.Markdown("### Grad-CAM Overlays")
            overlay_gallery = gr.Gallery(
                label="CAM Overlays",
                show_label=False,
                elem_id="gallery",
                columns=3,
                object_fit="contain",
                height="auto",
            )

        # Event handlers
        url_button.click(
            fn=load_checkpoint_from_url,
            inputs=[url_input, preset_dropdown],
            outputs=[status_text, meta_display, sample_button, class_filter, sample_slider, sample_info],
        )

        file_button.click(
            fn=load_checkpoint_from_file,
            inputs=[file_input],
            outputs=[status_text, meta_display, sample_button, class_filter, sample_slider, sample_info],
        )

        sample_button.click(
            fn=get_random_sample,
            inputs=[class_filter],
            outputs=[image_input, sample_info, sample_slider],
        )

        class_filter.change(
            fn=update_class_filter,
            inputs=[class_filter],
            outputs=[sample_slider],
        )

        sample_slider.change(
            fn=get_sample_by_index,
            inputs=[sample_slider, class_filter],
            outputs=[image_input, sample_info],
        )

        process_button.click(
            fn=process_image,
            inputs=[image_input, method_radio, topk_slider, alpha_slider],
            outputs=[process_status, predictions_table, overlay_gallery],
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
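    # share=True opens a public tunnel when running locally; hosted platforms
    # (e.g. HF Spaces) typically ignore this flag.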
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )