Spaces:
Sleeping
Sleeping
from transformers import pipeline | |
from imgutils.data import rgb_encode, load_image | |
from onnx_ import _open_onnx_model | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
import os | |
import requests | |
import timm | |
import torch | |
import json | |
def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
    """Resize *image* and encode it as a float32 channel-first (CHW) array.

    image: a PIL image (anything exposing ``.resize(size, resample)``).
    size: target (width, height) passed to ``image.resize``.
    normalize: optional ``(mean, std)`` pair; when given, the encoded data is
        shifted by ``mean`` and divided by ``std``. Pass ``None`` to skip.
    Returns: the encoded array cast to ``np.float32``.
    """
    resized = image.resize(size, Image.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')
    if normalize is None:
        return encoded.astype(np.float32)

    center, scale = normalize
    # Shape (-1, 1, 1) lets per-channel values broadcast over the H and W axes.
    center = np.asarray([center]).reshape((-1, 1, 1))
    scale = np.asarray([scale]).reshape((-1, 1, 1))
    return ((encoded - center) / scale).astype(np.float32)
nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")

# Remote location of the ONNX rating model and its metadata (labels etc.).
_RATING_BASE_URL = "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus"


def _download_if_missing(url, path, timeout=60):
    """Download *url* to *path* unless *path* already exists.

    The body is streamed into ``path + ".part"`` and only renamed to *path*
    after the whole response was written. A failed or interrupted download
    therefore never leaves a truncated file that a later run would mistake
    for a complete one.

    timeout: seconds for ``requests`` connect/read timeouts.
    Returns: True if a download happened, False if *path* already existed.
    Raises: requests.HTTPError on a non-2xx response (instead of saving an
        error page as the model).
    """
    if os.path.exists(path):
        return False
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(tmp_path, "wb") as fh:
            # Stream in 1 MiB chunks so the whole model is never held in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(tmp_path, path)  # atomic rename on the same filesystem
    return True


# Each file is checked independently, so a missing config is fetched even when
# the model is already present.
_model_downloaded = _download_if_missing(f"{_RATING_BASE_URL}/model.onnx", "timm.onnx")
_download_if_missing(f"{_RATING_BASE_URL}/meta.json", "timmcfg.json")
if not _model_downloaded:
    print("Model already exists, skipping redownload")

with open("timmcfg.json", encoding="utf-8") as file:
    tm_cfg = json.load(file)
nsfw_tm = _open_onnx_model("timm.onnx")
# Weight each model's top label contributes to the verdict; labels not listed
# (unexpected model outputs) contribute 0, matching the original match/case.
_TM_LABEL_WEIGHTS = {"safe": -2, "r15": 1, "r18": 2}
_TF_LABEL_WEIGHTS = {"safe": -2, "suggestive": 1, "r18": 2}


def launch(img):
    """Classify *img* with both NSFW models and return whether it is NSFW.

    The top label of the ONNX rating model (``nsfw_tm``) and of the
    transformers pipeline (``nsfw_tf``) are each mapped to a weight via the
    tables above. The image is reported NSFW when the summed weight is > 0.

    img: the Gradio image input. Gradio's ``"image"`` input passes a numpy
        array by default; a PIL image or a path accepted by ``load_image``
        also works.
    Returns: bool, True when the combined weight is positive.
    """
    if isinstance(img, np.ndarray):
        # Convert raw arrays once so both models receive the same PIL image.
        img = Image.fromarray(img)
    image = load_image(img, mode='RGB')

    tm_input = _img_encode(image, size=(256, 256))[None, ...]  # add batch axis
    tm_scores, = nsfw_tm.run(['output'], {'input': tm_input})
    # Top-1 label: index of the highest score in the single batch row.
    tm_label = tm_cfg["labels"][int(np.argmax(tm_scores[0]))]

    # Presumably predictions come sorted by descending score (transformers
    # image-classification pipeline), so [0] is the top label.
    tf_label = nsfw_tf(image)[0]["label"]

    weight = _TM_LABEL_WEIGHTS.get(tm_label, 0) + _TF_LABEL_WEIGHTS.get(tf_label, 0)
    return weight > 0
# Wire the classifier into a Gradio UI: an image input, and launch()'s boolean
# verdict shown as text. (The original referenced an undefined `generate`.)
app = gr.Interface(fn=launch, inputs="image", outputs="text")
app.launch()