Spaces:
Running
Running
| # app.py | |
import hashlib
import io
import json
import math
import threading
import uuid
from contextvars import ContextVar
from typing import Optional, Dict, Any

import timm
import torch
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile, File, Query, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from timm.layers.pos_embed import resample_abs_pos_embed

try:
    from timm.layers.patch_embed import resample_patch_embed
except Exception:
    resample_patch_embed = None
# -----------------------
# Config
# -----------------------
# timm model identifier; created with pretrained=True in build_model().
MODEL_NAME = "flexivit_large.300ep_in1k"

# Served input resolution and patch size: 96 px / 32 px patches -> 3x3 token grid.
TARGET_IMG = 96
TARGET_PATCH = 32
_GRID_SIDE = TARGET_IMG // TARGET_PATCH
NEW_GRID = (_GRID_SIDE, _GRID_SIDE)  # (3, 3)

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ImageNet normalization statistics, shaped (3, 1, 1) to broadcast over a CHW image.
IMNET_MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMNET_STD = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
# -----------------------
# Load labels (local file recommended)
# -----------------------
def load_imagenet_labels(path="imagenet_classes.txt"):
    """Read one class name per line from ``path``.

    Surrounding whitespace is stripped and blank lines are skipped. The file is
    decoded as UTF-8; a leading byte-order mark (as written by some Windows
    editors) is dropped instead of leaking into the first label.

    Returns:
        The list of label strings, or ``None`` if ``path`` does not exist, so the
        service can still run without human-readable class names.
    """
    try:
        # "utf-8-sig" reads plain UTF-8 unchanged but strips a leading BOM,
        # which str.strip() would otherwise leave on the first label.
        with open(path, "r", encoding="utf-8-sig") as f:
            return [label for label in (line.strip() for line in f) if label]
    except FileNotFoundError:
        # If missing, still works but without names.
        return None


IMAGENET_LABELS = load_imagenet_labels()
# -----------------------
# Build & adapt model once
# -----------------------
def _resize_patch_kernel(w: torch.Tensor) -> torch.Tensor:
    """Resample a conv kernel [embed_dim, in_chans, p, p] to TARGET_PATCH x TARGET_PATCH."""
    old_ps = w.shape[-1]
    if old_ps == TARGET_PATCH:
        return w
    if resample_patch_embed is not None:
        return resample_patch_embed(w, (TARGET_PATCH, TARGET_PATCH))
    # Fallback when this timm version lacks resample_patch_embed: bicubic
    # interpolation of every (out, in) kernel slice independently.
    # NOTE(review): plain interpolation likely does not preserve the patch
    # response magnitude the way timm's resampler aims to -- confirm if hit.
    ed, ic, _, _ = w.shape
    w_ = w.reshape(ed * ic, 1, old_ps, old_ps)
    w_ = F.interpolate(w_, size=(TARGET_PATCH, TARGET_PATCH), mode="bicubic", align_corners=False)
    return w_.reshape(ed, ic, TARGET_PATCH, TARGET_PATCH)


def _infer_pos_embed_layout(model: torch.nn.Module, num_pos: int, old_grid):
    """Return ``(num_prefix_tokens, grid)`` describing the rows of ``model.pos_embed``.

    ``num_pos`` is ``pos_embed.shape[1]``; ``old_grid`` is the patch grid the model
    had BEFORE adaptation, or None when unknown.

    Raises:
        ValueError: if no square grid is consistent with ``num_pos``.
    """
    # Infer prefix tokens (cls, dist, etc.).
    prefix = int(getattr(model, "num_prefix_tokens", 0))
    if prefix == 0 and getattr(model, "cls_token", None) is not None:
        prefix = 1
    if old_grid is not None:
        grid_tokens = old_grid[0] * old_grid[1]
        if num_pos == grid_tokens:
            # Prefix tokens have no positional-embedding rows.
            return 0, old_grid
        if num_pos == grid_tokens + prefix:
            return prefix, old_grid
    # Metadata missing or inconsistent: assume a square grid, trying the
    # inferred prefix count first and then "no prefix rows".
    for candidate in (prefix, 0):
        n = num_pos - candidate
        side = math.isqrt(n) if n > 0 else 0
        if side > 0 and side * side == n:
            return candidate, (side, side)
    raise ValueError(f"Cannot infer a square patch grid from a pos_embed with {num_pos} tokens")


