Spaces:
Running
Running
spuuntries
commited on
Commit
•
68e9e5e
1
Parent(s):
52ffbe2
download model
Browse files
app.py
CHANGED
@@ -3,15 +3,34 @@ import gradio as gr
|
|
3 |
import timm
|
4 |
import torch
|
5 |
|
6 |
-
nsfw_tf = pipeline(
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
tm_config = timm.data.resolve_model_data_config(model)
|
14 |
tm_trans = timm.data.create_transform(**tm_config, is_training=False)
|
15 |
|
|
|
16 |
def launch(img):
|
17 |
-
tm_output = nsfw_tm(transforms(img).unsqueeze(0))
|
|
|
3 |
import timm
|
4 |
import torch
|
5 |
|
6 |
+
nsfw_tf = pipeline(
|
7 |
+
"image-classification",
|
8 |
+
model=AutoModelForImageClassification.from_pretrained(
|
9 |
+
"carbon225/vit-base-patch16-224-hentai"
|
10 |
+
),
|
11 |
+
feature_extractor=AutoFeatureExtractor.from_pretrained(
|
12 |
+
"carbon225/vit-base-patch16-224-hentai"
|
13 |
+
),
|
14 |
+
)
|
15 |
|
16 |
+
if not os.path.exists("timm.ckpt"):
|
17 |
+
open("timm.ckpt", "wb").write(
|
18 |
+
requests.get(
|
19 |
+
"https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/model.ckpt"
|
20 |
+
).content
|
21 |
+
)
|
22 |
+
else:
|
23 |
+
print("Model already exists, skipping redownload")
|
24 |
+
|
25 |
+
|
26 |
+
nsfw_tm = timm.create_model(
|
27 |
+
"caformer_s36.sail_in22k_ft_in1k_384",
|
28 |
+
checkpoint_path="./timm.ckpt",
|
29 |
+
pretrained=True,
|
30 |
+
).eval()
|
31 |
tm_config = timm.data.resolve_model_data_config(model)
|
32 |
tm_trans = timm.data.create_transform(**tm_config, is_training=False)
|
33 |
|
34 |
+
|
35 |
def launch(img):
|
36 |
+
tm_output = nsfw_tm(transforms(img).unsqueeze(0))
|