# Scraped from a Hugging Face Hub file page: author "neggles",
# commit c24a176 ("make the thing work"), file size 4.99 kB.
import math
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from torch import Tensor, nn
@dataclass
class Heatmap:
    """
    Record pairing one label with a score and an image.

    NOTE(review): presumably ``image`` is a visualization (e.g. a heatmap
    overlay) for ``label`` and ``score`` is that label's model score — confirm
    against the code that constructs these.
    """

    label: str  # label / tag name this record is for
    score: float  # score attached to the label
    image: Image.Image  # PIL image associated with the label
@dataclass
class LabelData:
    """
    Tag vocabulary plus the row positions of each tag category.

    As filled in by ``load_labels_hf``: ``names`` is the ``name`` column of
    ``selected_tags.csv`` in row order, and ``rating`` / ``general`` /
    ``character`` hold the integer row positions (annotated as ``np.int64``)
    of the rows whose ``category`` is 9, 0 and 4 respectively.
    """

    names: list[str]  # every tag name, in CSV row order
    rating: list[np.int64]  # row positions of rating tags
    general: list[np.int64]  # row positions of general tags
    character: list[np.int64]  # row positions of character tags
@dataclass
class ImageLabels:
    """
    Labelling result for one image: two text renderings plus per-category scores.

    NOTE(review): presumably ``caption`` is a human-readable tag list and
    ``booru`` a booru-style tag string, and each dict maps tag name -> score —
    confirm against the code that builds these.
    """

    caption: str  # caption-style text output
    booru: str  # booru-style text output
    rating: dict[str, float]  # rating tags -> score
    general: dict[str, float]  # general tags -> score
    character: dict[str, float]  # character tags -> score
