Convert the model to TorchScript or ONNX
Hi,
I would like to run the model, feature extractor, and tokenizer in C++.
So I am looking to convert the model to TorchScript. I load everything with the parameter torchscript=True, as below.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning", torchscript=True)
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning", torchscript=True)
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning", torchscript=True)
But I can't seem to figure out what parameters to pass to the trace function.
traced_model = torch.jit.trace(model, WHAT_INPUT_TO_PASS)
I did try passing the pixel_values generated by feature_extractor and a random tensor of shape (1, 16), but the traced model seems to be incorrect.
I also tried converting it to ONNX, but "Export a custom model for an unsupported architecture" seemed very confusing.
Any guidance will be deeply appreciated.
Regards,
Prabesh Khadka
First of all, I would suggest reading this blog post: https://ankur3107.github.io/blogs/the-illustrated-image-captioning-using-transformers/
It will help you understand more about training and inference with vision encoder-decoder models.
This is how you can do it.
import requests
import torch
from PIL import Image

# model, feature_extractor and tokenizer are loaded as in the question above
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
labels = tokenizer(
    "an image of two cats chilling on a couch",
    return_tensors="pt",
).input_ids
# trace with example inputs; passed positionally, the second tensor is received by forward() as decoder_input_ids
traced_model = torch.jit.trace(model, (pixel_values, labels))
torch.jit.save(traced_model, "traced_vit-gpt2-image-captioning.pt")
# load model
loaded_model = torch.jit.load("traced_vit-gpt2-image-captioning.pt")
You can also see a running Colab notebook: https://colab.research.google.com/drive/1a96pgxfpqfsJ6OCOwE6i0sGWmkhnjwv8?usp=sharing
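If it helps, one way to sanity-check a trace before moving to C++ is to save it, load it back, and compare its output with the eager model on the same inputs. Below is a minimal sketch of that round trip. To keep it runnable without downloading any weights, it uses a hypothetical toy module (TinyEncoderDecoder, with made-up sizes) that only mimics the two-tensor call pattern of the captioning model; it is not the real architecture, and the same check should be repeated on the actual traced model.

```python
import io

import torch
from torch import nn


class TinyEncoderDecoder(nn.Module):
    """Hypothetical stand-in: takes an image tensor and token ids, returns logits."""

    def __init__(self, vocab_size=50, hidden=8):
        super().__init__()
        self.encoder = nn.Linear(3 * 4 * 4, hidden)
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, pixel_values, decoder_input_ids):
        # Flatten the image and project it to one hidden vector per example.
        image_state = self.encoder(pixel_values.flatten(1)).unsqueeze(1)
        # Add the image vector to every token embedding, then predict logits.
        return self.head(self.embed(decoder_input_ids) + image_state)


torch.manual_seed(0)
model = TinyEncoderDecoder().eval()

# Assumed toy shapes: one 3x4x4 "image" and a sequence of 6 token ids.
pixel_values = torch.rand(1, 3, 4, 4)
decoder_input_ids = torch.randint(0, 50, (1, 6))

# Example inputs are passed positionally, in the order forward() expects.
traced = torch.jit.trace(model, (pixel_values, decoder_input_ids))

# Round-trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

with torch.no_grad():
    eager_logits = model(pixel_values, decoder_input_ids)
    loaded_logits = loaded(pixel_values, decoder_input_ids)

outputs_match = torch.allclose(eager_logits, loaded_logits, atol=1e-6)
print("traced output matches eager output:", outputs_match)
```

Keep in mind that tracing only records the operations executed for the example inputs, so any Python-side control flow that depends on the data is frozen into the trace. It is worth running the comparison with more than one input before relying on the saved file.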
Thank you so much. This works like a charm.
Hi, what if I have a ViT image classifier like nateraw/vit-age-classifier? How would I convert it to TorchScript?