from __future__ import annotations

from contextlib import nullcontext
from typing import List, Tuple

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import resolve_model_data_config, create_transform

from .utils import load_tag_names
class EVAHeadPreserving:
    """
    Head-preserving inference for EVA-02 backbones (Animetimm / WD-EVA02).

    The backbone's pretrained tag head is kept intact; a separate linear head
    loaded from `head_path` maps the pre-logit features to the custom categories.
    Interface: encode / logits / prob / tags_prob / top_tags.
    """

    def __init__(self,
                 repo_id: str,
                 head_path: str,
                 categories: List[str],
                 tag_csv: str = "selected_tags.csv"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.categories = list(categories)
        self.tag_csv = tag_csv

        # Frozen EVA-02 backbone pulled from the Hugging Face Hub via timm.
        self.backbone = timm.create_model(f"hf-hub:{repo_id}", pretrained=True)
        self.backbone = self.backbone.to(self.device).eval().requires_grad_(False)

        # Preprocessing transform derived from the backbone's data config.
        cfg = resolve_model_data_config(self.backbone)
        self.preprocess = create_transform(**cfg)

        # Probe the backbone with a dummy batch to discover the pre-logit
        # feature width D and the size T of the pretrained tag vocabulary.
        with torch.no_grad():
            in_size = cfg.get("input_size", (3, 448, 448))
            h, w = int(in_size[-2]), int(in_size[-1])
            dummy = torch.zeros(1, 3, h, w, device=self.device)
            fx = self.backbone.forward_features(dummy)
            pre = self.backbone.forward_head(fx, pre_logits=True)
            tags_log = self.backbone.forward_head(fx, pre_logits=False)
            D, T = int(pre.shape[-1]), int(tags_log.shape[-1])

        # Custom classification head applied to the normalized pre-logit features.
        self.custom_head = nn.Linear(D, len(self.categories)).to(self.device).eval().requires_grad_(False)
        ckpt = torch.load(head_path, map_location=self.device, weights_only=True)
        state = ckpt.get("state_dict", ckpt)
        w = state["head.weight"].to(self.device).float()
        b = state["head.bias"].to(self.device).float()
        # Accept checkpoints saved with a transposed weight layout.
        if w.shape != self.custom_head.weight.shape and w.t().shape == self.custom_head.weight.shape:
            w = w.t()
        with torch.no_grad():
            self.custom_head.weight.copy_(w)
            self.custom_head.bias.copy_(b)
        # Names for the pretrained tag vocabulary.
        self.tag_names = load_tag_names(T, self.tag_csv)

        # Autocast (fp16) and TF32 fast paths only apply on CUDA.
        self.use_amp = self.device == "cuda"
        if self.device == "cuda":
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
    def encode(self, pil_list: List) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (L2-normalized pre-logit features, pretrained tag logits) for a batch of PIL images."""
        x = torch.stack([self.preprocess(im.convert("RGB")) for im in pil_list], 0)
        x = x.to(self.device, non_blocking=True, memory_format=torch.channels_last)
        ctx = torch.amp.autocast("cuda", dtype=self.torch_dtype) if self.use_amp else nullcontext()
        with ctx:
            fx = self.backbone.forward_features(x)
            pre = self.backbone.forward_head(fx, pre_logits=True)
            feat = F.normalize(pre, dim=1)
            tags_log = self.backbone.forward_head(fx, pre_logits=False)
        return feat.float(), tags_log.float()
    def logits(self, pil_list: List) -> torch.Tensor:
        """Raw custom-head logits over the normalized backbone features."""
        feat_norm, _ = self.encode(pil_list)
        return self.custom_head(feat_norm)

    def prob(self, pil_list: List) -> torch.Tensor:
        """Per-category sigmoid probabilities from the custom head."""
        z = torch.clamp(self.logits(pil_list), -20, 20)
        return torch.sigmoid(z)

    def tags_prob(self, pil_list: List) -> torch.Tensor:
        """Sigmoid probabilities over the backbone's pretrained tag vocabulary."""
        _, tags_log = self.encode(pil_list)
        z = torch.clamp(tags_log, -20, 20)
        return torch.sigmoid(z)
    def top_tags(self, pil_image, top_k: int = 50):
        """Return the top-k (tag name, probability) pairs for a single PIL image."""
        p = self.tags_prob([pil_image])[0].tolist()
        k = max(0, min(top_k, len(p)))
        idx = sorted(range(len(p)), key=lambda i: -p[i])[:k]
        names = self.tag_names
        return [(names[i] if i < len(names) else f"tag_{i:04d}", float(p[i])) for i in idx]
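

# Usage sketch (illustrative only): the repo id, checkpoint path, image path, and
# category list below are hypothetical placeholders, not values shipped with this
# module. Because of the relative import above, run this file as part of its
# package (e.g. via `python -m ...`) rather than as a standalone script.
if __name__ == "__main__":
    from PIL import Image

    model = EVAHeadPreserving(
        repo_id="SmilingWolf/wd-eva02-large-tagger-v3",   # assumed WD-EVA02 tagger repo
        head_path="custom_head.pt",                       # hypothetical trained head checkpoint
        categories=["safe", "questionable", "explicit"],  # hypothetical category set
    )
    img = Image.open("example.jpg")
    print(model.prob([img]))        # custom-head probabilities, shape (1, len(categories))
    print(model.top_tags(img, 10))  # top-10 (tag, probability) pairs from the pretrained head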