Spaces:
Running
Running
import os, re, cv2 | |
from typing import Mapping, Tuple, Dict | |
import gradio as gr | |
import numpy as np | |
import io | |
import pandas as pd | |
from PIL import Image | |
from huggingface_hub import hf_hub_download | |
from onnxruntime import InferenceSession | |
# noinspection PyUnresolvedReferences | |
def make_square(img, target_size):
    """Pad ``img`` with white borders so it becomes a square canvas.

    The canvas side is the largest of the image height, the image width and
    ``target_size``. Padding is split between opposite sides; when the amount
    is odd, the extra pixel goes to the bottom / right.

    Args:
        img: Array whose first two shape entries are (height, width).
        target_size: Minimum side length of the returned canvas.

    Returns:
        The padded image produced by ``cv2.copyMakeBorder``.
    """
    height, width = img.shape[:2]
    side = max(height, width, target_size)

    pad_rows = side - height
    pad_cols = side - width
    pad_top = pad_rows // 2
    pad_left = pad_cols // 2

    return cv2.copyMakeBorder(
        img,
        pad_top,
        pad_rows - pad_top,
        pad_left,
        pad_cols - pad_left,
        cv2.BORDER_CONSTANT,
        value=[255, 255, 255],  # white fill
    )
# noinspection PyUnresolvedReferences | |
def smart_resize(img, size):
    """Resize a square image to ``size`` x ``size``.

    Only the first dimension of ``img`` is inspected, so the image is
    expected to have gone through ``make_square`` already.

    Args:
        img: Square image array.
        size: Target side length in pixels.

    Returns:
        ``img`` itself when it already has the requested side length,
        otherwise the output of ``cv2.resize``.
    """
    side = img.shape[0]
    if side > size:
        # Shrinking: area interpolation.
        method = cv2.INTER_AREA
    elif side < size:
        # Enlarging: bicubic interpolation.
        method = cv2.INTER_CUBIC
    else:
        # Already the right size; hand the input back untouched.
        return img
    return cv2.resize(img, (size, size), interpolation=method)
class WaifuDiffusionInterrogator:
    """Image tagger that runs an ONNX model fetched from the Hugging Face Hub.

    Creating an instance is cheap: the model file and the tag table are only
    downloaded and loaded the first time a calculation needs them, and are
    reused afterwards.
    """

    def __init__(
            self,
            repo='SmilingWolf/wd-v1-4-vit-tagger',
            model_path='model.onnx',
            tags_path='selected_tags.csv',
            mode: str = "auto"
    ) -> None:
        """Record where the model lives; nothing is downloaded here.

        Args:
            repo: Hub repository id that holds the model and tag files.
            model_path: File name of the ONNX model inside ``repo``.
            tags_path: File name of the tag CSV inside ``repo``.
            mode: Stored as ``_provider_mode``; nothing in this class reads
                it afterwards.
        """
        self.__repo = repo
        self.__model_path = model_path
        self.__tags_path = tags_path
        self._provider_mode = mode
        self.__initialized = False
        self._model, self._tags = None, None

    def _init(self) -> None:
        """Download and load the model and the tag table on first use."""
        if self.__initialized:
            return
        model_path = hf_hub_download(self.__repo, filename=self.__model_path)
        tags_path = hf_hub_download(self.__repo, filename=self.__tags_path)
        self._model = InferenceSession(str(model_path))
        self._tags = pd.read_csv(tags_path)
        self.__initialized = True

    def _calculation(self, image: Image.Image) -> pd.DataFrame:
        """Run the model on ``image`` and return per-tag confidences.

        Args:
            image: PIL image to tag.

        Returns:
            A copy of the tag table's ``name`` and ``category`` columns with
            an added ``confidence`` column taken from the first row of the
            model's first output.
        """
        # TODO (from the original author): figure out what to do if the
        # input is a URL.
        self._init()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # The input shape is unpacked as four entries and the second one is
        # used as the square side length (presumably an NHWC layout - confirm
        # against the model).
        model_input = self._model.get_inputs()[0]
        _, height, _, _ = model_input.shape

        # Flatten any transparency onto a white background.
        rgba = image.convert('RGBA')
        canvas = Image.new('RGBA', rgba.size, 'WHITE')
        canvas.paste(rgba, mask=rgba)
        pixels = np.asarray(canvas.convert('RGB'))

        # PIL RGB to OpenCV BGR
        pixels = pixels[:, :, ::-1]

        pixels = make_square(pixels, height)
        pixels = smart_resize(pixels, height)
        # Add a leading batch axis of length one.
        batch = np.expand_dims(pixels.astype(np.float32), 0)

        # evaluate model
        label_name = self._model.get_outputs()[0].name
        confidence = self._model.run([label_name], {model_input.name: batch})[0]

        full_tags = self._tags[['name', 'category']].copy()
        full_tags['confidence'] = confidence[0]
        return full_tags

    def interrogate(self, image: Image.Image) -> Tuple[Dict[str, float], Dict[str, float]]:
        """Tag ``image`` and split the result into ratings and regular tags.

        Args:
            image: PIL image to tag.

        Returns:
            A ``(ratings, tags)`` pair of ``{name: confidence}`` dicts.
            ``ratings`` holds the rows whose category is 9 and ``tags`` holds
            every other row.
        """
        full_tags = self._calculation(image)
        is_rating = full_tags['category'] == 9

        # Per the original author's note, the rating rows are the first 4
        # items (general, sensitive, questionable, explicit).
        ratings = dict(full_tags[is_rating][['name', 'confidence']].values)

        # rest are regular tags
        tags = dict(full_tags[~is_rating][['name', 'confidence']].values)

        return ratings, tags
