#!/usr/bin/python
# -*- coding: utf-8 -*-
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.onnx import FeaturesManager, export

# Local checkpoint to convert and the target location for the ONNX export.
model_id = "./distilgpt2-base-pretrained-he"
export_folder = "tmp/onnx/"
file_name = "model.onnx"
print('Loading tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Save the tokenizer next to the ONNX graph so the export folder is self-contained.
print('Saving tokenizer to', export_folder)
tokenizer.save_pretrained(export_folder)

print('Loading model...')
model = AutoModelForCausalLM.from_pretrained(model_id)
feature= "causal-lm"
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
onnx_config = model_onnx_config(model.config)
print("model_kind = {0}\nonx_config = {1}\n".format(model_kind, onnx_config))
onnx_path = Path(export_folder+file_name)
print('Exporting model to ', onnx_path)
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
print('Done')
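
# --- Optional sanity check (a minimal sketch, not part of the original export) ---
# Assumes onnxruntime is installed and that the exported graph takes the usual
# GPT-2 causal-lm inputs (input_ids, attention_mask); adjust as needed.
import onnxruntime as ort

session = ort.InferenceSession(str(onnx_path))
encoded = tokenizer("שלום עולם", return_tensors="np")
# Feed only the inputs the exported graph actually declares.
ort_inputs = {name: encoded[name] for name in onnx_inputs if name in encoded}
logits = session.run(onnx_outputs, ort_inputs)[0]
print("ONNX logits shape:", logits.shape)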