# Source captured from a Hugging Face Space page (status shown: "Running on Zero", i.e. ZeroGPU hardware).
import huggingface_hub | |
from PIL import Image | |
from pathlib import Path | |
import csv | |
import spaces | |
import onnxruntime as rt | |
# Load the Z3D-E621-Convnext ONNX model and its tag vocabulary once, at import time.
# Loading is best-effort so the rest of the app can still start; on failure both
# module-level names are set to "not loaded" values (None / []) so callers can
# detect the missing model instead of hitting a NameError or half-loaded state.
try:
    e621_model_path = Path(huggingface_hub.snapshot_download('toynya/Z3D-E621-Convnext'))
    # str(): older onnxruntime releases only accept str/bytes model paths, not Path objects.
    e621_model_session = rt.InferenceSession(
        str(e621_model_path / 'model.onnx'),
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    with open(e621_model_path / 'tags-selected.csv', mode='r', encoding='utf-8') as file:
        csv_reader = csv.DictReader(file)
        # One tag name per row of the 'name' column.
        e621_model_tags = [row['name'].strip() for row in csv_reader]
except Exception as e:
    # Reset both names so a partially successful load (e.g. session ok, CSV failed)
    # is never mistaken for a usable model.
    e621_model_session = None
    e621_model_tags = []
    print(f"Failed to load Z3D-E621-Convnext model: {e!r}")
def prepare_image_e621(image: Image.Image, target_size: int):
    """Convert a PIL image into the batched array Z3D-E621-Convnext is fed with.

    The image is flattened onto white if it carries transparency, padded to a
    centred white square, resized to ``target_size`` x ``target_size`` with
    bicubic resampling (only if needed), and converted to float32 BGR.

    Args:
        image: input image, any PIL mode.
        target_size: side length of the square model input (the caller in this
            file passes 448).

    Returns:
        float32 numpy array of shape (1, target_size, target_size, 3), BGR
        channel order, values in the 0-255 range.
    """
    import numpy as np

    # Flatten transparency onto white. Pasting an image with alpha onto an RGB
    # canvas without a mask just drops the alpha channel, exposing whatever
    # colour is stored under transparent pixels (often black) instead of the
    # white used for padding below.
    if image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info:
        rgba = image.convert("RGBA")
        flattened = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        flattened.alpha_composite(rgba)
        image = flattened.convert("RGB")
    elif image.mode != "RGB":
        image = image.convert("RGB")

    # Pad image to a centred square on a white background
    width, height = image.size
    max_dim = max(width, height)
    pad_left = (max_dim - width) // 2
    pad_top = (max_dim - height) // 2
    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
    padded_image.paste(image, (pad_left, pad_top))

    # Resize
    if max_dim != target_size:
        padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)

    # Convert to numpy array
    # Based on the ONNX graph, the model appears to expect inputs in the range of 0-255
    image_array = np.asarray(padded_image, dtype=np.float32)
    # Convert PIL-native RGB to BGR; copy so the buffer handed to the runtime is
    # C-contiguous rather than a negative-stride view.
    image_array = np.ascontiguousarray(image_array[:, :, ::-1])
    return np.expand_dims(image_array, axis=0)
def predict_e621(image: Image.Image, threshold: float = 0.3):
    """Tag *image* with Z3D-E621-Convnext and return the tags as one string.

    Args:
        image: PIL image to classify.
        threshold: a tag is kept when its score from the 'predictions_sigmoid'
            output is strictly greater than this value.

    Returns:
        Kept tags joined with ", ", in model-output order, with underscores
        replaced by spaces.

    Raises:
        RuntimeError: if the model session / tag list was not loaded at startup.
    """
    # Fail with a clear message if module-level loading did not succeed.
    try:
        session, tags = e621_model_session, e621_model_tags
    except NameError as err:
        raise RuntimeError("Z3D-E621-Convnext model is not loaded") from err
    if session is None:
        raise RuntimeError("Z3D-E621-Convnext model is not loaded")

    # Preprocess once (the image was previously prepared twice, first result discarded).
    image_array = prepare_image_e621(image, 448)
    input_name = 'input_1:0'
    output_name = 'predictions_sigmoid'
    # run() returns one array per requested output: take output 0, batch item 0.
    scores = session.run([output_name], {input_name: image_array})[0][0]
    # NOTE(review): assumes tags-selected.csv rows line up 1:1 with the model's
    # output units; zip stops at the shorter of the two.
    predicted_tags = [tag for tag, score in zip(tags, scores) if score > threshold]
    return ', '.join(predicted_tags).replace("_", " ")
def predict_tags_e621(image: Image.Image, input_tags: str, algo: list[str], threshold: float = 0.3):
    """Merge user-supplied tags with Z3D-E621-Convnext predictions.

    Args:
        image: PIL image to tag.
        input_tags: existing comma-separated tags; they come first in the result,
            in their original order.
        algo: selected algorithm labels; prediction only runs when it contains
            "Use Z3D-E621-Convnext".
        threshold: score threshold forwarded to ``predict_e621``.

    Returns:
        A ", "-joined, de-duplicated tag string with blank entries removed, or
        *input_tags* unchanged when the model is not selected.
    """
    def to_list(s):
        # Split on commas, strip whitespace and drop blank items
        # (e.g. from "a, , b", a trailing comma, or an all-whitespace string).
        return [x.strip() for x in (s or "").split(",") if x.strip()]

    def list_uniq(items):
        # Order-preserving de-duplication in O(n).
        return list(dict.fromkeys(items))

    if not algo or "Use Z3D-E621-Convnext" not in algo:
        return input_tags
    tag_list = list_uniq(to_list(input_tags) + to_list(predict_e621(image, threshold)))
    return ", ".join(tag_list)