Spaces:
Runtime error
Runtime error
ashish rai
commited on
Commit
•
19da0ee
1
Parent(s):
13315ca
onnx object for sentiment
Browse files- sentiment_onnx.py +44 -0
sentiment_onnx.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from transformers import AutoTokenizer,AutoModelForSequenceClassification
|
3 |
+
import transformers.convert_graph_to_onnx as onnx_convert
|
4 |
+
from pathlib import Path
|
5 |
+
import transformers
|
6 |
+
from onnxruntime.quantization import quantize_dynamic,QuantType
|
7 |
+
import onnx
|
8 |
+
import torch
|
9 |
+
import onnxruntime as ort
|
10 |
+
import streamlit as st
|
11 |
+
|
12 |
+
"""
|
13 |
+
type in cmd to create onnx model of hugging face chkpt
|
14 |
+
python3 -m transformers.onnx --model= distilbert-base-uncased-finetuned-sst-2-english sentiment_onnx/
|
15 |
+
"""
|
16 |
+
|
17 |
+
model= AutoModelForSequenceClassification.from_pretrained('sentiment_classifier/')
|
18 |
+
tokenizer= AutoTokenizer.from_pretrained('sentiment_classifier/')
|
19 |
+
|
20 |
+
"""
|
21 |
+
or download the model directly from hub --
|
22 |
+
chkpt='distilbert-base-uncased-finetuned-sst-2-english'
|
23 |
+
model= AutoModelForSequenceClassification.from_pretrained(chkpt)
|
24 |
+
tokenizer= AutoTokenizer.from_pretrained(chkpt)
|
25 |
+
"""
|
26 |
+
|
27 |
+
|
28 |
+
pipeline=transformers.pipeline("text-classification",model=model,tokenizer=tokenizer)
|
29 |
+
|
30 |
+
""" convert pipeline to onnx object"""
|
31 |
+
onnx_convert.convert_pytorch(pipeline,
|
32 |
+
opset=11,
|
33 |
+
output=Path("sent_clf_onnx/sentiment_classifier_onnx.onnx"),
|
34 |
+
use_external_format=False
|
35 |
+
)
|
36 |
+
|
37 |
+
""" convert onnx object to another onnx object with int8 quantization """
|
38 |
+
quantize_dynamic("sent_clf_onnx/sentiment_classifier_onnx.onnx","sent_clf_onnx/sentiment_classifier_onnx_int8.onnx",
|
39 |
+
weight_type=QuantType.QUInt8)
|
40 |
+
|
41 |
+
print(ort.__version__)
|
42 |
+
print(onnx.__version__)
|
43 |
+
|
44 |
+
|