# nsfw_image_detection / convert.py
# DenisNovac's picture
# TorchScript converted model with synset
# d19449b verified
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image
from transformers import AutoTokenizer
# Convert an image-classification checkpoint to TorchScript by tracing its
# forward pass with one example image, then save the traced module to disk.

# Model identifier handed to both `from_pretrained` calls below.
model_name = "DenisNovac/nsfw_image_detection"

# `torchscript=True` / `return_dict=False` are the settings the Transformers
# TorchScript export guide asks for: the traced forward must return plain
# tuples rather than ModelOutput objects.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    torchscript=True,
    return_dict=False,
)
processor = AutoImageProcessor.from_pretrained(model_name)

# `Image.open` is lazy and keeps the file handle open; the context manager
# releases it deterministically. `convert("RGB")` returns a fully loaded copy
# that no longer depends on the file, and guarantees a 3-channel example even
# if the JPEG is stored as grayscale, palette or CMYK.
with Image.open("images/hentai.jpg") as source_image:
    image = source_image.convert("RGB")

image_inputs = processor(images=image, return_tensors="pt")

# `trace_module` takes a mapping of method name -> example inputs; only
# `forward` is traced, using the preprocessed pixel tensor as the example.
config = {'forward': [image_inputs['pixel_values']]}
converted = torch.jit.trace_module(model, config)

torch.jit.save(converted, "converted-to-torchscript.pt")