ashishraics commited on
Commit
a48f2db
1 Parent(s): 6886461

optimized app

Browse files
.gitignore CHANGED
@@ -1,3 +1,6 @@
1
  venv/
2
  sent_clf_onnx/
3
- sentiment_model_dir/
 
 
 
1
  venv/
2
  sent_clf_onnx/
3
+ sentiment_model_dir/
4
+ zs_model_dir/
5
+ zeroshot_onnx_dir/
6
+ sent_clf_onnx_dir/
__pycache__/sentiment.cpython-39.pyc DELETED
Binary file (939 Bytes)
__pycache__/zeroshot_clf.cpython-39.pyc DELETED
Binary file (1.61 kB)
app.py CHANGED
@@ -2,21 +2,19 @@ import numpy as np
2
  import pandas as pd
3
  import streamlit as st
4
  from streamlit_text_rating.st_text_rater import st_text_rater
5
- from sentiment import classify_sentiment
6
- from sentiment_onnx_classify import classify_sentiment_onnx, classify_sentiment_onnx_quant,create_onnx_model
7
- from zeroshot_clf import zero_shot_classification
8
  from transformers import AutoTokenizer,AutoModelForSequenceClassification
9
- from onnxruntime.quantization import quantize_dynamic,QuantType
10
- import transformers.convert_graph_to_onnx as onnx_convert
11
- from pathlib import Path
12
  import os
13
  import time
14
  import plotly.express as px
15
  import plotly.graph_objects as go
16
- import onnxruntime as ort
17
  global _plotly_config
18
  _plotly_config={'displayModeBar': False}
19
 
 
 
 
 
20
  st.set_page_config( # Alternate names: setup_page, page, layout
21
  layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
22
  initial_sidebar_state="auto", # Can be "auto", "expanded", "collapsed"
@@ -73,7 +71,7 @@ st.markdown(hide_streamlit_style, unsafe_allow_html=True)
73
 
74
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
75
  def create_model_dir(chkpt, model_dir):
76
- if not os.path.exists(chkpt):
77
  try:
78
  os.mkdir(path=model_dir)
79
  except:
@@ -101,30 +99,57 @@ if select_task=='README':
101
 
102
  sent_chkpt = "distilbert-base-uncased-finetuned-sst-2-english"
103
  sent_model_dir="sentiment_model_dir"
104
- #create model/token dir
105
  create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)
 
 
 
 
106
 
107
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
108
- def task_selected(task,sent_model_dir=sent_model_dir):
109
  model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
110
  tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_model_dir)
 
 
 
 
111
 
112
- create_onnx_model(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
113
 
114
- #create inference session
115
- sentiment_session = ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx.onnx")
116
- sentiment_session_int8 = ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx_int8.onnx")
117
 
118
- return model_sentiment,tokenizer_sentiment,sentiment_session,sentiment_session_int8
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- ############## Pre-Download & instantiate objects for sentiment analysis ********************* END **********************************
122
 
123
  if select_task == 'Detect Sentiment':
124
 
125
  t1=time.time()
126
  model_sentiment,tokenizer_sentiment,\
127
- sentiment_session,sentiment_session_int8 = task_selected(task=select_task)
128
  t2 = time.time()
129
  st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
130
 
@@ -159,8 +184,8 @@ if select_task == 'Detect Sentiment':
159
  st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
160
  elif response3:
161
  start = time.time()
162
- sentiments=classify_sentiment_onnx_quant(input_texts,
163
- _session=sentiment_session_int8,
164
  _tokenizer=tokenizer_sentiment)
165
  end = time.time()
166
  st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
@@ -190,8 +215,8 @@ if select_task == 'Detect Sentiment':
190
  onnx_runtime_quant=[]
191
  for i in range(100):
192
  start=time.time()
193
- sentiments = classify_sentiment_onnx_quant(input_texts,
194
- _session=sentiment_session,
195
  _tokenizer=tokenizer_sentiment)
196
  end=time.time()
197
 
@@ -227,6 +252,12 @@ if select_task == 'Detect Sentiment':
227
  color_background='rgb(233, 116, 81)',key=t)
