Spaces:
Running
Running
from transformers import pipeline | |
import gradio as gr | |
import os | |
import requests | |
import timm | |
import torch | |
import json | |
# Hub id of the ViT classifier used as the second NSFW opinion in launch().
# No task is passed, so transformers infers it from the model's Hub metadata.
NSFW_TF_MODEL_ID = "carbon225/vit-base-patch16-224-hentai"
nsfw_tf = pipeline(model=NSFW_TF_MODEL_ID)
# Base URL of the timm CAFormer rating model (checkpoint + metadata). Both
# files are cached next to the script and fetched only when missing; each
# file is checked on its own so a missing config is fetched even when the
# checkpoint is already present.
TM_REPO_URL = (
    "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus"
)


def _download_if_missing(path, url, timeout=60):
    """Download *url* to *path* unless *path* already exists.

    The response body is streamed into ``<path>.part`` and renamed to *path*
    only once it is complete. An interrupted download is therefore retried
    on the next start instead of leaving a truncated file that would be
    treated as cached.

    Args:
        path: Local destination file.
        url: Remote file to fetch.
        timeout: Seconds to wait for the connection / between received bytes
            (a requests timeout, not a cap on the total download time).

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    if os.path.exists(path):
        print(f"{path} already exists, skipping redownload")
        return
    tmp_path = f"{path}.part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an HTML/JSON error page would be saved as the file.
        response.raise_for_status()
        with open(tmp_path, "wb") as fh:
            # Stream in 1 MiB chunks rather than holding the whole checkpoint
            # in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(tmp_path, path)


_download_if_missing("timm.ckpt", f"{TM_REPO_URL}/model.ckpt")
_download_if_missing("timmcfg.json", f"{TM_REPO_URL}/meta.json")
# Metadata published next to the checkpoint. JSON is UTF-8 by specification,
# so read it explicitly as UTF-8 rather than in the platform locale encoding.
with open("timmcfg.json", encoding="utf-8") as file:
    tm_cfg = json.load(file)
# NOTE(review): presumably tm_cfg holds the checkpoint's class labels (launch()
# expects "safe"/"r15"/"r18") — confirm against the published meta.json.
nsfw_tm = timm.create_model(
    "caformer_s36.sail_in22k_ft_in1k_384",
    checkpoint_path="./timm.ckpt",
    # NOTE(review): `model_config` is not a named parameter of create_model, so
    # it is forwarded as an extra kwarg to the model constructor rather than
    # merged into `pretrained_cfg` (that is `pretrained_cfg_overlay`) — confirm.
    model_config=tm_cfg,
    # NOTE(review): pretrained=True fetches the hub weights first; the weights
    # in checkpoint_path are then loaded on top of them.
    pretrained=True,
).eval()
# Eval-time preprocessing (resize / crop / normalize) derived from the
# model's data config; used by launch() before calling nsfw_tm.
tm_config = timm.data.resolve_model_data_config(nsfw_tm)
tm_trans = timm.data.create_transform(**tm_config)
# Score contributed by each top label of the timm (CAFormer) rating model.
_TM_LABEL_WEIGHTS = {"safe": -2, "r15": 1, "r18": 2}
# Score contributed by each top label of the transformers (ViT) classifier.
# NOTE(review): confirm these match the label names in that model's config;
# any label not listed in either table contributes 0 to the score.
_TF_LABEL_WEIGHTS = {"safe": -2, "suggestive": 1, "r18": 2}


def _to_rgb_image(img):
    """Return *img* as an RGB ``PIL.Image``.

    Accepts a PIL image, an array exposing ``__array_interface__`` (what a
    numpy-typed Gradio image input delivers), or anything ``Image.open``
    accepts (a file path or binary file object).
    """
    from PIL import Image

    # Check PIL first: PIL images also expose __array_interface__.
    if isinstance(img, Image.Image):
        pil_img = img
    elif hasattr(img, "__array_interface__"):
        pil_img = Image.fromarray(img)
    else:
        pil_img = Image.open(img)
    return pil_img.convert("RGB")


def launch(img):
    """Classify *img* with both models and return True if it is judged NSFW.

    Each model's top label is mapped to a score via the weight tables above.
    The image is flagged when the summed score is strictly positive. With
    these weights, a "safe" verdict from either model is enough to return
    False.

    Args:
        img: Image path, PIL image, or numpy-style array.

    Returns:
        bool: True when the combined score is > 0.
    """
    rgb = _to_rgb_image(img)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = nsfw_tm(tm_trans(rgb).unsqueeze(0))[0]
    # softmax is monotonic, so the arg-max of the raw logits is the same class.
    # NOTE(review): labels are read from the checkpoint's meta.json (tm_cfg);
    # timm's pretrained_cfg for this arch presumably has no "labels" entry —
    # confirm the key name against the published meta.json.
    tm_label = tm_cfg["labels"][int(torch.argmax(logits))]
    # Image-classification pipeline results are sorted by descending score.
    tf_label = nsfw_tf(rgb)[0]["label"]
    weight = _TM_LABEL_WEIGHTS.get(tm_label, 0) + _TF_LABEL_WEIGHTS.get(tf_label, 0)
    return weight > 0
# Web UI: the upload widget hands launch() the uploaded image as a file path
# (matching launch's use of Image.open); the boolean verdict is shown as text.
app = gr.Interface(fn=launch, inputs=gr.Image(type="filepath"), outputs="text")
app.launch()