# Spaces:
# Runtime error
# Runtime error
import numpy as np
import pandas as pd
import streamlit as st
from streamlit_text_rating.st_text_rater import st_text_rater
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import onnxruntime as ort
import os
import time
import plotly.express as px
import plotly.graph_objects as go

# Plotly config passed to every st.plotly_chart call: hide the mode bar.
# (A module-level `global` statement is a no-op, so none is needed here.)
_plotly_config = {'displayModeBar': False}

from sentiment_clf_helper import classify_sentiment, create_onnx_model_sentiment, classify_sentiment_onnx
from zeroshot_clf_helper import zero_shot_classification, create_onnx_model_zs, zero_shot_classification_onnx
# Page-wide layout settings for the app.
st.set_page_config(
    layout="wide",                 # use the full browser width
    initial_sidebar_state="auto",  # let Streamlit decide whether the sidebar starts open
    # A real title; the string 'None' would literally show "None" in the tab.
    page_title="NLP use cases",
)

# Remove the top padding of the main container (0 rem).
padding_top = 0
st.markdown(f"""
<style>
.reportview-container .main .block-container{{
padding-top: {padding_top}rem;
}}
</style>""",
    unsafe_allow_html=True,
)
def set_page_title(title):
    """Force the browser tab title to ``title``.

    Renders a zero-height iframe in the sidebar whose script sets the parent
    document's ``<title>`` and installs a MutationObserver that restores
    ``title`` whenever the title element's children change. The observer is
    stored on ``window.parent.titleObserver`` so that the next call disconnects
    the previous observer instead of stacking a new one on top of it.

    Args:
        title: Text to show in the browser tab.
    """
    import html
    import json

    # `title` lands in a JS string literal, inside a <script>, inside a
    # double-quoted HTML attribute: JSON-encode it for JS, split any "</" so it
    # cannot terminate the script early, then HTML-escape it for the attribute
    # (the browser decodes the entities before running the script).
    js_title = html.escape(json.dumps(str(title)).replace("</", "<\\/"), quote=True)

    # Keep every line non-blank: Markdown ends a raw-HTML block at the first
    # blank line, which would cut the iframe markup in half. Each JS statement
    # ends with ';' so the script does not rely on line breaks.
    script_lines = [
        "const title = window.parent.document.querySelector('title');",
        "const oldObserver = window.parent.titleObserver;",
        "if (oldObserver) {",
        "  oldObserver.disconnect();",
        "}",
        "const newObserver = new MutationObserver(function(mutations) {",
        "  const target = mutations[0].target;",
        f"  if (target.text !== {js_title}) {{",
        f"    target.text = {js_title};",
        "  }",
        "});",
        "newObserver.observe(title, { childList: true });",
        "window.parent.titleObserver = newObserver;",
        f"title.text = {js_title};",
    ]
    body = "\n".join(['<iframe height=0 srcdoc="<script>', *script_lines, '</script>" />'])
    st.sidebar.markdown(unsafe_allow_html=True, body=body)
set_page_title("NLP use cases")

# Hide Streamlit's hamburger menu (#MainMenu) and the page footer via CSS.
hide_streamlit_style = (
    "\n"
    "<style>\n"
    "#MainMenu {visibility: hidden;}\n"
    "footer {visibility: hidden;}\n"
    "</style>\n"
)
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
def create_model_dir(chkpt, model_dir):
    """Save a local copy of a sequence-classification checkpoint, once.

    If ``model_dir`` already exists nothing is done. Otherwise the model and
    tokenizer for ``chkpt`` are loaded with ``from_pretrained`` and written to
    ``model_dir`` with ``save_pretrained``.

    Args:
        chkpt: Model identifier or path accepted by ``from_pretrained``.
        model_dir: Local directory to populate.

    Errors from loading the checkpoint or writing the directory propagate to
    the caller.
    """
    if os.path.exists(model_dir):
        return
    # Load both artifacts before touching the filesystem, so a failed download
    # does not leave behind an empty directory that later runs would mistake
    # for a finished one.
    model = AutoModelForSequenceClassification.from_pretrained(chkpt)
    tokenizer = AutoTokenizer.from_pretrained(chkpt)
    # exist_ok tolerates another session creating the directory in the
    # meantime; any other failure (e.g. permissions) is raised instead of being
    # silently swallowed.
    os.makedirs(model_dir, exist_ok=True)
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
st.title("NLP use cases")