228
 
229
  if select_task=='Zero Shot Classification':
 
 
 
 
 
 
230
  st.header("You are now performing Zero Shot Classification")
231
  input_texts = st.text_input(label="Input text to classify into topics")
232
  input_lables = st.text_input(label="Enter labels separated by commas")
@@ -240,14 +271,97 @@ if select_task=='Zero Shot Classification':
240
  with c3:
241
  response3=st.button("ONNX runtime with Quantization")
242
  with c4:
243
- response4 = st.button("Simulate 100 runs each runtime")
244
 
245
  if any([response1,response2,response3,response4]):
246
  if response1:
247
  start=time.time()
248
- output = zero_shot_classification(input_texts, input_lables)
249
  end=time.time()
250
  st.write("")
251
  st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
252
- st.plotly_chart(output, config=_plotly_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
2
  import pandas as pd
3
  import streamlit as st
4
  from streamlit_text_rating.st_text_rater import st_text_rater
 
 
 
5
  from transformers import AutoTokenizer,AutoModelForSequenceClassification
6
+ import onnxruntime as ort
 
 
7
  import os
8
  import time
9
  import plotly.express as px
10
  import plotly.graph_objects as go
 
11
  global _plotly_config
12
  _plotly_config={'displayModeBar': False}
13
 
14
+ from sentiment_clf_helper import classify_sentiment,create_onnx_model_sentiment,classify_sentiment_onnx
15
+ from zeroshot_clf_helper import zero_shot_classification,create_onnx_model_zs,zero_shot_classification_onnx
16
+
17
+
18
  st.set_page_config( # Alternate names: setup_page, page, layout
19
  layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
20
  initial_sidebar_state="auto", # Can be "auto", "expanded", "collapsed"
71
 
72
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
73
  def create_model_dir(chkpt, model_dir):
74
+ if not os.path.exists(model_dir):
75
  try:
76
  os.mkdir(path=model_dir)
77
  except:
99
 
100
  sent_chkpt = "distilbert-base-uncased-finetuned-sst-2-english"
101
  sent_model_dir="sentiment_model_dir"
102
+ #create model/token dir for sentiment classification
103
  create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)
104
+ #create onnx model for sentiment classification
105
+ model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
106
+ tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_model_dir)
107
+ create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
108
 
109
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
110
+ def sentiment_task_selected(task,sent_model_dir=sent_model_dir):
111
  model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
112
  tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_model_dir)
113
+ # create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
114
+ #create inference session
115
+ sentiment_session = ort.InferenceSession("sent_clf_onnx_dir/sentiment_classifier_onnx.onnx")
116
+ sentiment_session_quant = ort.InferenceSession("sent_clf_onnx_dir/sentiment_classifier_onnx_quant.onnx")
117
 
118
+ return model_sentiment,tokenizer_sentiment,sentiment_session,sentiment_session_quant
119
 
120
+ ############## Pre-Download & instantiate objects for sentiment analysis ********************* END **********************************
 
 
121
 
 
122
 
123
+ ############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************
124
+
125
+ zs_chkpt = "valhalla/distilbart-mnli-12-1"
126
+ zs_model_dir = "zs_model_dir"
127
+ # create model/token dir for zeroshot clf
128
+ create_model_dir(chkpt=zs_chkpt, model_dir=zs_model_dir)
129
+ #ceate onnx model for zeroshot
130
+ create_onnx_model_zs()
131
+
132
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
133
+ def zs_task_selected(task, zs_model_dir=zs_model_dir,onnx_dir='zeroshot_onnx_dir'):
134
+
135
+ #model & tokenizer initialization for normal ZS classification
136
+ model_zs=AutoModelForSequenceClassification.from_pretrained(zs_model_dir)
137
+ tokenizer_zs=AutoTokenizer.from_pretrained(zs_model_dir)
138
+
139
+ #create inference session from onnx model
140
+ zs_session = ort.InferenceSession(f"{onnx_dir}/model.onnx")
141
+ zs_session_quant = ort.InferenceSession(f"{onnx_dir}/model_quant.onnx")
142
+
143
+ return model_zs,tokenizer_zs,zs_session,zs_session_quant
144
+
145
+ ############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************
146
 
 
147
 
