NLP / sentiment_onnx_classify.py
ashishraics's picture
bug fix
6886461
raw
history blame
3.79 kB
import onnxruntime as ort
import torch
from transformers import AutoTokenizer,AutoModelForSequenceClassification
import numpy as np
import transformers
from onnxruntime.quantization import quantize_dynamic,QuantType
import transformers.convert_graph_to_onnx as onnx_convert
from pathlib import Path
import os
# chkpt='distilbert-base-uncased-finetuned-sst-2-english'
# model= AutoModelForSequenceClassification.from_pretrained(chkpt)
# tokenizer= AutoTokenizer.from_pretrained(chkpt)
def create_onnx_model(_model, _tokenizer):
    """
    Export a text-classification model to ONNX and write a dynamically
    quantized copy next to it.

    Args:
        _model: model checkpoint loaded with AutoModelForSequenceClassification
        _tokenizer: matching tokenizer loaded with AutoTokenizer
    Returns:
        None. Writes "sent_clf_onnx/sentiment_classifier_onnx.onnx" (opset 11)
        and "sent_clf_onnx/sentiment_classifier_onnx_int8.onnx" (weights
        quantized to unsigned 8-bit via QuantType.QUInt8). If both files
        already exist, nothing is done, so repeated calls are cheap.
    Raises:
        OSError: if the output directory cannot be created.
    """
    out_dir = Path('sent_clf_onnx')
    onnx_path = out_dir / 'sentiment_classifier_onnx.onnx'
    int8_path = out_dir / 'sentiment_classifier_onnx_int8.onnx'

    # Gate on the artifacts themselves, not on the directory: a previous run
    # that crashed after creating the folder would otherwise never be retried.
    if onnx_path.exists() and int8_path.exists():
        return

    # exist_ok removes the check-then-create race; any other OSError (e.g.
    # permissions) surfaces here instead of as a later export failure.
    os.makedirs(out_dir, exist_ok=True)

    # convert_pytorch operates on a pipeline object, so wrap model + tokenizer.
    pipeline = transformers.pipeline("text-classification", model=_model, tokenizer=_tokenizer)

    # Export the PyTorch model behind the pipeline to a plain ONNX graph.
    onnx_convert.convert_pytorch(pipeline,
                                 opset=11,
                                 output=onnx_path,
                                 use_external_format=False
                                 )

    # Write a second graph with dynamically quantized (uint8) weights.
    quantize_dynamic(str(onnx_path), str(int8_path),
                     weight_type=QuantType.QUInt8)
# #create onnx & onnx_int_8 sessions
# session = ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx.onnx")
# session_int8 = ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx_int8.onnx")
# options=ort.SessionOptions()
# options.inter_op_num_threads=1
# options.intra_op_num_threads=1
def classify_sentiment_onnx(texts, _session, _tokenizer):
    """
    Classify one or more texts as 'Positive' or 'Negative' using an ONNX
    Runtime session.

    Args:
        texts: a single string (split on ',' into separate texts) or a
            list/tuple of strings
        _session: ONNX Runtime session for the exported classifier; it is fed
            "input_ids" and "attention_mask" and its "output_0" is read
        _tokenizer: tokenizer matching the model, e.g.
            AutoTokenizer.from_pretrained("same checkpoint as the model")
    Returns:
        list with one 'Positive' / 'Negative' label per input text
        (an empty list when no texts are given)
    """
    # Only strings are split; other sequences are used as given. This replaces
    # a bare `except: pass` that would have hidden any error, not just the
    # AttributeError raised for non-string input.
    texts = texts.split(',') if isinstance(texts, str) else list(texts)
    if not texts:
        # Nothing to classify; avoids feeding an empty batch to the model.
        return []

    encoded = _tokenizer(texts, padding=True, truncation=True,
                         return_tensors="np")
    # The graph exported from PyTorch takes int64 ids, but NumPy's default
    # integer type is int32 on some platforms (e.g. Windows with NumPy < 2),
    # which ONNX Runtime rejects, so cast explicitly.
    input_feed = {
        "input_ids": np.asarray(encoded['input_ids'], dtype=np.int64),
        "attention_mask": np.asarray(encoded['attention_mask'], dtype=np.int64),
    }
    logits = _session.run(input_feed=input_feed, output_names=['output_0'])[0]
    predictions = np.argmax(logits, axis=1)
    # Assumes class index 1 means positive (true for the SST-2 checkpoint
    # referenced in this file); every other index is reported as 'Negative'.
    return ['Positive' if i == 1 else 'Negative' for i in predictions]
def classify_sentiment_onnx_quant(texts, _session, _tokenizer):
    """
    Classify one or more texts as 'Positive' or 'Negative' using the
    quantized ONNX model's session.

    The quantized graph is fed exactly the same inputs ("input_ids",
    "attention_mask") and read from the same output ("output_0") as the
    float graph, so this delegates to classify_sentiment_onnx rather than
    keeping a second copy of the same pre/post-processing.

    Args:
        texts: a single string (split on ',' into separate texts) or a list
            of strings
        _session: ONNX Runtime session for the quantized classifier
        _tokenizer: tokenizer matching the model, e.g.
            AutoTokenizer.from_pretrained("same checkpoint as the model")
    Returns:
        list with one 'Positive' / 'Negative' label per input text
    """
    return classify_sentiment_onnx(texts, _session, _tokenizer)