File size: 821 Bytes
3962c71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Sequence length of the dummy tensors used to trace the model. The sequence
# axis is declared dynamic below, so this value only shapes the example inputs.
max_seq_length = 128

# Load the pretrained cross-encoder checkpoint and put it in evaluation mode
# before it is traced for export.
model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-MiniLM-L-6-v2')
model.eval()

# Dummy int64 inputs of shape (1, max_seq_length). The insertion order of this
# dict matters: the values are passed positionally to the model in the export
# call and are matched one-to-one with `input_names`.
inputs = {"input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
        "attention_mask": torch.ones(1, max_seq_length, dtype=torch.int64),
        "token_type_ids": torch.ones(1, max_seq_length, dtype=torch.int64)}

# Axes of every input tensor that must stay variable in the exported graph.
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}

# NOTE(review): the model is a sequence-classification head, so this tensor is
# presumably the classification scores rather than a hidden state. The name is
# kept unchanged so existing consumers of the ONNX file keep working - confirm
# before renaming.
output_name = 'last_hidden_state'

torch.onnx.export(
    model,
    args=tuple(inputs.values()),
    f='pytorch_model.onnx',
    export_params=True,
    input_names=['input_ids', 'attention_mask', 'token_type_ids'],
    output_names=[output_name],
    dynamic_axes={
        'input_ids': symbolic_names,
        'attention_mask': symbolic_names,
        'token_type_ids': symbolic_names,
        # Tensors missing from dynamic_axes are exported with the exact shape
        # seen while tracing, so the output's batch axis must be declared
        # dynamic too; otherwise it is pinned to the dummy batch size of 1.
        output_name: {0: 'batch_size'},
    },
)