Spaces: Runtime error

Commit cca4ece
Parent(s): 9111b95
optimizing app

app.py CHANGED
@@ -14,6 +14,25 @@ _plotly_config={'displayModeBar': False}
 from sentiment_clf_helper import classify_sentiment,create_onnx_model_sentiment,classify_sentiment_onnx
 from zeroshot_clf_helper import zero_shot_classification,create_onnx_model_zs,zero_shot_classification_onnx

+import yaml
+def read_yaml(file_path):
+    with open(file_path, "r") as f:
+        return yaml.safe_load(f)
+
+config = read_yaml('config.yaml')
+
+sent_chkpt=config['SENTIMENT_CLF']['sent_chkpt']
+sent_mdl_dir=config['SENTIMENT_CLF']['sent_mdl_dir']
+sent_onnx_mdl_dir=config['SENTIMENT_CLF']['sent_onnx_mdl_dir']
+sent_onnx_mdl_name=config['SENTIMENT_CLF']['sent_onnx_mdl_name']
+sent_onnx_quant_mdl_name=config['SENTIMENT_CLF']['sent_onnx_quant_mdl_name']
+
+zs_chkpt=config['ZEROSHOT_CLF']['zs_chkpt']
+zs_mdl_dir=config['ZEROSHOT_CLF']['zs_mdl_dir']
+zs_onnx_mdl_dir=config['ZEROSHOT_CLF']['zs_onnx_mdl_dir']
+zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
+zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
+

 st.set_page_config( # Alternate names: setup_page, page, layout
     layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
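For reference, the keys read above imply a config.yaml with two top-level sections. A minimal sketch of that layout, with placeholder values (the actual checkpoint names and paths are not shown in this commit):

import yaml

# Hypothetical config.yaml contents; the key layout mirrors the reads in app.py,
# but every value below is a placeholder, not taken from the repo.
example_cfg = yaml.safe_load("""
SENTIMENT_CLF:
  sent_chkpt: some-org/sentiment-checkpoint
  sent_mdl_dir: sent_clf_model_dir
  sent_onnx_mdl_dir: sent_clf_onnx_dir
  sent_onnx_mdl_name: sentiment_classifier.onnx
  sent_onnx_quant_mdl_name: sentiment_classifier_quant.onnx
ZEROSHOT_CLF:
  zs_chkpt: some-org/nli-checkpoint
  zs_mdl_dir: zs_model_dir
  zs_onnx_mdl_dir: zs_onnx_dir
  zs_onnx_mdl_name: zs_classifier.onnx
  zs_onnx_quant_mdl_name: zs_classifier_quant.onnx
""")

assert example_cfg['ZEROSHOT_CLF']['zs_onnx_mdl_name'] == 'zs_classifier.onnx'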
@@ -97,24 +116,27 @@ if select_task=='README':

 ############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************

-
-
-#create model/token dir for sentiment classification
-create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)
+# #create model/token dir for sentiment classification for faster inference
+# create_model_dir(chkpt=sent_chkpt, model_dir=sent_mdl_dir)


 @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
-def sentiment_task_selected(task,
+def sentiment_task_selected(task,
+                            sent_chkpt=sent_chkpt,
+                            sent_mdl_dir=sent_mdl_dir,
+                            sent_onnx_mdl_dir=sent_onnx_mdl_dir,
+                            sent_onnx_mdl_name=sent_onnx_mdl_name,
+                            sent_onnx_quant_mdl_name=sent_onnx_quant_mdl_name):
     #model & tokenizer initialization for normal sentiment classification
-    model_sentiment=AutoModelForSequenceClassification.from_pretrained(
-    tokenizer_sentiment=AutoTokenizer.from_pretrained(
+    model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_chkpt)
+    tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_chkpt)

     # create onnx model for sentiment classification
     create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)

     #create inference session
-    sentiment_session = ort.InferenceSession("
-    sentiment_session_quant = ort.InferenceSession("
+    sentiment_session = ort.InferenceSession(f"{sent_onnx_mdl_dir}/{sent_onnx_mdl_name}")
+    sentiment_session_quant = ort.InferenceSession(f"{sent_onnx_mdl_dir}/{sent_onnx_quant_mdl_name}")

     return model_sentiment,tokenizer_sentiment,sentiment_session,sentiment_session_quant

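The helper that consumes these sessions (classify_sentiment_onnx from sentiment_clf_helper) is not part of this diff; presumably it follows the usual ONNX Runtime pattern of tokenizing to NumPy arrays, running the session, and applying a softmax. A rough sketch under that assumption (the function name and details below are illustrative only):

import numpy as np

def run_sentiment_onnx(text, session, tokenizer):
    # Tokenize to NumPy arrays and keep only the inputs the ONNX graph declares.
    enc = tokenizer(text, return_tensors="np")
    input_names = {i.name for i in session.get_inputs()}
    feed = {k: v for k, v in enc.items() if k in input_names}
    logits = session.run(None, feed)[0]                      # shape: (1, num_labels)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return probs / probs.sum(axis=-1, keepdims=True)         # softmax over labels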
@@ -123,26 +145,31 @@ def sentiment_task_selected(task,sent_model_dir=sent_model_dir):

 ############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************

-
-
-# create model/token dir for zeroshot clf
-create_model_dir(chkpt=zs_chkpt, model_dir=zs_model_dir)
+# # create model/token dir for zeroshot clf
+# create_model_dir(chkpt=zs_chkpt, model_dir=zs_mdl_dir)


 @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
-def zs_task_selected(task,
-
-
-
-
-
+def zs_task_selected(task,
+                     zs_chkpt=zs_chkpt,
+                     zs_mdl_dir=zs_mdl_dir,
+                     zs_onnx_mdl_dir=zs_onnx_mdl_dir,
+                     zs_onnx_mdl_name=zs_onnx_mdl_name,
+                     zs_onnx_quant_mdl_name=zs_onnx_quant_mdl_name):
+
+    ##model & tokenizer initialization for normal ZS classification
+    # model_zs=AutoModelForSequenceClassification.from_pretrained(zs_chkpt)
+    # we just need tokenizer for inference and not model since onnx model is already saved
+    tokenizer_zs=AutoTokenizer.from_pretrained(zs_chkpt)
+
+    # create onnx model for zeroshot
     create_onnx_model_zs()

     #create inference session from onnx model
-    zs_session = ort.InferenceSession(f"{
-    zs_session_quant = ort.InferenceSession(f"{
+    zs_session = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}")
+    zs_session_quant = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}")

-    return
+    return tokenizer_zs,zs_session,zs_session_quant

 ############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************

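zero_shot_classification_onnx (from zeroshot_clf_helper, also not shown in this diff) presumably implements the standard NLI-based zero-shot trick on top of zs_session: each candidate label is scored as the hypothesis of an entailment check against the premise, and the entailment scores are normalized across labels. A rough sketch under that assumption (the names, hypothesis template, and logit ordering are all illustrative):

import numpy as np

def zero_shot_onnx_sketch(premise, labels, session, tokenizer):
    input_names = {i.name for i in session.get_inputs()}
    entail_logits = []
    for label in labels:
        # Premise/hypothesis pair for the NLI model, tokenized to NumPy arrays.
        enc = tokenizer(premise, f"This example is {label}.", return_tensors="np")
        feed = {k: v for k, v in enc.items() if k in input_names}
        logits = session.run(None, feed)[0][0]
        entail_logits.append(logits[-1])          # assume the last logit is entailment
    scores = np.exp(entail_logits)
    return dict(zip(labels, (scores / scores.sum()).tolist()))

The real helper returns a DataFrame with 'labels' and 'Probability' columns, which is what the px.bar calls further down expect.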
@@ -256,7 +283,7 @@ if select_task == 'Detect Sentiment':
 if select_task=='Zero Shot Classification':

     t1=time.time()
-
+    tokenizer_zs,zs_session,zs_session_quant = zs_task_selected(task=select_task)
     t2 = time.time()
     st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")

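Because zs_task_selected (like sentiment_task_selected above) is wrapped in @st.cache, the t2-t1 figure mostly reflects the first run; subsequent reruns of the script get the cached tokenizer and sessions back almost instantly. A minimal illustration of that caching behaviour, independent of this app:

import time
import streamlit as st

@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def slow_loader():
    time.sleep(2)            # stands in for model download / ONNX export
    return {"ready": True}

t0 = time.time()
slow_loader()                # first call pays the full cost
t1 = time.time()
slow_loader()                # same arguments: served from the cache
t2 = time.time()
st.write(f"first call {t1 - t0:.2f}s, cached call {t2 - t1:.2f}s")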
@@ -267,29 +294,16 @@ if select_task=='Zero Shot Classification':
     c1,c2,c3,c4=st.columns(4)

     with c1:
-        response1=st.button("
+        response1=st.button("ONNX runtime")
     with c2:
-        response2=st.button("ONNX runtime")
-    with c3:
-
-    with c4:
-
+        response2=st.button("ONNX runtime Quantized")
+    # with c3:
+    #     response3=st.button("ONNX runtime with Quantization")
+    # with c4:
+    #     response4 = st.button("Simulate 10 runs each runtime")

-    if any([response1,response2
+    if any([response1,response2]):
         if response1:
-            start=time.time()
-            df_output = zero_shot_classification(input_texts, input_lables,model=model_zs,tokenizer=tokenizer_zs)
-            end=time.time()
-            st.write("")
-            st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
-            fig = px.bar(x='Probability',
-                         y='labels',
-                         text='Probability',
-                         data_frame=df_output,
-                         title='Zero Shot Normalized Probabilities')
-
-            st.plotly_chart(fig, config=_plotly_config)
-        elif response2:
             start = time.time()
             df_output=zero_shot_classification_onnx(premise=input_texts,labels=input_lables,_session=zs_session,_tokenizer=tokenizer_zs)
             end=time.time()
@@ -303,7 +317,7 @@ if select_task=='Zero Shot Classification':
                          title='Zero Shot Normalized Probabilities')

             st.plotly_chart(fig,config=_plotly_config)
-        elif
+        elif response2:
             start = time.time()
             df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
                                                        _tokenizer=tokenizer_zs)
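The *_quant ONNX files loaded into zs_session_quant and sentiment_session_quant are presumably produced inside create_onnx_model_sentiment / create_onnx_model_zs, whose code is not part of this commit. Dynamic quantization with ONNX Runtime is the likely mechanism; a generic sketch of that step with placeholder paths (an assumption about the helpers, not their actual code):

from onnxruntime.quantization import quantize_dynamic, QuantType

# Convert FP32 weights to INT8 for a smaller file and faster CPU inference.
quantize_dynamic(
    model_input="zs_onnx_dir/zs_classifier.onnx",         # placeholder path
    model_output="zs_onnx_dir/zs_classifier_quant.onnx",  # placeholder path
    weight_type=QuantType.QInt8,
)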
@@ -317,53 +331,6 @@ if select_task=='Zero Shot Classification':
                          title='Zero Shot Normalized Probabilities')

             st.plotly_chart(fig, config=_plotly_config)
-        elif response4:
-            normal_runtime = []
-            for i in range(100):
-                start = time.time()
-                _ = zero_shot_classification(input_texts, input_lables,model=model_zs,tokenizer=tokenizer_zs)
-                end = time.time()
-                t = (end - start) * 1000
-                normal_runtime.append(t)
-            normal_runtime = np.clip(normal_runtime, 50, 400)
-
-            onnx_runtime = []
-            for i in range(100):
-                start = time.time()
-                _ = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session,
-                                                  _tokenizer=tokenizer_zs)
-                end = time.time()
-                t = (end - start) * 1000
-                onnx_runtime.append(t)
-            onnx_runtime = np.clip(onnx_runtime, 50, 200)
-
-            onnx_runtime_quant = []
-            for i in range(100):
-                start = time.time()
-                _ = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
-                                                  _tokenizer=tokenizer_zs)
-                end = time.time()
-
-                t = (end - start) * 1000
-                onnx_runtime_quant.append(t)
-            onnx_runtime_quant = np.clip(onnx_runtime_quant, 50, 200)
-
-            temp_df = pd.DataFrame({'Normal Runtime (ms)': normal_runtime,
-                                    'ONNX Runtime (ms)': onnx_runtime,
-                                    'ONNX Quant Runtime (ms)': onnx_runtime_quant})
-
-            from plotly.subplots import make_subplots
-
-            fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
-                                subplot_titles=['Normal Runtime', 'ONNX Runtime', 'ONNX Runtime with Quantization'])
-
-            fig.add_trace(go.Histogram(x=temp_df['Normal Runtime (ms)']), row=1, col=1)
-            fig.add_trace(go.Histogram(x=temp_df['ONNX Runtime (ms)']), row=1, col=2)
-            fig.add_trace(go.Histogram(x=temp_df['ONNX Quant Runtime (ms)']), row=1, col=3)
-            fig.update_layout(height=400, width=1000,
-                              title_text="10 Simulations of different Runtimes",
-                              showlegend=False)
-            st.plotly_chart(fig, config=_plotly_config)
     else:
         pass
