# Origin: repository file view — author yilunzhang,
# commit 9ae1ebe ("Initial Commit"), original size 639 bytes.
from transformers import AutoModel
import torch
# Export the bare sentence-transformers/all-mpnet-base-v2 encoder (loaded via
# AutoModel, i.e. without any sentence-pooling head) to "model.onnx".
#
# The dummy tensors below are only used as example inputs for tracing; the
# batch and sequence axes of every graph input and output are declared
# dynamic, so the exported model accepts other (batch, seq) shapes.
max_seq_length = 384

model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model.eval()  # inference mode (e.g. dropout off) before tracing

# Example inputs of shape (1, max_seq_length), int64 as the export expects.
# NOTE(review): all-ones token ids are placeholders; this assumes the forward
# pass has no value-dependent Python branching that tracing would bake in.
inputs = {
    "input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
    "attention_mask": torch.ones(1, max_seq_length, dtype=torch.int64),
}

# Axis 0 is the batch, axis 1 the token position.
symbolic_names = {0: "batch_size", 1: "max_seq_len"}

# Tracing needs no autograd history; no_grad avoids that bookkeeping.
with torch.no_grad():
    torch.onnx.export(
        model,
        # Passed explicitly in forward-signature order (input_ids,
        # attention_mask) instead of relying on the dict's insertion order.
        args=(inputs["input_ids"], inputs["attention_mask"]),
        f="model.onnx",
        export_params=True,
        input_names=["input_ids", "attention_mask"],
        # The model returns both outputs; naming the second one lets it be
        # given a dynamic batch axis below.
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": symbolic_names,
            "attention_mask": symbolic_names,
            # Without these entries the outputs keep the traced example's
            # static shape even though the inputs are dynamic.
            "last_hidden_state": symbolic_names,
            "pooler_output": {0: "batch_size"},
        },
    )