# PaintsUndo / wd14tagger.py
# Uploaded by MohamedRashad ("Upload code", commit 6dd488f); 3.33 kB.
# Hugging Face file-page residue: "raw", "history blame", "No virus" (malware-scan badge).
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
import os
import csv
import numpy as np
import onnxruntime as ort
from PIL import Image
from onnxruntime import InferenceSession
from torch.hub import download_url_to_file
# Process-wide caches, both empty until default_interrogator() fills them on
# first use: the ONNX InferenceSession and the tag-CSV rows (header skipped).
global_model, global_csv = None, None
def download_model(url, local_path):
    """Download ``url`` to ``local_path`` unless that file already exists.

    The payload is first written to ``local_path + '.tmp'`` and only moved to
    ``local_path`` once the download call returns. An interrupted download
    therefore never leaves a truncated file at ``local_path``, which the
    existence check would later mistake for a complete file.

    Args:
        url: source URL passed to ``torch.hub.download_url_to_file``.
        local_path: destination file path; missing parent directories are created.

    Returns:
        ``local_path``.
    """
    if os.path.exists(local_path):
        return local_path
    parent = os.path.dirname(local_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    temp_path = local_path + '.tmp'
    download_url_to_file(url=url, dst=temp_path)
    # os.replace (unlike os.rename) also overwrites an existing destination on
    # Windows, e.g. when another process finished the same download first.
    os.replace(temp_path, local_path)
    return local_path
def _get_model(model_onnx_filename):
    """Return the cached ONNX InferenceSession, creating it on first use."""
    global global_model
    if global_model is None:
        # CPU only. To use a GPU, pass providers=['CUDAExecutionProvider']
        # (requires a CUDA-enabled onnxruntime build).
        global_model = InferenceSession(model_onnx_filename, providers=['CPUExecutionProvider'])
    return global_model


def _get_tag_rows(model_csv_filename):
    """Return the cached tag-CSV rows (header row skipped), reading the file on first use."""
    global global_csv
    if global_csv is None:
        # Explicit UTF-8 so tag names do not depend on the platform locale;
        # newline='' is how the csv module expects its files to be opened.
        with open(model_csv_filename, encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            global_csv = list(reader)
    return global_csv


def _to_rgb_or_rgba(image):
    """Return a converted copy of a PIL image: RGBA if it carries transparency, else RGB."""
    has_alpha = image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info
    return image.convert("RGBA" if has_alpha else "RGB")


def _prepare_image(image, size):
    """Letterbox ``image`` onto a white ``size`` x ``size`` canvas.

    Args:
        image: a file path, a numpy array accepted by ``Image.fromarray``,
            or a PIL image.
        size: side length of the square model input.

    Returns:
        A float32 array of shape (1, size, size, 3) with channels in BGR order.
    """
    if isinstance(image, str):
        # Convert inside the context manager so the file is closed once the
        # pixels have been copied into the converted image.
        with Image.open(image) as opened:
            image = _to_rgb_or_rgba(opened)
    else:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = _to_rgb_or_rgba(image)

    scale = float(size) / max(image.size)
    # round() keeps the long side at exactly `size`. max(1, ...) stops a very
    # thin image from producing a zero-pixel side, which resize() does not accept.
    new_size = tuple(max(1, round(side * scale)) for side in image.size)
    image = image.resize(new_size, Image.LANCZOS)

    canvas = Image.new("RGB", (size, size), (255, 255, 255))
    offset = ((size - new_size[0]) // 2, (size - new_size[1]) // 2)
    # For RGBA sources the alpha band is used as the paste mask, so transparent
    # areas keep the white background instead of their hidden RGB values.
    canvas.paste(image, offset, image if image.mode == "RGBA" else None)

    array = np.array(canvas).astype(np.float32)
    array = array[:, :, ::-1]  # RGB -> BGR
    return np.expand_dims(array, 0)


def _tag_key(tag):
    """Normalise a tag for exclusion matching (case, whitespace, '_' vs ' ')."""
    return tag.strip().lower().replace("_", " ")


def default_interrogator(image, threshold=0.35, character_threshold=0.85, exclude_tags=""):
    """Tag an image with the WD14 MOAT tagger and return a prompt-style tag string.

    On first use, downloads the ONNX model and its tag CSV into the current
    directory. The inference session and the CSV rows are cached in the
    module globals ``global_model`` / ``global_csv``.

    Args:
        image: path to an image file, a numpy array accepted by
            ``Image.fromarray``, or a PIL image.
        threshold: general tags are kept when their score is strictly greater.
        character_threshold: character tags are kept when their score is
            strictly greater.
        exclude_tags: comma-separated tags to drop. Matching ignores case,
            surrounding whitespace, and the underscore/space difference.
            ``None`` is treated as "".

    Returns:
        Kept character tags followed by kept general tags, each group in CSV
        order, joined by ", ". Parentheses are backslash-escaped and
        underscores are replaced by spaces.
    """
    model_name = "wd-v1-4-moat-tagger-v2"
    model_onnx_filename = download_model(
        url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx',
        local_path=f'./{model_name}.onnx',
    )
    model_csv_filename = download_model(
        url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv',
        local_path=f'./{model_name}.csv',
    )

    model = _get_model(model_onnx_filename)
    model_input = model.get_inputs()[0]
    # The model is fed a (1, H, H, 3) array below, so shape[1] is used as the
    # square side. Assumes a static dimension here — TODO confirm.
    height = model_input.shape[1]
    batch = _prepare_image(image, height)

    rows = _get_tag_rows(model_csv_filename)
    tags = [row[1] for row in rows]
    # Column 2 is the tag category. Scores are sliced on the first general
    # ("0") row and the first character ("4") row, which assumes the CSV lists
    # general rows before character rows.
    general_index = next((i for i, row in enumerate(rows) if row[2] == "0"), None)
    character_index = next((i for i, row in enumerate(rows) if row[2] == "4"), None)
    # A missing category gets an empty slice. Without this, a None bound would
    # make the character slice span every row and duplicate tags.
    if character_index is None:
        character_index = len(rows)
    if general_index is None:
        general_index = character_index

    label_name = model.get_outputs()[0].name
    probs = model.run([label_name], {model_input.name: batch})[0]
    scored = list(zip(tags, probs[0]))

    general = [item for item in scored[general_index:character_index] if item[1] > threshold]
    character = [item for item in scored[character_index:] if item[1] > character_threshold]

    excluded = {_tag_key(tag) for tag in (exclude_tags or "").split(",")}
    selected = [item for item in character + general if _tag_key(item[0]) not in excluded]

    return ", ".join(
        tag.replace("(", "\\(").replace(")", "\\)") for tag, _ in selected
    ).replace("_", " ")