148
  if select_task == 'Detect Sentiment':
149
 
150
  t1=time.time()
151
  model_sentiment,tokenizer_sentiment,\
152
+ sentiment_session,sentiment_session_quant = sentiment_task_selected(task=select_task)
153
  t2 = time.time()
154
  st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
155
 
184
  st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
185
  elif response3:
186
  start = time.time()
187
+ sentiments=classify_sentiment_onnx(input_texts,
188
+ _session=sentiment_session_quant,
189
  _tokenizer=tokenizer_sentiment)
190
  end = time.time()
191
  st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
215
  onnx_runtime_quant=[]
216
  for i in range(100):
217
  start=time.time()
218
+ sentiments = classify_sentiment_onnx(input_texts,
219
+ _session=sentiment_session_quant,
220
  _tokenizer=tokenizer_sentiment)
221
  end=time.time()
222
 
252
  color_background='rgb(233, 116, 81)',key=t)
253
 
254
  if select_task=='Zero Shot Classification':
255
+
256
+ t1=time.time()
257
+ model_zs,tokenizer_zs,zs_session,zs_session_quant = zs_task_selected(task=select_task)
258
+ t2 = time.time()
259
+ st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
260
+
261
  st.header("You are now performing Zero Shot Classification")
262
  input_texts = st.text_input(label="Input text to classify into topics")
263
  input_lables = st.text_input(label="Enter labels separated by commas")
271
  with c3:
272
  response3=st.button("ONNX runtime with Quantization")
273
  with c4:
274
+ response4 = st.button("Simulate 10 runs each runtime")
275
 
276
  if any([response1,response2,response3,response4]):
277
  if response1:
278
  start=time.time()
279
+ df_output = zero_shot_classification(input_texts, input_lables,model=model_zs,tokenizer=tokenizer_zs)
280
  end=time.time()
281
  st.write("")
282
  st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
283
+ fig = px.bar(x='Probability',
284
+ y='labels',
285
+ text='Probability',
286
+ data_frame=df_output,
287
+ title='Zero Shot Normalized Probabilities')
288
+
289
+ st.plotly_chart(fig, config=_plotly_config)
290
+ elif response2:
291
+ start = time.time()
292
+ df_output=zero_shot_classification_onnx(premise=input_texts,labels=input_lables,_session=zs_session,_tokenizer=tokenizer_zs)
293
+ end=time.time()
294
+ st.write("")
295
+ st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
296
+
297
+ fig = px.bar(x='Probability',
298
+ y='labels',
299
+ text='Probability',
300
+ data_frame=df_output,
301
+ title='Zero Shot Normalized Probabilities')
302
+
303
+ st.plotly_chart(fig,config=_plotly_config)
304
+ elif response3:
305
+ start = time.time()
306
+ df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
307
+ _tokenizer=tokenizer_zs)
308
+ end = time.time()
309
+ st.write("")
310
+ st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
311
+ fig = px.bar(x='Probability',
312
+ y='labels',
313
+ text='Probability',
314
+ data_frame=df_output,
315
+ title='Zero Shot Normalized Probabilities')
316
+
317
+ st.plotly_chart(fig, config=_plotly_config)
318
+ elif response4:
319
+ normal_runtime = []
320
+ for i in range(100):
321
+ start = time.time()
322
+ _ = zero_shot_classification(input_texts, input_lables,model=model_zs,tokenizer=tokenizer_zs)
323
+ end = time.time()
324
+ t = (end - start) * 1000
325
+ normal_runtime.append(t)
326
+ normal_runtime = np.clip(normal_runtime, 50, 400)
327
+
328
+ onnx_runtime = []
329
+ for i in range(100):
330
+ start = time.time()
331
+ _ = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session,
332
+ _tokenizer=tokenizer_zs)
333
+ end = time.time()
334
+ t = (end - start) * 1000
335
+ onnx_runtime.append(t)
336
+ onnx_runtime = np.clip(onnx_runtime, 50, 200)
337
+
338
+ onnx_runtime_quant = []
339
+ for i in range(100):
340
+ start = time.time()
341
+ _ = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
342
+ _tokenizer=tokenizer_zs)
343
+ end = time.time()
344
+
345
+ t = (end - start) * 1000
346
+ onnx_runtime_quant.append(t)
347
+ onnx_runtime_quant = np.clip(onnx_runtime_quant, 50, 200)
348
+
349
+ temp_df = pd.DataFrame({'Normal Runtime (ms)': normal_runtime,
350
+ 'ONNX Runtime (ms)': onnx_runtime,
351
+ 'ONNX Quant Runtime (ms)': onnx_runtime_quant})
352
+
353
+ from plotly.subplots import make_subplots
354
+
355
+ fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
356
+ subplot_titles=['Normal Runtime', 'ONNX Runtime', 'ONNX Runtime with Quantization'])
357
+
358
+ fig.add_trace(go.Histogram(x=temp_df['Normal Runtime (ms)']), row=1, col=1)
359
+ fig.add_trace(go.Histogram(x=temp_df['ONNX Runtime (ms)']), row=1, col=2)
360
+ fig.add_trace(go.Histogram(x=temp_df['ONNX Quant Runtime (ms)']), row=1, col=3)
361
+ fig.update_layout(height=400, width=1000,
362
+ title_text="10 Simulations of different Runtimes",
363
+ showlegend=False)
364
+ st.plotly_chart(fig, config=_plotly_config)
365
+ else:
366
+ pass
367
 
