ashishraics commited on
Commit
cca4ece
1 Parent(s): 9111b95

optimizing app

Browse files
Files changed (1) hide show
  1. app.py +58 -91
app.py CHANGED
@@ -14,6 +14,25 @@ _plotly_config={'displayModeBar': False}
14
  from sentiment_clf_helper import classify_sentiment,create_onnx_model_sentiment,classify_sentiment_onnx
15
  from zeroshot_clf_helper import zero_shot_classification,create_onnx_model_zs,zero_shot_classification_onnx
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  st.set_page_config( # Alternate names: setup_page, page, layout
19
  layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
@@ -97,24 +116,27 @@ if select_task=='README':
97
 
98
  ############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************
99
 
100
- sent_chkpt = "distilbert-base-uncased-finetuned-sst-2-english"
101
- sent_model_dir="sentiment_model_dir"
102
- #create model/token dir for sentiment classification
103
- create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)
104
 
105
 
106
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
107
- def sentiment_task_selected(task,sent_model_dir=sent_model_dir):
 
 
 
 
 
108
  #model & tokenizer initialization for normal sentiment classification
109
- model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
110
- tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_model_dir)
111
 
112
  # create onnx model for sentiment classification
113
  create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
114
 
115
  #create inference session
116
- sentiment_session = ort.InferenceSession("sent_clf_onnx_dir/sentiment_classifier_onnx.onnx")
117
- sentiment_session_quant = ort.InferenceSession("sent_clf_onnx_dir/sentiment_classifier_onnx_quant.onnx")
118
 
119
  return model_sentiment,tokenizer_sentiment,sentiment_session,sentiment_session_quant
120
 
@@ -123,26 +145,31 @@ def sentiment_task_selected(task,sent_model_dir=sent_model_dir):
123
 
124
  ############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************
125
 
126
- zs_chkpt = "valhalla/distilbart-mnli-12-1"
127
- zs_model_dir = "zs_model_dir"
128
- # create model/token dir for zeroshot clf
129
- create_model_dir(chkpt=zs_chkpt, model_dir=zs_model_dir)
130
 
131
 
132
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
133
- def zs_task_selected(task, zs_model_dir=zs_model_dir,onnx_dir='zeroshot_onnx_dir'):
134
- #model & tokenizer initialization for normal ZS classification
135
- model_zs=AutoModelForSequenceClassification.from_pretrained(zs_model_dir)
136
- tokenizer_zs=AutoTokenizer.from_pretrained(zs_model_dir)
137
-
138
- # ceate onnx model for zeroshot
 
 
 
 
 
 
 
139
  create_onnx_model_zs()
140
 
141
  #create inference session from onnx model
142
- zs_session = ort.InferenceSession(f"{onnx_dir}/model.onnx")
143
- zs_session_quant = ort.InferenceSession(f"{onnx_dir}/model_quant.onnx")
144
 
145
- return model_zs,tokenizer_zs,zs_session,zs_session_quant
146
 
147
  ############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************
148
 
@@ -256,7 +283,7 @@ if select_task == 'Detect Sentiment':
256
  if select_task=='Zero Shot Classification':
257
 
258
  t1=time.time()
259
- model_zs,tokenizer_zs,zs_session,zs_session_quant = zs_task_selected(task=select_task)
260
  t2 = time.time()
261
  st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
262
 
@@ -267,29 +294,16 @@ if select_task=='Zero Shot Classification':
267
  c1,c2,c3,c4=st.columns(4)
268
 
269
  with c1:
270
- response1=st.button("Normal runtime")
271
  with c2:
272
- response2=st.button("ONNX runtime")
273
- with c3:
274
- response3=st.button("ONNX runtime with Quantization")
275
- with c4:
276
- response4 = st.button("Simulate 10 runs each runtime")
277
 
278
- if any([response1,response2,response3,response4]):
279
  if response1:
280
- start=time.time()
281
- df_output = zero_shot_classification(input_texts, input_lables,model=model_zs,tokenizer=tokenizer_zs)
282
- end=time.time()
283
- st.write("")
284
- st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
285
- fig = px.bar(x='Probability',
286
- y='labels',
287
- text='Probability',
288
- data_frame=df_output,
289
- title='Zero Shot Normalized Probabilities')
290
-
291
- st.plotly_chart(fig, config=_plotly_config)
292
- elif response2:
293
  start = time.time()
