Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,33 @@
|
|
1 |
from transformers import pipeline
|
|
|
|
|
|
|
2 |
import gradio as gr
|
|
|
3 |
import os
|
4 |
import requests
|
5 |
import timm
|
6 |
import torch
|
7 |
import json
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")
|
10 |
|
11 |
-
if not os.path.exists("timm.
|
12 |
-
open("timm.
|
13 |
requests.get(
|
14 |
-
"https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.
|
15 |
).content
|
16 |
)
|
17 |
open("timmcfg.json", "wb").write(
|
@@ -25,26 +41,14 @@ else:
|
|
25 |
with open("timmcfg.json") as file:
|
26 |
tm_cfg = json.load(file)
|
27 |
|
28 |
-
nsfw_tm = timm.
|
29 |
-
"caformer_s36.sail_in22k_ft_in1k_384",
|
30 |
-
checkpoint_path="./timm.ckpt",
|
31 |
-
model_config=tm_cfg,
|
32 |
-
num_classes=3
|
33 |
-
).eval()
|
34 |
-
tm_config = timm.data.resolve_model_data_config(nsfw_tm)
|
35 |
-
tm_trans = timm.data.create_transform((256, 256), **tm_config, is_training=False)
|
36 |
-
|
37 |
|
38 |
def launch(img):
|
39 |
weight = 0
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
nsfw_tm(transforms(img).unsqueeze(0))[0], dim=0
|
45 |
-
)
|
46 |
-
)
|
47 |
-
]
|
48 |
|
49 |
match tm_output:
|
50 |
case "safe":
|
@@ -54,8 +58,8 @@ def launch(img):
|
|
54 |
case "r18":
|
55 |
weight += 2
|
56 |
|
57 |
-
|
58 |
-
tf_output = nsfw_tf(
|
59 |
|
60 |
match tf_output:
|
61 |
case "safe":
|
|
|
1 |
from transformers import pipeline
|
2 |
+
from imgutils.data import rgb_encode, load_image
|
3 |
+
from onnx_ import _open_onnx_model
|
4 |
+
from PIL import Image
|
5 |
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
import os
|
8 |
import requests
|
9 |
import timm
|
10 |
import torch
|
11 |
import json
|
12 |
|
13 |
+
def _img_encode(image, size=(384,384), normalize=(0.5,0.5)):
|
14 |
+
image = image.resize(size, Image.BILINEAR)
|
15 |
+
data = rgb_encode(image, order_='CHW')
|
16 |
+
|
17 |
+
if normalize is not None:
|
18 |
+
mean_, std_ = normalize
|
19 |
+
mean = np.asarray([mean_]).reshape((-1, 1, 1))
|
20 |
+
std = np.asarray([std_]).reshape((-1, 1, 1))
|
21 |
+
data = (data - mean) / std
|
22 |
+
|
23 |
+
return data.astype(np.float32)
|
24 |
+
|
25 |
nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")
|
26 |
|
27 |
+
if not os.path.exists("timm.onnx"):
|
28 |
+
open("timm.onnx", "wb").write(
|
29 |
requests.get(
|
30 |
+
"https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.onnx"
|
31 |
).content
|
32 |
)
|
33 |
open("timmcfg.json", "wb").write(
|
|
|
41 |
with open("timmcfg.json") as file:
|
42 |
tm_cfg = json.load(file)
|
43 |
|
44 |
+
nsfw_tm = _open_onnx_model("timm.onnx")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
def launch(img):
|
47 |
weight = 0
|
48 |
+
tm_image = load_image(img, mode='RGB')
|
49 |
+
tm_input_ = _img_encode(tm_image, size=(256, 256))[None, ...]
|
50 |
+
tm_output, = nsfw_tm.run(['output'], {'input': tm_input_})
|
51 |
+
tm_output = zip(tm_cfg["labels"], map(lambda x: x.item(), output[0]))[0][0]
|
|
|
|
|
|
|
|
|
52 |
|
53 |
match tm_output:
|
54 |
case "safe":
|
|
|
58 |
case "r18":
|
59 |
weight += 2
|
60 |
|
61 |
+
tf_img = Image.open(img).convert('RGB')
|
62 |
+
tf_output = nsfw_tf(tf_img)[0]["label"]
|
63 |
|
64 |
match tf_output:
|
65 |
case "safe":
|