How to use ONNX model in Triton efficiently?
I'm using the Phi-3 ONNX model in a Triton model server.
It's running significantly slower than the PyTorch model, probably because I'm making some obvious mistakes in the setup. Any help would be welcome.
I want to set up the Triton server to take in a user prompt as text and return the generated text. To this end I'm using a Python backend that does the following (a code sketch follows this list):
- tokenize the input text and create the `input_ids` tensor using HF transformers
- create the `attention_mask` tensor
- create zero-initialized `past_key_values` tensors
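A minimal sketch of that setup step, run outside Triton for clarity. The KV-cache input names and shape (`num_layers=32`, `num_kv_heads=32`, `head_dim=96` for Phi-3-mini) are my assumptions about the export, not something confirmed; the real names and shapes can be read off the ONNX graph inputs:

```python
# Prefill inputs for the first call. Names/shapes below are assumptions
# for a Phi-3-mini export (32 layers, 32 KV heads, head_dim 96); check
# your model, e.g. via onnx.load("model.onnx").graph.input.
import numpy as np
from transformers import AutoTokenizer

NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM = 32, 32, 96  # assumed Phi-3-mini dims

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
enc = tokenizer("Tell me a joke.", return_tensors="np")

input_ids = enc["input_ids"].astype(np.int64)        # (1, prompt_len)
attention_mask = enc["attention_mask"].astype(np.int64)

# An empty (zero-length) past is the usual way to signal "no cache yet";
# a zero-filled past of nonzero length would actually be attended over.
past_key_values = {
    f"past_key_values.{i}.{kv}": np.zeros(
        (1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=np.float32)
    for i in range(NUM_LAYERS)
    for kv in ("key", "value")
}
```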
Then, repeat until an EOS token is detected (loop sketch after this list):
- send `input_ids`, `attention_mask`, and `past_key_values` to the Phi-3 ONNX model
- take the argmax of the returned `logits` and append it to `input_ids`
- copy the returned `present_key_values` into new `past_key_values` tensors
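A sketch of my loop, continuing from the setup above but calling onnxruntime directly so the block stays self-contained; in the actual Triton Python backend each `session.run()` is instead a BLS call via `pb_utils.InferenceRequest`. The I/O names (`logits`, `present.{i}.key`, ...) are again assumptions about the export:

```python
# Greedy decode loop; continues from the prefill sketch above.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
output_names = [o.name for o in session.get_outputs()]
max_new_tokens = 128  # arbitrary cap for the sketch

generated = []
for _ in range(max_new_tokens):
    out = dict(zip(output_names, session.run(
        None,
        {"input_ids": input_ids, "attention_mask": attention_mask,
         **past_key_values},
    )))

    # Greedy decoding: argmax over the vocabulary at the last position.
    next_token = out["logits"][:, -1, :].argmax(-1, keepdims=True).astype(np.int64)
    token_id = next_token.item()
    if token_id == tokenizer.eos_token_id:
        break
    generated.append(token_id)

    # With a populated cache only the new token goes in; the attention
    # mask still has to cover the whole sequence seen so far.
    input_ids = next_token
    attention_mask = np.concatenate(
        [attention_mask, np.ones((1, 1), dtype=np.int64)], axis=1)

    # Feed present_* outputs back as past_* inputs for the next step.
    # This host-side copy is exactly what question 2 below is about; ORT
    # I/O binding or ONNX Runtime GenAI can keep the cache on-device.
    past_key_values = {
        name.replace("present", "past_key_values"): value
        for name, value in out.items() if name.startswith("present")
    }

print(tokenizer.decode(generated))
```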
In a diagram: *(loop diagram omitted)*
Some questions I have:
1. I'm currently using the Phi-3 model in Triton by simply putting `model.onnx` and its companion `model.onnx.data` in the Triton model repository (layout sketched below). I can't figure out how to use all the other config and Python files included in the ONNX repository with Triton; am I missing something here?
2. I'm "manually" copying the `present_key_values` into new `past_key_values` tensors, which seems very inefficient. I'm also not sure the Phi-3 model is using the `past_key_values` at all, since on the first call I send over zeros and that works fine, which seems strange to me.
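For reference, the model-repository layout I have right now is just this (the directory name and the `config.pbtxt` note are my own guesses at a minimal setup, not anything shipped with the Phi-3 ONNX release):

```
model_repository/
└── phi3_onnx/               # model name is arbitrary
    ├── config.pbtxt         # can be as minimal as: backend: "onnxruntime"
    └── 1/
        ├── model.onnx
        └── model.onnx.data  # external weights; must sit beside model.onnx
```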
Any pointers would be helpful; this is my first time setting this up, so I'm probably doing things wrong.
Closing this since it has already been posted as an issue on the ORT GenAI repo and is being tracked there.