|
|
|
|
|
import comfy.utils |
|
import asyncio |
|
import aiohttp |
|
import numpy as np |
|
import csv |
|
import os |
|
import sys |
|
import onnxruntime as ort |
|
from onnxruntime import InferenceSession |
|
from PIL import Image |
|
from server import PromptServer |
|
from aiohttp import web |
|
from .pysssss import get_ext_dir, get_comfy_dir, download_to_file, update_node_status, wait_for_async, get_extension_config |
|
# Put the "comfy" directory that sits next to this file at the front of the
# module search path.
_this_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(_this_dir, "comfy"))

config = get_extension_config()

# Built-in settings; entries under "settings" in the extension config
# override the values below.
defaults = dict(
    model="wd-v1-4-moat-tagger-v2",
    threshold=0.35,
    character_threshold=0.85,
    exclude_tags="",
)
defaults.update(config.get("settings", {}))

# Directory holding the .onnx model files and their .csv tag lists.
models_dir = get_ext_dir("models", mkdir=True)

# Model names offered in the node's "model" drop-down.
all_models = (
    "wd-v1-4-moat-tagger-v2",
    "wd-v1-4-convnext-tagger-v2",
    "wd-v1-4-convnext-tagger",
    "wd-v1-4-convnextv2-tagger-v2",
    "wd-v1-4-vit-tagger-v2",
)
|
|
|
|
|
def get_installed_models():
    """Return a lazy iterator over the ``.onnx`` file names found in ``models_dir``."""
    return (entry for entry in os.listdir(models_dir) if entry.endswith(".onnx"))