config.yaml ADDED
File without changes
sentiment.py DELETED
@@ -1,23 +0,0 @@
1
- import torch
2
- from transformers import AutoModelForSequenceClassification,AutoTokenizer
3
-
4
- chkpt='distilbert-base-uncased-finetuned-sst-2-english'
5
- model=AutoModelForSequenceClassification.from_pretrained(chkpt)
6
- tokenizer=AutoTokenizer.from_pretrained(chkpt)
7
- # tokenizer=AutoTokenizer.from_pretrained('sentiment_classifier/')
8
-
9
- def classify_sentiment(texts,model,tokenizer):
10
- """
11
- user will pass texts separated by comma
12
- """
13
- try:
14
- texts=texts.split(',')
15
- except:
16
- pass
17
-
18
- input = tokenizer(texts, padding=True, truncation=True,
19
- return_tensors="pt")
20
- logits = model(**input)['logits'].softmax(dim=1)
21
- logits = torch.argmax(logits, dim=1)
22
- output = ['Positive' if i == 1 else 'Negative' for i in logits]
23
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sentiment_onnx_classify.py → sentiment_clf_helper.py RENAMED
@@ -1,6 +1,3 @@
1
- import onnxruntime as ort
2
- import torch
3
- from transformers import AutoTokenizer,AutoModelForSequenceClassification
4
  import numpy as np
5
  import transformers
6
  from onnxruntime.quantization import quantize_dynamic,QuantType
@@ -8,11 +5,32 @@ import transformers.convert_graph_to_onnx as onnx_convert
8
  from pathlib import Path
9
  import os
10
 
11
- # chkpt='distilbert-base-uncased-finetuned-sst-2-english'
12
- # model= AutoModelForSequenceClassification.from_pretrained(chkpt)
13
- # tokenizer= AutoTokenizer.from_pretrained(chkpt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- def create_onnx_model(_model, _tokenizer):
16
  """
17
 
18
  Args:
@@ -23,44 +41,25 @@ def create_onnx_model(_model, _tokenizer):
23
  Creates a simple ONNX model & int8 Quantized Model in the directory "sent_clf_onnx/" if directory not present
24
 
25
  """
26
- if not os.path.exists('sent_clf_onnx'):
27
  try:
28
- os.mkdir('sent_clf_onnx')
29
  except:
30
  pass
31
- """
32
- Making ONNX model object
33
- """
34
  pipeline=transformers.pipeline("text-classification", model=_model, tokenizer=_tokenizer)
35
 
