The textual ONNX models seem to have issues with more than 16 tokens on input

#22
by ondrejnespor - opened

It is my understanding that the ONNX textual model expects token IDs as input:

import onnxruntime as ort

ort_sess = ort.InferenceSession('text_model.onnx')
for i in ort_sess.get_inputs():
    print(i)

says

NodeArg(name='input_ids', type='tensor(int64)', shape=['batch_size', 'sequence_length'])

so the expected use would be something like:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-clip-v1')

# onnxruntime needs numpy arrays rather than plain Python lists
input = tokenizer(['hello world'], return_tensors='np')
output = ort_sess.run(None, {'input_ids': input['input_ids']})

The issue is that the model stops working with inputs of 17 tokens or longer.

For 16 tokens, it returns an embedding:

input = tokenizer(['hello world hello world hello world hello world hello world hello world hello world'], return_tensors='np')
print(len(input['input_ids'][0])) # 16
output = ort_sess.run(None, {'input_ids': input['input_ids']})

With 17 or more, it throws an error:

input = tokenizer(['hello world hello world hello world hello world hello world hello world hello world hello'], return_tensors='np')
print(len(input['input_ids'][0])) # 17
output = ort_sess.run(None, {'input_ids': input['input_ids']})

outputs

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Sub node. Name:'/transformer/encoder/layers.0/mixer/inner_attn/Sub' Status Message: D:\a\_work\1\s\onnxruntime\core/providers/cpu/math/element_wise_ops.h:540 onnxruntime::BroadcastIterator::Init axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 16 by 17

The same error occurs with all of text_model, text_model_quantized, and text_model_int8.
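As a minimal sketch (not part of the original report, and assuming text_model.onnx is in the working directory), one way to poke at this is to look for initializers in the exported graph with a hard-coded 16 along a sequence axis, which would explain the "16 by 17" broadcast failure, at least if the offending tensor shows up as an initializer:

import onnx

# Look for tensors whose shape contains a hard-coded 16, e.g. an attention
# bias baked in at export time with a fixed sequence length.
graph_model = onnx.load('text_model.onnx')
for init in graph_model.graph.initializer:
    if 16 in init.dims:
        print(init.name, list(init.dims))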

When I try to export an ONNX model myself:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
tm = model.text_model
tm.eval()

torch.onnx.export(
    tm,
    torch.randint(1, 1024, (1, 1024)),  # dummy input: batch of 1, 1024 random token IDs
    "./manual.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names = ['input_ids'],
    output_names = ['text_embeds'],
    dynamic_axes={
        'input_ids' : {0 : 'batch_size', 1: 'sequence_length'},
        'text_embeds' : {0 : 'batch_size'}
    }
)

I should end up with a model very similar to text_model.onnx from this repo. Indeed, it seems to return the same embeddings, but it does support inputs of 17 tokens or more.
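As a sanity check, something like the following sketch (not from the original post; the expected 768-dimensional output is an assumption about jina-clip-v1) runs the manually exported model on the 17-token input that breaks the published one:

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-clip-v1')
sess = ort.InferenceSession('manual.onnx')

# 17 tokens, which makes the published text_model.onnx fail
text = ['hello world hello world hello world hello world hello world hello world hello world hello']
inputs = tokenizer(text, return_tensors='np')
embeds = sess.run(None, {'input_ids': inputs['input_ids'].astype(np.int64)})[0]
print(embeds.shape)  # expected to be (1, 768)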

Am I missing something or do the textual ONNX exports have an issue?

Yes, I've encountered the same issue. I found that the cause is the SelfAttention code needing to update variables at runtime, specifically in the mha file of the jina-bert-flash-implementation:

if self.alibi_slopes is not None:
    if seqlen > self.linear_biases.shape[-1]:
        self.linear_biases = self._build_linear_biases(seqlen)
    cropped_biases = self.linear_biases[..., :seqlen, :seqlen]
    # print(self.linear_biases, self.linear_biases.shape)
    scores = scores - cropped_biases
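A possible workaround until the export is fixed (only a sketch, not a confirmed fix: the attribute names come from the snippet above, and scanning tm.modules() for them is my assumption) is to pre-build the ALiBi biases up to the maximum sequence length before tracing, so the exported graph never needs to regrow the cached tensor at runtime:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
tm = model.text_model
tm.eval()

max_seqlen = 1024  # assumed maximum sequence length for the export
with torch.no_grad():
    for module in tm.modules():
        # the attention modules cache the ALiBi biases in `linear_biases`
        if hasattr(module, '_build_linear_biases'):
            module.linear_biases = module._build_linear_biases(max_seqlen)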
Jina AI org

Thanks, I'll take a look today!

Jina AI org

@Riddler2024 @ondrejnespor the issue with the text_model.onnx should be fixed now!

bwang0911 changed discussion status to closed

Any chance the other ONNX model weights can be updated with this fix as well? It looks like only the full and fp16 models have it.
