chainyo committed on
Commit
4c5289e
1 Parent(s): 87e8a82
Files changed (1)
  1. main.py +185 -31
main.py CHANGED
@@ -1,21 +1,96 @@
 """⭐ Text Classification with Optimum and ONNXRuntime
 
+Streamlit application to classify text using multiple models.
+
 Author:
     - @ChainYo - https://github.com/ChainYo
 """
 
+import plotly
+import numpy as np
+import pandas as pd
 import streamlit as st
 
-from transformers import AutoTokenizer, AutoModel, pipeline
-from optimum.onnxruntime import ORTModelForSequenceClassification
-from optimum.pipelines import pipeline
+from pathlib import Path
+from time import sleep
+from typing import Dict, List, Union
+
+from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
+from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
+from optimum.onnxruntime.model import ORTModel
+from optimum.pipelines import pipeline as ort_pipeline
+from transformers import BertTokenizer, BertForSequenceClassification, pipeline
+
+from utils import calculate_inference_time
 
 
-MODEL_PATH = "cardiffnlp/twitter-roberta-base-sentiment-latest"
+HUB_MODEL_PATH = "yiyanghkust/finbert-tone"
+BASE_PATH = Path("models")
+ONNX_MODEL_PATH = BASE_PATH.joinpath("model.onnx")
+OPTIMIZED_BASE_PATH = BASE_PATH.joinpath("optimized")
+OPTIMIZED_MODEL_PATH = OPTIMIZED_BASE_PATH.joinpath("model-optimized.onnx")
+QUANTIZED_BASE_PATH = BASE_PATH.joinpath("quantized")
+QUANTIZED_MODEL_PATH = QUANTIZED_BASE_PATH.joinpath("model-quantized.onnx")
+VAR2LABEL = {
+    "pt_pipeline": "PyTorch",
+    "ort_pipeline": "ONNXRuntime",
+    "ort_optimized_pipeline": "ONNXRuntime (Optimized)",
+    "ort_quantized_pipeline": "ONNXRuntime (Quantized)",
+}
+
+
+def get_timers(
+    samples: Union[List[str], str], exp_number: int, only_mean: bool = False
+) -> Dict[str, Union[float, List[float]]]:
+    """
+    Calculate inference time for each model for a given sample or list of samples.
+
+    Parameters
+    ----------
+    samples : Union[List[str], str]
+        Sample or list of samples to calculate inference time for.
+    exp_number : int
+        Number of experiments to run.
+    only_mean : bool, default False
+        Whether to keep only the mean inference time for each model instead of all measurements.
+
+    Returns
+    -------
+    Dict[str, Union[float, List[float]]]
+        Dictionary of inference times for each model for the given samples.
+    """
+    if isinstance(samples, str):
+        samples = [samples]
+
+    timers: Dict[str, Union[float, List[float]]] = {}
+    for model in VAR2LABEL.keys():
+        time_buffer = []
+        for _ in range(exp_number):
+            with calculate_inference_time(time_buffer):
+                st.session_state[model](samples)
+        timers[VAR2LABEL[model]] = np.mean(time_buffer) if only_mean else time_buffer
+    return timers
+
+
+def get_plot(timers: Dict[str, Union[float, List[float]]]) -> plotly.graph_objs._figure.Figure:
+    """
+    Plot the inference time for each model.
+
+    Parameters
+    ----------
+    timers : Dict[str, Union[float, List[float]]]
+        Dictionary of inference times for each model.
+    """
+    data = pd.DataFrame.from_dict(timers, orient="columns")
+    colors = ["#140f0d", "#2b2c4f", "#615aa2", "#a991fa"]
+    fig = plotly.figure_factory.create_distplot(
+        [data[col] for col in data.columns], data.columns, bin_size=0.2, colors=colors
+    )
+    fig.update_layout(title_text="Inference Time", xaxis_title="Inference Time (s)", yaxis_title="Number of Samples")
+    return fig
+
 
 st.set_page_config(page_title="Optimum Text Classification", page_icon="⭐")
-st.title("🤗 Optimum Text Classification")
-st.subheader("Sentiment analysis with 🤗 Optimum and ONNXRuntime")
+st.title(" Optimum Text Classification")
+st.subheader("Classify financial news tone with 🤗 Optimum and ONNXRuntime")
 st.markdown("""
 [![GitHub](https://img.shields.io/badge/-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/ChainYo)
 [![HuggingFace](https://img.shields.io/badge/-yellow.svg?style=for-the-badge&logo=data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBzdGFuZGFsb25lPSJubyI/Pgo8IURPQ1RZUEUgc3ZnIFBVQkxJQyAiLS8vVzNDLy9EVEQgU1ZHIDIwMDEwOTA0Ly9FTiIKICJodHRwOi8vd3d3LnczLm9yZy9UUi8yMDAxL1JFQy1TVkctMjAwMTA5MDQvRFREL3N2ZzEwLmR0ZCI+CjxzdmcgdmVyc2lvbj0iMS4wIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciCiB3aWR0aD0iMTc1LjAwMDAwMHB0IiBoZWlnaHQ9IjE3NS4wMDAwMDBwdCIgdmlld0JveD0iMCAwIDE3NS4wMDAwMDAgMTc1LjAwMDAwMCIKIHByZXNlcnZlQXNwZWN0UmF0aW89InhNaWRZTWlkIG1lZXQiPgoKPGcgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoMC4wMDAwMDAsMTc1LjAwMDAwMCkgc2NhbGUoMC4xMDAwMDAsLTAuMTAwMDAwKSIKZmlsbD0iIzAwMDAwMCIgc3Ryb2tlPSJub25lIj4KPHBhdGggZD0iTTU2MyAxMjM2IGMtMjkgLTEzIC02MyAtNTkgLTYzIC04NiAwIC0yNiAzMyAtODAgNTIgLTg2IDE1IC00IDI2IDEKNDMgMjEgMjAgMjYgMjQgMjcgNTMgMTcgMjggLTkgMzMgLTggNDIgOCAxNyAzMiAxMSA2OSAtMTcgOTkgLTM0IDM3IC02OCA0NQotMTEwIDI3eiIvPgo8cGF0aCBkPSJNMTA2NCAxMjQwIGMtNTAgLTIwIC03NyAtODYgLTU0IC0xMzAgOSAtMTYgMTQgLTE3IDQyIC04IDI5IDEwIDMzIDkKNTUgLTE3IDIxIC0yNCAyNyAtMjYgNDggLTE3IDMxIDE0IDUxIDc2IDM2IDExNCAtMTcgNDYgLTg0IDc2IC0xMjcgNTh6Ii8+CjxwYXRoIGQ9Ik02MDAgODg4IGMwIC00OSAxNiAtOTggNTAgLTE1MSA4NSAtMTM0IDMyNSAtMTM0IDQxMCAwIDUxIDgwIDY5IDE4MwozMSAxODMgLTEwIDAgLTUwIC0xNSAtODcgLTMyIC02MCAtMjkgLTc5IC0zMyAtMTQ5IC0zMyAtNzAgMCAtODkgNCAtMTQ5IDMzCi0zNyAxNyAtNzcgMzIgLTg3IDMyIC0xNSAwIC0xOSAtNyAtMTkgLTMyeiIvPgo8L2c+Cjwvc3ZnPgo=)](https://huggingface.co/ChainYo)
@@ -23,36 +98,115 @@ st.markdown("""
 [![Discord](https://img.shields.io/badge/Chainyo%233610-%237289DA.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/)
 """)
 
-if "tokenizer" not in st.session_state:
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-    st.session_state["tokenizer"] = tokenizer
-
-if "ort_model" not in st.session_state:
-    ort_model = ORTModelForSequenceClassification.from_pretrained(MODEL_PATH, from_transformers=True)
-    st.session_state["ort_model"] = ort_model
-
-if "pt_model" not in st.session_state:
-    pt_model = AutoModel.from_pretrained(MODEL_PATH)
-    st.session_state["pt_model"] = pt_model
-
-if "ort_pipeline" not in st.session_state:
-    ort_pipeline = pipeline(
-        "text-classification", tokenizer=st.session_state["tokenizer"], model=st.session_state["ort_model"]
-    )
-    st.session_state["ort_pipeline"] = ort_pipeline
-
-if "pt_pipeline" not in st.session_state:
-    pt_pipeline = pipeline(
-        "text-classification", tokenizer=st.session_state["tokenizer"], model=st.session_state["pt_model"]
-    )
-    st.session_state["pt_pipeline"] = pt_pipeline
-
-
-model_format = st.radio("Choose the model format", ("PyTorch", "ONNXRuntime"))
-optimized = st.checkbox("Optimize the model for inference", value=False)
-quantized = st.checkbox("Quantize the model", value=False)
-
-if model_format == "PyTorch":
-    optimized.disabled = True
-    quantized.disabled = True
+with st.expander("⭐ Details", expanded=True):
+    st.markdown(
+        """
+        This app is a **demo** of the [🤗 Optimum Text Classification](https://huggingface.co/docs/optimum/onnxruntime/modeling_ort#optimum-inference-with-onnx-runtime) pipeline.
+        We aim to compare the original PyTorch pipeline with the ONNXRuntime pipelines.
+
+        We use the [Finbert-Tone](https://huggingface.co/yiyanghkust/finbert-tone) model to classify financial news tone for the demo.
+
+        You can classify several sentences at once by separating them with a semicolon (`;`).
+        """
+    )
+
+if "init_models" not in st.session_state:
+    st.session_state["init_models"] = True
+if st.session_state["init_models"]:
+    with st.spinner(text="Loading files and models..."):
+        loading_logs = st.empty()
+        with loading_logs.container():
+            BASE_PATH.mkdir(exist_ok=True)
+            QUANTIZED_BASE_PATH.mkdir(exist_ok=True)
+            OPTIMIZED_BASE_PATH.mkdir(exist_ok=True)
+
+            if "tokenizer" not in st.session_state:
+                tokenizer = BertTokenizer.from_pretrained(HUB_MODEL_PATH)
+                st.session_state["tokenizer"] = tokenizer
+                st.text("✅ Tokenizer loaded.")
+
+            if "pt_model" not in st.session_state:
+                pt_model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
+                st.session_state["pt_model"] = pt_model
+                st.text("✅ PyTorch model loaded.")
+
+            if "ort_model" not in st.session_state:
+                ort_model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
+                # if not ONNX_MODEL_PATH.exists():
+                #     ort_model.save_pretrained(ONNX_MODEL_PATH)
+                st.session_state["ort_model"] = ort_model
+                st.text("✅ ONNX model loaded.")
+
+            if "optimized_model" not in st.session_state:
+                optimization_config = OptimizationConfig(optimization_level=99)
+                optimizer = ORTOptimizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
+                if not OPTIMIZED_MODEL_PATH.exists():
+                    optimizer.export(ONNX_MODEL_PATH, OPTIMIZED_MODEL_PATH, optimization_config=optimization_config)
+                    optimizer.model.config.save_pretrained(OPTIMIZED_BASE_PATH)
+                optimized_model = ORTModelForSequenceClassification.from_pretrained(
+                    OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
+                )
+                st.session_state["optimized_model"] = optimized_model
+                st.text("✅ Optimized ONNX model loaded.")
+
+            if "quantized_model" not in st.session_state:
+                quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
+                quantizer = ORTQuantizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
+                if not QUANTIZED_MODEL_PATH.exists():
+                    quantizer.export(ONNX_MODEL_PATH, QUANTIZED_MODEL_PATH, quantization_config=quantization_config)
+                    quantizer.model.config.save_pretrained(QUANTIZED_BASE_PATH)
+                quantized_model = ORTModelForSequenceClassification.from_pretrained(
+                    QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
+                )
+                st.session_state["quantized_model"] = quantized_model
+                st.text("✅ Quantized ONNX model loaded.")
+
+            if "pt_pipeline" not in st.session_state:
+                pt_pipeline = pipeline(
+                    "sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=st.session_state["pt_model"]
+                )
+                st.session_state["pt_pipeline"] = pt_pipeline
+
+            if "ort_pipeline" not in st.session_state:
+                ort_pipeline = ort_pipeline(
+                    "text-classification", tokenizer=st.session_state["tokenizer"], model=st.session_state["ort_model"]
+                )
+                st.session_state["ort_pipeline"] = ort_pipeline
+
+            if "ort_optimized_pipeline" not in st.session_state:
+                ort_optimized_pipeline = pipeline(
+                    "text-classification",
+                    tokenizer=st.session_state["tokenizer"],
+                    model=st.session_state["optimized_model"],
+                )
+                st.session_state["ort_optimized_pipeline"] = ort_optimized_pipeline
+
+            if "ort_quantized_pipeline" not in st.session_state:
+                ort_quantized_pipeline = pipeline(
+                    "text-classification",
+                    tokenizer=st.session_state["tokenizer"],
+                    model=st.session_state["quantized_model"],
+                )
+                st.session_state["ort_quantized_pipeline"] = ort_quantized_pipeline
+
+            st.text("✅ All pipelines are ready.")
+            sleep(2)
+        loading_logs.success("🎉 Everything is ready!")
+    st.session_state["init_models"] = False
+
+if "inference_timers" not in st.session_state:
+    st.session_state["inference_timers"] = {}
+
+exp_number = st.slider("The number of experiments per model.", min_value=100, max_value=300, value=150)
+get_only_mean = st.checkbox("Get only the mean of the inference time for each model.", value=False)
+input_text = st.text_area(
+    "Enter text to classify",
+    "there is a shortage of capital, and we need extra financing; growth is strong and we have plenty of liquidity; there are doubts about our finances; profits are flat"
+)
+run_inference = st.button("🚀 Run inference")
+
+if run_inference:
+    st.text("🔎 Running inference...")
+    sentences = input_text.split(";")
+    st.session_state["inference_timers"] = get_timers(samples=sentences, exp_number=exp_number, only_mean=get_only_mean)
+    st.plotly_chart(get_plot(st.session_state["inference_timers"]), use_container_width=True)
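
Note: `calculate_inference_time` is imported from a local `utils` module that is not part of this diff. From its usage in `get_timers` (wrapping each pipeline call, then averaging `time_buffer` with `np.mean`), it behaves like a context manager that appends the elapsed time of its block to the buffer it receives. A minimal sketch under that assumption:

```python
# Hypothetical sketch of utils.calculate_inference_time (not part of this commit):
# a context manager that times its body and appends the elapsed seconds to the
# caller's buffer, matching how get_timers in main.py uses it.
from contextlib import contextmanager
from time import perf_counter
from typing import Iterator, List


@contextmanager
def calculate_inference_time(time_buffer: List[float]) -> Iterator[None]:
    """Append the wall-clock duration of the with-block to time_buffer."""
    start = perf_counter()
    try:
        yield
    finally:
        time_buffer.append(perf_counter() - start)
```

With that contract, each pipeline named in `VAR2LABEL` is timed `exp_number` times and `get_plot` overlays the resulting distributions. Assuming the usual Streamlit entry point, the app would be launched with `streamlit run main.py`.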