36
- """
37
- convert pipeline to onnx object
38
- """
39
  onnx_convert.convert_pytorch(pipeline,
40
  opset=11,
41
- output=Path("sent_clf_onnx/sentiment_classifier_onnx.onnx"),
42
  use_external_format=False
43
  )
44
 
45
- """
46
- convert onnx object to another onnx object with int8 quantization
47
- """
48
- quantize_dynamic("sent_clf_onnx/sentiment_classifier_onnx.onnx","sent_clf_onnx/sentiment_classifier_onnx_int8.onnx",
49
  weight_type=QuantType.QUInt8)
50
  else:
51
  pass
52
 
53
 
54
-
55
-
56
- # #create onnx & onnx_int_8 sessions
57
- # session = ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx.onnx")
58
- # session_int8 = ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx_int8.onnx")
59
-
60
- # options=ort.SessionOptions()
61
- # options.inter_op_num_threads=1
62
- # options.intra_op_num_threads=1
63
-
64
  def classify_sentiment_onnx(texts, _session, _tokenizer):
65
  """
66
 
@@ -92,41 +91,3 @@ def classify_sentiment_onnx(texts, _session, _tokenizer):
92
  output = ['Positive' if i == 1 else 'Negative' for i in output]
93
  return output
94
 
95
- def classify_sentiment_onnx_quant(texts, _session, _tokenizer):
96
- """
97
- Args:
98
- texts: input texts from user
99
- _session: pass ONNX runtime session
100
- _tokenizer: Relevant Tokenizer e.g. AutoTokenizer.from_pretrained("same checkpoint as the model")
101
-
102
- Returns:
103
- list of Positve and Negative texts
104
-
105
- """
106
- try:
107
- texts=texts.split(',')
108
- except:
109
- pass
110
-
111
- _inputs = _tokenizer(texts, padding=True, truncation=True,
112
- return_tensors="np")
113
-
114
-
115
- input_feed={
116
- "input_ids":np.array(_inputs['input_ids']),
117
- "attention_mask":np.array((_inputs['attention_mask']))
118
- }
119
-
120
- output = _session.run(input_feed=input_feed, output_names=['output_0'])[0]
121
-
122
- output=np.argmax(output,axis=1)
123
- output = ['Positive' if i == 1 else 'Negative' for i in output]
124
-
125
- return output
126
-
127
-
128
-
129
-
130
-
131
-
132
-
 
 
 
1
  import numpy as np
2
  import transformers
3
  from onnxruntime.quantization import quantize_dynamic,QuantType
5
  from pathlib import Path
6
  import os
7
 
8
+ import torch
9
+ from transformers import AutoModelForSequenceClassification,AutoTokenizer
10
+
11
+ chkpt='distilbert-base-uncased-finetuned-sst-2-english'
12
+ model=AutoModelForSequenceClassification.from_pretrained(chkpt)
13
+ tokenizer=AutoTokenizer.from_pretrained(chkpt)
14
+ # tokenizer=AutoTokenizer.from_pretrained('sentiment_classifier/')
15
+
16
+ def classify_sentiment(texts,model,tokenizer):
17
+ """
18
+ user will pass texts separated by comma
19
+ """
20
+ try:
21
+ texts=texts.split(',')
22
+ except:
23
+ pass
24
+
25
+ input = tokenizer(texts, padding=True, truncation=True,
26
+ return_tensors="pt")
27
+ logits = model(**input)['logits'].softmax(dim=1)
28
+ logits = torch.argmax(logits, dim=1)
29
+ output = ['Positive' if i == 1 else 'Negative' for i in logits]
30
+ return output
31
+
32
 
33
+ def create_onnx_model_sentiment(_model, _tokenizer):
34
  """
35
 
36
  Args:
41
  Creates a simple ONNX model & int8 Quantized Model in the directory "sent_clf_onnx/" if directory not present
42
 
