from transformers import pipeline
from imgutils.data import rgb_encode, load_image
from onnx_ import _open_onnx_model
from PIL import Image
import gradio as gr
import numpy as np
import os
import requests
import torch
import json
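
# An uploaded image is scored by two classifiers: a transformers ViT pipeline
# and an ONNX anime-rating model (caformer_s36_plus). Their predicted labels
# are combined into a single weight, and the app reports whether that weight
# flags the image as NSFW.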

def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
    # Resize, convert to a CHW float array, and normalize with the
    # mean/std expected by the ONNX rating model.
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')

    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std

    return data.astype(np.float32)
    
# transformers image-classification pipeline used as the second classifier.
nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")

# Download the ONNX anime-rating model and its label metadata on first run.
if not os.path.exists("timm.onnx"):
    with open("timm.onnx", "wb") as f:
        f.write(
            requests.get(
                "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.onnx"
            ).content
        )
    with open("timmcfg.json", "wb") as f:
        f.write(
            requests.get(
                "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/meta.json"
            ).content
        )
else:
    print("Model already exists, skipping redownload")

with open("timmcfg.json") as file:
    tm_cfg = json.load(file)

nsfw_tm = _open_onnx_model("timm.onnx")

def launch(img):
    # Combine both classifiers into a single score: safe labels subtract
    # from the weight, NSFW-leaning labels add to it.
    weight = 0
    img = img.convert('RGB')

    # ONNX anime-rating model: preprocess to 256x256 CHW and rank labels by score.
    tm_image = load_image(img, mode='RGB')
    tm_input_ = _img_encode(tm_image, size=(256, 256))[None, ...]
    tm_items, = nsfw_tm.run(['output'], {'input': tm_input_})
    tm_scores = sorted(zip(tm_cfg["labels"], (x.item() for x in tm_items[0])),
                       key=lambda x: x[1], reverse=True)
    tm_output = tm_scores[0][0]

    match tm_output:
        case "safe":
            weight -= 1
        case "r15":
            weight += 2
        case "r18":
            weight += 2

    # transformers pipeline: use its top predicted label.
    tf_output = nsfw_tf(img)[0]["label"]

    match tf_output:
        case "safe":
            weight -= 1
        case "suggestive":
            weight += 1
        case "r18":
            weight += 2

    print(tm_scores, tf_output)
    return weight > 0

# Gradio UI: the handler expects a PIL image, so request one explicitly.
app = gr.Interface(fn=launch, inputs=gr.Image(type="pil"), outputs="text")
app.launch()
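
# Quick local check without the web UI (hypothetical file name, adjust to a real image):
#   from PIL import Image
#   print(launch(Image.open("sample.jpg")))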