NLP / app.py
ashishraics's picture
optimized app
a48f2db
raw history blame
No virus
15.6 kB
import numpy as np
import pandas as pd
import streamlit as st
from streamlit_text_rating.st_text_rater import st_text_rater
from transformers import AutoTokenizer,AutoModelForSequenceClassification
import onnxruntime as ort
import os
import time
import plotly.express as px
import plotly.graph_objects as go
# Plotly `config` shared by every st.plotly_chart call: hide the floating mode bar.
# (The former module-level `global _plotly_config` statement was a no-op and is dropped.)
_plotly_config = {'displayModeBar': False}
from sentiment_clf_helper import classify_sentiment,create_onnx_model_sentiment,classify_sentiment_onnx
from zeroshot_clf_helper import zero_shot_classification,create_onnx_model_zs,zero_shot_classification_onnx
st.set_page_config(
    layout="wide",  # "centered" or "wide"
    initial_sidebar_state="auto",  # "auto", "expanded" or "collapsed"
    # Must be a str or None. The original passed the *string* 'None', so the
    # browser tab showed the literal text "None"; use the app's real title.
    page_title="NLP use cases",
)

# Remove the top padding of the main block container.
# NOTE(review): `.reportview-container` is a class name from older Streamlit
# releases; newer versions may not match this selector -- confirm against the
# pinned Streamlit version.
padding_top = 0
st.markdown(f"""
<style>
.reportview-container .main .block-container{{
padding-top: {padding_top}rem;
}}
</style>""",
    unsafe_allow_html=True,
)
def page_title_markup(title):
    """Return hidden-iframe HTML whose script pins the parent page's <title>.

    The script sets ``title`` once and installs a MutationObserver that puts it
    back whenever something else rewrites the <title> element. A previously
    installed observer (stored on ``window.parent.titleObserver``) is
    disconnected first, so reruns do not stack observers.

    The title is encoded as a JSON string literal (valid JavaScript) and then
    HTML-escaped, because the script lives inside the double-quoted ``srcdoc``
    attribute. Quotes or ``&`` in ``title`` therefore cannot break the markup.

    Every JS statement is ';'-terminated and the markup contains no blank lines.
    The previous version relied on trailing-backslash continuations, which made
    Python join statements onto one line ("... querySelector('title') const
    oldObserver ..."). Without semicolons that is a JavaScript syntax error.
    """
    import html
    import json

    js_title = html.escape(json.dumps(title), quote=True)
    return f"""
<iframe height=0 srcdoc="<script>
const title = window.parent.document.querySelector('title');
const oldObserver = window.parent.titleObserver;
if (oldObserver) {{
    oldObserver.disconnect();
}}
const newObserver = new MutationObserver(function(mutations) {{
    const target = mutations[0].target;
    if (target.text !== {js_title}) {{
        target.text = {js_title};
    }}
}});
newObserver.observe(title, {{ childList: true }});
window.parent.titleObserver = newObserver;
title.text = {js_title};
</script>" />
"""


def set_page_title(title):
    """Force the browser-tab title of this Streamlit page to ``title``.

    Renders the markup from ``page_title_markup`` into the sidebar with raw
    HTML enabled.
    """
    st.sidebar.markdown(unsafe_allow_html=True, body=page_title_markup(title))
set_page_title('NLP use cases')

# CSS that hides Streamlit's hamburger menu (#MainMenu) and the page footer.
hide_streamlit_style = "\n".join([
    "",
    "<style>",
    "#MainMenu {visibility: hidden;}",
    "footer {visibility: hidden;}",
    "</style>",
    "",
])
st.markdown(body=hide_streamlit_style, unsafe_allow_html=True)
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
def create_model_dir(chkpt, model_dir):
    """Download a hub checkpoint once and save it into a local directory.

    The model and tokenizer for ``chkpt`` are fetched with ``from_pretrained``
    and written to ``model_dir`` with ``save_pretrained``. Later code can then
    load them from disk instead of the hub.

    Args:
        chkpt: Hugging Face hub checkpoint name.
        model_dir: local directory to populate.
    """
    # The directory only counts as populated once save_pretrained has written
    # its config. Checking bare existence would skip the download forever after
    # a run that created the folder but failed mid-download.
    if os.path.isfile(os.path.join(model_dir, "config.json")):
        return
    # exist_ok replaces the former bare `except: pass` around os.mkdir, which
    # also silenced real failures such as permission errors.
    os.makedirs(model_dir, exist_ok=True)
    model = AutoModelForSequenceClassification.from_pretrained(chkpt)
    tokenizer = AutoTokenizer.from_pretrained(chkpt)
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
st.title("NLP use cases")

# The task picker lives in the sidebar; its value decides which page section renders below.
_task_choices = ['README', 'Detect Sentiment', 'Zero Shot Classification']
st.sidebar.title("NLP tasks")
select_task = st.sidebar.selectbox(
    label="Select task from drop down menu",
    options=_task_choices,
)

if select_task == 'README':
    st.header("NLP Summary")