43
  """
44
+ if not os.path.exists('sent_clf_onnx_dir'):
45
  try:
46
+ os.mkdir('sent_clf_onnx_dir')
47
  except:
48
  pass
 
 
 
49
  pipeline=transformers.pipeline("text-classification", model=_model, tokenizer=_tokenizer)
50
 
 
 
 
51
  onnx_convert.convert_pytorch(pipeline,
52
  opset=11,
53
+ output=Path("sent_clf_onnx_dir/sentiment_classifier_onnx.onnx"),
54
  use_external_format=False
55
  )
56
 
57
+ quantize_dynamic("sent_clf_onnx_dir/sentiment_classifier_onnx.onnx","sent_clf_onnx_dir/sentiment_classifier_onnx_quant.onnx",
 
 
 
58
  weight_type=QuantType.QUInt8)
59
  else:
60
  pass
61
 
62
 
 
 
 
 
 
 
 
 
 
 
63
  def classify_sentiment_onnx(texts, _session, _tokenizer):
64
  """
65
 
91
  output = ['Positive' if i == 1 else 'Negative' for i in output]
92
  return output
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sentiment_onnx.py DELETED
@@ -1,41 +0,0 @@
1
- from transformers import AutoTokenizer,AutoModelForSequenceClassification
2
- import transformers.convert_graph_to_onnx as onnx_convert
3
- from pathlib import Path
4
- import transformers
5
- from onnxruntime.quantization import quantize_dynamic,QuantType
6
- import onnx
7
- import onnxruntime as ort
8
-
9
- """
10
- type in cmd to create onnx model of hugging face chkpt
11
- python3 -m transformers.onnx --model= distilbert-base-uncased-finetuned-sst-2-english sentiment_onnx/
12
- """
13
- chkpt='distilbert-base-uncased-finetuned-sst-2-english'
14
- model= AutoModelForSequenceClassification.from_pretrained(chkpt)
15
- tokenizer= AutoTokenizer.from_pretrained(chkpt)
16
-
17
- """
18
- or download the model directly from hub --
19
- chkpt='distilbert-base-uncased-finetuned-sst-2-english'
20
- model= AutoModelForSequenceClassification.from_pretrained(chkpt)
21
- tokenizer= AutoTokenizer.from_pretrained(chkpt)
22
- """
23
-
24
-
25
- pipeline=transformers.pipeline("text-classification",model=model,tokenizer=tokenizer)
26
-
27
- """ convert pipeline to onnx object"""
28
- onnx_convert.convert_pytorch(pipeline,
29
- opset=11,
30
- output=Path("sent_clf_onnx/sentiment_classifier_onnx.onnx"),
31
- use_external_format=False
32
- )
33
-
34
- """ convert onnx object to another onnx object with int8 quantization """
35
- quantize_dynamic("sent_clf_onnx/sentiment_classifier_onnx.onnx","sent_clf_onnx/sentiment_classifier_onnx_int8.onnx",
36
- weight_type=QuantType.QUInt8)
37
-
38
- print(ort.__version__)
39
- print(onnx.__version__)
40
-
41
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+ from onnxruntime import InferenceSession
3
+ import numpy as np
4
+ import subprocess
5
+ import os
6
+
7
+
8
+ #create onnx model using
9
+ if not os.path.exists("zs_model_onnx"):
10
+ try:
11
+ subprocess.run(['python3','-m','transformers.onnx',
12
+ '--model=facebook/bart-large-mnli',
13
+ '--feature=sequence-classification',
14
+ 'zs_model_onnx/'])
15
+ except:
16
+ pass
17
+
18
+ #create session of saved onnx model
19
+ session = InferenceSession("zs_model_onnx/model.onnx")
20
+
21
+ #tokenizer for the chkpt
22
+ tokenizer=AutoTokenizer.from_pretrained('zs_model_dir')
23
+
24
+ # ONNX Runtime expects NumPy arrays as input
25
+ inputs = tokenizer("Using DistilBERT with ONNX Runtime!","you know how", return_tensors="np")
26
+ input_feed = {
27
+ "input_ids": np.array(inputs['input_ids']),
28
+ "attention_mask": np.array((inputs['attention_mask']))
29
+ }
30
+
31
+ #output
32
+ outputs = session.run(output_names=["logits"], input_feed=dict(input_feed))
33
+
34
+ print(outputs)
zeroshot_clf.py → zeroshot_clf_helper.py RENAMED
@@ -1,13 +1,10 @@
1
- import pandas as pd
2
- import streamlit
3
  import torch