# Sidebar task picker; its value decides which page section renders below.
with st.sidebar:
    st.title("NLP tasks")
    select_task = st.selectbox(
        label="Select task from drop down menu",
        options=['README', 'Detect Sentiment', 'Zero Shot Classification'],
    )

if select_task == 'README':
    st.header("NLP Summary")
############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************
# Hub checkpoint for the sentiment classifier and where its local copy lives.
sent_chkpt = "distilbert-base-uncased-finetuned-sst-2-english"
sent_model_dir = "sentiment_model_dir"

# Make sure the local copy exists, reload it from disk, and hand model and
# tokenizer to the helper that creates the sentiment ONNX model.
# NOTE(review): this is module-level code, so it runs on every Streamlit rerun
# regardless of the selected task — confirm the helper caches its work.
create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)
model_sentiment = AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_model_dir)
create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
def sentiment_task_selected(task, sent_model_dir=sent_model_dir, onnx_dir='sent_clf_onnx_dir'):
    """Load everything the "Detect Sentiment" page needs.

    Args:
        task: Name of the selected task; not used by the body, kept so existing
            callers keep working.
        sent_model_dir: Directory holding the saved model and tokenizer.
        onnx_dir: Directory holding ``sentiment_classifier_onnx.onnx`` and
            ``sentiment_classifier_onnx_quant.onnx`` (mirrors the ``onnx_dir``
            parameter of ``zs_task_selected``).

    Returns:
        Tuple ``(model, tokenizer, onnx_session, quantized_onnx_session)``.
    """
    model_sentiment = AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
    tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_model_dir)
    # ONNX Runtime sessions for the plain and the quantized export.
    sentiment_session = ort.InferenceSession(f"{onnx_dir}/sentiment_classifier_onnx.onnx")
    sentiment_session_quant = ort.InferenceSession(f"{onnx_dir}/sentiment_classifier_onnx_quant.onnx")
    return model_sentiment, tokenizer_sentiment, sentiment_session, sentiment_session_quant
############## Pre-Download & instantiate objects for sentiment analysis ********************* END **********************************
############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************
# Hub checkpoint for the zero-shot classifier and where its local copy lives.
zs_chkpt = "valhalla/distilbart-mnli-12-1"
zs_model_dir = "zs_model_dir"

# Make sure the local copy exists, then create the zero-shot ONNX model.
create_model_dir(chkpt=zs_chkpt, model_dir=zs_model_dir)
create_onnx_model_zs()
def zs_task_selected(task, zs_model_dir=zs_model_dir, onnx_dir='zeroshot_onnx_dir'):
    """Load everything the "Zero Shot Classification" page needs.

    Args:
        task: Name of the selected task; not used by the body.
        zs_model_dir: Directory holding the saved model and tokenizer.
        onnx_dir: Directory holding ``model.onnx`` and ``model_quant.onnx``.

    Returns:
        Tuple ``(model, tokenizer, onnx_session, quantized_onnx_session)``.
    """
    # PyTorch model and tokenizer for the "Normal runtime" path.
    loaded_model = AutoModelForSequenceClassification.from_pretrained(zs_model_dir)
    loaded_tokenizer = AutoTokenizer.from_pretrained(zs_model_dir)
    # ONNX Runtime sessions, created in order: plain export first, then quantized.
    plain_session, quant_session = (
        ort.InferenceSession(f"{onnx_dir}/{file_name}")
        for file_name in ("model.onnx", "model_quant.onnx")
    )
    return loaded_model, loaded_tokenizer, plain_session, quant_session
