# ViT-Explainer / app.py (JuanHernandez-uc, commit 6ed0dab "fix app.py")
# app.py
import io
import uuid
import json
import threading
import hashlib
from contextvars import ContextVar
from typing import Optional, Dict, Any
import torch
import torch.nn.functional as F
import timm
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from timm.layers.pos_embed import resample_abs_pos_embed
try:
from timm.layers.patch_embed import resample_patch_embed
except Exception:
resample_patch_embed = None
# -----------------------
# Config
# -----------------------
MODEL_NAME = "flexivit_large.300ep_in1k"  # timm FlexiViT checkpoint to load
TARGET_IMG = 96    # input resolution the model is adapted to
TARGET_PATCH = 32  # patch size after adaptation
NEW_GRID = (TARGET_IMG // TARGET_PATCH, TARGET_IMG // TARGET_PATCH) # (3,3)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ImageNet normalization
IMNET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMNET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
# -----------------------
# Load labels (local file recommended)
# -----------------------
def load_imagenet_labels(path="imagenet_classes.txt"):
    """Load ImageNet class names, one name per line.

    Args:
        path: Text file with one class name per line.

    Returns:
        List of non-empty, stripped class names, or None when the file is
        missing (the API then serves class indices without names).
    """
    try:
        # Iterate the file object directly instead of materializing
        # readlines(); same result, no intermediate list.
        with open(path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        # If missing, still works but without names.
        return None
IMAGENET_LABELS = load_imagenet_labels()
# -----------------------
# Build & adapt model once
# -----------------------
def adapt_flexivit_to_3x3(model: torch.nn.Module):
    """Adapt a pretrained FlexiViT to TARGET_IMG input with TARGET_PATCH patches.

    Two in-place surgeries:
      1. Resample the patch-embedding conv kernel to TARGET_PATCH x TARGET_PATCH
         and swap in a fresh Conv2d with matching stride.
      2. Resample the absolute positional embeddings to NEW_GRID.

    Returns:
        The same model object, modified in place.
    """
    # --- Resize patch embedding conv weight ---
    with torch.no_grad():
        proj = model.patch_embed.proj
        w = proj.weight.detach().cpu() # [embed_dim, in_chans, old_ps, old_ps]
        b = proj.bias.detach().cpu() if proj.bias is not None else None
        old_ps = w.shape[-1]
        if old_ps != TARGET_PATCH:
            if resample_patch_embed is not None:
                # Preferred: timm's PI-resize, designed for FlexiViT kernels.
                w2 = resample_patch_embed(w, (TARGET_PATCH, TARGET_PATCH))
            else:
                # Fallback: plain bicubic resize of each 2-D kernel slice.
                ed, ic, _, _ = w.shape
                w_ = w.reshape(ed * ic, 1, old_ps, old_ps)
                w_ = F.interpolate(w_, size=(TARGET_PATCH, TARGET_PATCH), mode="bicubic", align_corners=False)
                w2 = w_.reshape(ed, ic, TARGET_PATCH, TARGET_PATCH)
        else:
            w2 = w
        embed_dim, in_chans, _, _ = w2.shape
        # Fresh conv with stride == kernel so tokens tile the image exactly.
        new_proj = torch.nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=TARGET_PATCH,
            stride=TARGET_PATCH,
            padding=0,
            bias=(b is not None),
        )
        new_proj.weight.copy_(w2)
        if b is not None:
            new_proj.bias.copy_(b)
        model.patch_embed.proj = new_proj.to(DEVICE)
    # Update patch embed metadata if present (attribute names vary across
    # timm versions, hence the hasattr guards).
    if hasattr(model.patch_embed, "patch_size"):
        model.patch_embed.patch_size = (TARGET_PATCH, TARGET_PATCH)
    if hasattr(model.patch_embed, "img_size"):
        model.patch_embed.img_size = (TARGET_IMG, TARGET_IMG)
    if hasattr(model.patch_embed, "grid_size"):
        model.patch_embed.grid_size = NEW_GRID
    if hasattr(model.patch_embed, "num_patches"):
        model.patch_embed.num_patches = NEW_GRID[0] * NEW_GRID[1]
    # --- Resize absolute positional embeddings to 3x3 ---
    if hasattr(model, "pos_embed") and model.pos_embed is not None:
        with torch.no_grad():
            pe = model.pos_embed.detach()
            # infer prefix tokens (cls, dist, etc.)
            prefix = int(getattr(model, "num_prefix_tokens", 0))
            if prefix == 0 and hasattr(model, "cls_token") and model.cls_token is not None:
                prefix = 1
            # infer old grid
            # NOTE(review): grid_size was already overwritten to NEW_GRID above,
            # so this reads the *new* (3,3) grid, not the checkpoint's original
            # grid. For this checkpoint that forces the square-inference
            # fallback below, which happens to recover the right size — confirm
            # this is intended (capturing the old grid before the metadata
            # update would be more direct).
            old_grid = None
            if hasattr(model, "patch_embed") and hasattr(model.patch_embed, "grid_size"):
                old_grid = tuple(model.patch_embed.grid_size)
            if old_grid is not None:
                grid_tokens = old_grid[0] * old_grid[1]
                if pe.shape[1] == grid_tokens:
                    # pos_embed carries no prefix rows after all.
                    prefix = 0
                elif pe.shape[1] == grid_tokens + prefix:
                    pass  # metadata consistent: grid rows + prefix rows
                else:
                    # Inconsistent metadata: fall back to square-grid inference.
                    prefix = 0
                    old_grid = None
            if old_grid is None:
                # Assume a square token grid after removing prefix rows.
                n = pe.shape[1] - prefix
                g = int(n ** 0.5)
                old_grid = (g, g)
            pe2 = resample_abs_pos_embed(
                pe,
                new_size=NEW_GRID,
                old_size=old_grid,
                num_prefix_tokens=prefix,
                interpolation="bicubic",
                antialias=True,
            )
            model.pos_embed = torch.nn.Parameter(pe2)
    return model
def build_model():
    """Create the pretrained FlexiViT, disable fused attention, adapt to 3x3."""
    model = timm.create_model(MODEL_NAME, pretrained=True).to(DEVICE).eval()
    # Fused (SDPA) attention would bypass attn_drop, starving the attention
    # hooks registered below — force the explicit implementation instead.
    for block in model.blocks:
        attn = block.attn
        if hasattr(attn, "fused_attn"):
            attn.fused_attn = False
    return adapt_flexivit_to_3x3(model)
# Build the adapted model once at import time; shared by every request.
MODEL = build_model()
print(f"[server] model={MODEL_NAME} device={DEVICE} grid={NEW_GRID}")
# -----------------------
# Hooks using ContextVar (safe-ish for concurrent requests)
# -----------------------
# Per-request capture lists carried in ContextVars, so concurrent requests
# each see their own storage instead of a shared global.
_attn_var: ContextVar[Optional[list]] = ContextVar("_attn_var", default=None)
_tok_var: ContextVar[Optional[list]] = ContextVar("_tok_var", default=None)

def _save_attn_drop_input(module, inp, out):
    """Forward hook on attn_drop: record its input (the attention matrix)."""
    sink = _attn_var.get()
    if sink is None:
        return
    if inp and torch.is_tensor(inp[0]):
        # inp[0]: [B, H, N, N] — moved to CPU immediately.
        sink.append(inp[0].detach().cpu())

def _save_block_out(module, inp, out):
    """Forward hook on each transformer block: record its token output."""
    sink = _tok_var.get()
    if sink is None:
        return
    if torch.is_tensor(out):
        # out: [B, N, D] — kept on the original device.
        sink.append(out.detach())
# Register hooks once, at import time. They stay attached for the process
# lifetime; per-request capture is enabled/disabled via the ContextVars.
ATTN_HOOKS = []
TOK_HOOKS = []
for blk in MODEL.blocks:
    # attn_drop's *input* is the post-softmax attention matrix [B, H, N, N].
    ATTN_HOOKS.append(blk.attn.attn_drop.register_forward_hook(_save_attn_drop_input))
    TOK_HOOKS.append(blk.register_forward_hook(_save_block_out))
# -----------------------
# Preprocess
# -----------------------
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """Center-crop to a square, resize to TARGET_IMG, and normalize.

    Args:
        pil_img: Any PIL image; converted to RGB.

    Returns:
        Float tensor of shape [1, 3, TARGET_IMG, TARGET_IMG], normalized
        with ImageNet mean/std.
    """
    img = pil_img.convert("RGB")
    w, h = img.size
    s = min(w, h)
    left = (w - s) // 2
    top = (h - s) // 2
    img = img.crop((left, top, left + s, top + s)).resize((TARGET_IMG, TARGET_IMG), Image.BICUBIC)
    # torch.frombuffer replaces the deprecated ByteStorage.from_buffer /
    # from_numpy round-trip of the original; bytearray gives torch a
    # writable buffer so it does not warn about read-only memory.
    x = torch.frombuffer(bytearray(img.tobytes()), dtype=torch.uint8)
    x = x.view(TARGET_IMG, TARGET_IMG, 3).to(torch.float32) / 255.0
    x = x.permute(2, 0, 1)  # HWC -> CHW
    x = (x - IMNET_MEAN) / IMNET_STD
    return x.unsqueeze(0)  # [1,3,H,W]
# -----------------------
# Compute logit lens + attention export
# -----------------------
def compute_logit_lens_from_tokens(tokens_per_layer, model):
    """Apply the model's final norm + head to every layer's token output.

    For each layer's tokens [B, N, D], only the CLS token (index 0) is fed
    through the classifier head — the "logit lens" view of intermediate
    layers.

    Returns:
        Tuple of (logits, probs), each shaped [L, B, C] on CPU.
    """
    use_norm = getattr(model, "norm", None) is not None
    logits_acc = []
    probs_acc = []
    with torch.no_grad():
        for tokens in tokens_per_layer:
            normed = model.norm(tokens) if use_norm else tokens
            layer_logits = model.head(normed[:, 0])  # CLS token only
            logits_acc.append(layer_logits.detach().cpu())
            probs_acc.append(torch.softmax(layer_logits, dim=-1).detach().cpu())
    return torch.stack(logits_acc, dim=0), torch.stack(probs_acc, dim=0)
def round_tensor(t: torch.Tensor, decimals: int):
    """Round *t* element-wise to the given number of decimal places."""
    scale = 10 ** decimals
    return (t * scale).round() / scale

# Serializes access to the shared model + hook state during inference.
MODEL_LOCK = threading.Lock()
def analyze_image(pil_img: Image.Image) -> Dict[str, Any]:
    """Run one inference and export logit-lens + attention JSON payloads.

    The forward hooks populate the per-request lists below via ContextVars;
    MODEL_LOCK serializes use of the shared model/hook state.

    Returns:
        Dict with keys "logit_lens_full" and "attention_full", both
        JSON-serializable.

    Raises:
        RuntimeError: if no attention was captured by the hooks.
    """
    x = preprocess(pil_img).to(DEVICE)
    # Per-request storage
    attn_maps = []     # per-layer attention [B, H, N, N], on CPU
    layer_tokens = []  # per-layer block outputs [B, N, D], on device
    tok_token = _tok_var.set(layer_tokens)
    attn_token = _attn_var.set(attn_maps)
    try:
        with torch.no_grad():
            # Lock recommended if you run multiple workers/threads with GPU,
            # and because we use shared model + hooks
            with MODEL_LOCK:
                logits_final = MODEL(x)
        # Final probs
        probs_final = torch.softmax(logits_final, dim=-1)[0].detach().cpu()
        probs_final = round_tensor(probs_final, 6)
        # Logit lens
        logits_by_layer, probs_by_layer = compute_logit_lens_from_tokens(layer_tokens, MODEL)
        # Export logit lens json
        export_logit = {
            "model": MODEL_NAME,
            "grid": [NEW_GRID[0], NEW_GRID[1]],
            "num_layers": int(logits_by_layer.shape[0]),
            "num_classes": int(logits_by_layer.shape[-1]),
            "class_names": IMAGENET_LABELS,
            "logits": [],
            "final_probs": probs_final.tolist()
        }
        for l in range(logits_by_layer.shape[0]):
            v = logits_by_layer[l, 0]  # [C] — batch index 0 (single image)
            v = round_tensor(v, 3)
            export_logit["logits"].append(v.tolist())
        # Attention json
        # attn_maps is list length L, each: [B,H,N,N] CPU
        attn_maps2 = [a.squeeze(0) for a in attn_maps]  # -> [H,N,N]
        if len(attn_maps2) == 0:
            raise RuntimeError("No attention captured. (Hook may not match this timm model/config)")
        # Nested Python lists for JSON: layers -> heads -> N x N matrices.
        attn_serializable = []
        for layer in attn_maps2:
            layer_data = []
            for head in layer:
                head = round_tensor(head, 4)
                layer_data.append(head.tolist())
            attn_serializable.append(layer_data)
        export_attn = {
            "num_layers": len(attn_serializable),
            "num_heads": len(attn_serializable[0]),
            "num_tokens": len(attn_serializable[0][0]),
            "grid": [NEW_GRID[0], NEW_GRID[1]],
            "attention": attn_serializable
        }
        return {
            "logit_lens_full": export_logit,
            "attention_full": export_attn
        }
    finally:
        # Always restore the ContextVars and drop references to the captured
        # tensors, even when inference or export fails.
        _tok_var.reset(tok_token)
        _attn_var.reset(attn_token)
        layer_tokens.clear()
        attn_maps.clear()
# -----------------------
# FastAPI app
# -----------------------
app = FastAPI(title="ViT Explainer API", version="1.0")
# CORS wide open for development; restrict origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# In-memory store for "file-like endpoints" (job-based)
RESULTS: Dict[str, Dict[str, Any]] = {}
# In-memory store for "current files" (no-regenerate on GET)
CURRENT: Dict[str, Any] = {
    "hash": None,             # sha256 hex digest of the last-processed upload
    "attention_full": None,   # last attention JSON payload
    "logit_lens_full": None,  # last logit-lens JSON payload
}

def _no_store(resp: JSONResponse) -> JSONResponse:
    """Mark a response as uncacheable so clients always refetch fresh JSON."""
    resp.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
    resp.headers["Pragma"] = "no-cache"
    return resp
@app.get("/health")
def health():
    """Liveness probe: model identity, device, grid, and whether a
    current image has already been analyzed."""
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "device": DEVICE,
        "grid": list(NEW_GRID),
        "has_current": CURRENT["attention_full"] is not None,
    }
