"""⭐ Text Classification with Optimum and ONNXRuntime | |
Streamlit application to classify text using multiple models. | |
Author: | |
- @ChainYo - https://github.com/ChainYo | |
""" | |
import numpy as np
import pandas as pd
import plotly.figure_factory
import plotly.graph_objs
import streamlit as st

from pathlib import Path
from time import sleep
from typing import Dict, List, Union

from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig, OptimizationConfig
from optimum.pipelines import pipeline as ort_pipeline
from transformers import BertTokenizer, BertForSequenceClassification, pipeline

from utils import calculate_inference_time
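
# `calculate_inference_time` comes from the local utils.py, which is not part of this file.
# From its use in `get_timers` below, it is a context manager that appends the elapsed
# wall-clock time of its block to the given buffer. A minimal sketch of what it presumably
# looks like (hypothetical, since utils.py is not shown here):
#
#     from contextlib import contextmanager
#     from time import perf_counter
#
#     @contextmanager
#     def calculate_inference_time(time_buffer: List[float]):
#         start = perf_counter()
#         try:
#             yield
#         finally:
#             time_buffer.append(perf_counter() - start)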

HUB_MODEL_PATH = "yiyanghkust/finbert-tone"
BASE_PATH = Path("models")
ONNX_MODEL_PATH = BASE_PATH.joinpath("model.onnx")
OPTIMIZED_BASE_PATH = BASE_PATH.joinpath("optimized")
OPTIMIZED_MODEL_PATH = OPTIMIZED_BASE_PATH.joinpath("model-optimized.onnx")
QUANTIZED_BASE_PATH = BASE_PATH.joinpath("quantized")
QUANTIZED_MODEL_PATH = QUANTIZED_BASE_PATH.joinpath("model-quantized.onnx")
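
# The exported ONNX graphs are cached under `models/`, so optimization and quantization only
# run on the first launch; later sessions load the existing .onnx files from disk.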

VAR2LABEL = {
    "pt_pipeline": "PyTorch",
    "ort_pipeline": "ONNXRuntime",
    "ort_optimized_pipeline": "ONNXRuntime (Optimized)",
    "ort_quantized_pipeline": "ONNXRuntime (Quantized)",
}
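
# VAR2LABEL maps the st.session_state key of each pipeline to the label shown in the plot
# legend; `get_timers` iterates over these keys, so each entry must match a pipeline built below.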


def get_timers(
    samples: Union[List[str], str], exp_number: int, only_mean: bool = False
) -> Dict[str, Union[float, List[float]]]:
    """
    Calculate inference time for each model for a given sample or list of samples.

    Parameters
    ----------
    samples : Union[List[str], str]
        Sample or list of samples to calculate inference time for.
    exp_number : int
        Number of experiments to run per model.
    only_mean : bool, optional
        If True, keep only the mean inference time per model instead of the full list
        of timings. Defaults to False.

    Returns
    -------
    Dict[str, Union[float, List[float]]]
        Dictionary of inference times (or their mean) for each model.
    """
    if isinstance(samples, str):
        samples = [samples]

    timers: Dict[str, Union[float, List[float]]] = {}
    for model in VAR2LABEL.keys():
        time_buffer: List[float] = []
        for _ in range(exp_number):
            with calculate_inference_time(time_buffer):
                st.session_state[model](samples)
        timers[VAR2LABEL[model]] = np.mean(time_buffer) if only_mean else time_buffer
    return timers
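

# Note: `create_distplot` fits a distribution over many timings per model, so `get_plot`
# expects the full timing lists produced with only_mean=False (mean-only results are
# rendered as a bar chart at the bottom of the app instead).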
def get_plot(timers: Dict[str, List[float]]) -> plotly.graph_objs._figure.Figure:
    """
    Plot the inference time distribution for each model.

    Parameters
    ----------
    timers : Dict[str, List[float]]
        Dictionary of per-run inference times for each model.

    Returns
    -------
    plotly.graph_objs._figure.Figure
        Distribution plot of the inference times.
    """
    data = pd.DataFrame.from_dict(timers, orient="columns")
    colors = ["#140f0d", "#2b2c4f", "#615aa2", "#a991fa"]
    fig = plotly.figure_factory.create_distplot(
        [data[col] for col in data.columns], list(data.columns), bin_size=0.2, colors=colors
    )
    fig.update_layout(
        title_text="Inference Time", xaxis_title="Inference Time (s)", yaxis_title="Number of Samples"
    )
    return fig


st.set_page_config(page_title="Optimum Text Classification", page_icon="⭐")
st.title("⭐ Optimum Text Classification")
st.subheader("Classify financial news tone with 🤗 Optimum and ONNXRuntime")
st.markdown("""
[![GitHub](https://img.shields.io/badge/-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/ChainYo)
[![HuggingFace](https://img.shields.io/badge/-yellow.svg?style=for-the-badge&logo=data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBzdGFuZGFsb25lPSJubyI/Pgo8IURPQ1RZUEUgc3ZnIFBVQkxJQyAiLS8vVzNDLy9EVEQgU1ZHIDIwMDEwOTA0Ly9FTiIKICJodHRwOi8vd3d3LnczLm9yZy9UUi8yMDAxL1JFQy1TVkctMjAwMTA5MDQvRFREL3N2ZzEwLmR0ZCI+CjxzdmcgdmVyc2lvbj0iMS4wIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciCiB3aWR0aD0iMTc1LjAwMDAwMHB0IiBoZWlnaHQ9IjE3NS4wMDAwMDBwdCIgdmlld0JveD0iMCAwIDE3NS4wMDAwMDAgMTc1LjAwMDAwMCIKIHByZXNlcnZlQXNwZWN0UmF0aW89InhNaWRZTWlkIG1lZXQiPgoKPGcgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoMC4wMDAwMDAsMTc1LjAwMDAwMCkgc2NhbGUoMC4xMDAwMDAsLTAuMTAwMDAwKSIKZmlsbD0iIzAwMDAwMCIgc3Ryb2tlPSJub25lIj4KPHBhdGggZD0iTTU2MyAxMjM2IGMtMjkgLTEzIC02MyAtNTkgLTYzIC04NiAwIC0yNiAzMyAtODAgNTIgLTg2IDE1IC00IDI2IDEKNDMgMjEgMjAgMjYgMjQgMjcgNTMgMTcgMjggLTkgMzMgLTggNDIgOCAxNyAzMiAxMSA2OSAtMTcgOTkgLTM0IDM3IC02OCA0NQotMTEwIDI3eiIvPgo8cGF0aCBkPSJNMTA2NCAxMjQwIGMtNTAgLTIwIC03NyAtODYgLTU0IC0xMzAgOSAtMTYgMTQgLTE3IDQyIC04IDI5IDEwIDMzIDkKNTUgLTE3IDIxIC0yNCAyNyAtMjYgNDggLTE3IDMxIDE0IDUxIDc2IDM2IDExNCAtMTcgNDYgLTg0IDc2IC0xMjcgNTh6Ii8+CjxwYXRoIGQ9Ik02MDAgODg4IGMwIC00OSAxNiAtOTggNTAgLTE1MSA4NSAtMTM0IDMyNSAtMTM0IDQxMCAwIDUxIDgwIDY5IDE4MyAzMSAxODMgLTEwIDAgLTUwIC0xNSAtODcgLTMyIC02MCAtMjkgLTc5IC0zMyAtMTQ5IC0zMyAtNzAgMCAtODkgNCAtMTQ5IDMzCi0zNyAxNyAtNzcgMzIgLTg3IDMyIC0xNSAwIC0xOSAtNyAtMTkgLTMyeiIvPgo8L2c+Cjwvc3ZnPgo=)](https://huggingface.co/ChainYo)
[![LinkedIn](https://img.shields.io/badge/-%230077B5.svg?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/thomas-chaigneau-dev/)
[![Discord](https://img.shields.io/badge/Chainyo%233610-%237289DA.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/)
""")

with st.expander("⭐ Details", expanded=True):
    st.markdown(
        """
This app is a **demo** of the [🤗 Optimum Text Classification](https://huggingface.co/docs/optimum/onnxruntime/modeling_ort#optimum-inference-with-onnx-runtime) pipeline.
We compare the original PyTorch pipeline with its ONNXRuntime counterparts.

The demo uses the [FinBERT-tone](https://huggingface.co/yiyanghkust/finbert-tone) model to classify the tone of financial news.
You can classify several sentences at once by separating them with a `;` (semicolon).
        """
    )
if "init_models" not in st.session_state: | |
st.session_state["init_models"] = True | |
if st.session_state["init_models"]: | |
    with st.spinner(text="Loading files and models..."):
        loading_logs = st.empty()
        with loading_logs.container():
            BASE_PATH.mkdir(exist_ok=True)
            QUANTIZED_BASE_PATH.mkdir(exist_ok=True)
            OPTIMIZED_BASE_PATH.mkdir(exist_ok=True)

            if "tokenizer" not in st.session_state:
                tokenizer = BertTokenizer.from_pretrained(HUB_MODEL_PATH)
                st.session_state["tokenizer"] = tokenizer
            st.text("✅ Tokenizer loaded.")

            if "pt_model" not in st.session_state:
                pt_model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
                st.session_state["pt_model"] = pt_model
            st.text("✅ PyTorch model loaded.")
if "ort_model" not in st.session_state: | |
ort_model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True) | |
# if not ONNX_MODEL_PATH.exists(): | |
# ort_model.save_pretrained(ONNX_MODEL_PATH) | |
st.session_state["ort_model"] = ort_model | |
st.text("✅ ONNX Model loaded.") | |
if "optimized_model" not in st.session_state: | |
optimization_config = OptimizationConfig(optimization_level=99) | |
optimizer = ORTOptimizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification") | |
if not OPTIMIZED_MODEL_PATH.exists(): | |
optimizer.export(ONNX_MODEL_PATH, OPTIMIZED_MODEL_PATH, optimization_config=optimization_config) | |
optimizer.model.config.save_pretrained(OPTIMIZED_BASE_PATH) | |
optimized_model = ORTModelForSequenceClassification.from_pretrained( | |
OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name | |
) | |
st.session_state["optimized_model"] = optimized_model | |
st.text("✅ Optimized ONNX model loaded.") | |
if "quantized_model" not in st.session_state: | |
quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False) | |
quantizer = ORTQuantizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification") | |
if not QUANTIZED_MODEL_PATH.exists(): | |
quantizer.export(ONNX_MODEL_PATH, QUANTIZED_MODEL_PATH, quantization_config=quantization_config) | |
quantizer.model.config.save_pretrained(QUANTIZED_BASE_PATH) | |
quantized_model = ORTModelForSequenceClassification.from_pretrained( | |
QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name | |
) | |
st.session_state["quantized_model"] = quantized_model | |
st.text("✅ Quantized ONNX model loaded.") | |
if "pt_pipeline" not in st.session_state: | |
pt_pipeline = pipeline( | |
"sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=st.session_state["pt_model"] | |
) | |
st.session_state["pt_pipeline"] = pt_pipeline | |
if "ort_pipeline" not in st.session_state: | |
ort_pipeline = ort_pipeline( | |
"text-classification", tokenizer=st.session_state["tokenizer"], model=st.session_state["ort_model"] | |
) | |
st.session_state["ort_pipeline"] = ort_pipeline | |
if "ort_optimized_pipeline" not in st.session_state: | |
ort_optimized_pipeline = pipeline( | |
"text-classification", | |
tokenizer=st.session_state["tokenizer"], | |
model=st.session_state["optimized_model"], | |
) | |
st.session_state["ort_optimized_pipeline"] = ort_optimized_pipeline | |
if "ort_quantized_pipeline" not in st.session_state: | |
ort_quantized_pipeline = pipeline( | |
"text-classification", | |
tokenizer=st.session_state["tokenizer"], | |
model=st.session_state["quantized_model"], | |
) | |
st.session_state["ort_quantized_pipeline"] = ort_quantized_pipeline | |
st.text("✅ All pipelines are ready.") | |
            sleep(2)
    loading_logs.success("🎉 Everything is ready!")
    st.session_state["init_models"] = False
if "inference_timers" not in st.session_state: | |
st.session_state["inference_timers"] = {} | |
exp_number = st.slider("The number of experiments per model.", min_value=100, max_value=300, value=150) | |
get_only_mean = st.checkbox("Get only the mean of the inference time for each model.", value=False) | |
input_text = st.text_area( | |
"Enter text to classify", | |
"there is a shortage of capital, and we need extra financing; growth is strong and we have plenty of liquidity; there are doubts about our finances; profits are flat" | |
) | |
run_inference = st.button("🚀 Run inference") | |
if run_inference: | |
st.text("🔎 Running inference...") | |
sentences = input_text.split(";") | |
st.session_state["inference_timers"] = get_timers(samples=sentences, exp_number=exp_number, only_mean=get_only_mean) | |
st.plotly_chart(get_plot(st.session_state["inference_timers"]), use_container_width=True) | |