############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************
if select_task == 'Detect Sentiment':
    # Report how long loading the model, tokenizer and sessions took.
    t1 = time.perf_counter()
    (model_sentiment, tokenizer_sentiment,
     sentiment_session, sentiment_session_quant) = sentiment_task_selected(task=select_task)
    t2 = time.perf_counter()
    st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
    st.header("You are now performing Sentiment Analysis")
    input_texts = st.text_input(label="Input texts separated by comma")

    # Timed runs per runtime for the simulation button; the button label and
    # the chart title are built from this count so they always match.
    n_sim_runs = 100

    c1, c2, c3, c4 = st.columns(4)
    with c1:
        response1 = st.button("Normal runtime")
    with c2:
        response2 = st.button("ONNX runtime")
    with c3:
        response3 = st.button("ONNX runtime with Quantization")
    with c4:
        response4 = st.button(f"Simulate {n_sim_runs} runs each runtime")

    # One zero-argument callable per runtime, shared by the single-run buttons
    # and the simulation so both time exactly the same call.
    def run_sentiment_normal():
        return classify_sentiment(input_texts, model=model_sentiment, tokenizer=tokenizer_sentiment)

    def run_sentiment_onnx():
        return classify_sentiment_onnx(input_texts, _session=sentiment_session,
                                       _tokenizer=tokenizer_sentiment)

    def run_sentiment_onnx_quant():
        return classify_sentiment_onnx(input_texts, _session=sentiment_session_quant,
                                       _tokenizer=tokenizer_sentiment)

    if any([response1, response2, response3, response4]):
        # The first pressed single-run button wins, in button order.
        single_runs = [
            (response1, run_sentiment_normal),
            (response2, run_sentiment_onnx),
            (response3, run_sentiment_onnx_quant),
        ]
        chosen_run = next((run for pressed, run in single_runs if pressed), None)

        if chosen_run is not None:
            start = time.perf_counter()
            sentiments = chosen_run()
            end = time.perf_counter()
            st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
        else:
            # Only the simulation button can be pressed here: time n_sim_runs
            # calls of each runtime. `sentiments` keeps the last (quantized)
            # result for the display loop below.
            simulations = [
                ('Normal Runtime (ms)', run_sentiment_normal, (10, 60)),
                ('ONNX Runtime (ms)', run_sentiment_onnx, (0, 20)),
                ('ONNX Quant Runtime (ms)', run_sentiment_onnx_quant, (0, 20)),
            ]
            timings = {}
            for column, run, (low, high) in simulations:
                durations = []
                for _ in range(n_sim_runs):
                    start = time.perf_counter()
                    sentiments = run()
                    durations.append((time.perf_counter() - start) * 1000)
                # Clamp to a fixed display range: values outside it are pinned
                # to the bounds, so the histograms do not show true outliers.
                timings[column] = np.clip(durations, low, high)
            temp_df = pd.DataFrame(timings)

            from plotly.subplots import make_subplots
            fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
                                subplot_titles=['Normal Runtime', 'ONNX Runtime',
                                                'ONNX Runtime with Quantization'])
            for col, column in enumerate(temp_df.columns, start=1):
                fig.add_trace(go.Histogram(x=temp_df[column]), row=1, col=col)
            fig.update_layout(height=400, width=1000,
                              title_text=f"{n_sim_runs} Simulations of different Runtimes",
                              showlegend=False)
            st.plotly_chart(fig, config=_plotly_config)

        # Show one rating widget per comma-separated input text.
        # NOTE(review): assumes the classifiers return one label per
        # comma-separated text, in order — confirm in sentiment_clf_helper.
        for i, text in enumerate(input_texts.split(',')):
            label = sentiments[i]
            background = 'rgb(154,205,50)' if label == 'Positive' else 'rgb(233, 116, 81)'
            # Widget keys must be unique within a run; keying on the text alone
            # fails when the same text is entered twice, so include the index.
            st_text_rater(text + f"--> This statement is {label}",
                          color_background=background, key=f"{i}_{text}")