def adapt_flexivit_to_3x3(model: torch.nn.Module):
    """Re-target a timm ViT/FlexiViT, in place, to TARGET_IMG inputs split into TARGET_PATCH patches.

    1. Resample the patch-embedding conv kernel to TARGET_PATCH and swap in a new
       Conv2d (stride = kernel = TARGET_PATCH) on DEVICE.
    2. Update ``patch_embed`` size metadata (patch_size, img_size, grid_size,
       num_patches) where those attributes exist.
    3. Resample the absolute positional embedding from the original grid to NEW_GRID.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if the positional-embedding grid cannot be inferred.
    """
    patch_embed = model.patch_embed
    # Read the ORIGINAL grid before step 2 overwrites it with NEW_GRID; step 3
    # needs the source grid to match pos_embed rows against.
    old_grid = tuple(patch_embed.grid_size) if hasattr(patch_embed, "grid_size") else None

    # --- Resize patch embedding conv weight ---
    with torch.no_grad():
        proj = patch_embed.proj
        w = proj.weight.detach().cpu()  # [embed_dim, in_chans, old_ps, old_ps]
        b = proj.bias.detach().cpu() if proj.bias is not None else None
        w2 = _resize_patch_kernel(w)
        embed_dim, in_chans, _, _ = w2.shape
        new_proj = torch.nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=TARGET_PATCH,
            stride=TARGET_PATCH,
            padding=0,
            bias=(b is not None),
        )
        new_proj.weight.copy_(w2)
        if b is not None:
            new_proj.bias.copy_(b)
        patch_embed.proj = new_proj.to(DEVICE)

    # --- Update patch embed metadata if present ---
    if hasattr(patch_embed, "patch_size"):
        patch_embed.patch_size = (TARGET_PATCH, TARGET_PATCH)
    if hasattr(patch_embed, "img_size"):
        patch_embed.img_size = (TARGET_IMG, TARGET_IMG)
    if hasattr(patch_embed, "grid_size"):
        patch_embed.grid_size = NEW_GRID
    if hasattr(patch_embed, "num_patches"):
        patch_embed.num_patches = NEW_GRID[0] * NEW_GRID[1]

    # --- Resize absolute positional embeddings to NEW_GRID ---
    if getattr(model, "pos_embed", None) is not None:
        with torch.no_grad():
            pe = model.pos_embed.detach()
            prefix, src_grid = _infer_pos_embed_layout(model, pe.shape[1], old_grid)
            pe2 = resample_abs_pos_embed(
                pe,
                new_size=NEW_GRID,
                old_size=src_grid,
                num_prefix_tokens=prefix,
                interpolation="bicubic",
                antialias=True,
            )
        model.pos_embed = torch.nn.Parameter(pe2)
    return model
def build_model():
    """Create the pretrained timm model on DEVICE in eval mode, then adapt it to NEW_GRID.

    Blocks whose attention module has a ``fused_attn`` flag get it switched off.
    Presumably (timm's non-fused path) this routes the attention probabilities
    through ``attn.attn_drop``, where this file's forward hooks read them.
    """
    net = timm.create_model(MODEL_NAME, pretrained=True)
    net = net.to(DEVICE).eval()
    for block in net.blocks:
        attention = block.attn
        if hasattr(attention, "fused_attn"):
            attention.fused_attn = False
    return adapt_flexivit_to_3x3(net)


