spuun commited on
Commit
941b996
1 Parent(s): 68e9e5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -14
app.py CHANGED
@@ -1,17 +1,9 @@
1
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
2
  import gradio as gr
3
  import timm
4
  import torch
5
 
6
- nsfw_tf = pipeline(
7
- "image-classification",
8
- model=AutoModelForImageClassification.from_pretrained(
9
- "carbon225/vit-base-patch16-224-hentai"
10
- ),
11
- feature_extractor=AutoFeatureExtractor.from_pretrained(
12
- "carbon225/vit-base-patch16-224-hentai"
13
- ),
14
- )
15
 
16
  if not os.path.exists("timm.ckpt"):
17
  open("timm.ckpt", "wb").write(
@@ -19,6 +11,11 @@ if not os.path.exists("timm.ckpt"):
19
  "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.ckpt"
20
  ).content
21
  )
 
 
 
 
 
22
  else:
23
  print("Model already exists, skipping redownload")
24
 
@@ -26,11 +23,44 @@ else:
26
  nsfw_tm = timm.create_model(
27
  "caformer_s36.sail_in22k_ft_in1k_384",
28
  checkpoint_path="./timm.ckpt",
29
- pretrained=True,
 
30
  ).eval()
31
- tm_config = timm.data.resolve_model_data_config(model)
32
- tm_trans = timm.data.create_transform(**tm_config, is_training=False)
33
 
34
 
35
  def launch(img):
36
- tm_output = nsfw_tm(transforms(img).unsqueeze(0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
  import gradio as gr
3
  import timm
4
  import torch
5
 
6
+ nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")
 
 
 
 
 
 
 
 
7
 
8
  if not os.path.exists("timm.ckpt"):
9
  open("timm.ckpt", "wb").write(
 
11
  "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.ckpt"
12
  ).content
13
  )
14
+ open("timmcfg.json", "wb").write(
15
+ requests.get(
16
+ "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/meta.json"
17
+ ).content
18
+ )
19
  else:
20
  print("Model already exists, skipping redownload")
21
 
 
23
  nsfw_tm = timm.create_model(
24
  "caformer_s36.sail_in22k_ft_in1k_384",
25
  checkpoint_path="./timm.ckpt",
26
+ pretrained_cfg="./timmcfg.json",
27
+ pretrained=True
28
  ).eval()
29
+ tm_config = timm.data.resolve_model_data_config(nsfw_tm.pretrained_cfg, model=nsfw_tm)
30
+ tm_trans = timm.data.create_transform(**tm_config)
31
 
32
 
33
  def launch(img):
34
+ weight = 0
35
+ img = Image.open(img).convert('RGB')
36
+ tm_output = model.pretrained_cfg['labels'][
37
+ torch.argmax(
38
+ torch.nn.functional.softmax(
39
+ nsfw_tm(transforms(img).unsqueeze(0))[0], dim=0
40
+ )
41
+ )
42
+ ]
43
+
44
+ match tm_output:
45
+ case "safe":
46
+ weight -= 2
47
+ case "r15":
48
+ weight += 1
49
+ case "r18":
50
+ weight += 2
51
+
52
+
53
+ tf_output = nsfw_tf(img)[0]["label"]
54
+
55
+ match tf_output:
56
+ case "safe":
57
+ weight -= 2
58
+ case "suggestive":
59
+ weight += 1
60
+ case "r18":
61
+ weight += 2
62
+
63
+ return weight > 0
64
+
65
+ app = gr.Interface(fn=generate, inputs="image", outputs="text")
66
+ app.launch()