# content_moderation_demo/src/models/eva_headpreserving.py
from __future__ import annotations

from contextlib import nullcontext
from typing import List, Tuple

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import create_transform, resolve_model_data_config

from .utils import load_tag_names

class EVAHeadPreserving:
    """Head-preserving inference for EVA-02 backbones (Animetimm / WD-EVA02).

    Interface: encode / logits / prob / tags_prob / top_tags.
    """
    def __init__(self,
                 repo_id: str,
                 head_path: str,
                 categories: List[str],
                 tag_csv: str = "selected_tags.csv"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.categories = list(categories)
        self.tag_csv = tag_csv

        # Frozen pretrained backbone from the HF Hub, with its own preprocessing.
        self.backbone = timm.create_model(f"hf-hub:{repo_id}", pretrained=True)
        self.backbone = self.backbone.to(self.device).eval().requires_grad_(False)
        cfg = resolve_model_data_config(self.backbone)
        self.preprocess = create_transform(**cfg)

        # Probe the backbone once to discover the pre-logits feature width (D)
        # and the width of the original tag head (T).
        with torch.no_grad():
            in_size = cfg.get("input_size", (3, 448, 448))
            h, w = int(in_size[-2]), int(in_size[-1])
            dummy = torch.zeros(1, 3, h, w, device=self.device)
            fx = self.backbone.forward_features(dummy)
            pre = self.backbone.forward_head(fx, pre_logits=True)
            tags_log = self.backbone.forward_head(fx, pre_logits=False)
            D, T = int(pre.shape[-1]), int(tags_log.shape[-1])

        # Custom category head, loaded from a separately trained checkpoint.
        self.custom_head = nn.Linear(D, len(self.categories)).to(self.device).eval().requires_grad_(False)
        ckpt = torch.load(head_path, map_location=self.device, weights_only=True)
        state = ckpt.get("state_dict", ckpt)
        w = state["head.weight"].to(self.device).float()
        b = state["head.bias"].to(self.device).float()
        # Accept checkpoints that stored the weight matrix transposed.
        if w.shape != self.custom_head.weight.shape and w.t().shape == self.custom_head.weight.shape:
            w = w.t()
        with torch.no_grad():
            self.custom_head.weight.copy_(w)
            self.custom_head.bias.copy_(b)

        self.tag_names = load_tag_names(T, self.tag_csv)

        # Autocast is only meaningful on CUDA, so gate AMP on the device.
        self.use_amp = self.device == "cuda"
        if self.device == "cuda":
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True

    @torch.inference_mode()
    def encode(self, pil_list: List) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (L2-normalized pre-logits features, raw tag logits)."""
        x = torch.stack([self.preprocess(im.convert("RGB")) for im in pil_list], 0)
        x = x.to(self.device, non_blocking=True, memory_format=torch.channels_last)
        ctx = torch.amp.autocast("cuda", dtype=self.torch_dtype) if self.use_amp else nullcontext()
        with ctx:
            fx = self.backbone.forward_features(x)
            pre = self.backbone.forward_head(fx, pre_logits=True)
            feat = F.normalize(pre, dim=1)
            tags_log = self.backbone.forward_head(fx, pre_logits=False)
        return feat.float(), tags_log.float()

    @torch.inference_mode()
    def logits(self, pil_list: List) -> torch.Tensor:
        feat_norm, _ = self.encode(pil_list)
        return self.custom_head(feat_norm)

    @torch.inference_mode()
    def prob(self, pil_list: List) -> torch.Tensor:
        # Clamp logits for numerical stability before the sigmoid.
        z = torch.clamp(self.logits(pil_list), -20, 20)
        return torch.sigmoid(z)

    @torch.inference_mode()
    def tags_prob(self, pil_list: List) -> torch.Tensor:
        _, tags_log = self.encode(pil_list)
        z = torch.clamp(tags_log, -20, 20)
        return torch.sigmoid(z)

    @torch.inference_mode()
    def top_tags(self, pil_image, top_k: int = 50):
        """Return the top_k (tag_name, probability) pairs for a single image."""
        p = self.tags_prob([pil_image])[0].tolist()
        k = max(0, min(top_k, len(p)))
        idx = sorted(range(len(p)), key=lambda i: -p[i])[:k]
        names = self.tag_names
        return [(names[i] if i < len(names) else f"tag_{i:04d}", float(p[i])) for i in idx]
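

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows how the
# encode / prob / top_tags interface above is meant to be called. The repo id,
# head checkpoint path, category names, and image path below are hypothetical
# placeholders; substitute values from your own deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    model = EVAHeadPreserving(
        repo_id="your-org/your-eva02-tagger",       # hypothetical HF Hub repo
        head_path="checkpoints/custom_head.pt",     # hypothetical head weights
        categories=["safe", "questionable", "explicit"],  # hypothetical labels
    )
    img = Image.open("example.jpg")                 # hypothetical input image

    # Per-category sigmoid scores from the custom head.
    probs = model.prob([img])[0]
    for name, p in zip(model.categories, probs.tolist()):
        print(f"{name}: {p:.3f}")

    # Highest-confidence tags from the preserved original tag head.
    for tag, p in model.top_tags(img, top_k=10):
        print(f"{tag}: {p:.3f}")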