if select_task == 'Zero Shot Classification':
    # Report how long loading the model, tokenizer and sessions took.
    t1 = time.perf_counter()
    model_zs, tokenizer_zs, zs_session, zs_session_quant = zs_task_selected(task=select_task)
    t2 = time.perf_counter()
    st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
    st.header("You are now performing Zero Shot Classification")
    input_texts = st.text_input(label="Input text to classify into topics")
    input_lables = st.text_input(label="Enter labels separated by commas")

    # Timed runs per runtime for the simulation button. The button label, the
    # loops and the chart title all use this one count, so the page does what
    # the button promises ("10 runs").
    n_sim_runs = 10

    c1, c2, c3, c4 = st.columns(4)
    with c1:
        response1 = st.button("Normal runtime")
    with c2:
        response2 = st.button("ONNX runtime")
    with c3:
        response3 = st.button("ONNX runtime with Quantization")
    with c4:
        response4 = st.button(f"Simulate {n_sim_runs} runs each runtime")

    # One zero-argument callable per runtime, shared by the single-run buttons
    # and the simulation so both time exactly the same call.
    def run_zs_normal():
        return zero_shot_classification(input_texts, input_lables,
                                        model=model_zs, tokenizer=tokenizer_zs)

    def run_zs_onnx():
        return zero_shot_classification_onnx(premise=input_texts, labels=input_lables,
                                             _session=zs_session, _tokenizer=tokenizer_zs)

    def run_zs_onnx_quant():
        return zero_shot_classification_onnx(premise=input_texts, labels=input_lables,
                                             _session=zs_session_quant, _tokenizer=tokenizer_zs)

    if any([response1, response2, response3, response4]):
        # The first pressed single-run button wins, in button order.
        single_runs = [
            (response1, run_zs_normal),
            (response2, run_zs_onnx),
            (response3, run_zs_onnx_quant),
        ]
        chosen_run = next((run for pressed, run in single_runs if pressed), None)

        if chosen_run is not None:
            start = time.perf_counter()
            df_output = chosen_run()
            end = time.perf_counter()
            st.write("")
            st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
            fig = px.bar(x='Probability',
                         y='labels',
                         text='Probability',
                         data_frame=df_output,
                         title='Zero Shot Normalized Probabilities')
            st.plotly_chart(fig, config=_plotly_config)
        else:
            # Only the simulation button can be pressed here: time n_sim_runs
            # calls of each runtime; the classification results are discarded.
            simulations = [
                ('Normal Runtime (ms)', run_zs_normal, (50, 400)),
                ('ONNX Runtime (ms)', run_zs_onnx, (50, 200)),
                ('ONNX Quant Runtime (ms)', run_zs_onnx_quant, (50, 200)),
            ]
            timings = {}
            for column, run, (low, high) in simulations:
                durations = []
                for _ in range(n_sim_runs):
                    start = time.perf_counter()
                    run()
                    durations.append((time.perf_counter() - start) * 1000)
                # Clamp to a fixed display range: values outside it are pinned
                # to the bounds, so the histograms do not show true outliers.
                timings[column] = np.clip(durations, low, high)
            temp_df = pd.DataFrame(timings)

            from plotly.subplots import make_subplots
            fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
                                subplot_titles=['Normal Runtime', 'ONNX Runtime',
                                                'ONNX Runtime with Quantization'])
            for col, column in enumerate(temp_df.columns, start=1):
                fig.add_trace(go.Histogram(x=temp_df[column]), row=1, col=col)
            fig.update_layout(height=400, width=1000,
                              title_text=f"{n_sim_runs} Simulations of different Runtimes",
                              showlegend=False)
            st.plotly_chart(fig, config=_plotly_config)