chainyo commited on
Commit
1770058
1 Parent(s): d015acd

fix pipeline referencement

Browse files
Files changed (1) hide show
  1. main.py +4 -4
main.py CHANGED
@@ -19,7 +19,7 @@ from typing import Dict, List, Union
19
  from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
20
  from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
21
  from optimum.pipelines import pipeline as ort_pipeline
22
- from transformers import BertTokenizer, BertForSequenceClassification, pipeline
23
 
24
  from utils import calculate_inference_time
25
 
@@ -105,7 +105,7 @@ def load_pipeline(pipeline_name: str) -> None:
105
  """
106
  if pipeline_name == "pt_pipeline":
107
  model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
108
- pipeline = pipeline("sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=model)
109
  elif pipeline_name == "ort_pipeline":
110
  model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
111
  if not ONNX_MODEL_PATH.exists():
@@ -120,7 +120,7 @@ def load_pipeline(pipeline_name: str) -> None:
120
  model = ORTModelForSequenceClassification.from_pretrained(
121
  OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
122
  )
123
- pipeline = pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
124
  elif pipeline_name == "ort_quantized_pipeline":
125
  if not QUANTIZED_MODEL_PATH.exists():
126
  quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
@@ -130,7 +130,7 @@ def load_pipeline(pipeline_name: str) -> None:
130
  model = ORTModelForSequenceClassification.from_pretrained(
131
  QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
132
  )
133
- pipeline = pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
134
  print(type(pipeline))
135
  return pipeline
136
 
 
19
  from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
20
  from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
21
  from optimum.pipelines import pipeline as ort_pipeline
22
+ from transformers import BertTokenizer, BertForSequenceClassification, pt_pipeline
23
 
24
  from utils import calculate_inference_time
25
 
 
105
  """
106
  if pipeline_name == "pt_pipeline":
107
  model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
108
+ pipeline = pt_pipeline("sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=model)
109
  elif pipeline_name == "ort_pipeline":
110
  model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
111
  if not ONNX_MODEL_PATH.exists():
 
120
  model = ORTModelForSequenceClassification.from_pretrained(
121
  OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
122
  )
123
+ pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
124
  elif pipeline_name == "ort_quantized_pipeline":
125
  if not QUANTIZED_MODEL_PATH.exists():
126
  quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
 
130
  model = ORTModelForSequenceClassification.from_pretrained(
131
  QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
132
  )
133
+ pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
134
  print(type(pipeline))
135
  return pipeline
136