294
  df_output=zero_shot_classification_onnx(premise=input_texts,labels=input_lables,_session=zs_session,_tokenizer=tokenizer_zs)
295
  end=time.time()
@@ -303,7 +317,7 @@ if select_task=='Zero Shot Classification':
303
  title='Zero Shot Normalized Probabilities')
304
 
305
  st.plotly_chart(fig,config=_plotly_config)
306
- elif response3:
307
  start = time.time()
308
  df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
309
  _tokenizer=tokenizer_zs)
@@ -317,53 +331,6 @@ if select_task=='Zero Shot Classification':
317
  title='Zero Shot Normalized Probabilities')
318
 
319
  st.plotly_chart(fig, config=_plotly_config)
320
- elif response4:
321
- normal_runtime = []
322
- for i in range(100):
323
- start = time.time()
324
- _ = zero_shot_classification(input_texts, input_lables,model=model_zs,tokenizer=tokenizer_zs)
325
- end = time.time()
326
- t = (end - start) * 1000
327
- normal_runtime.append(t)
328
- normal_runtime = np.clip(normal_runtime, 50, 400)
329
-
330
- onnx_runtime = []
331
- for i in range(100):
332
- start = time.time()
333
- _ = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session,
334
- _tokenizer=tokenizer_zs)
335
- end = time.time()
336
- t = (end - start) * 1000
337
- onnx_runtime.append(t)
338
- onnx_runtime = np.clip(onnx_runtime, 50, 200)
339
-
340
- onnx_runtime_quant = []
341
- for i in range(100):
342
- start = time.time()
343
- _ = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
344
- _tokenizer=tokenizer_zs)
345
- end = time.time()
346
-
347
- t = (end - start) * 1000
348
- onnx_runtime_quant.append(t)
349
- onnx_runtime_quant = np.clip(onnx_runtime_quant, 50, 200)
350
-
351
- temp_df = pd.DataFrame({'Normal Runtime (ms)': normal_runtime,
352
- 'ONNX Runtime (ms)': onnx_runtime,
353
- 'ONNX Quant Runtime (ms)': onnx_runtime_quant})
354
-
355
- from plotly.subplots import make_subplots
356
-
357
- fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
358
- subplot_titles=['Normal Runtime', 'ONNX Runtime', 'ONNX Runtime with Quantization'])
359
-
360
- fig.add_trace(go.Histogram(x=temp_df['Normal Runtime (ms)']), row=1, col=1)
361
- fig.add_trace(go.Histogram(x=temp_df['ONNX Runtime (ms)']), row=1, col=2)
362
- fig.add_trace(go.Histogram(x=temp_df['ONNX Quant Runtime (ms)']), row=1, col=3)
363
- fig.update_layout(height=400, width=1000,
364
- title_text="10 Simulations of different Runtimes",
365
- showlegend=False)
366
- st.plotly_chart(fig, config=_plotly_config)
367
  else:
368
  pass
369
 
 
14
  from sentiment_clf_helper import classify_sentiment,create_onnx_model_sentiment,classify_sentiment_onnx
15
  from zeroshot_clf_helper import zero_shot_classification,create_onnx_model_zs,zero_shot_classification_onnx
16
 
17
+ import yaml
18
+ def read_yaml(file_path):
19
+ with open(file_path, "r") as f:
20
+ return yaml.safe_load(f)
21
+
22
+ config = read_yaml('config.yaml')
23
+
24
+ sent_chkpt=config['SENTIMENT_CLF']['sent_chkpt']
25
+ sent_mdl_dir=config['SENTIMENT_CLF']['sent_mdl_dir']
26
+ sent_onnx_mdl_dir=config['SENTIMENT_CLF']['sent_onnx_mdl_dir']
27
+ sent_onnx_mdl_name=config['SENTIMENT_CLF']['sent_onnx_mdl_name']
28
+ sent_onnx_quant_mdl_name=config['SENTIMENT_CLF']['sent_onnx_quant_mdl_name']
29
+
30
+ zs_chkpt=config['ZEROSHOT_CLF']['zs_chkpt']
31
+ zs_mdl_dir=config['ZEROSHOT_CLF']['zs_mdl_dir']
32
+ zs_onnx_mdl_dir=config['ZEROSHOT_CLF']['zs_onnx_mdl_dir']
33
+ zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
34
+ zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
35
+
36
 
