Spaces:
Sleeping
Sleeping
File size: 2,215 Bytes
941b996 72d3376 15a6715 72d3376 0b3046d 15a6715 5ab85aa 15a6715 72d3376 941b996 15a6715 72d3376 68e9e5e 72d3376 68e9e5e 941b996 68e9e5e 5ab85aa 68e9e5e 72d3376 68e9e5e 15a6715 941b996 cec0a27 72d3376 a533c0a 688fe94 941b996 465fe88 941b996 6e7e55f 941b996 465fe88 941b996 688fe94 941b996 c230ae8 941b996 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from transformers import pipeline
from imgutils.data import rgb_encode, load_image
from onnx_ import _open_onnx_model
from PIL import Image
import gradio as gr
import numpy as np
import os
import requests
import torch
import json
def _img_encode(image, size=(384,384), normalize=(0.5,0.5)):
    """Resize *image* and encode it as a float32 array in CHW order.

    Args:
        image: PIL image; it is resized with bilinear resampling.
        size: target size handed to ``Image.resize``.
        normalize: ``(mean, std)`` pair applied as ``(x - mean) / std``,
            or ``None`` to skip normalization.

    Returns:
        The ``rgb_encode`` output (CHW order), normalized if requested,
        cast to ``np.float32``.
    """
    resized = image.resize(size, Image.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')
    if normalize is None:
        return encoded.astype(np.float32)
    mean_, std_ = normalize
    # Reshape to (k, 1, 1) so a scalar or per-channel value broadcasts over
    # the trailing two axes of the CHW data.
    mean, std = (np.asarray([value]).reshape((-1, 1, 1)) for value in (mean_, std_))
    return ((encoded - mean) / std).astype(np.float32)
nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")


def _download_if_missing(url, path, timeout=60):
    """Download *url* to *path* unless *path* already exists.

    The response is streamed into a temporary sibling file and moved into
    place with ``os.replace`` only after it was received completely and with
    a success status. An HTTP error or an interrupted transfer therefore
    never leaves a truncated or error-page file behind that a later start
    would mistake for a cached copy.

    Args:
        url: remote file to fetch.
        path: local destination.
        timeout: seconds passed to ``requests.get`` as its timeout.

    Returns:
        True if the file was downloaded, False if *path* already existed.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.RequestException: network failure or timeout.
    """
    if os.path.exists(path):
        return False
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail loudly instead of caching an error page as the model file.
        response.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in response.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(tmp_path, path)
    return True


# Each file is checked on its own, so a missing config is fetched even when
# the model file is already cached (and vice versa).
_downloaded = [
    _download_if_missing(
        "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.onnx",
        "timm.onnx",
    ),
    _download_if_missing(
        "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/meta.json",
        "timmcfg.json",
    ),
]
if not any(_downloaded):
    print("Model already exists, skipping redownload")
with open("timmcfg.json", encoding="utf-8") as file:
    tm_cfg = json.load(file)
nsfw_tm = _open_onnx_model("timm.onnx")
# Score contributed by each top label of the ONNX rating model (nsfw_tm).
_TM_WEIGHTS = {"safe": -1, "r15": 1, "r18": 2}
# Score contributed by the top label of the transformers pipeline (nsfw_tf).
_TF_WEIGHTS = {"safe": -1, "suggestive": 1, "r18": 2}


def launch(img):
    """Classify *img* with both models and return whether the vote is NSFW.

    Each model's top label is mapped to a score (safe -> -1, mild -> +1,
    r18 -> +2). Labels not listed score 0. The image is flagged when the
    summed score is positive.

    Args:
        img: PIL image supplied by the Gradio input.

    Returns:
        True if the combined score is greater than zero, else False.

    Raises:
        ValueError: if no image was provided (``img`` is None).
    """
    if img is None:
        raise ValueError("No image provided")
    img = img.convert('RGB')
    tm_image = load_image(img, mode='RGB')
    tm_input_ = _img_encode(tm_image, size=(256, 256))[None, ...]
    tm_items, = nsfw_tm.run(['output'], {'input': tm_input_})
    # (label, score) pairs ranked highest-first. Computed once and reused for
    # both the decision and the debug print.
    tm_ranking = sorted(
        zip(tm_cfg["labels"], (score.item() for score in tm_items[0])),
        key=lambda pair: pair[1],
        reverse=True,
    )
    tm_output = tm_ranking[0][0]
    tf_output = nsfw_tf(img)[0]["label"]
    weight = _TM_WEIGHTS.get(tm_output, 0) + _TF_WEIGHTS.get(tf_output, 0)
    print(tm_ranking, tf_output)
    return weight > 0
# Gradio UI: one PIL image in, launch()'s verdict shown as text.
app = gr.Interface(fn=launch, inputs="pil", outputs="text")
app.launch()