MODEL = build_model()
print(f"[server] model={MODEL_NAME} device={DEVICE} grid={NEW_GRID}")
| # ----------------------- | |
| # Hooks using ContextVar (safe-ish for concurrent requests) | |
| # ----------------------- | |
| _attn_var: ContextVar[Optional[list]] = ContextVar("_attn_var", default=None) | |
| _tok_var: ContextVar[Optional[list]] = ContextVar("_tok_var", default=None) | |
| def _save_attn_drop_input(module, inp, out): | |
| lst = _attn_var.get() | |
| if lst is not None and len(inp) > 0 and torch.is_tensor(inp[0]): | |
| # inp[0]: [B, H, N, N] | |
| lst.append(inp[0].detach().cpu()) | |
| def _save_block_out(module, inp, out): | |
| lst = _tok_var.get() | |
| if lst is not None and torch.is_tensor(out): | |
| # out: [B, N, D] | |
| lst.append(out.detach()) | |
# Register hooks once, at import time: per transformer block, one hook on
# attn.attn_drop (attention maps) and one on the block itself (block outputs).
# The returned handles are kept so the hooks can be removed later if needed.
ATTN_HOOKS = []
TOK_HOOKS = []
for block in MODEL.blocks:
    attn_handle = block.attn.attn_drop.register_forward_hook(_save_attn_drop_input)
    tok_handle = block.register_forward_hook(_save_block_out)
    ATTN_HOOKS.append(attn_handle)
    TOK_HOOKS.append(tok_handle)
