# nsfw-det / app.py
# Last commit 19c4f43 by spuun: "Update app.py"
from transformers import pipeline
from imgutils.data import rgb_encode, load_image
from onnx_ import _open_onnx_model
from PIL import Image
import gradio as gr
import numpy as np
import os
import requests
import torch
import json
def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
    """Resize *image* and encode it as a float32 channel-first array.

    Args:
        image: PIL image to encode (the caller in this file passes RGB).
        size: target ``(width, height)`` handed to ``Image.resize``.
        normalize: ``(mean, std)`` pair applied as ``(x - mean) / std``,
            or ``None`` to skip normalization.

    Returns:
        ``np.ndarray`` of dtype float32 in CHW order.
    """
    resized = image.resize(size, Image.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')
    if normalize is None:
        return encoded.astype(np.float32)

    mean_value, std_value = normalize
    # Reshape to (C, 1, 1) so the statistics broadcast over height and width.
    mean_arr = np.asarray([mean_value]).reshape((-1, 1, 1))
    std_arr = np.asarray([std_value]).reshape((-1, 1, 1))
    return ((encoded - mean_arr) / std_arr).astype(np.float32)
# transformers image-classification pipeline used as the second classifier.
nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")

_TM_REPO_URL = "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus"


def _download(url, path, timeout=60):
    """Download *url* to *path* unless *path* already exists.

    Each file is checked on its own, so a missing config is fetched even when
    the model is already cached. The body is streamed into a temporary file
    and moved into place only after the whole response was received, so an
    interrupted download never leaves a truncated file that a later run would
    mistake for a complete one.

    Raises:
        requests.HTTPError: on a non-2xx response (instead of saving the
            error page as the file).
    """
    if os.path.exists(path):
        print(f"{path} already exists, skipping redownload")
        return
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    # Atomic rename: readers only ever see a missing or a complete file.
    os.replace(tmp_path, path)


_download(f"{_TM_REPO_URL}/model.onnx", "timm.onnx")
_download(f"{_TM_REPO_URL}/meta.json", "timmcfg.json")

# Model metadata; "labels" is read by launch() to name the ONNX outputs.
with open("timmcfg.json", encoding="utf-8") as file:
    tm_cfg = json.load(file)
nsfw_tm = _open_onnx_model("timm.onnx")
# Votes per predicted label: positive pushes toward NSFW, negative toward safe.
# A label not listed here contributes 0 (same as an unmatched `case`).
_TM_WEIGHTS = {"safe": -1, "r15": 2, "r18": 2}
_TF_WEIGHTS = {"safe": -1, "suggestive": 1, "r18": 2}


def launch(img):
    """Classify *img* with both models and report whether it looks NSFW.

    The top label of each model is mapped to a vote through ``_TM_WEIGHTS``
    (ONNX model) and ``_TF_WEIGHTS`` (transformers pipeline). The image is
    flagged when the summed vote is positive.

    Args:
        img: PIL image from the Gradio input, or ``None`` if nothing was
            uploaded.

    Returns:
        ``True`` if the combined vote is greater than 0, else ``False``.

    Raises:
        ValueError: if no image was provided.
    """
    if img is None:
        raise ValueError("No image provided.")
    img = img.convert('RGB')

    # ONNX model: score every label, then rank descending by score.
    tm_image = load_image(img, mode='RGB')
    tm_input_ = _img_encode(tm_image, size=(256, 256))[None, ...]  # add batch axis
    tm_items, = nsfw_tm.run(['output'], {'input': tm_input_})
    # NOTE(review): assumes tm_cfg["labels"] is in the same order as the
    # model's output columns — confirm against meta.json.
    tm_ranked = sorted(
        zip(tm_cfg["labels"], (score.item() for score in tm_items[0])),
        key=lambda pair: pair[1],
        reverse=True,
    )
    tm_output = tm_ranked[0][0]

    # transformers pipeline: use the label of its first returned prediction.
    tf_output = nsfw_tf(img)[0]["label"]

    # Debug log of the full ranking (computed once, reused from above).
    print(tm_ranked, tf_output)

    weight = _TM_WEIGHTS.get(tm_output, 0) + _TF_WEIGHTS.get(tf_output, 0)
    return weight > 0
# Gradio front end: the uploaded image is passed to launch() and its boolean
# verdict is rendered as text.
app = gr.Interface(
    fn=launch,
    inputs="pil",
    outputs="text",
)
app.launch()