# pi-tagger / tagger/common.py
# Last commit 680a411 ("clean up") by neggles; 3.4 kB, no virus detected.
# (Header captured from the Hugging Face file viewer: raw / history / blame.)
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
@dataclass()
class LabelData:
    """Tag vocabulary plus per-category positions into that vocabulary.

    Instances are built by ``load_labels_hf`` from a ``selected_tags.csv``.
    """

    # every tag name, in CSV row order
    names: list[str]
    # positions in ``names`` whose category code is 9
    rating: list[np.int64]
    # positions in ``names`` whose category code is 0
    general: list[np.int64]
    # positions in ``names`` whose category code is 4
    character: list[np.int64]
@dataclass()
class ImageLabels:
    """Tagging result for a single image.

    NOTE(review): the three dict fields presumably map tag name -> model
    score; confirm against the code that builds these objects.
    """

    # free-form caption string (exact format not shown in this module)
    caption: str
    # booru-style tag string (exact format not shown in this module)
    booru: str
    # rating tags -> float
    rating: dict[str, float]
    # general tags -> float
    general: dict[str, float]
    # character tags -> float
    character: dict[str, float]
@lru_cache(maxsize=5)
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Fetch ``selected_tags.csv`` from a Hugging Face Hub repo and index its tags.

    Results are memoised by ``lru_cache`` (up to 5 distinct argument
    combinations), so repeated calls return the same ``LabelData`` object;
    callers should treat it as read-only.

    Args:
        repo_id: Hub repository that holds ``selected_tags.csv``.
        revision: Optional revision, passed straight to ``hf_hub_download``.
        token: Optional auth token, passed straight to ``hf_hub_download``.

    Returns:
        ``LabelData`` holding every tag name (CSV row order) and the row
        positions whose ``category`` column equals 9 (rating), 0 (general)
        or 4 (character).

    Raises:
        FileNotFoundError: If ``hf_hub_download`` raises ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
    except HfHubHTTPError as err:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from err

    frame: pd.DataFrame = pd.read_csv(Path(downloaded).resolve(), usecols=["name", "category"])
    categories = frame["category"]

    # category codes in the CSV: 9 -> rating, 0 -> general, 4 -> character
    rating_rows, general_rows, character_rows = (
        list(np.where(categories == code)[0]) for code in (9, 0, 4)
    )
    return LabelData(
        names=frame["name"].tolist(),
        rating=rating_rows,
        general=general_rows,
        character=character_rows,
    )
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return an RGB version of ``image``, flattening transparency onto white.

    Images are first normalised to RGBA (if they carry any transparency) or
    RGB (otherwise). RGBA results are then alpha-composited over an opaque
    white canvas and converted to RGB.

    Args:
        image: Source image in any PIL mode.

    Returns:
        An image in mode ``"RGB"``. It is the input object itself if that
        was already RGB.
    """
    # Alpha can live in a band ("LA", "PA", "RGBA", premultiplied "La"/"RGBa")
    # or, for palette/L/RGB images, in the "transparency" info key. Checking
    # only the info key would send e.g. "LA" through convert("RGB"), which
    # silently drops the alpha channel instead of flattening it onto white.
    has_alpha = "transparency" in image.info or any(
        band in ("A", "a") for band in image.getbands()
    )

    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGBA") if has_alpha else image.convert("RGB")

    # convert RGBA to RGB with an opaque white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Center ``image`` on a new square RGB canvas sized to its longer edge.

    Args:
        image: Image to pad. It is pasted without a mask.
        fill: RGB colour of the padding (white by default).

    Returns:
        A new ``"RGB"`` image of size ``(side, side)``, where ``side`` is
        the larger of the input's width and height.
    """
    width, height = image.size
    side = width if width >= height else height

    # top-left corner that centres the original (floor division biases left/up)
    offset = ((side - width) // 2, (side - height) // 2)

    padded = Image.new("RGB", (side, side), fill)
    padded.paste(image, offset)
    return padded
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    The image is converted to RGB with ``pil_ensure_rgb`` and padded to a
    square with ``pil_pad_square``. It is then resized toward ``size_px``:
    LANCZOS ``resize`` when it is too small, BICUBIC ``thumbnail`` when it
    is too large.

    Args:
        image: Source image in any PIL mode.
        size_px: Target size; an int ``n`` is treated as ``(n, n)``.
        upscale: When falsy, raise instead of enlarging an image that is
            smaller than the target.

    Returns:
        A new RGB image. For a square ``size_px`` its size is exactly
        ``size_px``.

    Raises:
        ValueError: If the padded image is smaller than ``size_px`` in either
            dimension and ``upscale`` is falsy.

    Note:
        For a non-square ``size_px`` the two paths differ. Upscaling
        stretches to exactly ``size_px``, while ``thumbnail`` keeps the
        square aspect ratio and only fits within ``size_px``.
    """
    if isinstance(size_px, int):
        size_px = (size_px, size_px)

    # ensure RGB and pad to square
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)

    # resize to target size
    if image.size[0] < size_px[0] or image.size[1] < size_px[1]:
        # truthiness check (PEP 8): `upscale is False` let 0/None slip through
        if not upscale:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(size_px, Image.LANCZOS)
    if image.size[0] > size_px[0] or image.size[1] > size_px[1]:
        # thumbnail() works in place; `image` is the fresh canvas made above,
        # so the caller's original image is never mutated
        image.thumbnail(size_px, Image.BICUBIC)
    return image
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
# Emoticon ("kaomoji") tag names that contain underscores, copied from the
# wd14-tagger UI linked above.
# NOTE(review): presumably used to keep these tags' underscores intact when
# tag names are prettified; confirm against callers.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]