nsfw_image_detection / convert.py
DenisNovac's picture
TorchScript converted model with synset
d19449b verified
raw history blame
No virus
629 Bytes
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image
from transformers import AutoTokenizer
# Export the image-classification checkpoint to TorchScript: trace the model's
# forward pass on the preprocessed sample image and save the traced module.
model_name = "DenisNovac/nsfw_image_detection"

# torchscript=True and return_dict=False request the trace-friendly model
# configuration (plain tuple outputs instead of dict-like output objects).
model = AutoModelForImageClassification.from_pretrained(model_name, torchscript=True, return_dict=False)
# Make inference mode explicit so train-only behaviour (e.g. dropout) cannot
# end up recorded in the trace.
model.eval()
processor = AutoImageProcessor.from_pretrained(model_name)

# Open the sample inside a context manager so the file handle is released
# deterministically. convert("RGB") yields a fully loaded in-memory image that
# stays usable after the file is closed, and guarantees three colour channels
# even if the sample is stored as grayscale, CMYK, or with an alpha channel.
with Image.open("images/hentai.jpg") as source_image:
    image = source_image.convert("RGB")
image_inputs = processor(images=image, return_tensors="pt")

# trace_module takes a mapping of method name -> example inputs for that
# method; only forward() is traced, with the pixel tensor as its sole argument.
config = {'forward': [image_inputs['pixel_values']]}
converted = torch.jit.trace_module(model, config)
torch.jit.save(converted, "converted-to-torchscript.pt")