ashishraics committed
Commit 84fa2e9
1 Parent(s): 8bb7965

change threading options for onnx inference

Files changed (1):
  1. app.py +11 -2
app.py CHANGED

@@ -87,6 +87,10 @@ hide_streamlit_style = """
 """
 st.markdown(hide_streamlit_style, unsafe_allow_html=True)
 
+options = ort.SessionOptions()
+options.intra_op_num_threads=1
+options.inter_op_num_threads=1
+
 
 @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
 def create_model_dir(chkpt, model_dir):
@@ -180,6 +184,9 @@ if select_task=='README':
 if select_task == 'Detect Sentiment':
     t1=time.time()
     tokenizer_sentiment,sentiment_session = sentiment_task_selected(task=select_task)
+    ##below 2 steps are slower as caching is not enabled
+    # tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_mdl_dir)
+    # sentiment_session = ort.InferenceSession(f"{sent_onnx_mdl_dir}/{sent_onnx_mdl_name}")
     t2 = time.time()
     st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
 
@@ -210,7 +217,9 @@ if select_task == 'Detect Sentiment':
 
 if select_task=='Zero Shot Classification':
     t1=time.time()
-    tokenizer_zs,zs_session = zs_task_selected(task=select_task)
+    tokenizer_zs,session_zs = zs_task_selected(task=select_task)
+    # tokenizer_zs= AutoTokenizer.from_pretrained(zs_mdl_dir)
+    # session_zs = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}")
     t2 = time.time()
     st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
 
@@ -225,7 +234,7 @@ if select_task=='Zero Shot Classification':
 
     if response1:
        start = time.time()
-       df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session,
+       df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=session_zs,
                                                  _tokenizer=tokenizer_zs)
        end = time.time()
        st.write("")
 