File size: 2,070 Bytes
941b996
72d3376
 
 
15a6715
72d3376
0b3046d
 
15a6715
 
5ab85aa
15a6715
72d3376
 
 
 
 
 
 
 
 
 
 
 
941b996
15a6715
72d3376
 
68e9e5e
72d3376
68e9e5e
 
941b996
 
 
 
 
68e9e5e
 
 
5ab85aa
 
68e9e5e
72d3376
68e9e5e
15a6715
941b996
72d3376
 
 
 
941b996
 
 
 
 
 
 
 
 
72d3376
 
941b996
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from transformers import pipeline
from imgutils.data import rgb_encode, load_image
from onnx_ import _open_onnx_model
from PIL import Image
import gradio as gr
import numpy as np
import os
import requests
import timm
import torch
import json

def _img_encode(image, size=(384,384), normalize=(0.5,0.5)):
    """Resize a PIL image and encode it as a float32 channel-first array.

    Args:
        image: PIL image to encode; it is resized to ``size`` with bilinear
            resampling before encoding.
        size: target ``(width, height)`` passed to ``Image.resize``.
        normalize: ``(mean, std)`` pair applied as ``(x - mean) / std``; each
            entry may be a scalar or a per-channel sequence (reshaped to
            broadcast over the channel axis). ``None`` skips normalisation.

    Returns:
        ``np.float32`` array in CHW order as produced by ``rgb_encode``
        (value range before normalisation presumably [0, 1] — confirm
        against imgutils).
    """
    resized = image.resize(size, Image.BILINEAR)
    chw = rgb_encode(resized, order_='CHW')

    if normalize is None:
        return chw.astype(np.float32)

    mean_value, std_value = normalize
    channel_shape = (-1, 1, 1)  # broadcast over H and W, one entry per channel
    mean = np.asarray([mean_value]).reshape(channel_shape)
    std = np.asarray([std_value]).reshape(channel_shape)
    return ((chw - mean) / std).astype(np.float32)
    
nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")

# Remote location of the deepghs anime-rating ONNX model and its metadata,
# keyed by the local filenames the rest of this script reads.
_TM_BASE_URL = "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus"
_TM_FILES = {
    "timm.onnx": f"{_TM_BASE_URL}/model.onnx",
    "timmcfg.json": f"{_TM_BASE_URL}/meta.json",
}


def _download(url, path, timeout=60, chunk_size=1 << 20):
    """Download *url* to *path*, replacing it only once the transfer succeeds.

    The body is streamed to ``path + ".part"`` and then renamed into place, so
    a failed or interrupted download never leaves a truncated file (or an HTTP
    error page) behind that a later run would mistake for a cached copy.

    Args:
        url: address to fetch.
        path: local destination file.
        timeout: per-socket-operation timeout in seconds passed to requests.
        chunk_size: bytes read per streamed chunk.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.RequestException: the connection failed or timed out.
    """
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in response.iter_content(chunk_size=chunk_size):
                fh.write(chunk)
    os.replace(tmp_path, path)


def _ensure_tm_files():
    """Fetch each model/config file that is missing locally.

    Each file is checked independently, so a missing config is re-fetched
    even when the model itself is already present.
    """
    missing = {path: url for path, url in _TM_FILES.items() if not os.path.exists(path)}
    if not missing:
        print("Model already exists, skipping redownload")
    for path, url in missing.items():
        _download(url, path)


_ensure_tm_files()

with open("timmcfg.json", encoding="utf-8") as file:
    tm_cfg = json.load(file)

nsfw_tm = _open_onnx_model("timm.onnx")

# Score contribution of each classifier's top label. Labels not listed here
# contribute 0. NOTE(review): the label strings are assumed to match
# tm_cfg["labels"] and the HF pipeline's id2label — confirm against the models.
_TM_WEIGHTS = {"safe": -2, "r15": 1, "r18": 2}
_TF_WEIGHTS = {"safe": -2, "suggestive": 1, "r18": 2}


def launch(img):
    """Return True when the combined rating of *img* leans NSFW.

    Both classifiers vote with their highest-scoring label. The votes are
    mapped to weights via ``_TM_WEIGHTS`` / ``_TF_WEIGHTS`` and summed; a
    strictly positive total means NSFW.

    Args:
        img: a numpy image array (Gradio's default ``"image"`` input), a PIL
            image, or anything else ``imgutils.data.load_image`` accepts,
            such as a file path.

    Returns:
        bool: ``True`` if the summed weight is greater than zero.
    """
    # Gradio hands over a numpy array by default; load_image / pipeline want PIL.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Decode once and feed the same RGB image to both models.
    rgb_image = load_image(img, mode='RGB')

    weight = 0

    tm_input = _img_encode(rgb_image, size=(256, 256))[None, ...]  # add batch axis
    tm_scores, = nsfw_tm.run(['output'], {'input': tm_input})
    # Pick the label with the highest score for the single batch element.
    tm_label = tm_cfg["labels"][int(np.argmax(tm_scores[0]))]
    weight += _TM_WEIGHTS.get(tm_label, 0)

    # The image-classification pipeline returns predictions sorted by score,
    # so element 0 is the top label.
    tf_label = nsfw_tf(rgb_image)[0]["label"]
    weight += _TF_WEIGHTS.get(tf_label, 0)

    return weight > 0

# Web UI: one image input, textual True/False output from launch().
# (The original referenced an undefined `generate`, which raised NameError.)
app = gr.Interface(fn=launch, inputs="image", outputs="text")
app.launch()