37
  st.set_page_config( # Alternate names: setup_page, page, layout
38
  layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
 
116
 
117
  ############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************
118
 
119
+ # #create model/token dir for sentiment classification for faster inference
120
+ # create_model_dir(chkpt=sent_chkpt, model_dir=sent_mdl_dir)
 
 
121
 
122
 
123
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
124
+ def sentiment_task_selected(task,
125
+ sent_chkpt=sent_chkpt,
126
+ sent_mdl_dir=sent_mdl_dir,
127
+ sent_onnx_mdl_dir=sent_onnx_mdl_dir,
128
+ sent_onnx_mdl_name=sent_onnx_mdl_name,
129
+ sent_onnx_quant_mdl_name=sent_onnx_quant_mdl_name):
130
  #model & tokenizer initialization for normal sentiment classification
131
+ model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_chkpt)
132
+ tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_chkpt)
133
 
134
  # create onnx model for sentiment classification
135
  create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
136
 
137
  #create inference session
138
+ sentiment_session = ort.InferenceSession(f"{sent_onnx_mdl_dir}/{sent_onnx_mdl_name}")
139
+ sentiment_session_quant = ort.InferenceSession(f"{sent_onnx_mdl_dir}/{sent_onnx_quant_mdl_name}")
140
 
141
  return model_sentiment,tokenizer_sentiment,sentiment_session,sentiment_session_quant
142
 
 
145
 
146
  ############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************
147
 
148
+ # # create model/token dir for zeroshot clf
149
+ # create_model_dir(chkpt=zs_chkpt, model_dir=zs_mdl_dir)
 
 
150
 
151
 
152
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
153
+ def zs_task_selected(task,
154
+ zs_chkpt=zs_chkpt ,
155
+ zs_mdl_dir=zs_mdl_dir,
156
+ zs_onnx_mdl_dir=zs_onnx_mdl_dir,
157
+ zs_onnx_mdl_name=zs_onnx_mdl_name,
158
+ zs_onnx_quant_mdl_name=zs_onnx_quant_mdl_name):
159
+
160
+ ##model & tokenizer initialization for normal ZS classification
161
+ # model_zs=AutoModelForSequenceClassification.from_pretrained(zs_chkpt)
162
+ # we just need tokenizer for inference and not model since onnx model is already saved
163
+ tokenizer_zs=AutoTokenizer.from_pretrained(zs_chkpt)
164
+
165
+ # create onnx model for zeroshot
166
  create_onnx_model_zs()
167
 
168
  #create inference session from onnx model
169
+ zs_session = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}")
170
+ zs_session_quant = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}")
171
 
172
+ return tokenizer_zs,zs_session,zs_session_quant
173
 
174
  ############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************
175
 
 
283
  if select_task=='Zero Shot Classification':
284
 
285
  t1=time.time()
286
+ tokenizer_zs,zs_session,zs_session_quant = zs_task_selected(task=select_task)
287
  t2 = time.time()
288
  st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
289
 
 
294
  c1,c2,c3,c4=st.columns(4)
295
 
296
  with c1:
297
+ response1=st.button("ONNX runtime")
298
  with c2:
299
+ response2=st.button("ONNX runtime Quantized")
300
+ # with c3:
301
+ # response3=st.button("ONNX runtime with Quantization")
302
+ # with c4:
303
+ # response4 = st.button("Simulate 10 runs each runtime")
304
 
305
+ if any([response1,response2]):
306
  if response1:
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  start = time.time()
308
  df_output=zero_shot_classification_onnx(premise=input_texts,labels=input_lables,_session=zs_session,_tokenizer=tokenizer_zs)
309
  end=time.time()
 
317
  title='Zero Shot Normalized Probabilities')
318
 
319
  st.plotly_chart(fig,config=_plotly_config)
320
+ elif response2:
321
  start = time.time()
322
  df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
323
  _tokenizer=tokenizer_zs)
 
331
  title='Zero Shot Normalized Probabilities')
332
 
333
  st.plotly_chart(fig, config=_plotly_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  else:
335
  pass
336