# (display name, Hub repository) pairs for every tagger that needs an explicit
# repository. Order matters: it is the order of the keys in WAIFU_MODELS.
_TAGGER_REPOS = (
    ('chen-convnext', 'SmilingWolf/wd-v1-4-convnext-tagger'),
    ('chen-convnext2', "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"),
    ('chen-swinv2', 'SmilingWolf/wd-v1-4-swinv2-tagger-v2'),
    ('chen-moat2', 'SmilingWolf/wd-v1-4-moat-tagger-v2'),
    ('chen-convnext3', 'SmilingWolf/wd-convnext-tagger-v3'),
    ('chen-vit3', 'SmilingWolf/wd-vit-tagger-v3'),
    ('chen-swinv3', 'SmilingWolf/wd-swinv2-tagger-v3'),
)

# Registry of the available taggers, keyed by model name. The first entry
# relies on the interrogator's default repository.
WAIFU_MODELS: Mapping[str, WaifuDiffusionInterrogator] = {
    'chen-vit': WaifuDiffusionInterrogator(),
    **{
        model_key: WaifuDiffusionInterrogator(repo=repo_id)
        for model_key, repo_id in _TAGGER_REPOS
    },
}
# Matches the characters that get a backslash prefix when escaping is on:
# the backslash itself and both parentheses.
RE_SPECIAL = re.compile(r'([\\()])')


