ce-esci-MiniLM-L12-v2 / onnx_convert.py
shuttie's picture
use proper model output dim
c4a5571
raw
history blame
797 Bytes
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import torch
# Export the sequence-classification model stored in the current directory to
# ONNX (pytorch_model.onnx) with dynamic batch and sequence-length axes.

# Sequence length of the dummy tensors used to trace the model. The exported
# graph marks this axis as dynamic, so it does not cap the runtime length.
max_seq_length = 128

# ONNX graph input names. Their order must match the positional order of the
# model's forward() (input_ids, attention_mask, token_type_ids for BERT-style
# models), because the traced inputs are passed positionally.
INPUT_NAMES = ("input_ids", "attention_mask", "token_type_ids")

# Name of the single graph output. A sequence-classification head returns
# logits of shape (batch, num_labels); the historical name 'last_hidden_state'
# is kept so existing consumers that look the output up by name keep working.
OUTPUT_NAME = "last_hidden_state"


def build_dummy_inputs(seq_len=max_seq_length, batch_size=1):
    """Return all-ones int64 tracing tensors keyed by ONNX input name.

    Args:
        seq_len: sequence length of each dummy tensor.
        batch_size: batch size of each dummy tensor.

    Returns:
        A dict mapping every name in INPUT_NAMES (in that order) to a
        ``(batch_size, seq_len)`` int64 tensor filled with ones.
    """
    return {
        name: torch.ones(batch_size, seq_len, dtype=torch.int64)
        for name in INPUT_NAMES
    }


def build_dynamic_axes():
    """Return the ``dynamic_axes`` mapping for ``torch.onnx.export``.

    Every input gets a dynamic batch axis (0) and sequence axis (1). The
    output's batch axis (0) is declared dynamic too; otherwise the traced
    batch size of 1 would be baked into the output shape. The label axis
    of the logits stays static.
    """
    axes = {
        name: {0: "batch_size", 1: "max_seq_len"}
        for name in INPUT_NAMES
    }
    axes[OUTPUT_NAME] = {0: "batch_size"}
    return axes


def main(model_dir=".", onnx_path="pytorch_model.onnx", seq_len=max_seq_length):
    """Load the classifier from ``model_dir`` and export it to ``onnx_path``.

    Args:
        model_dir: directory passed to ``from_pretrained``.
        onnx_path: destination file for the exported ONNX graph.
        seq_len: sequence length of the dummy tracing inputs.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    # Put dropout and similar layers into inference mode before tracing.
    model.eval()
    inputs = build_dummy_inputs(seq_len)
    torch.onnx.export(
        model,
        # Build args from INPUT_NAMES so the positional order always agrees
        # with input_names below.
        args=tuple(inputs[name] for name in INPUT_NAMES),
        f=onnx_path,
        export_params=True,
        input_names=list(INPUT_NAMES),
        output_names=[OUTPUT_NAME],
        dynamic_axes=build_dynamic_axes(),
    )


if __name__ == "__main__":
    main()