@lru_cache(maxsize=5)
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """
    Download ``selected_tags.csv`` from a Hub repo and split its rows by category.

    Only the ``name`` and ``category`` columns are read. Rows with category 9
    go to ``rating``, 0 to ``general`` and 4 to ``character``; each of those
    lists holds row positions into ``names``.

    Results are memoized per ``(repo_id, revision, token)`` for up to 5 distinct
    argument combinations. The cached ``LabelData`` is shared between callers,
    so it should not be mutated.

    Args:
        repo_id: Hugging Face Hub repository id to download from.
        revision: Optional revision passed through to ``hf_hub_download``.
        token: Optional auth token passed through to ``hf_hub_download``.

    Returns:
        A ``LabelData`` built from the CSV.

    Raises:
        FileNotFoundError: if ``hf_hub_download`` raises ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
    except HfHubHTTPError as err:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from err

    tags_df: pd.DataFrame = pd.read_csv(Path(downloaded).resolve(), usecols=["name", "category"])
    category = tags_df["category"]

    # category code in the CSV -> LabelData field that collects its row positions
    code_by_field = {"rating": 9, "general": 0, "character": 4}
    positions = {
        field: list(np.where(category == code)[0]) for field, code in code_by_field.items()
    }

    return LabelData(names=tags_df["name"].tolist(), **positions)
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sort the scores in descending order, find the largest drop between two
    consecutive scores, and return the midpoint of that drop. When the drop is
    non-zero, the threshold separates the scores before it from those after it.

    Args:
        probs: Scores to threshold. Any array-like is accepted and is flattened,
            so a ``(1, N)`` row behaves like a 1-D vector of N scores.

    Returns:
        The threshold as a Python ``float``.

    Raises:
        ValueError: if fewer than two scores are given (there is no gap to cut).
    """
    scores = np.sort(np.asarray(probs).ravel())[::-1]
    if scores.size < 2:
        # argmax over an empty gap array would otherwise fail with a cryptic error
        raise ValueError(f"mcut_threshold needs at least 2 scores, got {scores.size}")
    gaps = scores[:-1] - scores[1:]
    idx = int(gaps.argmax())  # first position of the largest drop
    return float((scores[idx] + scores[idx + 1]) / 2)
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """
    Return ``image`` in RGB mode, flattening any transparency onto white.

    - RGB input is returned as-is (same object).
    - Input with an alpha band (RGBA, LA, PA, ...) or with a ``"transparency"``
      entry in ``image.info`` (palette / colour-key transparency) is converted
      to RGBA and alpha-composited over an opaque white canvas.
    - Anything else is converted directly with ``convert("RGB")``.
    """
    if image.mode == "RGB":
        return image

    if image.mode != "RGBA":
        # Keep transparency for anything that carries it. Checking image.info
        # alone misses images with a real alpha band such as LA / PA; sending
        # those to convert("RGB") would drop alpha and leave whatever colour is
        # under transparent pixels (often black) instead of a white background.
        if "A" in image.getbands() or "transparency" in image.info:
            image = image.convert("RGBA")
        else:
            return image.convert("RGB")

    # Flatten RGBA onto an opaque white canvas.
    canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
    canvas.alpha_composite(image)
    return canvas.convert("RGB")
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """
    Center ``image`` on a new square RGB canvas sized to its longer edge.

    The area not covered by the image is filled with ``fill`` (white by
    default). The input image is not modified. No paste mask is used, so the
    source's alpha channel, if any, does not take part in the paste.

    Args:
        image: Image to pad.
        fill: RGB colour for the padding.

    Returns:
        A new ``side x side`` RGB image, where ``side = max(width, height)``.
    """
    width, height = image.size
    side = width if width >= height else height
    # integer centering: any odd leftover pixel ends up on the right/bottom
    offset = ((side - width) // 2, (side - height) // 2)
    square = Image.new("RGB", (side, side), fill)
    square.paste(image, offset)
    return square
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Make ``image`` an RGB square on a white background, scaled toward ``size_px``.

    Steps:
      1. flatten to RGB with ``pil_ensure_rgb``;
      2. center on a white square with ``pil_pad_square``;
      3. if the square is smaller than the target on either axis, resize it to
         the target with LANCZOS (or raise when ``upscale is False``);
      4. if it is then still larger than the target on either axis, shrink it
         in place with ``thumbnail`` (BICUBIC), which keeps the aspect ratio.

    NOTE(review): with a non-square ``size_px`` the upscale path stretches the
    square to the exact target, while the ``thumbnail`` path keeps it square,
    so the output size can differ from ``size_px``.

    Args:
        image: Input PIL image.
        size_px: Target size; an int means ``(size_px, size_px)``.
        upscale: When exactly ``False``, refuse to enlarge small images.

    Returns:
        The processed RGB image.

    Raises:
        ValueError: if the padded image is smaller than the target and
            ``upscale`` is ``False``.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px

    square = pil_pad_square(pil_ensure_rgb(image))

    too_small = square.size[0] < target[0] or square.size[1] < target[1]
    if too_small:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        square = square.resize(target, Image.LANCZOS)

    # re-measure: the resize above may already have reached the target
    too_large = square.size[0] > target[0] or square.size[1] > target[1]
    if too_large:
        square.thumbnail(target, Image.BICUBIC)

    return square
def pil_make_grid(
    images: list[Image.Image],
    max_cols: int = 8,
    padding: int = 4,
    bg_color: tuple[int, int, int] = (40, 42, 54),  # dracula background color
    partial_rows: bool = True,
) -> Image.Image:
    """
    Tile ``images`` into a padded grid on a new RGB canvas.

    The grid has ``min(floor(sqrt(len(images))), max_cols)`` columns and
    ``ceil(len(images) / n_cols)`` rows. With ``partial_rows=False``, an
    incomplete last row is dropped and its images are not drawn. ``padding``
    pixels of ``bg_color`` surround and separate the cells.

    Every cell is sized from ``images[0].size``; the images are assumed to all
    be that size.

    Args:
        images: Images to place, in row-major order.
        max_cols: Upper bound on the number of columns (must be >= 1).
        padding: Gap in pixels around and between cells.
        bg_color: RGB background colour.
        partial_rows: Whether to keep an incomplete final row.

    Returns:
        The grid image.

    Raises:
        ValueError: if ``images`` is empty or ``max_cols < 1``. Previously both
            cases failed with ``ZeroDivisionError``.
    """
    if not images:
        raise ValueError("pil_make_grid requires at least one image")
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols}")

    n_cols = min(math.floor(math.sqrt(len(images))), max_cols)  # >= 1 given the checks above
    n_rows = math.ceil(len(images) / n_cols)
    # if the final row is not full and partial_rows is False, remove a row
    if n_cols * n_rows > len(images) and not partial_rows:
        n_rows -= 1
    # Only paste images that land inside the grid; anything in a dropped row
    # would be pasted entirely outside the canvas anyway.
    cells = images[: n_cols * n_rows]

    # assumes all images are same size
    image_width, image_height = images[0].size
    canvas_width = ((image_width + padding) * n_cols) + padding
    canvas_height = ((image_height + padding) * n_rows) + padding

    canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)
    for i, img in enumerate(cells):
        x = (i % n_cols) * (image_width + padding) + padding
        y = (i // n_cols) * (image_height + padding) + padding
        canvas.paste(img, (x, y))
    return canvas
# Kaomoji-style tags (emoticons that contain underscores). Source of this list:
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
# NOTE(review): presumably consulted so these tags keep their underscores when
# other tags have "_" replaced with spaces — confirm against the caller.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]