Spaces:
Running
Running
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags | |
# https://github.com/pythongosssss/ComfyUI-WD14-Tagger/blob/main/wd14tagger.py | |
# { | |
# "wd-v1-4-moat-tagger-v2": "https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2", | |
# "wd-v1-4-convnextv2-tagger-v2": "https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2", | |
# "wd-v1-4-convnext-tagger-v2": "https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2", | |
# "wd-v1-4-convnext-tagger": "https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger", | |
# "wd-v1-4-vit-tagger-v2": "https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2" | |
# } | |
import numpy as np | |
import csv | |
import onnxruntime as ort | |
from PIL import Image | |
from onnxruntime import InferenceSession | |
from modules.config import path_clip_vision | |
from modules.model_loader import load_file_from_url | |
global_model = None | |
global_csv = None | |
def default_interrogator(image_rgb, threshold=0.35, character_threshold=0.85, exclude_tags=""): | |
global global_model, global_csv | |
model_name = "wd-v1-4-moat-tagger-v2" | |
model_onnx_filename = load_file_from_url( | |
url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx', | |
model_dir=path_clip_vision, | |
file_name=f'{model_name}.onnx', | |
) | |
model_csv_filename = load_file_from_url( | |
url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv', | |
model_dir=path_clip_vision, | |
file_name=f'{model_name}.csv', | |
) | |
if global_model is not None: | |
model = global_model | |
else: | |
model = InferenceSession(model_onnx_filename, providers=ort.get_available_providers()) | |
global_model = model | |
input = model.get_inputs()[0] | |
height = input.shape[1] | |
image = Image.fromarray(image_rgb) # RGB | |
ratio = float(height)/max(image.size) | |
new_size = tuple([int(x*ratio) for x in image.size]) | |
image = image.resize(new_size, Image.LANCZOS) | |
square = Image.new("RGB", (height, height), (255, 255, 255)) | |
square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2)) | |
image = np.array(square).astype(np.float32) | |
image = image[:, :, ::-1] # RGB -> BGR | |
image = np.expand_dims(image, 0) | |
if global_csv is not None: | |
csv_lines = global_csv | |
else: | |
csv_lines = [] | |
with open(model_csv_filename) as f: | |
reader = csv.reader(f) | |
next(reader) | |
for row in reader: | |
csv_lines.append(row) | |
global_csv = csv_lines | |
tags = [] | |
general_index = None | |
character_index = None | |
for line_num, row in enumerate(csv_lines): | |
if general_index is None and row[2] == "0": | |
general_index = line_num | |
elif character_index is None and row[2] == "4": | |
character_index = line_num | |
tags.append(row[1]) | |
label_name = model.get_outputs()[0].name | |
probs = model.run([label_name], {input.name: image})[0] | |
result = list(zip(tags, probs[0])) | |
general = [item for item in result[general_index:character_index] if item[1] > threshold] | |
character = [item for item in result[character_index:] if item[1] > character_threshold] | |
all = character + general | |
remove = [s.strip() for s in exclude_tags.lower().split(",")] | |
all = [tag for tag in all if tag[0] not in remove] | |
res = ", ".join((item[0].replace("(", "\\(").replace(")", "\\)") for item in all)).replace('_', ' ') | |
return res | |