NLP / sentiment_onnx_classify.py
ashish rai
Added script for ONNX sentiment classification
0d4914a
raw
history blame
1.69 kB
import onnxruntime as ort
import torch
from transformers import AutoTokenizer
import numpy as np
# Tokenizer used to encode inputs for the ONNX classifiers defined below.
tokenizer = AutoTokenizer.from_pretrained("sentiment_classifier/")

# ONNX Runtime sessions: the exported model and its int8-quantized variant.
# NOTE(review): these paths are relative, so they resolve against the current
# working directory rather than this file's location — confirm that is intended.
session = ort.InferenceSession(
    "sent_clf_onnx/sentiment_classifier_onnx.onnx"
)
session_int8 = ort.InferenceSession(
    "sent_clf_onnx/sentiment_classifier_onnx_int8.onnx"
)
def classify_sentiment_onnx(texts, _model=session, _tokenizer=tokenizer):
    """Classify the sentiment of one or more texts with an ONNX Runtime session.

    Args:
        texts: Either a single string holding several texts separated by
            commas (e.g. ``"great film, awful plot"``), or an already-split
            sequence of strings such as a list or tuple.
        _model: ONNX Runtime ``InferenceSession`` to run. Defaults to the
            module-level ``session`` (the non-int8 model).
        _tokenizer: Hugging Face tokenizer used to encode ``texts``.
            Defaults to the module-level ``tokenizer``.

    Returns:
        A list with one label per input text, in input order. The label is
        ``'Positive'`` when the arg-max class index is 1 and ``'Negative'``
        otherwise. An empty input sequence yields an empty list.
    """
    if isinstance(texts, str):
        # Comma-separated string: split it and strip the spaces users
        # usually type after a comma, so they are not fed to the tokenizer.
        texts = [part.strip() for part in texts.split(",")]
    else:
        texts = list(texts)
    if not texts:
        # Nothing to classify; avoid running the model on an empty batch.
        return []

    encoded = _tokenizer(texts, padding=True, truncation=True, return_tensors="np")
    input_feed = {
        "input_ids": np.asarray(encoded["input_ids"]),
        "attention_mask": np.asarray(encoded["attention_mask"]),
    }
    # Request only the 'output_0' graph output; run() returns a list of the
    # requested outputs, so [0] selects it.
    logits = _model.run(output_names=["output_0"], input_feed=input_feed)[0]
    # NOTE(review): assumes logits are (batch, num_labels) with class index 1
    # meaning positive — confirm against the exported model's label mapping.
    predictions = np.argmax(logits, axis=1)
    return ["Positive" if idx == 1 else "Negative" for idx in predictions]
def classify_sentiment_onnx_quant(texts, _model=session_int8, _tokenizer=tokenizer):
    """Classify sentiment with the int8-quantized ONNX model.

    Args:
        texts: A single string of texts separated by commas, or a list of
            strings.
        _model: ONNX Runtime ``InferenceSession`` to run. Defaults to the
            module-level ``session_int8``.
        _tokenizer: Hugging Face tokenizer used to encode ``texts``.
            Defaults to the module-level ``tokenizer``.

    Returns:
        A list with one ``'Positive'``/``'Negative'`` label per input text.
    """
    # The only difference from classify_sentiment_onnx is the default
    # session, so delegate rather than duplicating the pre/post-processing.
    return classify_sentiment_onnx(texts, _model=_model, _tokenizer=_tokenizer)