# -----------------------
# Legacy: returns JSON directly OR job endpoints
# -----------------------
@app.post("/analyze")
async def analyze(
    file: UploadFile = File(...),
    store: int = Query(0, description="1 => guarda resultados y entrega endpoints /results/{id}/..."),
):
    """Analyze an uploaded image.

    With store=0 the full JSON payloads are returned directly; with store=1
    the results are kept in RESULTS under a fresh job id and the response
    only lists the endpoints to fetch them from.
    """
    # Reject non-image uploads early based on the declared content type.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Please upload an image file.")
    raw = await file.read()
    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Could not decode image.")
    try:
        out = analyze_image(img)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")
    if store == 1:
        # NOTE(review): RESULTS is never evicted, so a long-running server
        # grows without bound — consider a TTL or LRU cap.
        job_id = str(uuid.uuid4())
        RESULTS[job_id] = out
        return {
            "job_id": job_id,
            "endpoints": {
                "attention_full": f"/results/{job_id}/attention_full.json",
                "logit_lens_full": f"/results/{job_id}/logit_lens_full.json",
            }
        }
    return out
@app.get("/results/{job_id}/attention_full.json")
def get_attention(job_id: str):
    """Fetch the stored attention payload for a previously stored job."""
    result = RESULTS.get(job_id)
    if result is None:
        raise HTTPException(status_code=404, detail="job_id not found")
    return _no_store(JSONResponse(result["attention_full"]))