|
|
|
|
|
def _fit_to_square(image, side):
    """Scale ``image`` onto a white ``side`` x ``side`` canvas and return it as a float32 batch of one."""
    scale = float(side) / max(image.size)
    # Clamp to 1px: for extreme aspect ratios int() would round the short
    # dimension down to 0, which Image.resize rejects.
    fitted_size = tuple(max(1, int(dim * scale)) for dim in image.size)
    resized = image.resize(fitted_size, Image.LANCZOS)

    # Centre the resized image on a white square.
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    canvas.paste(resized, ((side - fitted_size[0]) // 2, (side - fitted_size[1]) // 2))

    pixels = np.array(canvas).astype(np.float32)
    pixels = pixels[:, :, ::-1]  # reverse the channel axis (RGB -> BGR)
    return np.expand_dims(pixels, 0)


def _read_tag_table(csv_path):
    """Return ``(names, general_start, character_start)`` read from a tag CSV.

    ``names`` holds the second column of every data row. ``general_start`` and
    ``character_start`` are the indices of the first rows whose third column is
    ``"0"`` and ``"4"`` respectively, or ``None`` when no such row exists.
    """
    names = []
    general_start = None
    character_start = None
    with open(csv_path, newline="", encoding="utf-8") as f:
        rows = csv.reader(f)
        next(rows, None)  # skip the header row
        # enumerate() counts records; reader.line_num counts physical lines and
        # would drift if a quoted field ever contained a line break.
        for index, row in enumerate(rows):
            if general_start is None and row[2] == "0":
                general_start = index
            elif character_start is None and row[2] == "4":
                character_start = index
            names.append(row[1])
    return names, general_start, character_start


async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", client_id=None, node=None):
    """Run a tagger model on ``image`` and return the matching tags as one string.

    Args:
        image: PIL image to tag.
        model_name: Model name, with or without a trailing ``.onnx``. The model
            is downloaded first when its ``.onnx`` file is not in ``models_dir``.
        threshold: Minimum score for a general tag to be kept.
        character_threshold: Minimum score for a character tag to be kept.
        exclude_tags: Comma-separated tag names to drop (matched case-insensitively).
        client_id: Passed through to ``download_model`` for status updates.
        node: Passed through to ``download_model`` for status updates.

    Returns:
        Character tags followed by general tags, joined with ``", "``, with
        parentheses backslash-escaped.
    """
    if model_name.endswith(".onnx"):
        model_name = model_name[:-len(".onnx")]

    # Exact file-name match; a substring test could mistake a longer file name
    # that merely contains this one for an installed model.
    if model_name + ".onnx" not in set(get_installed_models()):
        await download_model(model_name, client_id, node)

    model_path = os.path.join(models_dir, model_name + ".onnx")
    model = InferenceSession(model_path, providers=ort.get_available_providers())

    model_input = model.get_inputs()[0]
    # NOTE(review): dimension 1 of the input shape is used as both width and
    # height, i.e. a square input is assumed - confirm against the models used.
    side = model_input.shape[1]
    batch = _fit_to_square(image, side)

    tags, general_index, character_index = _read_tag_table(
        os.path.join(models_dir, model_name + ".csv"))

    label_name = model.get_outputs()[0].name
    probs = model.run([label_name], {model_input.name: batch})[0]

    result = list(zip(tags, probs[0]))

    # A missing section yields no tags. Slicing with a None start would
    # otherwise sweep in every row of the table.
    if general_index is None:
        general = []
    else:
        general = [item for item in result[general_index:character_index] if item[1] > threshold]
    if character_index is None:
        character = []
    else:
        character = [item for item in result[character_index:] if item[1] > character_threshold]

    excluded = {name.strip() for name in exclude_tags.lower().split(",")}
    kept = [item for item in character + general if item[0].lower() not in excluded]

    res = ", ".join(item[0].replace("(", "\\(").replace(")", "\\)") for item in kept)

    print(res)
    return res
|
|
|
|
|
async def download_model(model, client_id, node):
    """Download the ``.onnx`` model and its tag CSV for ``model``.

    Both files are fetched from the ``SmilingWolf/<model>`` repository on
    huggingface.co, and progress is reported through ``update_node_status``.

    Args:
        model: Model (repository) name.
        client_id: Passed to ``update_node_status`` with each progress update.
        node: Passed to ``update_node_status`` with each progress update.

    Returns:
        A ``web.Response`` with status 200.
    """
    url = f"https://huggingface.co/SmilingWolf/{model}/resolve/main/"
    # The session binds to the running event loop by default, so the deprecated
    # ``loop=`` argument is not needed.
    async with aiohttp.ClientSession() as session:
        async def update_callback(perc):
            # The message is left blank once the download reaches 100%.
            message = ""
            if perc < 100:
                message = f"Downloading {model}"
            update_node_status(client_id, node, message, perc)

        await download_to_file(
            f"{url}model.onnx", os.path.join("models", f"{model}.onnx"), update_callback, session=session)
        await download_to_file(
            f"{url}selected_tags.csv", os.path.join("models", f"{model}.csv"), update_callback, session=session)

        update_node_status(client_id, node, None)

    return web.Response(status=200)
|
|
|
|
|
@PromptServer.instance.routes.get("/pysssss/wd14tagger/tag")
async def get_tags(request):
    """HTTP handler: tag an image from a ComfyUI directory and return the tags as JSON.

    Query parameters:
        filename: Image file name (required; 404 when absent).
        type: ``output`` (default), ``input`` or ``temp``; anything else gives 400.
        subfolder: Optional sub-directory inside the chosen directory.
        clientId, node: Passed through to ``tag``.

    Responds 403 when the resolved path falls outside the chosen directory and
    404 when the file does not exist.
    """
    query = request.rel_url.query
    if "filename" not in query:
        return web.Response(status=404)

    dir_type = query.get("type", "output")
    if dir_type not in ["output", "input", "temp"]:
        return web.Response(status=400)

    target_dir = get_comfy_dir(dir_type)
    image_path = os.path.abspath(os.path.join(
        target_dir, query.get("subfolder", ""), query["filename"]))

    # Reject paths that resolve outside target_dir (e.g. via "..").
    try:
        common = os.path.commonpath((image_path, target_dir))
    except ValueError:
        # commonpath raises when the paths cannot be compared (for example
        # different drives on Windows); such a path is not inside target_dir.
        return web.Response(status=403)
    if common != target_dir:
        return web.Response(status=403)

    if not os.path.isfile(image_path):
        return web.Response(status=404)

    # Use the first installed model, falling back to the configured default.
    model = next(iter(get_installed_models()), defaults["model"])

    # Keep the file open only for the duration of tagging so the handle is
    # released even if tagging raises.
    with Image.open(image_path) as image:
        tags = await tag(image, model, client_id=query.get("clientId", ""), node=query.get("node", ""))

    return web.json_response(tags)
|
|
|
|
|
class WD14Tagger:
    """Node that runs the module-level ``tag`` coroutine on every image of a batch."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs; default values come from ``defaults``."""
        required = {}
        required["image"] = ("IMAGE", )
        required["model"] = (all_models, )
        required["threshold"] = (
            "FLOAT",
            {"default": defaults["threshold"], "min": 0.0, "max": 1, "step": 0.05},
        )
        required["character_threshold"] = (
            "FLOAT",
            {"default": defaults["character_threshold"], "min": 0.0, "max": 1, "step": 0.05},
        )
        required["exclude_tags"] = ("STRING", {"default": defaults["exclude_tags"]})
        return {"required": required}

    RETURN_TYPES = ("STRING",)
    OUTPUT_IS_LIST = (True,)
    FUNCTION = "tag"
    OUTPUT_NODE = True

    CATEGORY = "image"

    def tag(self, image, model, threshold, character_threshold, exclude_tags=""):
        """Tag each image in the batch and return the tag strings.

        ``image`` is scaled by 255 and converted to a ``uint8`` array; each
        entry along its first axis is tagged separately. The result carries
        the list of tag strings both under ``ui.tags`` and as the node output.
        """
        batch = np.array(image * 255, dtype=np.uint8)

        pbar = comfy.utils.ProgressBar(batch.shape[0])
        tags = []
        for frame in batch:
            pil_image = Image.fromarray(frame)
            # Inside a method the bare name ``tag`` resolves to the
            # module-level coroutine, not to this method.
            tags.append(wait_for_async(
                lambda: tag(pil_image, model, threshold, character_threshold, exclude_tags)))
            pbar.update(1)
        return {"ui": {"tags": tags}, "result": (tags,)}
|
|
|
|
|
# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "WD14Tagger|pysssss": WD14Tagger,
}

# Node identifier -> label shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    # The label ends with the snake emoji (U+1F40D). It is written as an escape
    # because the literal character had been corrupted by a wrong decoding.
    "WD14Tagger|pysssss": "WD14 Tagger \U0001F40D",
}
|
|