spuuntries commited on
Commit
68e9e5e
1 Parent(s): 52ffbe2

download model

Browse files
Files changed (1) hide show
  1. app.py +26 -7
app.py CHANGED
@@ -3,15 +3,34 @@ import gradio as gr
3
  import timm
4
  import torch
5
 
6
- nsfw_tf = pipeline("image-classification",
7
- model=AutoModelForImageClassification.from_pretrained(
8
- "carbon225/vit-base-patch16-224-hentai"),
9
- feature_extractor=AutoFeatureExtractor.from_pretrained(
10
- "carbon225/vit-base-patch16-224-hentai"))
 
 
 
 
11
 
12
- nsfw_tm = timm.create_model('deepghs/anime_rating', pretrained=True).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  tm_config = timm.data.resolve_model_data_config(model)
14
  tm_trans = timm.data.create_transform(**tm_config, is_training=False)
15
 
 
16
  def launch(img):
17
- tm_output = nsfw_tm(transforms(img).unsqueeze(0))
 
3
  import timm
4
  import torch
5
 
6
+ nsfw_tf = pipeline(
7
+ "image-classification",
8
+ model=AutoModelForImageClassification.from_pretrained(
9
+ "carbon225/vit-base-patch16-224-hentai"
10
+ ),
11
+ feature_extractor=AutoFeatureExtractor.from_pretrained(
12
+ "carbon225/vit-base-patch16-224-hentai"
13
+ ),
14
+ )
15
 
16
+ if not os.path.exists("timm.ckpt"):
17
+ open("timm.ckpt", "wb").write(
18
+ requests.get(
19
+ "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.ckpt"
20
+ ).content
21
+ )
22
+ else:
23
+ print("Model already exists, skipping redownload")
24
+
25
+
26
+ nsfw_tm = timm.create_model(
27
+ "caformer_s36.sail_in22k_ft_in1k_384",
28
+ checkpoint_path="./timm.ckpt",
29
+ pretrained=True,
30
+ ).eval()
31
  tm_config = timm.data.resolve_model_data_config(model)
32
  tm_trans = timm.data.create_transform(**tm_config, is_training=False)
33
 
34
+
35
  def launch(img):
36
+ tm_output = nsfw_tm(transforms(img).unsqueeze(0))