############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************
# Hub checkpoint for sentiment classification, and the local folder it is saved into.
sent_chkpt, sent_model_dir = "distilbert-base-uncased-finetuned-sst-2-english", "sentiment_model_dir"

# Save the checkpoint locally (create_model_dir skips the download when the folder is already there).
create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)

# Load model + tokenizer from the local copy and hand them to the ONNX exporter.
# NOTE(review): Streamlit reruns this script on every interaction, so these loads
# repeat each time; whether the export itself is skipped on reruns depends on
# create_onnx_model_sentiment -- confirm in sentiment_clf_helper.
model_sentiment = AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_model_dir)
create_onnx_model_sentiment(
    _model=model_sentiment,
    _tokenizer=tokenizer_sentiment,
)
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
def sentiment_task_selected(task, sent_model_dir=sent_model_dir, onnx_dir='sent_clf_onnx_dir'):
    """Load the sentiment model, tokenizer and both ONNX sessions (cached by st.cache).

    Args:
        task: the selected task name. It is not read in the body; it only takes
            part in st.cache's key.
        sent_model_dir: local directory written by ``create_model_dir``.
        onnx_dir: directory holding ``sentiment_classifier_onnx.onnx`` and
            ``sentiment_classifier_onnx_quant.onnx``. Defaults to the path that
            used to be hard-coded here; mirrors ``zs_task_selected(onnx_dir=...)``.

    Returns:
        ``(model, tokenizer, onnx_session, quantized_onnx_session)``.
    """
    model_sentiment = AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
    tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_model_dir)
    # The ONNX files are not created here; they must already exist
    # (presumably written by create_onnx_model_sentiment at startup -- confirm).
    sentiment_session = ort.InferenceSession(f"{onnx_dir}/sentiment_classifier_onnx.onnx")
    sentiment_session_quant = ort.InferenceSession(f"{onnx_dir}/sentiment_classifier_onnx_quant.onnx")
    return model_sentiment, tokenizer_sentiment, sentiment_session, sentiment_session_quant
############## Pre-Download & instantiate objects for sentiment analysis ********************* END **********************************
############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************
# Hub checkpoint for zero-shot classification, and the local folder it is saved into.
zs_chkpt, zs_model_dir = "valhalla/distilbart-mnli-12-1", "zs_model_dir"

# Save the checkpoint locally (create_model_dir skips the download when the folder is already there).
create_model_dir(chkpt=zs_chkpt, model_dir=zs_model_dir)

# Export the zero-shot ONNX graphs. The helper takes no arguments, so its paths
# are presumably fixed inside zeroshot_clf_helper (zs_task_selected reads from
# 'zeroshot_onnx_dir' by default) -- confirm.
create_onnx_model_zs()
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
def zs_task_selected(task, zs_model_dir=zs_model_dir, onnx_dir='zeroshot_onnx_dir'):
    """Load the zero-shot model, tokenizer and both ONNX sessions (cached by st.cache).

    ``task`` is not read in the body; it only takes part in st.cache's key.
    ``zs_model_dir`` is the local copy written by ``create_model_dir``.
    ``onnx_dir`` must contain ``model.onnx`` and ``model_quant.onnx``.

    Returns ``(model, tokenizer, onnx_session, quantized_onnx_session)``.
    """
    zs_model = AutoModelForSequenceClassification.from_pretrained(zs_model_dir)
    zs_tokenizer = AutoTokenizer.from_pretrained(zs_model_dir)
    # Full-precision graph first, then the quantized one.
    plain_session, quant_session = (
        ort.InferenceSession(f"{onnx_dir}/{file_name}")
        for file_name in ("model.onnx", "model_quant.onnx")
    )
    return zs_model, zs_tokenizer, plain_session, quant_session
