Spaces:
Running
Running
import json | |
import math | |
from dataclasses import dataclass, field | |
from os import PathLike, cpu_count | |
from pathlib import Path | |
from typing import Any, Optional, TypeAlias | |
import colorcet as cc | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
import timm | |
import torch | |
from matplotlib.colors import LinearSegmentedColormap | |
from PIL import Image | |
from timm.data import create_transform, resolve_data_config | |
from timm.models import VisionTransformer | |
from torch import Tensor, nn | |
from torch.nn import functional as F | |
from torchvision import transforms as T | |
from .common import Heatmap, ImageLabels, LabelData, load_labels_hf, pil_ensure_rgb, pil_make_grid | |
# Working directory: the folder containing this file when `__file__` is
# defined, otherwise (e.g. an interactive session) the current directory.
if "__file__" in locals():
    work_dir = Path(__file__).parent.resolve()
else:
    work_dir = Path.cwd().resolve()

# Scratch directory under the working directory, created on import if missing.
temp_dir = work_dir / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)

# Per-repo caches shared by the loader functions in this module.
model_cache: dict[str, VisionTransformer] = {}
transform_cache: dict[str, T.Compose] = {}

# Default device: CUDA when available, CPU otherwise.
torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class RGBtoBGR(nn.Module):
    """Module that reorders the first three channels of a tensor as 2, 1, 0.

    For 4-dimensional input the channel axis is taken to be axis 1; for any
    other rank the leading axis is indexed instead.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with its channels picked in the order ``[2, 1, 0]``."""
        channel_order = [2, 1, 0]
        if x.ndim != 4:
            # Unbatched input: channels live on the leading axis.
            return x[channel_order, :, :]
        # Batched input: channels live on axis 1.
        return x[:, channel_order, :, :]
def model_device(model: nn.Module) -> torch.device:
    """Return the device of the first parameter yielded by ``model.parameters()``."""
    first_param = next(model.parameters())
    return first_param.device
def load_model(repo_id: str) -> VisionTransformer:
    """Return the model for ``repo_id``, creating and caching it on first use.

    A missing entry is built with ``timm.create_model("hf-hub:" + repo_id,
    pretrained=True)``, switched to eval mode, moved to ``torch_device`` and
    stored in ``model_cache``. An entry that is already cached is returned
    as-is; it is not moved to ``torch_device`` here.
    """
    cached = model_cache.get(repo_id)
    if cached is None:
        # First request for this repo: build the model and remember it.
        cached = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(torch_device)
        model_cache[repo_id] = cached
    return cached
def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
    """Return the cached model and preprocessing transform for ``repo_id``.

    Missing entries are created on demand. The model is built with
    ``timm.create_model`` and put in eval mode (it is not moved to any device
    here). The transform is derived from the model's ``pretrained_cfg`` and has
    an ``RGBtoBGR`` step appended to it.
    """
    model = model_cache.get(repo_id)
    if model is None:
        # First request for this repo: build the model and remember it.
        model = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval()
        model_cache[repo_id] = model

    transform = transform_cache.get(repo_id)
    if transform is None:
        data_config = resolve_data_config(model.pretrained_cfg, model=model)
        base_transform = create_transform(**data_config)
        # Append the channel swap to timm's own pipeline, then cache the result.
        transform = T.Compose(base_transform.transforms + [RGBtoBGR()])
        transform_cache[repo_id] = transform

    return model, transform
def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
) -> tuple[str, str, dict, dict, dict]:
    """Convert per-label probabilities into caption strings and label dicts.

    Args:
        probs: Tensor of probabilities that is zipped element-wise with
            ``labels.names``. It may require grad or live on an accelerator;
            it is detached and copied to the CPU before conversion.
        labels: Label metadata providing ``names`` plus the index collections
            ``rating``, ``general`` and ``character``.
        gen_threshold: General labels are kept only when their probability is
            strictly greater than this value.
        char_threshold: Same cut-off, applied to character labels.

    Returns:
        ``(caption, booru, rating_labels, char_labels, gen_labels)`` where
        ``caption`` is the comma-separated list of kept general labels followed
        by kept character labels (each group sorted by descending probability)
        with parentheses backslash-escaped, ``booru`` is ``caption`` with
        underscores replaced by spaces, and the three dicts map label name to
        probability.
    """
    # detach() and cpu() are no-ops for a plain CPU tensor, but numpy() would
    # raise on a tensor that requires grad or is stored on another device.
    scored = list(zip(labels.names, probs.detach().cpu().numpy()))

    def _pick(indices, threshold):
        # Keep entries strictly above the threshold, highest probability first.
        kept = {name: score for name, score in (scored[i] for i in indices) if score > threshold}
        return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

    # Rating labels are always reported, regardless of probability.
    rating_labels = dict(scored[i] for i in labels.rating)
    gen_labels = _pick(labels.general, gen_threshold)
    char_labels = _pick(labels.character, char_threshold)

    # General labels first, then character labels; each group keeps its own
    # descending-probability order (the two groups are not merged/re-sorted).
    combined_names = [*gen_labels, *char_labels]

    # Escape parentheses so the caption can be used as a training caption.
    caption = ", ".join(combined_names).replace("(", "\\(").replace(")", "\\)")
    booru = caption.replace("_", " ")

    return caption, booru, rating_labels, char_labels, gen_labels