# -----------------------
# Preprocess
# -----------------------
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized batch of shape [1, 3, TARGET_IMG, TARGET_IMG].

    Steps: convert to RGB, center-crop to the largest square, bicubic-resize to
    TARGET_IMG x TARGET_IMG, scale to [0, 1] as float32, and normalize with the
    ImageNet mean/std.
    """
    img = pil_img.convert("RGB")
    w, h = img.size
    s = min(w, h)
    left = (w - s) // 2
    top = (h - s) // 2
    img = img.crop((left, top, left + s, top + s)).resize((TARGET_IMG, TARGET_IMG), Image.BICUBIC)
    # Wrap the raw row-major RGB bytes directly as uint8. bytearray makes the
    # buffer writable (torch.frombuffer warns on read-only buffers) and replaces
    # the deprecated ByteStorage.from_buffer + numpy round trip.
    pixels = torch.frombuffer(bytearray(img.tobytes()), dtype=torch.uint8)
    x = pixels.view(TARGET_IMG, TARGET_IMG, 3).to(torch.float32) / 255.0
    x = x.permute(2, 0, 1)  # HWC -> CHW
    x = (x - IMNET_MEAN) / IMNET_STD
    return x.unsqueeze(0)  # [1, 3, H, W]
# -----------------------
# Compute logit lens + attention export
# -----------------------
def compute_logit_lens_from_tokens(tokens_per_layer, model):
    """Project each captured block output's CLS token through the classifier ("logit lens").

    Each ``x_l`` goes through ``model.norm`` (if present), then token 0 (CLS) is taken,
    then ``model.fc_norm`` (if present), then ``model.head``. This is intended to
    mirror, for CLS-token pooling, how timm's VisionTransformer turns features
    into logits.

    Args:
        tokens_per_layer: iterable of per-block outputs, each [B, N, D].
        model: module exposing ``head`` and optionally ``norm`` / ``fc_norm``.

    Returns:
        ``(logits, probs)``: CPU tensors of shape [L, B, C]; ``probs`` is the
        softmax of ``logits`` over the class dimension.

    Raises:
        RuntimeError: if no block outputs were supplied.
    """
    tokens_per_layer = list(tokens_per_layer)
    if not tokens_per_layer:
        raise RuntimeError("No block outputs captured; cannot compute logit lens.")
    norm = getattr(model, "norm", None)
    fc_norm = getattr(model, "fc_norm", None)
    logits_list = []
    with torch.no_grad():
        for x_l in tokens_per_layer:
            x_ln = norm(x_l) if norm is not None else x_l
            cls_l = x_ln[:, 0]  # CLS token
            if fc_norm is not None:
                cls_l = fc_norm(cls_l)
            logits_list.append(model.head(cls_l).detach().cpu())
    logits_per_layer = torch.stack(logits_list, dim=0)  # [L, B, C]
    probs_per_layer = torch.softmax(logits_per_layer, dim=-1)
    return logits_per_layer, probs_per_layer
def round_tensor(t: torch.Tensor, decimals: int):
    """Round ``t`` elementwise to ``decimals`` decimal places: scale, torch-round, unscale."""
    scale = 10 ** decimals
    scaled = t * scale
    return scaled.round() / scale
# Serializes forward passes: the model and its registered hooks are shared module state.
MODEL_LOCK = threading.Lock()


def _export_logit_lens(logits_by_layer: torch.Tensor, probs_final: torch.Tensor) -> Dict[str, Any]:
    """Build the logit-lens JSON payload for batch item 0 (logits rounded to 3 decimals)."""
    return {
        "model": MODEL_NAME,
        "grid": [NEW_GRID[0], NEW_GRID[1]],
        "num_layers": int(logits_by_layer.shape[0]),
        "num_classes": int(logits_by_layer.shape[-1]),
        "class_names": IMAGENET_LABELS,
        # Round the whole [L, C] slice at once; .tolist() yields one list per layer.
        "logits": round_tensor(logits_by_layer[:, 0], 3).tolist(),
        "final_probs": probs_final.tolist(),
    }


def _export_attention(attn_maps) -> Dict[str, Any]:
    """Build the attention JSON payload from per-layer [B, H, N, N] tensors (batch item 0, 4 decimals).

    Raises:
        RuntimeError: if no attention maps were captured.
    """
    if len(attn_maps) == 0:
        raise RuntimeError("No attention captured. (Hook may not match this timm model/config)")
    # Each entry becomes nested lists [H][N][N].
    attention = [round_tensor(a.squeeze(0), 4).tolist() for a in attn_maps]
    return {
        "num_layers": len(attention),
        "num_heads": len(attention[0]),
        "num_tokens": len(attention[0][0]),
        "grid": [NEW_GRID[0], NEW_GRID[1]],
        "attention": attention,
    }


def analyze_image(pil_img: Image.Image) -> Dict[str, Any]:
    """Run one forward pass on ``pil_img`` and build the logit-lens and attention exports.

    Per-request lists are installed in ``_attn_var`` / ``_tok_var`` so the forward
    hooks collect this call's attention maps and block outputs. The variables are
    reset and the lists cleared afterwards, even on error.

    Returns:
        ``{"logit_lens_full": {...}, "attention_full": {...}}`` for batch item 0.

    Raises:
        RuntimeError: if no attention maps or no block outputs were captured.
    """
    x = preprocess(pil_img).to(DEVICE)
    # Per-request storage filled by the forward hooks.
    attn_maps = []
    layer_tokens = []
    tok_token = _tok_var.set(layer_tokens)
    attn_token = _attn_var.set(attn_maps)
    try:
        with torch.no_grad():
            # The model + hooks are shared, so only one forward pass runs at a time.
            with MODEL_LOCK:
                logits_final = MODEL(x)
        probs_final = round_tensor(torch.softmax(logits_final, dim=-1)[0].detach().cpu(), 6)
        # Validate the attention capture first so a hook mismatch fails fast.
        export_attn = _export_attention(attn_maps)
        logits_by_layer, _ = compute_logit_lens_from_tokens(layer_tokens, MODEL)
        export_logit = _export_logit_lens(logits_by_layer, probs_final)
        return {
            "logit_lens_full": export_logit,
            "attention_full": export_attn,
        }
    finally:
        _tok_var.reset(tok_token)
        _attn_var.reset(attn_token)
        # Drop references to captured tensors (block outputs may live on the GPU).
        layer_tokens.clear()
        attn_maps.clear()
# -----------------------
# FastAPI app
# -----------------------
app = FastAPI(title="ViT Explainer API", version="1.0")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets any
# origin make credentialed requests (Starlette echoes the request origin in that
# case) -- tighten both in prod.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Job store for the per-job endpoints: job_id -> analyze_image() result dict.
RESULTS: Dict[str, Dict[str, Any]] = {}

# Latest "current" result; "hash" is the SHA-256 of the uploaded bytes, so GET
# readers never trigger recomputation.
CURRENT: Dict[str, Any] = dict(
    hash=None,
    attention_full=None,
    logit_lens_full=None,
)


def _no_store(resp: JSONResponse) -> JSONResponse:
    """Set HTTP/1.1 and HTTP/1.0 no-cache headers on ``resp`` and return it."""
    for header, value in (
        ("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0"),
        ("Pragma", "no-cache"),
    ):
        resp.headers[header] = value
    return resp
def health():
    """Status payload: model, device, token grid, and whether a current result is cached."""
    has_current = CURRENT["attention_full"] is not None
    return dict(
        status="ok",
        model=MODEL_NAME,
        device=DEVICE,
        grid=list(NEW_GRID),
        has_current=has_current,
    )
# -----------------------
# Legacy: returns JSON directly OR job endpoints
# -----------------------
async def analyze(
    file: UploadFile = File(...),
    store: int = Query(0, description="1 => guarda resultados y entrega endpoints /results/{id}/..."),
):
    """Analyze an uploaded image.

    With ``store=1``, the result is kept in RESULTS under a new UUID job id and the
    response lists the per-job endpoints. Otherwise the full result
    (``logit_lens_full`` and ``attention_full``) is returned inline.

    Raises:
        HTTPException(400): non-image content type, or the bytes cannot be decoded.
        HTTPException(500): model inference failed.
    """
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Please upload an image file.")
    raw = await file.read()
    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Could not decode image.") from e
    try:
        # Inference is blocking CPU/GPU work: run it in a worker thread so the
        # event loop keeps serving other requests. MODEL_LOCK inside
        # analyze_image serializes the forward pass across threads.
        out = await run_in_threadpool(analyze_image, img)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}") from e
    if store == 1:
        job_id = str(uuid.uuid4())
        # NOTE(review): RESULTS is never pruned in this file, so memory grows
        # with every store=1 request.
        RESULTS[job_id] = out
        return {
            "job_id": job_id,
            "endpoints": {
                "attention_full": f"/results/{job_id}/attention_full.json",
                "logit_lens_full": f"/results/{job_id}/logit_lens_full.json",
            }
        }
    return out
def get_attention(job_id: str):
    """Serve the stored attention export for ``job_id`` (404 if unknown)."""
    return _stored_job_payload(job_id, "attention_full")


def get_logit(job_id: str):
    """Serve the stored logit-lens export for ``job_id`` (404 if unknown)."""
    return _stored_job_payload(job_id, "logit_lens_full")


def _stored_job_payload(job_id: str, key: str):
    """Return RESULTS[job_id][key] as a non-cacheable JSON response; 404 for unknown jobs."""
    job = RESULTS.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="job_id not found")
    return _no_store(JSONResponse(job[key]))
# -----------------------
# Preferred: "current files" endpoints (keep frontend fetch paths stable)
# - POST /analyze_current only when image changes
# - GET /attention_full.json and /logit_lens_full.json are just readers
# -----------------------
async def analyze_current(file: UploadFile = File(...)):
    """Analyze an upload and store it as the "current" result served by the reader endpoints.

    If the SHA-256 of the uploaded bytes matches the cached image and a result
    exists, nothing is recomputed and ``{"status": "unchanged", ...}`` is returned.
    Otherwise CURRENT is replaced and ``{"status": "updated", ...}`` is returned.

    Raises:
        HTTPException(400): non-image content type, or the bytes cannot be decoded.
        HTTPException(500): model inference failed.
    """
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Please upload an image file.")
    raw = await file.read()
    img_hash = hashlib.sha256(raw).hexdigest()
    # Do not regenerate if the same image was already processed.
    if CURRENT["hash"] == img_hash and CURRENT["attention_full"] is not None:
        return {"status": "unchanged", "hash": img_hash}
    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Could not decode image.") from e
    try:
        # Blocking inference runs in a worker thread so the event loop stays
        # responsive; MODEL_LOCK inside analyze_image serializes forward passes.
        out = await run_in_threadpool(analyze_image, img)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}") from e
    CURRENT["hash"] = img_hash
    CURRENT["attention_full"] = out["attention_full"]
    CURRENT["logit_lens_full"] = out["logit_lens_full"]
    return {"status": "updated", "hash": img_hash}
def attention_full_current():
    """Serve the most recent attention export without recomputing (404 before the first analysis)."""
    payload = CURRENT["attention_full"]
    if payload is None:
        raise HTTPException(status_code=404, detail="No attention computed yet. Call POST /analyze_current first.")
    return _no_store(JSONResponse(payload))


def logit_lens_current():
    """Serve the most recent logit-lens export without recomputing (404 before the first analysis)."""
    payload = CURRENT["logit_lens_full"]
    if payload is None:
        raise HTTPException(status_code=404, detail="No logit lens computed yet. Call POST /analyze_current first.")
    return _no_store(JSONResponse(payload))