4
- from transformers import AutoModelForSequenceClassification,AutoTokenizer
 
 
5
  import numpy as np
6
- import plotly.express as px
7
 
8
- # chkpt='valhalla/distilbart-mnli-12-1'
9
- # model=AutoModelForSequenceClassification.from_pretrained(chkpt)
10
- # tokenizer=AutoTokenizer.from_pretrained('zero_shot_clf/')
11
 
12
  def zero_shot_classification(premise: str, labels: str, model, tokenizer):
13
  try:
@@ -35,14 +32,66 @@ def zero_shot_classification(premise: str, labels: str, model, tokenizer):
35
 
36
  df=pd.DataFrame({'labels':labels,
37
  'Probability':labels_prob_norm})
38
- fig=px.bar(x='Probability',
39
- y='labels',
40
- text='Probability',
41
- data_frame=df,
42
- title='Zero Shot Normalized Probabilities')
43
- return fig
44
 
 
 
 
45
  # zero_shot_classification(premise='Tiny worms and breath analyzers could screen for \disease while it’s early and treatable',
46
  # labels='science, sports, museum')
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from onnxruntime.quantization import quantize_dynamic,QuantType
3
+ import os
4
+ import subprocess
5
  import numpy as np
6
+ import pandas as pd
7
 
 
 
 
8
 
9
  def zero_shot_classification(premise: str, labels: str, model, tokenizer):
10
  try:
32
 
33
  df=pd.DataFrame({'labels':labels,
34
  'Probability':labels_prob_norm})
 
 
 
 
 
 
35
 
36
+ return df
37
+
38
+ ##example
39
  # zero_shot_classification(premise='Tiny worms and breath analyzers could screen for \disease while it’s early and treatable',
40
  # labels='science, sports, museum')
41
 
42
 
43
+ def create_onnx_model_zs(art_path='zeroshot_onnx_dir'):
44
+
45
+ # create onnx model using
46
+ if not os.path.exists(art_path):
47
+ try:
48
+ subprocess.run(['python3', '-m', 'transformers.onnx',
49
+ '--model=facebook/bart-large-mnli',
50
+ '--feature=sequence-classification',
51
+ art_path])
52
+ except:
53
+ pass
54
+
55
+ #create quanitzed model from vanila onnx
56
+ quantize_dynamic(f"{art_path}/model.onnx",f"{art_path}/model_quant.onnx",weight_type=QuantType.QUInt8)
57
+ else:
58
+ pass
59
+
60
+ def zero_shot_classification_onnx(premise,labels,_session,_tokenizer):
61
+ try:
62
+ labels=labels.split(',')
63
+ labels=[l.lower() for l in labels]
64
+ except:
65
+ raise Exception("please pass atleast 2 labels to classify")
66
+
67
+ premise=premise.lower()
68
+
69
+ labels_prob=[]
70
+
71
+ for l in labels:
72
+
73
+ hypothesis= f'this is an example of {l}'
74
+
75
+ inputs = _tokenizer(premise,hypothesis,
76
+ return_tensors='pt',
77
+ truncation_strategy='only_first')
78
+
79
+ input_feed = {
80
+ "input_ids": np.array(inputs['input_ids']),
81
+ "attention_mask": np.array((inputs['attention_mask']))
82
+ }
83
+
84
+ output = _session.run(output_names=["logits"],input_feed=dict(input_feed))[0] #returns logits as array
85
+ output=torch.from_numpy(output)
86
+ entail_contra_prob = output[:,[0,2]].softmax(dim=1)[:,1].item() #only normalizing entail & contradict probabilties
87
+ labels_prob.append(entail_contra_prob)
88
+
89
+ labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
90
+
91
+ df=pd.DataFrame({'labels':labels,
92
+ 'Probability':labels_prob_norm})
93
+
94
+ return df
95
+
96
+
97
+