ashishraics commited on
Commit
b47aba9
1 Parent(s): 884971a
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -101,16 +101,17 @@ sent_chkpt = "distilbert-base-uncased-finetuned-sst-2-english"
101
  sent_model_dir="sentiment_model_dir"
102
  #create model/token dir for sentiment classification
103
  create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)
104
- #create onnx model for sentiment classification
105
- model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
106
- tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_model_dir)
107
- create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
108
 
109
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
110
  def sentiment_task_selected(task,sent_model_dir=sent_model_dir):
 
111
  model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
112
  tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_model_dir)
113
- # create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
 
 
 
114
  #create inference session
115
  sentiment_session = ort.InferenceSession("sent_clf_onnx_dir/sentiment_classifier_onnx.onnx")
116
  sentiment_session_quant = ort.InferenceSession("sent_clf_onnx_dir/sentiment_classifier_onnx_quant.onnx")
@@ -126,16 +127,17 @@ zs_chkpt = "valhalla/distilbart-mnli-12-1"
126
  zs_model_dir = "zs_model_dir"
127
  # create model/token dir for zeroshot clf
128
  create_model_dir(chkpt=zs_chkpt, model_dir=zs_model_dir)
129
- #ceate onnx model for zeroshot
130
- create_onnx_model_zs()
131
 
132
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
133
  def zs_task_selected(task, zs_model_dir=zs_model_dir,onnx_dir='zeroshot_onnx_dir'):
134
-
135
  #model & tokenizer initialization for normal ZS classification
136
  model_zs=AutoModelForSequenceClassification.from_pretrained(zs_model_dir)
137
  tokenizer_zs=AutoTokenizer.from_pretrained(zs_model_dir)
138
 
 
 
 
139
  #create inference session from onnx model
140
  zs_session = ort.InferenceSession(f"{onnx_dir}/model.onnx")
141
  zs_session_quant = ort.InferenceSession(f"{onnx_dir}/model_quant.onnx")
 
101
  sent_model_dir="sentiment_model_dir"
102
  #create model/token dir for sentiment classification
103
  create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)
104
+
 
 
 
105
 
106
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
107
  def sentiment_task_selected(task,sent_model_dir=sent_model_dir):
108
+ #model & tokenizer initialization for normal sentiment classification
109
  model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_model_dir)
110
  tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_model_dir)
111
+
112
+ # create onnx model for sentiment classification
113
+ create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
114
+
115
  #create inference session
116
  sentiment_session = ort.InferenceSession("sent_clf_onnx_dir/sentiment_classifier_onnx.onnx")
117
  sentiment_session_quant = ort.InferenceSession("sent_clf_onnx_dir/sentiment_classifier_onnx_quant.onnx")
 
127
  zs_model_dir = "zs_model_dir"
128
  # create model/token dir for zeroshot clf
129
  create_model_dir(chkpt=zs_chkpt, model_dir=zs_model_dir)
130
+
 
131
 
132
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
133
  def zs_task_selected(task, zs_model_dir=zs_model_dir,onnx_dir='zeroshot_onnx_dir'):
 
134
  #model & tokenizer initialization for normal ZS classification
135
  model_zs=AutoModelForSequenceClassification.from_pretrained(zs_model_dir)
136
  tokenizer_zs=AutoTokenizer.from_pretrained(zs_model_dir)
137
 
138
+ # ceate onnx model for zeroshot
139
+ create_onnx_model_zs()
140
+
141
  #create inference session from onnx model
142
  zs_session = ort.InferenceSession(f"{onnx_dir}/model.onnx")
143
  zs_session_quant = ort.InferenceSession(f"{onnx_dir}/model_quant.onnx")