from transformers import pipeline
from imgutils.data import rgb_encode, load_image
from onnx_ import _open_onnx_model
from PIL import Image
import gradio as gr
import numpy as np
import os
import requests
import timm
import torch
import json

# Remote locations of the anime-rating ONNX model and its metadata file
# (the metadata's "labels" list names the model's output columns).
_TM_MODEL_URL = "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.onnx"
_TM_META_URL = "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/meta.json"
_TM_MODEL_PATH = "timm.onnx"
_TM_META_PATH = "timmcfg.json"

# Contribution of each predicted label to the combined NSFW weight.
# Labels not listed contribute 0 (the original `match` statements had no
# default case, which is equivalent).
_TM_WEIGHTS = {"safe": -2, "r15": 1, "r18": 2}
_TF_WEIGHTS = {"safe": -2, "suggestive": 1, "r18": 2}


def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
    """Encode a PIL image as a normalized float32 CHW array.

    The image is resized to ``size`` with bilinear resampling, converted to a
    channel-first array by ``rgb_encode``, and, when ``normalize`` is given as
    ``(mean, std)``, standardized as ``(data - mean) / std`` (the values are
    reshaped to ``(-1, 1, 1)`` so they broadcast over the channel axis).
    """
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')
    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std
    return data.astype(np.float32)


def _download(url, path, timeout=60):
    """Fetch ``url`` and store the body at ``path``.

    Raises ``requests.HTTPError`` on a non-2xx response instead of saving an
    error page. The body is written to a temporary file first and moved into
    place only once complete, so an interrupted download never leaves a
    truncated file that later runs would mistake for a valid one.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    tmp_path = path + ".part"
    with open(tmp_path, "wb") as file:
        file.write(response.content)
    os.replace(tmp_path, path)


nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")

# Check each file independently: previously the metadata was only fetched
# together with the model, so a missing timmcfg.json was never repaired.
if os.path.exists(_TM_MODEL_PATH):
    print("Model already exists, skipping redownload")
else:
    _download(_TM_MODEL_URL, _TM_MODEL_PATH)
if not os.path.exists(_TM_META_PATH):
    _download(_TM_META_URL, _TM_META_PATH)

with open(_TM_META_PATH, encoding="utf-8") as file:
    tm_cfg = json.load(file)
nsfw_tm = _open_onnx_model(_TM_MODEL_PATH)


def launch(img):
    """Classify ``img`` with both NSFW models and report whether it is NSFW.

    Each model's top label is mapped to a weight (see ``_TM_WEIGHTS`` and
    ``_TF_WEIGHTS``); the image is reported as NSFW when the summed weight is
    strictly positive.

    Args:
        img: a PIL image (what the Gradio ``Image(type="pil")`` input
            provides) or a path to an image file.

    Returns:
        ``True`` if the combined weight is greater than 0, else ``False``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    weight = 0

    # ONNX anime-rating model.
    tm_image = load_image(img, mode='RGB')
    tm_input_ = _img_encode(tm_image, size=(256, 256))[None, ...]
    tm_output, = nsfw_tm.run(['output'], {'input': tm_input_})
    # Assumes tm_output is (batch, num_labels) with columns ordered like
    # tm_cfg["labels"] — TODO confirm against meta.json. Pick the top label.
    tm_label = tm_cfg["labels"][int(np.argmax(tm_output[0]))]
    weight += _TM_WEIGHTS.get(tm_label, 0)

    # Transformers ViT classifier; the pipeline's first result is its top label.
    if isinstance(img, Image.Image):
        tf_img = img.convert('RGB')
    else:
        with Image.open(img) as opened:
            tf_img = opened.convert('RGB')
    tf_label = nsfw_tf(tf_img)[0]["label"]
    weight += _TF_WEIGHTS.get(tf_label, 0)

    return weight > 0


# type="pil" makes Gradio pass a PIL image; the default would pass a numpy
# array, which neither load_image nor Image.open accepts.
app = gr.Interface(fn=launch, inputs=gr.Image(type="pil"), outputs="text")
app.launch()