spuun commited on
Commit
72d3376
1 Parent(s): 255fd29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -22
app.py CHANGED
@@ -1,17 +1,33 @@
1
  from transformers import pipeline
 
 
 
2
  import gradio as gr
 
3
  import os
4
  import requests
5
  import timm
6
  import torch
7
  import json
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")
10
 
11
- if not os.path.exists("timm.ckpt"):
12
- open("timm.ckpt", "wb").write(
13
  requests.get(
14
- "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.ckpt"
15
  ).content
16
  )
17
  open("timmcfg.json", "wb").write(
@@ -25,26 +41,14 @@ else:
25
  with open("timmcfg.json") as file:
26
  tm_cfg = json.load(file)
27
 
28
- nsfw_tm = timm.create_model(
29
- "caformer_s36.sail_in22k_ft_in1k_384",
30
- checkpoint_path="./timm.ckpt",
31
- model_config=tm_cfg,
32
- num_classes=3
33
- ).eval()
34
- tm_config = timm.data.resolve_model_data_config(nsfw_tm)
35
- tm_trans = timm.data.create_transform((256, 256), **tm_config, is_training=False)
36
-
37
 
38
  def launch(img):
39
  weight = 0
40
- img = Image.open(img).convert('RGB')
41
- tm_output = model.pretrained_cfg['labels'][
42
- torch.argmax(
43
- torch.nn.functional.softmax(
44
- nsfw_tm(transforms(img).unsqueeze(0))[0], dim=0
45
- )
46
- )
47
- ]
48
 
49
  match tm_output:
50
  case "safe":
@@ -54,8 +58,8 @@ def launch(img):
54
  case "r18":
55
  weight += 2
56
 
57
-
58
- tf_output = nsfw_tf(img)[0]["label"]
59
 
60
  match tf_output:
61
  case "safe":
 
1
  from transformers import pipeline
2
+ from imgutils.data import rgb_encode, load_image
3
+ from onnx_ import _open_onnx_model
4
+ from PIL import Image
5
  import gradio as gr
6
+ import numpy as np
7
  import os
8
  import requests
9
  import timm
10
  import torch
11
  import json
12
 
13
+ def _img_encode(image, size=(384,384), normalize=(0.5,0.5)):
14
+ image = image.resize(size, Image.BILINEAR)
15
+ data = rgb_encode(image, order_='CHW')
16
+
17
+ if normalize is not None:
18
+ mean_, std_ = normalize
19
+ mean = np.asarray([mean_]).reshape((-1, 1, 1))
20
+ std = np.asarray([std_]).reshape((-1, 1, 1))
21
+ data = (data - mean) / std
22
+
23
+ return data.astype(np.float32)
24
+
25
  nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")
26
 
27
+ if not os.path.exists("timm.onnx"):
28
+ open("timm.onnx", "wb").write(
29
  requests.get(
30
+ "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.onnx"
31
  ).content
32
  )
33
  open("timmcfg.json", "wb").write(
 
41
  with open("timmcfg.json") as file:
42
  tm_cfg = json.load(file)
43
 
44
+ nsfw_tm = _open_onnx_model("timm.onnx")
 
 
 
 
 
 
 
 
45
 
46
  def launch(img):
47
  weight = 0
48
+ tm_image = load_image(img, mode='RGB')
49
+ tm_input_ = _img_encode(tm_image, size=(256, 256))[None, ...]
50
+ tm_output, = nsfw_tm.run(['output'], {'input': tm_input_})
51
+ tm_output = zip(tm_cfg["labels"], map(lambda x: x.item(), output[0]))[0][0]
 
 
 
 
52
 
53
  match tm_output:
54
  case "safe":
 
58
  case "r18":
59
  weight += 2
60
 
61
+ tf_img = Image.open(img).convert('RGB')
62
+ tf_output = nsfw_tf(tf_img)[0]["label"]
63
 
64
  match tf_output:
65
  case "safe":