"""Gradio demo that flags an image as NSFW by combining two classifiers.

Two models vote on each image: a ``transformers`` image-classification
pipeline and a ``timm`` CAFormer checkpoint.  Each predicted label
contributes a signed weight; the image is reported as NSFW when the summed
weight is positive.
"""
import os

import gradio as gr
import requests
import timm
import torch
from PIL import Image
from transformers import pipeline

# Local filename -> remote URL for the timm rating model artefacts.
_TIMM_FILES = {
    "timm.ckpt": (
        "https://huggingface.co/deepghs/anime_rating/resolve/main/"
        "caformer_s36_plus/model.ckpt"
    ),
    "timmcfg.json": (
        "https://huggingface.co/deepghs/anime_rating/resolve/main/"
        "caformer_s36_plus/meta.json"
    ),
}

# Signed vote contributed by each label; labels not listed contribute 0.
_TM_WEIGHTS = {"safe": -2, "r15": 1, "r18": 2}
_TF_WEIGHTS = {"safe": -2, "suggestive": 1, "r18": 2}


def _download(url, dest):
    """Fetch *url* and write the response body to *dest*.

    Raises:
        requests.HTTPError: if the server answers with an error status, so
            an error page is never saved in place of the real file.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(dest, "wb") as fh:
        fh.write(response.content)


nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")

# Fetch whichever artefacts are missing (not only when the checkpoint is).
_missing = [name for name in _TIMM_FILES if not os.path.exists(name)]
if _missing:
    for _name in _missing:
        _download(_TIMM_FILES[_name], _name)
else:
    print("Model already exists, skipping redownload")

# NOTE(review): timm may interpret a string ``pretrained_cfg`` as a pretrained
# tag rather than a path to a JSON file -- confirm that the resulting
# ``nsfw_tm.pretrained_cfg`` really carries the "labels" entry used below.
nsfw_tm = timm.create_model(
    "caformer_s36.sail_in22k_ft_in1k_384",
    checkpoint_path="./timm.ckpt",
    pretrained_cfg="./timmcfg.json",
    pretrained=True,
).eval()
tm_config = timm.data.resolve_model_data_config(nsfw_tm.pretrained_cfg, model=nsfw_tm)
tm_trans = timm.data.create_transform(**tm_config)


def _to_rgb(img):
    """Return *img* as an RGB ``PIL.Image``.

    Accepts a PIL image, a filesystem path / open binary file, or an array
    (what Gradio passes when the image component is in its numpy mode).
    """
    if isinstance(img, Image.Image):
        image = img
    elif isinstance(img, (str, os.PathLike)) or hasattr(img, "read"):
        image = Image.open(img)
    else:
        image = Image.fromarray(img)
    return image.convert("RGB")


def launch(img):
    """Classify *img* with both models and return ``True`` if it looks NSFW.

    Args:
        img: The image to rate, as a PIL image, a path / file object, or an
            array.

    Returns:
        ``True`` when the combined vote of the two models is positive,
        ``False`` otherwise.
    """
    image = _to_rgb(img)

    # Pure inference: no autograd bookkeeping needed.  argmax over the raw
    # logits selects the same class as argmax over their softmax.
    with torch.inference_mode():
        logits = nsfw_tm(tm_trans(image).unsqueeze(0))[0]
    tm_label = nsfw_tm.pretrained_cfg["labels"][int(torch.argmax(logits))]

    # The pipeline output is indexed at 0 for its leading prediction.
    tf_label = nsfw_tf(image)[0]["label"]

    weight = _TM_WEIGHTS.get(tm_label, 0) + _TF_WEIGHTS.get(tf_label, 0)
    return weight > 0


# Ask Gradio for a PIL image so the handler does not depend on the default
# numpy representation.
app = gr.Interface(fn=launch, inputs=gr.Image(type="pil"), outputs="text")
app.launch()