@app.get("/results/{job_id}/logit_lens_full.json")
def get_logit(job_id: str):
    """Fetch the stored logit-lens payload for a previously stored job."""
    result = RESULTS.get(job_id)
    if result is None:
        raise HTTPException(status_code=404, detail="job_id not found")
    return _no_store(JSONResponse(result["logit_lens_full"]))
# -----------------------
# Preferred: "current files" endpoints (keep frontend fetch paths stable)
# - POST /analyze_current only when image changes
# - GET /attention_full.json and /logit_lens_full.json are just readers
# -----------------------
@app.post("/analyze_current")
async def analyze_current(file: UploadFile = File(...)):
    """Recompute the "current" artifacts only when the uploaded image changes.

    The upload is hashed (sha256); when it matches the last-processed image
    the cached payloads are kept and no inference runs.
    """
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Please upload an image file.")
    raw = await file.read()
    img_hash = hashlib.sha256(raw).hexdigest()
    # ✅ no regenerate if same image already processed
    if CURRENT["hash"] == img_hash and CURRENT["attention_full"] is not None:
        return {"status": "unchanged", "hash": img_hash}
    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Could not decode image.")
    try:
        out = analyze_image(img)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")
    # Publish the new artifacts for the reader endpoints below.
    CURRENT["hash"] = img_hash
    CURRENT["attention_full"] = out["attention_full"]
    CURRENT["logit_lens_full"] = out["logit_lens_full"]
    return {"status": "updated", "hash": img_hash}
@app.get("/attention_full.json")
def attention_full_current():
    """Serve the attention JSON for the current image (read-only, no recompute)."""
    payload = CURRENT["attention_full"]
    if payload is None:
        raise HTTPException(status_code=404, detail="No attention computed yet. Call POST /analyze_current first.")
    return _no_store(JSONResponse(payload))
@app.get("/logit_lens_full.json")
def logit_lens_current():
    """Serve the logit-lens JSON for the current image (read-only, no recompute)."""
    payload = CURRENT["logit_lens_full"]
    if payload is None:
        raise HTTPException(status_code=404, detail="No logit lens computed yet. Call POST /analyze_current first.")
    return _no_store(JSONResponse(payload))