spuuntries commited on
Commit
b75d01d
2 Parent(s): 65fb5ea 941b996

Merge pull branch

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ import gradio as gr
3
+ import timm
4
+ import torch
5
+
6
+ nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")
7
+
8
+ if not os.path.exists("timm.ckpt"):
9
+ open("timm.ckpt", "wb").write(
10
+ requests.get(
11
+ "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.ckpt"
12
+ ).content
13
+ )
14
+ open("timmcfg.json", "wb").write(
15
+ requests.get(
16
+ "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/meta.json"
17
+ ).content
18
+ )
19
+ else:
20
+ print("Model already exists, skipping redownload")
21
+
22
+
23
+ nsfw_tm = timm.create_model(
24
+ "caformer_s36.sail_in22k_ft_in1k_384",
25
+ checkpoint_path="./timm.ckpt",
26
+ pretrained_cfg="./timmcfg.json",
27
+ pretrained=True
28
+ ).eval()
29
+ tm_config = timm.data.resolve_model_data_config(nsfw_tm.pretrained_cfg, model=nsfw_tm)
30
+ tm_trans = timm.data.create_transform(**tm_config)
31
+
32
+
33
+ def launch(img):
34
+ weight = 0
35
+ img = Image.open(img).convert('RGB')
36
+ tm_output = model.pretrained_cfg['labels'][
37
+ torch.argmax(
38
+ torch.nn.functional.softmax(
39
+ nsfw_tm(transforms(img).unsqueeze(0))[0], dim=0
40
+ )
41
+ )
42
+ ]
43
+
44
+ match tm_output:
45
+ case "safe":
46
+ weight -= 2
47
+ case "r15":
48
+ weight += 1
49
+ case "r18":
50
+ weight += 2
51
+
52
+
53
+ tf_output = nsfw_tf(img)[0]["label"]
54
+
55
+ match tf_output:
56
+ case "safe":
57
+ weight -= 2
58
+ case "suggestive":
59
+ weight += 1
60
+ case "r18":
61
+ weight += 2
62
+
63
+ return weight > 0
64
+
65
+ app = gr.Interface(fn=generate, inputs="image", outputs="text")
66
+ app.launch()