JuanHernandez-uc commited on
Commit
d28b98f
·
1 Parent(s): abc4d20

add app.py

Browse files
Files changed (6) hide show
  1. Dockerfile +35 -0
  2. app.py +648 -0
  3. dino_chestmnist_head.pt +3 -0
  4. rad_dino_chestmnist_head.pt +3 -0
  5. requirements.txt +14 -0
  6. test.png +0 -0
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ ENV PYTHONDONTWRITEBYTECODE=1 \
4
+ PYTHONUNBUFFERED=1 \
5
+ PIP_NO_CACHE_DIR=1 \
6
+ PORT=7860 \
7
+ HF_HOME=/home/user/.cache/huggingface \
8
+ TORCH_HOME=/home/user/.cache/torch \
9
+ OMP_NUM_THREADS=1 \
10
+ MKL_NUM_THREADS=1
11
+
12
+ RUN apt-get update && apt-get install -y --no-install-recommends \
13
+ ca-certificates \
14
+ libgomp1 \
15
+ git \
16
+ && rm -rf /var/lib/apt/lists/*
17
+
18
+ RUN useradd -m -u 1000 user
19
+ USER user
20
+
21
+ ENV HOME=/home/user \
22
+ PATH=/home/user/.local/bin:$PATH
23
+
24
+ WORKDIR $HOME/app
25
+
26
+ COPY --chown=user:user requirements.txt .
27
+ RUN pip install --upgrade pip && pip install -r requirements.txt
28
+
29
+ COPY --chown=user:user . .
30
+
31
+ RUN git clone --depth 1 https://github.com/facebookresearch/dinov2.git
32
+
33
+ EXPOSE 7860
34
+
35
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860} --workers 1"]
app.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import io
3
+ import os
4
+ import uuid
5
+ import threading
6
+ import hashlib
7
+ from contextvars import ContextVar
8
+ from typing import Optional, Dict, Any, List
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from PIL import Image
14
+ from huggingface_hub import hf_hub_download
15
+ from safetensors.torch import load_file
16
+
17
+ from fastapi import FastAPI, UploadFile, File, Query, HTTPException
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from fastapi.responses import JSONResponse
20
+
21
+ # ============================================================
22
+ # Config
23
+ # ============================================================
24
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
+
26
+ MODEL_IMG_SIZE = 518
27
+ ALLOW_ORIGINS = os.environ.get("ALLOW_ORIGINS", "*").split(",")
28
+
29
+ # RAD-DINO checkpoint en HF
30
+ RAD_BACKBONE_REPO_ID = "microsoft/rad-dino"
31
+ RAD_BACKBONE_FILENAME = "backbone_compatible.safetensors"
32
+
33
+ # Heads
34
+ RAD_HEAD_CKPT_PATH = os.environ.get("RAD_HEAD_CKPT_PATH", "rad_dino_chestmnist_head.pt")
35
+ DINO_HEAD_CKPT_PATH = os.environ.get("DINO_HEAD_CKPT_PATH", "dino_chestmnist_head.pt")
36
+
37
+ # Normalización
38
+ RAD_MEAN = torch.tensor([0.5307, 0.5307, 0.5307], dtype=torch.float32).view(3, 1, 1)
39
+ RAD_STD = torch.tensor([0.2583, 0.2583, 0.2583], dtype=torch.float32).view(3, 1, 1)
40
+
41
+ # DINOv2 usual / ImageNet normalization
42
+ DINO_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1)
43
+ DINO_STD = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1)
44
+
45
+ DEFAULT_LABEL_NAMES = [
46
+ "atelectasis", "cardiomegaly", "effusion", "infiltration",
47
+ "mass", "nodule", "pneumonia", "pneumothorax",
48
+ "consolidation", "edema", "emphysema", "fibrosis",
49
+ "pleural", "hernia"
50
+ ]
51
+
52
+ MODEL_CONFIGS = {
53
+ "rad-dino": {
54
+ "backbone_type": "rad-dino",
55
+ "head_ckpt_path": RAD_HEAD_CKPT_PATH,
56
+ "model_name": "rad-dino-chestmnist",
57
+ "mean": RAD_MEAN,
58
+ "std": RAD_STD,
59
+ },
60
+ "dino": {
61
+ "backbone_type": "dino",
62
+ "head_ckpt_path": DINO_HEAD_CKPT_PATH,
63
+ "model_name": "dino-chestmnist",
64
+ "mean": DINO_MEAN,
65
+ "std": DINO_STD,
66
+ },
67
+ }
68
+
69
+
70
+ # ============================================================
71
+ # Model definitions
72
+ # ============================================================
73
+ class MedicalHead(nn.Module):
74
+ def __init__(self, in_dim: int = 768, num_classes: int = 14, dropout: float = 0.1):
75
+ super().__init__()
76
+ self.drop = nn.Dropout(dropout)
77
+ self.fc = nn.Linear(in_dim, num_classes)
78
+
79
+ def forward(self, cls_token: torch.Tensor) -> torch.Tensor:
80
+ return self.fc(self.drop(cls_token))
81
+
82
+
83
+ def round_tensor(t: torch.Tensor, decimals: int = 4) -> torch.Tensor:
84
+ s = 10 ** decimals
85
+ return torch.round(t * s) / s
86
+
87
+
88
+ def preprocess_pil(pil_img: Image.Image, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
89
+ img = pil_img.convert("RGB").resize((MODEL_IMG_SIZE, MODEL_IMG_SIZE), Image.BICUBIC)
90
+ arr = np.array(img).astype("float32") / 255.0
91
+ x = torch.from_numpy(arr).permute(2, 0, 1) # [3,H,W]
92
+ x = (x - mean) / std
93
+ return x.unsqueeze(0) # [1,3,H,W]
94
+
95
+
96
+ # ============================================================
97
+ # Build backbones
98
+ # ============================================================
99
+ def ensure_local_dinov2_repo():
100
+ if not os.path.exists("./dinov2"):
101
+ raise FileNotFoundError(
102
+ "No encontré ./dinov2. Clona el repo primero con:\n"
103
+ "git clone https://github.com/facebookresearch/dinov2.git"
104
+ )
105
+
106
+
107
+ def disable_fused_attn(model: nn.Module):
108
+ for blk in model.blocks:
109
+ if hasattr(blk.attn, "fused_attn"):
110
+ blk.attn.fused_attn = False
111
+
112
+
113
+ def build_dinov2_backbone() -> nn.Module:
114
+ ensure_local_dinov2_repo()
115
+ model = torch.hub.load("./dinov2", "dinov2_vitb14", source="local")
116
+ model.eval().to(DEVICE)
117
+ disable_fused_attn(model)
118
+ return model
119
+
120
+
121
+ def build_rad_dino_backbone() -> nn.Module:
122
+ model = build_dinov2_backbone()
123
+
124
+ backbone_path = hf_hub_download(
125
+ repo_id=RAD_BACKBONE_REPO_ID,
126
+ filename=RAD_BACKBONE_FILENAME
127
+ )
128
+ state = load_file(backbone_path)
129
+ model.load_state_dict(state, strict=True)
130
+ model.eval().to(DEVICE)
131
+ disable_fused_attn(model)
132
+ return model
133
+
134
+
135
+ def build_head(head_ckpt_path: str) -> tuple[nn.Module, Dict[str, Any], List[str]]:
136
+ ckpt = torch.load(head_ckpt_path, map_location=DEVICE)
137
+ label_names = ckpt.get("label_names", DEFAULT_LABEL_NAMES)
138
+ num_classes = len(label_names)
139
+
140
+ head = MedicalHead(in_dim=768, num_classes=num_classes, dropout=0.1).to(DEVICE)
141
+ head.load_state_dict(ckpt["head_state_dict"])
142
+ head.eval()
143
+
144
+ return head, ckpt, label_names
145
+
146
+
147
+ def build_model_bundle(model_key: str, cfg: Dict[str, Any]) -> Dict[str, Any]:
148
+ if cfg["backbone_type"] == "rad-dino":
149
+ backbone = build_rad_dino_backbone()
150
+ elif cfg["backbone_type"] == "dino":
151
+ backbone = build_dinov2_backbone()
152
+ else:
153
+ raise ValueError(f"backbone_type desconocido: {cfg['backbone_type']}")
154
+
155
+ head, ckpt, label_names = build_head(cfg["head_ckpt_path"])
156
+
157
+ bundle = {
158
+ "key": model_key,
159
+ "model_name": cfg["model_name"],
160
+ "backbone_type": cfg["backbone_type"],
161
+ "backbone": backbone,
162
+ "head": head,
163
+ "head_ckpt": ckpt,
164
+ "label_names": label_names,
165
+ "mean": cfg["mean"],
166
+ "std": cfg["std"],
167
+ "num_layers": len(backbone.blocks),
168
+ "num_heads": getattr(backbone.blocks[0].attn, "num_heads", None),
169
+ "current": {
170
+ "hash": None,
171
+ "attention_cls_full": None,
172
+ "logit_lens_full": None,
173
+ },
174
+ "results": {},
175
+ "lock": threading.Lock(),
176
+ }
177
+ return bundle
178
+
179
+
180
+ # ============================================================
181
+ # Hook registration per model
182
+ # ============================================================
183
+ def register_hooks(bundle: Dict[str, Any]):
184
+ _attn_in_var: ContextVar[Optional[list]] = ContextVar(
185
+ f"_attn_in_var_{bundle['key']}", default=None
186
+ )
187
+ _tok_var: ContextVar[Optional[list]] = ContextVar(
188
+ f"_tok_var_{bundle['key']}", default=None
189
+ )
190
+
191
+ def _save_attn_input(module, inp):
192
+ lst = _attn_in_var.get()
193
+ if lst is None:
194
+ return
195
+ if len(inp) == 0 or not torch.is_tensor(inp[0]):
196
+ return
197
+ # input to attn: [B, N, D]
198
+ lst.append(inp[0].detach())
199
+
200
+ def _save_block_out(module, inp, out):
201
+ lst = _tok_var.get()
202
+ if lst is None:
203
+ return
204
+ if torch.is_tensor(out):
205
+ # block output: [B, N, D]
206
+ lst.append(out.detach())
207
+
208
+ attn_hooks = []
209
+ tok_hooks = []
210
+
211
+ for blk in bundle["backbone"].blocks:
212
+ if not hasattr(blk, "attn"):
213
+ raise RuntimeError(f"No encontré blk.attn en backbone {bundle['key']}")
214
+ attn_hooks.append(blk.attn.register_forward_pre_hook(_save_attn_input))
215
+ tok_hooks.append(blk.register_forward_hook(_save_block_out))
216
+
217
+ bundle["_attn_in_var"] = _attn_in_var
218
+ bundle["_tok_var"] = _tok_var
219
+ bundle["_attn_hooks"] = attn_hooks
220
+ bundle["_tok_hooks"] = tok_hooks
221
+
222
+
223
+ # ============================================================
224
+ # Build all models
225
+ # ============================================================
226
+ MODELS: Dict[str, Dict[str, Any]] = {
227
+ key: build_model_bundle(key, cfg)
228
+ for key, cfg in MODEL_CONFIGS.items()
229
+ }
230
+
231
+ for _bundle in MODELS.values():
232
+ register_hooks(_bundle)
233
+
234
+ for key, bundle in MODELS.items():
235
+ print(f"[server] model_key={key}")
236
+ print(f"[server] model_name={bundle['model_name']}")
237
+ print(f"[server] backbone_type={bundle['backbone_type']} device={DEVICE}")
238
+ print(f"[server] head_ckpt={MODEL_CONFIGS[key]['head_ckpt_path']}")
239
+ print(f"[server] num_layers={bundle['num_layers']} num_heads={bundle['num_heads']}")
240
+ print(f"[server] num_classes={len(bundle['label_names'])}")
241
+ if "best_val_auc" in bundle["head_ckpt"]:
242
+ print(f"[server] checkpoint best_val_auc={bundle['head_ckpt']['best_val_auc']:.4f}")
243
+
244
+
245
+ # ============================================================
246
+ # Inference helpers
247
+ # ============================================================
248
+ @torch.no_grad()
249
+ def extract_cls(backbone: nn.Module, images: torch.Tensor) -> torch.Tensor:
250
+ feats = backbone.forward_features(images)
251
+ if "x_norm_clstoken" not in feats:
252
+ raise RuntimeError("forward_features no devolvió 'x_norm_clstoken'.")
253
+ return feats["x_norm_clstoken"]
254
+
255
+
256
+ @torch.no_grad()
257
+ def compute_logit_lens_from_tokens(tokens_per_layer: List[torch.Tensor], head: nn.Module):
258
+ logits_list = []
259
+ probs_list = []
260
+
261
+ for x_l in tokens_per_layer:
262
+ # x_l: [B, N, D]
263
+ cls_l = x_l[:, 0] # [B, D]
264
+ logits_l = head(cls_l) # [B, C]
265
+ probs_l = torch.sigmoid(logits_l)
266
+
267
+ logits_list.append(logits_l.detach().cpu())
268
+ probs_list.append(probs_l.detach().cpu())
269
+
270
+ logits_per_layer = torch.stack(logits_list, dim=0) # [L, B, C]
271
+ probs_per_layer = torch.stack(probs_list, dim=0) # [L, B, C]
272
+ return logits_per_layer, probs_per_layer
273
+
274
+
275
+ @torch.no_grad()
276
+ def compute_cls_attention_from_inputs(backbone: nn.Module, attn_inputs: List[torch.Tensor]):
277
+ """
278
+ Reconstruct CLS->tokens attention per layer from the input to attention.
279
+ Returns list of [B, H, N], one per layer.
280
+ """
281
+ cls_attn_per_layer = []
282
+
283
+ for blk, x in zip(backbone.blocks, attn_inputs):
284
+ x = x.to(DEVICE) # [B, N, D]
285
+
286
+ B, N, C = x.shape
287
+ num_heads = blk.attn.num_heads
288
+ head_dim = C // num_heads
289
+
290
+ qkv = blk.attn.qkv(x) # [B, N, 3*C]
291
+ qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
292
+ q, k, v = qkv[0], qkv[1], qkv[2] # [B, H, N, Hd]
293
+
294
+ attn = (q @ k.transpose(-2, -1)) * blk.attn.scale
295
+ attn = attn.softmax(dim=-1)
296
+
297
+ cls_attn = attn[:, :, 0, :].detach().cpu() # [B, H, N]
298
+ cls_attn_per_layer.append(cls_attn)
299
+
300
+ return cls_attn_per_layer
301
+
302
+
303
+ def analyze_image(bundle: Dict[str, Any], pil_img: Image.Image) -> Dict[str, Any]:
304
+ x = preprocess_pil(pil_img, bundle["mean"], bundle["std"]).to(DEVICE)
305
+
306
+ attn_inputs = []
307
+ layer_tokens = []
308
+
309
+ tok_token = bundle["_tok_var"].set(layer_tokens)
310
+ attn_token = bundle["_attn_in_var"].set(attn_inputs)
311
+
312
+ try:
313
+ with torch.no_grad():
314
+ with bundle["lock"]:
315
+ cls_final = extract_cls(bundle["backbone"], x) # [1, 768]
316
+ logits_final = bundle["head"](cls_final) # [1, C]
317
+
318
+ probs_final = torch.sigmoid(logits_final)[0].detach().cpu()
319
+ probs_final = round_tensor(probs_final, 6)
320
+
321
+ if len(layer_tokens) == 0:
322
+ raise RuntimeError("No se capturaron tokens por capa.")
323
+ if len(attn_inputs) == 0:
324
+ raise RuntimeError("No se capturaron entradas a atención por capa.")
325
+
326
+ logits_by_layer, probs_by_layer = compute_logit_lens_from_tokens(
327
+ layer_tokens, bundle["head"]
328
+ )
329
+ attn_maps = compute_cls_attention_from_inputs(bundle["backbone"], attn_inputs)
330
+
331
+ # ----------------------------------------------------
332
+ # attention_cls_full
333
+ # ----------------------------------------------------
334
+ attn_maps2 = [a.squeeze(0) for a in attn_maps] # list of [H, N]
335
+
336
+ attn_serializable_all = []
337
+ attn_serializable_patches = []
338
+
339
+ for layer in attn_maps2:
340
+ layer_all = []
341
+ layer_patches = []
342
+ for head in layer:
343
+ head = round_tensor(head, 4) # [N]
344
+ layer_all.append(head.tolist())
345
+ layer_patches.append(head[1:].tolist()) # remove CLS->CLS
346
+ attn_serializable_all.append(layer_all)
347
+ attn_serializable_patches.append(layer_patches)
348
+
349
+ num_tokens_all = len(attn_serializable_all[0][0])
350
+ num_patch_tokens = len(attn_serializable_patches[0][0])
351
+
352
+ export_attn = {
353
+ "model": bundle["model_name"],
354
+ "attention_type": "cls_only",
355
+ "num_layers": len(attn_serializable_all),
356
+ "num_heads": len(attn_serializable_all[0]),
357
+ "num_tokens_all": num_tokens_all,
358
+ "num_patch_tokens": num_patch_tokens,
359
+ "cls_index": 0,
360
+ "attention_cls_to_all_tokens": attn_serializable_all,
361
+ "attention_cls_to_patches": attn_serializable_patches,
362
+ }
363
+
364
+ # ----------------------------------------------------
365
+ # logit_lens_full
366
+ # ----------------------------------------------------
367
+ export_logit = {
368
+ "model": bundle["model_name"],
369
+ "num_layers": int(logits_by_layer.shape[0]),
370
+ "num_classes": int(logits_by_layer.shape[-1]),
371
+ "class_names": bundle["label_names"],
372
+ "checkpoint_best_val_auc": bundle["head_ckpt"].get("best_val_auc", None),
373
+ "final_probs": probs_final.tolist(),
374
+ "logits": [],
375
+ "probs_by_layer": [],
376
+ }
377
+
378
+ for l in range(logits_by_layer.shape[0]):
379
+ v_logits = round_tensor(logits_by_layer[l, 0], 4)
380
+ v_probs = round_tensor(probs_by_layer[l, 0], 6)
381
+ export_logit["logits"].append(v_logits.tolist())
382
+ export_logit["probs_by_layer"].append(v_probs.tolist())
383
+
384
+ return {
385
+ "attention_cls_full": export_attn,
386
+ "logit_lens_full": export_logit,
387
+ }
388
+
389
+ finally:
390
+ bundle["_tok_var"].reset(tok_token)
391
+ bundle["_attn_in_var"].reset(attn_token)
392
+ layer_tokens.clear()
393
+ attn_inputs.clear()
394
+
395
+
396
+ # ============================================================
397
+ # FastAPI app
398
+ # ============================================================
399
+ app = FastAPI(title="ChestMNIST Explainer API (RAD-DINO + DINO)", version="2.0")
400
+
401
+ app.add_middleware(
402
+ CORSMiddleware,
403
+ allow_origins=ALLOW_ORIGINS if ALLOW_ORIGINS != ["*"] else ["*"],
404
+ allow_credentials=True,
405
+ allow_methods=["*"],
406
+ allow_headers=["*"],
407
+ )
408
+
409
+
410
+ def _no_store(resp: JSONResponse) -> JSONResponse:
411
+ resp.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
412
+ resp.headers["Pragma"] = "no-cache"
413
+ return resp
414
+
415
+
416
+ def get_model_bundle(model_key: str) -> Dict[str, Any]:
417
+ if model_key not in MODELS:
418
+ raise HTTPException(
419
+ status_code=404,
420
+ detail=f"Unknown model_key '{model_key}'. Available: {list(MODELS.keys())}"
421
+ )
422
+ return MODELS[model_key]
423
+
424
+
425
+ # ============================================================
426
+ # Root / health
427
+ # ============================================================
428
+ @app.get("/")
429
+ def root():
430
+ return {
431
+ "status": "ok",
432
+ "device": DEVICE,
433
+ "available_models": list(MODELS.keys()),
434
+ "image_size": MODEL_IMG_SIZE,
435
+ }
436
+
437
+
438
+ @app.get("/health")
439
+ def health():
440
+ return {
441
+ "status": "ok",
442
+ "device": DEVICE,
443
+ "available_models": list(MODELS.keys()),
444
+ "models": {
445
+ key: {
446
+ "model": bundle["model_name"],
447
+ "num_layers": bundle["num_layers"],
448
+ "num_heads": bundle["num_heads"],
449
+ "num_classes": len(bundle["label_names"]),
450
+ "class_names": bundle["label_names"],
451
+ "checkpoint_best_val_auc": bundle["head_ckpt"].get("best_val_auc", None),
452
+ "has_current": bundle["current"]["attention_cls_full"] is not None,
453
+ }
454
+ for key, bundle in MODELS.items()
455
+ }
456
+ }
457
+
458
+
459
+ @app.get("/health/{model_key}")
460
+ def health_model(model_key: str):
461
+ bundle = get_model_bundle(model_key)
462
+ return {
463
+ "status": "ok",
464
+ "device": DEVICE,
465
+ "model_key": model_key,
466
+ "model": bundle["model_name"],
467
+ "image_size": MODEL_IMG_SIZE,
468
+ "num_layers": bundle["num_layers"],
469
+ "num_heads": bundle["num_heads"],
470
+ "num_classes": len(bundle["label_names"]),
471
+ "class_names": bundle["label_names"],
472
+ "checkpoint_best_val_auc": bundle["head_ckpt"].get("best_val_auc", None),
473
+ "has_current": bundle["current"]["attention_cls_full"] is not None,
474
+ }
475
+
476
+
477
+ # ============================================================
478
+ # Legacy analyze with stored jobs
479
+ # ============================================================
480
+ @app.post("/analyze/{model_key}")
481
+ async def analyze(
482
+ model_key: str,
483
+ file: UploadFile = File(...),
484
+ store: int = Query(0, description="1 => guarda resultados y entrega endpoints /results/{model_key}/{id}/..."),
485
+ ):
486
+ bundle = get_model_bundle(model_key)
487
+
488
+ if not file.content_type or not file.content_type.startswith("image/"):
489
+ raise HTTPException(status_code=400, detail="Please upload an image file.")
490
+
491
+ raw = await file.read()
492
+ try:
493
+ img = Image.open(io.BytesIO(raw)).convert("RGB")
494
+ except Exception:
495
+ raise HTTPException(status_code=400, detail="Could not decode image.")
496
+
497
+ try:
498
+ out = analyze_image(bundle, img)
499
+ except Exception as e:
500
+ raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")
501
+
502
+ if store == 1:
503
+ job_id = str(uuid.uuid4())
504
+ bundle["results"][job_id] = out
505
+ return {
506
+ "model_key": model_key,
507
+ "job_id": job_id,
508
+ "endpoints": {
509
+ "attention_cls_full": f"/results/{model_key}/{job_id}/attention_cls_full.json",
510
+ "logit_lens_full": f"/results/{model_key}/{job_id}/logit_lens_full.json",
511
+ }
512
+ }
513
+
514
+ return out
515
+
516
+
517
+ @app.get("/results/{model_key}/{job_id}/attention_cls_full.json")
518
+ def get_attention(model_key: str, job_id: str):
519
+ bundle = get_model_bundle(model_key)
520
+ if job_id not in bundle["results"]:
521
+ raise HTTPException(status_code=404, detail="job_id not found")
522
+ return _no_store(JSONResponse(bundle["results"][job_id]["attention_cls_full"]))
523
+
524
+
525
+ @app.get("/results/{model_key}/{job_id}/logit_lens_full.json")
526
+ def get_logit(model_key: str, job_id: str):
527
+ bundle = get_model_bundle(model_key)
528
+ if job_id not in bundle["results"]:
529
+ raise HTTPException(status_code=404, detail="job_id not found")
530
+ return _no_store(JSONResponse(bundle["results"][job_id]["logit_lens_full"]))
531
+
532
+
533
+ # ============================================================
534
+ # Preferred: current endpoints per model
535
+ # ============================================================
536
+ @app.post("/analyze_current/{model_key}")
537
+ async def analyze_current(model_key: str, file: UploadFile = File(...)):
538
+ bundle = get_model_bundle(model_key)
539
+
540
+ if not file.content_type or not file.content_type.startswith("image/"):
541
+ raise HTTPException(status_code=400, detail="Please upload an image file.")
542
+
543
+ raw = await file.read()
544
+ img_hash = hashlib.sha256(raw).hexdigest()
545
+
546
+ if bundle["current"]["hash"] == img_hash and bundle["current"]["attention_cls_full"] is not None:
547
+ return {"status": "unchanged", "hash": img_hash, "model_key": model_key}
548
+
549
+ try:
550
+ img = Image.open(io.BytesIO(raw)).convert("RGB")
551
+ except Exception:
552
+ raise HTTPException(status_code=400, detail="Could not decode image.")
553
+
554
+ try:
555
+ out = analyze_image(bundle, img)
556
+ except Exception as e:
557
+ raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")
558
+
559
+ bundle["current"]["hash"] = img_hash
560
+ bundle["current"]["attention_cls_full"] = out["attention_cls_full"]
561
+ bundle["current"]["logit_lens_full"] = out["logit_lens_full"]
562
+
563
+ return {"status": "ok", "hash": img_hash, "model_key": model_key}
564
+
565
+
566
+ @app.get("/{model_key}/attention_cls_full.json")
567
+ def current_attention(model_key: str):
568
+ bundle = get_model_bundle(model_key)
569
+
570
+ if bundle["current"]["attention_cls_full"] is None:
571
+ raise HTTPException(
572
+ status_code=404,
573
+ detail=f"No current attention file for '{model_key}'. POST /analyze_current/{model_key} first."
574
+ )
575
+ return _no_store(JSONResponse(bundle["current"]["attention_cls_full"]))
576
+
577
+
578
+ @app.get("/{model_key}/logit_lens_full.json")
579
+ def current_logit(model_key: str):
580
+ bundle = get_model_bundle(model_key)
581
+
582
+ if bundle["current"]["logit_lens_full"] is None:
583
+ raise HTTPException(
584
+ status_code=404,
585
+ detail=f"No current logit file for '{model_key}'. POST /analyze_current/{model_key} first."
586
+ )
587
+ return _no_store(JSONResponse(bundle["current"]["logit_lens_full"]))
588
+
589
+
590
+ # ============================================================
591
+ # Optional backward-compatible aliases for RAD-DINO
592
+ # ============================================================
593
+ @app.post("/analyze_current")
594
+ async def analyze_current_rad_default(file: UploadFile = File(...)):
595
+ return await analyze_current("rad-dino", file)
596
+
597
+
598
+ @app.get("/attention_cls_full.json")
599
+ def current_attention_rad_default():
600
+ return current_attention("rad-dino")
601
+
602
+
603
+ @app.get("/logit_lens_full.json")
604
+ def current_logit_rad_default():
605
+ return current_logit("rad-dino")
606
+
607
+
608
+ @app.post("/analyze")
609
+ async def analyze_rad_default(
610
+ file: UploadFile = File(...),
611
+ store: int = Query(0, description="1 => guarda resultados"),
612
+ ):
613
+ return await analyze("rad-dino", file, store)
614
+
615
+
616
+ # ============================================================
617
+ # Smoke test
618
+ # ============================================================
619
+ def smoke_test_local_image(image_path: str, model_key: str = "rad-dino"):
620
+ if not os.path.exists(image_path):
621
+ raise FileNotFoundError(f"No existe la imagen: {image_path}")
622
+
623
+ bundle = get_model_bundle(model_key)
624
+ img = Image.open(image_path).convert("RGB")
625
+ out = analyze_image(bundle, img)
626
+
627
+ print(f"\n[smoke test] model_key={model_key} OK")
628
+ print("[smoke test] capas:", out["attention_cls_full"]["num_layers"])
629
+ print("[smoke test] heads:", out["attention_cls_full"]["num_heads"])
630
+ print("[smoke test] patch tokens:", out["attention_cls_full"]["num_patch_tokens"])
631
+
632
+ final_probs = out["logit_lens_full"]["final_probs"]
633
+ pairs = sorted(zip(bundle["label_names"], final_probs), key=lambda t: t[1], reverse=True)
634
+
635
+ print("\nTop-5 predicciones:")
636
+ for name, p in pairs[:5]:
637
+ print(f" {name:<15} {p:.4f}")
638
+
639
+
640
+ if __name__ == "__main__":
641
+ test_path = os.environ.get("TEST_IMAGE_PATH", "").strip()
642
+ test_model = os.environ.get("TEST_MODEL_KEY", "rad-dino").strip()
643
+
644
+ if test_path:
645
+ smoke_test_local_image(test_path, test_model)
646
+
647
+ import uvicorn
648
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
dino_chestmnist_head.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12e5409624f3333ac9cc00402d0d84d2a58bfd62aac0462dc3fe33bbda14c1fa
3
+ size 133769
rad_dino_chestmnist_head.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b570912198f3cdf00fe2b01b7cae88cdc13539245173e7a13f665f4a241f5873
3
+ size 45613
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+
3
+ fastapi==0.115.6
4
+ uvicorn[standard]==0.30.6
5
+ python-multipart==0.0.9
6
+
7
+ torch==2.5.1
8
+ torchvision==0.20.1
9
+ timm==0.9.16
10
+
11
+ pillow==10.4.0
12
+ numpy==2.1.3
13
+ huggingface-hub==0.24.7
14
+ safetensors==0.4.5
test.png ADDED