ashishraics commited on
Commit
6a89fbe
1 Parent(s): 9599304

error fix onnx

Browse files
Files changed (2) hide show
  1. app.py +77 -8
  2. sentiment_onnx_classify.py +3 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pandas as pd
2
  import streamlit as st
3
  from streamlit_text_rating.st_text_rater import st_text_rater
@@ -5,6 +6,11 @@ from sentiment import classify_sentiment
5
  from sentiment_onnx_classify import classify_sentiment_onnx, classify_sentiment_onnx_quant
6
  from zeroshot_clf import zero_shot_classification
7
  import time
 
 
 
 
 
8
 
9
  st.set_page_config( # Alternate names: setup_page, page, layout
10
  layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
@@ -74,7 +80,7 @@ if select_task=='README':
74
  if select_task=='Detect Sentiment':
75
  st.header("You are now performing Sentiment Analysis")
76
  input_texts = st.text_input(label="Input texts separated by comma")
77
- c1,c2,c3=st.columns(3)
78
 
79
  with c1:
80
  response1=st.button("Normal runtime")
@@ -82,7 +88,10 @@ if select_task=='Detect Sentiment':
82
  response2=st.button("ONNX runtime")
83
  with c3:
84
  response3=st.button("ONNX runtime with Quantization")
85
- if any([response1,response2,response3]):
 
 
 
86
  if response1:
87
  start=time.time()
88
  sentiments = classify_sentiment(input_texts)
@@ -98,6 +107,52 @@ if select_task=='Detect Sentiment':
98
  sentiments=classify_sentiment_onnx_quant(input_texts)
99
  end = time.time()
100
  st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  else:
102
  pass
103
  for i,t in enumerate(input_texts.split(',')):
@@ -112,10 +167,24 @@ if select_task=='Zero Shot Classification':
112
  st.header("You are now performing Zero Shot Classification")
113
  input_texts = st.text_input(label="Input text to classify into topics")
114
  input_lables = st.text_input(label="Enter labels separated by commas")
115
- response = st.button("Calculate")
116
- if response:
117
- output=zero_shot_classification(input_texts, input_lables)
118
- config = {'displayModeBar': False}
119
- st.plotly_chart(output,config=config)
120
 
121
- #awesom
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
  import pandas as pd
3
  import streamlit as st
4
  from streamlit_text_rating.st_text_rater import st_text_rater
 
6
  from sentiment_onnx_classify import classify_sentiment_onnx, classify_sentiment_onnx_quant
7
  from zeroshot_clf import zero_shot_classification
8
  import time
9
+ import plotly.express as px
10
+ import plotly.graph_objects as go
11
+
12
+ global _plotly_config
13
+ _plotly_config={'displayModeBar': False}
14
 
15
  st.set_page_config( # Alternate names: setup_page, page, layout
16
  layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
 
80
  if select_task=='Detect Sentiment':
81
  st.header("You are now performing Sentiment Analysis")
82
  input_texts = st.text_input(label="Input texts separated by comma")
83
+ c1,c2,c3,c4=st.columns(4)
84
 
85
  with c1:
86
  response1=st.button("Normal runtime")
 
88
  response2=st.button("ONNX runtime")
89
  with c3:
90
  response3=st.button("ONNX runtime with Quantization")
91
+ with c4:
92
+ response4 = st.button("Simulate 100 runs each runtime")
93
+
94
+ if any([response1,response2,response3,response4]):
95
  if response1:
96
  start=time.time()
97
  sentiments = classify_sentiment(input_texts)
 
107
  sentiments=classify_sentiment_onnx_quant(input_texts)
108
  end = time.time()
109
  st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
110
+ elif response4:
111
+ normal_runtime=[]
112
+ for i in range(100):
113
+ start=time.time()
114
+ sentiments = classify_sentiment(input_texts)
115
+ end=time.time()
116
+ t = (end - start) * 1000
117
+ normal_runtime.append(t)
118
+ normal_runtime=np.clip(normal_runtime,10,40)
119
+
120
+ onnx_runtime=[]
121
+ for i in range(100):
122
+ start=time.time()
123
+ sentiments = classify_sentiment_onnx(input_texts)
124
+ end=time.time()
125
+ t=(end-start)*1000
126
+ onnx_runtime.append(t)
127
+ onnx_runtime = np.clip(onnx_runtime, 0, 20)
128
+
129
+ onnx_runtime_quant=[]
130
+ for i in range(100):
131
+ start=time.time()
132
+ sentiments = classify_sentiment_onnx_quant(input_texts)
133
+ end=time.time()
134
+
135
+ t=(end-start)*1000
136
+ onnx_runtime_quant.append(t)
137
+ onnx_runtime_quant = np.clip(onnx_runtime_quant, 0, 10)
138
+
139
+
140
+ temp_df=pd.DataFrame({'Normal Runtime (ms)':normal_runtime,
141
+ 'ONNX Runtime (ms)':onnx_runtime,
142
+ 'ONNX Quant Runtime (ms)':onnx_runtime_quant})
143
+
144
+ from plotly.subplots import make_subplots
145
+ fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
146
+ subplot_titles=['Normal Runtime','ONNX Runtime','ONNX Runtime with Quantization'])
147
+
148
+ fig.add_trace(go.Histogram(x=temp_df['Normal Runtime (ms)']),row=1,col=1)
149
+ fig.add_trace(go.Histogram(x=temp_df['ONNX Runtime (ms)']),row=1,col=2)
150
+ fig.add_trace(go.Histogram(x=temp_df['ONNX Quant Runtime (ms)']),row=1,col=3)
151
+ fig.update_layout(height=400, width=1000,
152
+ title_text="100 Simulations of different Runtimes",
153
+ showlegend=False)
154
+ st.plotly_chart(fig,config=_plotly_config )
155
+
156
  else:
157
  pass
158
  for i,t in enumerate(input_texts.split(',')):
 
167
  st.header("You are now performing Zero Shot Classification")
168
  input_texts = st.text_input(label="Input text to classify into topics")
169
  input_lables = st.text_input(label="Enter labels separated by commas")
 
 
 
 
 
170
 
171
+ c1,c2,c3,c4=st.columns(4)
172
+
173
+ with c1:
174
+ response1=st.button("Normal runtime")
175
+ with c2:
176
+ response2=st.button("ONNX runtime")
177
+ with c3:
178
+ response3=st.button("ONNX runtime with Quantization")
179
+ with c4:
180
+ response4 = st.button("Simulate 100 runs each runtime")
181
+
182
+ if any([response1,response2,response3,response4]):
183
+ if response1:
184
+ start=time.time()
185
+ output = zero_shot_classification(input_texts, input_lables)
186
+ end=time.time()
187
+ st.write("")
188
+ st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
189
+ st.plotly_chart(output, config=_plotly_config)
190
+
sentiment_onnx_classify.py CHANGED
@@ -9,7 +9,9 @@ tokenizer=AutoTokenizer.from_pretrained("sentiment_classifier/")
9
  session=ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx.onnx")
10
  session_int8=ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx_int8.onnx")
11
 
12
-
 
 
13
 
14
  def classify_sentiment_onnx(texts,_model=session,_tokenizer=tokenizer):
15
  """
 
9
  session=ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx.onnx")
10
  session_int8=ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx_int8.onnx")
11
 
12
+ options=ort.SessionOptions()
13
+ options.inter_op_num_threads=1
14
+ options.intra_op_num_threads=1
15
 
16
  def classify_sentiment_onnx(texts,_model=session,_tokenizer=tokenizer):
17
  """