############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************
if select_task == 'Detect Sentiment':
    load_start = time.perf_counter()
    (model_sentiment, tokenizer_sentiment,
     sentiment_session, sentiment_session_quant) = sentiment_task_selected(task=select_task)
    st.write(f"Total time to load Model is {(time.perf_counter() - load_start) * 1000:.1f} ms")
    st.header("You are now performing Sentiment Analysis")
    input_texts = st.text_input(label="Input texts separated by comma")

    # One value drives the button label, the loop count and the chart title.
    sentiment_n_runs = 100
    c1, c2, c3, c4 = st.columns(4)
    with c1:
        response1 = st.button("Normal runtime")
    with c2:
        response2 = st.button("ONNX runtime")
    with c3:
        response3 = st.button("ONNX runtime with Quantization")
    with c4:
        response4 = st.button(f"Simulate {sentiment_n_runs} runs each runtime")

    # One zero-argument callable per runtime, keyed by its chart title. The
    # single-run buttons and the benchmark share these instead of three copies.
    sentiment_runtimes = {
        'Normal Runtime': lambda: classify_sentiment(
            input_texts, model=model_sentiment, tokenizer=tokenizer_sentiment),
        'ONNX Runtime': lambda: classify_sentiment_onnx(
            input_texts, _session=sentiment_session, _tokenizer=tokenizer_sentiment),
        'ONNX Runtime with Quantization': lambda: classify_sentiment_onnx(
            input_texts, _session=sentiment_session_quant, _tokenizer=tokenizer_sentiment),
    }

    if any([response1, response2, response3, response4]):
        if not input_texts.strip():
            st.warning("Please enter at least one text to classify.")
        else:
            # Buttons 1-3 map, in order, onto the three runtimes above.
            pressed = [name for name, hit in zip(sentiment_runtimes, (response1, response2, response3)) if hit]
            if pressed:
                start = time.perf_counter()
                sentiments = sentiment_runtimes[pressed[0]]()
                end = time.perf_counter()
                st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
            else:
                # "Simulate" button: time each runtime sentiment_n_runs times and
                # plot the raw timings. The original np.clip'ed each series into a
                # hand-picked range, which misreported slow/fast outliers.
                from plotly.subplots import make_subplots
                timings = {}
                for name, run in sentiment_runtimes.items():
                    samples = []
                    for _ in range(sentiment_n_runs):
                        start = time.perf_counter()
                        sentiments = run()
                        samples.append((time.perf_counter() - start) * 1000)
                    timings[name] = samples
                fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
                                    subplot_titles=list(timings))
                for col, samples in enumerate(timings.values(), start=1):
                    fig.add_trace(go.Histogram(x=samples), row=1, col=col)
                fig.update_xaxes(title_text="ms")
                fig.update_layout(height=400, width=1000,
                                  title_text=f"{sentiment_n_runs} Simulations of different Runtimes",
                                  showlegend=False)
                st.plotly_chart(fig, config=_plotly_config)
            # Widget keys include the position: the bare text as key crashed with
            # a duplicate-key error when the same text was entered twice.
            for i, text in enumerate(input_texts.split(',')):
                background = 'rgb(154,205,50)' if sentiments[i] == 'Positive' else 'rgb(233, 116, 81)'
                st_text_rater(text + f"--> This statement is {sentiments[i]}",
                              color_background=background, key=f"{i}_{text}")
if select_task == 'Zero Shot Classification':
    load_start = time.perf_counter()
    model_zs, tokenizer_zs, zs_session, zs_session_quant = zs_task_selected(task=select_task)
    st.write(f"Total time to load Model is {(time.perf_counter() - load_start) * 1000:.1f} ms")
    st.header("You are now performing Zero Shot Classification")
    input_texts = st.text_input(label="Input text to classify into topics")
    input_labels = st.text_input(label="Enter labels separated by commas")

    # The button label and chart title always promised 10 runs while the loops
    # executed 100; this one constant now drives all three.
    zs_n_runs = 10
    c1, c2, c3, c4 = st.columns(4)
    with c1:
        response1 = st.button("Normal runtime")
    with c2:
        response2 = st.button("ONNX runtime")
    with c3:
        response3 = st.button("ONNX runtime with Quantization")
    with c4:
        response4 = st.button(f"Simulate {zs_n_runs} runs each runtime")

    # One zero-argument callable per runtime, keyed by its chart title.
    zs_runtimes = {
        'Normal Runtime': lambda: zero_shot_classification(
            input_texts, input_labels, model=model_zs, tokenizer=tokenizer_zs),
        'ONNX Runtime': lambda: zero_shot_classification_onnx(
            premise=input_texts, labels=input_labels, _session=zs_session, _tokenizer=tokenizer_zs),
        'ONNX Runtime with Quantization': lambda: zero_shot_classification_onnx(
            premise=input_texts, labels=input_labels, _session=zs_session_quant, _tokenizer=tokenizer_zs),
    }

    if any([response1, response2, response3, response4]):
        if not input_texts.strip() or not input_labels.strip():
            st.warning("Please enter a text and at least one label.")
        else:
            # Buttons 1-3 map, in order, onto the three runtimes above.
            pressed = [name for name, hit in zip(zs_runtimes, (response1, response2, response3)) if hit]
            if pressed:
                start = time.perf_counter()
                df_output = zs_runtimes[pressed[0]]()
                end = time.perf_counter()
                st.write("")
                st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
                fig = px.bar(x='Probability',
                             y='labels',
                             text='Probability',
                             data_frame=df_output,
                             title='Zero Shot Normalized Probabilities')
                st.plotly_chart(fig, config=_plotly_config)
            else:
                # "Simulate" button: time each runtime zs_n_runs times and plot the
                # raw timings (the original np.clip'ed them into fixed ranges).
                from plotly.subplots import make_subplots
                timings = {}
                for name, run in zs_runtimes.items():
                    samples = []
                    for _ in range(zs_n_runs):
                        start = time.perf_counter()
                        run()
                        samples.append((time.perf_counter() - start) * 1000)
                    timings[name] = samples
                fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
                                    subplot_titles=list(timings))
                for col, samples in enumerate(timings.values(), start=1):
                    fig.add_trace(go.Histogram(x=samples), row=1, col=col)
                fig.update_xaxes(title_text="ms")
                fig.update_layout(height=400, width=1000,
                                  title_text=f"{zs_n_runs} Simulations of different Runtimes",
                                  showlegend=False)
                st.plotly_chart(fig, config=_plotly_config)