def render_heatmap(
    image: Tensor,
    gradients: Tensor,
    image_feats: Tensor,
    image_probs: Tensor,
    image_labels: list[str],
    cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
    pos_embed_dim: int = 784,
    image_size: tuple[int, int] = (448, 448),
    font_args: Optional[dict] = None,
    partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image]:
    """Render one heatmap overlay per label and assemble them into a grid.

    Args:
        image: Input image tensor. It is mapped to pixels with ``(x + 1) * 127.5``
            and permuted from channel-first to channel-last.
            NOTE(review): this assumes values in [-1, 1] and BGR channel order
            (no colour conversion is applied before blending) - confirm.
        gradients: Per-label gradients w.r.t. ``image_feats``; averaged over
            axis 2 and multiplied with the features.
        image_feats: Feature tensor the gradients were taken against.
        image_probs: One score per label, iterated in step with ``image_labels``.
        image_labels: Label names; must contain at least one entry because one
            map per label is stacked.
        cmap: Colormap applied to the normalised heatmaps.
        pos_embed_dim: Number of feature positions; its integer square root is
            used as the side length of the heatmap grid.
        image_size: Size the heatmaps are bilinearly interpolated to.
        font_args: Keyword arguments forwarded to ``cv2.putText``. ``None``
            selects white Hershey-simplex text (scale 1, thickness 2, anti-aliased).
        partial_rows: Forwarded to ``pil_make_grid``.

    Returns:
        ``(hmap_imgs, hmap_grid)``: the ``Heatmap`` entries sorted by descending
        score, and the grid image built from them in that order.
    """
    if font_args is None:
        # Built per call so no mutable default is shared between invocations.
        font_args = {
            "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
            "fontScale": 1,
            "color": (255, 255, 255),
            "thickness": 2,
            "lineType": cv2.LINE_AA,
        }

    hmap_dim = int(math.sqrt(pos_embed_dim))

    # Weight features by the averaged gradients, then collapse the last axis.
    image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
    image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), hmap_dim, hmap_dim)

    # Keep positive contributions only (element-wise max with zero).
    image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))

    # Smallest positive normal number: leaves every non-degenerate denominator
    # untouched while turning 0/0 (all-zero or constant maps) into 0, not NaN.
    floor = torch.finfo(image_hmaps.dtype).tiny

    peak = image_hmaps.reshape(image_hmaps.shape[0], -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)
    image_hmaps = image_hmaps / peak.clamp_min(floor)

    # normalize to 0-1
    image_hmaps = torch.stack(
        [(x - x.min()) / (x.max() - x.min()).clamp_min(floor) for x in image_hmaps]
    ).unsqueeze(1)

    # interpolate to input image size
    image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)

    # The background is identical for every label, so convert it only once.
    image_pixels = image.add(1).mul(127.5).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    hmap_imgs: list[Heatmap] = []
    for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
        # Colormap output is RGBA bytes; drop alpha and switch to OpenCV's BGR.
        hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
        hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)

        # 50/50 blend; addWeighted returns a new array, image_pixels is untouched.
        hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
        if tag is not None:
            cv2.putText(hmap_image, tag, (10, 30), **font_args)
            cv2.putText(hmap_image, f"{score:.3f}", org=(10, 60), **font_args)

        hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
        hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))

    hmap_imgs = sorted(hmap_imgs, key=lambda x: x.score, reverse=True)
    hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)

    return hmap_imgs, hmap_grid
def process_heatmap(
    model: VisionTransformer,
    image: Tensor,
    labels: LabelData,
    threshold: float = 0.5,
    partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image, ImageLabels]:
    """Run the model on ``image`` and build gradient heatmaps for confident labels.

    Args:
        model: Model exposing ``forward_features`` and ``forward_head``.
        image: Input tensor; it is moved to the model's device before inference.
            NOTE(review): a batch of one is assumed (``squeeze(0)`` on the
            output and index 0 on the gradient's batch axis) - confirm.
        labels: Label metadata; ``labels.names`` is indexed by output position.
        threshold: Labels whose sigmoid probability is strictly greater than
            this get a heatmap; the same value is used as both the general and
            the character threshold for the returned tags.
        partial_rows: Forwarded to ``render_heatmap``.

    Returns:
        ``(hmap_imgs, hmap_grid, image_labels)``: the per-label heatmaps, the
        grid image, and an ``ImageLabels`` built from the tag results.
    """
    device = model_device(model)

    with torch.set_grad_enabled(True):
        features = model.forward_features(image.to(device))
        probs = model.forward_head(features)
        probs = torch.sigmoid(probs).squeeze(0)

        # Select the outputs above the threshold and look up their names.
        probs_mask = probs > threshold
        heatmap_probs = probs[probs_mask]
        label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
        heatmap_names = [labels.names[idx] for idx in label_indices.tolist()]

        # Identity grad_outputs with is_grads_batched=True gives one gradient
        # per selected output in a single autograd call.
        eye = torch.eye(heatmap_probs.shape[0], device=device, dtype=heatmap_probs.dtype)
        grads = torch.autograd.grad(
            outputs=heatmap_probs,
            inputs=features,
            grad_outputs=eye,
            is_grads_batched=True,
            retain_graph=True,
        )
        # Keep batch element 0 and restore a singleton axis in its place.
        grads = grads[0].detach().requires_grad_(False)[:, 0, :, :].unsqueeze(1)

    with torch.set_grad_enabled(False):
        hmap_imgs, hmap_grid = render_heatmap(
            image=image,
            gradients=grads,
            image_feats=features,
            image_probs=heatmap_probs,
            image_labels=heatmap_names,
            partial_rows=partial_rows,
        )

        # detach() is required: on a CPU model .cpu() returns the same tensor,
        # which still requires grad and cannot be converted to NumPy.
        caption, booru, ratings, character, general = get_tags(
            probs=probs.detach().cpu(),
            labels=labels,
            gen_threshold=threshold,
            char_threshold=threshold,
        )
        image_labels = ImageLabels(caption, booru, ratings, general, character)

    return hmap_imgs, hmap_grid, image_labels