def image_to_wd14_tags(image: Image.Image, model_name: str, threshold: float,
                       use_spaces: bool, use_escape: bool, include_ranks=False, score_descend=True) \
        -> Tuple[Mapping[str, float], str, Mapping[str, float]]:
    """Tag an image and render the tags above a threshold as text.

    Args:
        image: PIL image to tag.
        model_name: Key into ``WAIFU_MODELS`` selecting the tagger.
        threshold: Minimum confidence a tag needs to be kept.
        use_spaces: When true, underscores in a tag become ``-`` and the
            items are joined with a single space. When false, spaces in a
            tag become ``, ``, underscores become spaces, and the items are
            joined with ``, ``.
        use_escape: When true, backslashes and parentheses in each item are
            prefixed with a backslash.
        include_ranks: When true, each item is rendered as ``(tag:score)``
            with the score shown to three decimals.
        score_descend: When true, items are ordered by descending score,
            ties broken by tag name; otherwise the model's order is kept.

    Returns:
        A ``(ratings, output_text, filtered_tags)`` triple: the rating
        confidences, the rendered text, and the ``{tag: score}`` mapping of
        the tags that passed the threshold.

    Raises:
        gr.Error: If ``image`` is ``None``.
        KeyError: If ``model_name`` is not a key of ``WAIFU_MODELS``.
    """
    if image is None:
        # An empty image input reaches the handler as None; fail with a
        # readable message instead of an AttributeError from deep inside the
        # image conversion.
        raise gr.Error('No image provided. Please upload or paste an image first.')

    model = WAIFU_MODELS[model_name]
    ratings, tags = model.interrogate(image)

    filtered_tags = {
        tag: score for tag, score in tags.items()
        if score >= threshold
    }

    tags_pairs = filtered_tags.items()
    if score_descend:
        tags_pairs = sorted(tags_pairs, key=lambda x: (-x[1], x[0]))

    text_items = []
    for tag, score in tags_pairs:
        if use_spaces:
            tag_outformat = tag.replace('_', '-')
        else:
            tag_outformat = tag.replace(' ', ', ').replace('_', ' ')
        if use_escape:
            tag_outformat = RE_SPECIAL.sub(r'\\\1', tag_outformat)
        if include_ranks:
            tag_outformat = f"({tag_outformat}:{score:.3f})"
        text_items.append(tag_outformat)

    separator = ' ' if use_spaces else ', '
    output_text = separator.join(text_items)

    return ratings, output_text, filtered_tags
if __name__ == '__main__':
    # Build and launch the demo UI.
    # NOTE(review): the source this was recovered from had lost its
    # indentation; the nesting below (inputs in the left column, results in
    # the right column) is reconstructed - confirm against the deployed app.
    with gr.Blocks(analytics_enabled=False, theme="NoCrypt/miku") as demo:
        with gr.Row():
            # Inputs: image, model choice, threshold and formatting switches.
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Chen Chen', sources=['upload', 'clipboard'])
                with gr.Row():
                    gr_model = gr.Radio(list(WAIFU_MODELS), value='chen-moat2', label='Chen')
                    gr_threshold = gr.Slider(0.0, 1.0, 0.5, label='Chen Chen Chen Chen Chen')
                with gr.Row():
                    gr_space = gr.Checkbox(value=True, label='Use DashSpace')
                    gr_escape = gr.Checkbox(value=True, label='Chen Text Escape')
                gr_btn_submit = gr.Button(value='橙', variant='primary')

            # Outputs: ratings, tag confidences and the rendered text.
            with gr.Column():
                gr_ratings = gr.Label(label='橙 橙')
                with gr.Tabs():
                    with gr.Tab("Chens"):
                        gr_tags = gr.Label(label='Chens')
                    with gr.Tab("Chen Text"):
                        gr_output_text = gr.TextArea(label='Chen Text')

        # The five inputs map positionally onto the first five parameters of
        # image_to_wd14_tags; its three return values fill the outputs.
        gr_btn_submit.click(
            image_to_wd14_tags,
            inputs=[gr_input_image, gr_model, gr_threshold, gr_space, gr_escape],
            outputs=[gr_ratings, gr_output_text, gr_tags],
            api_name="classify"
        )

    # NOTE(review): assumes Gradio >= 4, which the `sources=` argument above
    # already requires. There the first positional parameter of queue() is
    # the status update rate, so the worker count has to be passed as
    # default_concurrency_limit. os.cpu_count() can return None, which would
    # mean "no limit"; fall back to a single worker in that case.
    demo.queue(default_concurrency_limit=os.cpu_count() or 1).launch()