import evaluate
import onnxruntime
from neural_compressor.experimental import Quantization, common
from optimum.onnxruntime import ORTModelForFeatureExtraction
from setfit.exporters.utils import mean_pooling
from tqdm import tqdm

accuracy = evaluate.load("accuracy")


class OnnxSetFitModel:
    """Wraps an ONNX-exported SetFit body together with its sklearn classification head."""

    def __init__(self, ort_model, tokenizer, model_head):
        self.ort_model = ort_model
        self.tokenizer = tokenizer
        self.model_head = model_head

    def predict(self, inputs):
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors="pt"
        )
        outputs = self.ort_model(**encoded_inputs)
        # Mean-pool the token embeddings into one sentence embedding per input.
        embeddings = mean_pooling(
            outputs["last_hidden_state"], encoded_inputs["attention_mask"]
        )
        return self.model_head.predict(embeddings)

    def __call__(self, inputs):
        return self.predict(inputs)


class MyQuantizer:
    def __init__(self, onnx_path, model_head, tokenizer, test_dataset):
        self.onnx_path = onnx_path  # pathlib.Path to the directory holding model.onnx
        self.head = model_head
        self.tokenizer = tokenizer
        self.test_dataset = test_dataset

    def eval_func(self, model):
        """Return the accuracy of a candidate quantized model on the test set.

        Neural Compressor calls this during tuning with each candidate ONNX
        model and accepts the first one that meets the accuracy criterion.
        """
        ort_model = ORTModelForFeatureExtraction.from_pretrained(self.onnx_path)
        # Swap in the candidate model produced by the tuner.
        ort_model.model = onnxruntime.InferenceSession(model.SerializeToString(), None)
        onnx_setfit_model = OnnxSetFitModel(ort_model, self.tokenizer, self.head)

        preds = []
        chunk_size = 100
        for i in tqdm(range(0, len(self.test_dataset["text"]), chunk_size)):
            preds.extend(
                onnx_setfit_model.predict(self.test_dataset["text"][i : i + chunk_size])
            )
        labels = self.test_dataset["label"]
        return accuracy.compute(predictions=preds, references=labels)["accuracy"]

    def build_dynamic_quant_yaml(self):
        # Post-training dynamic quantization config: tuning stops at the first
        # candidate whose accuracy drop is within 1% relative to the FP32 baseline.
        yaml = """
model:
  name: bert
  framework: onnxrt_integerops

device: cpu

quantization:
  approach: post_training_dynamic_quant

tuning:
  accuracy_criterion:
    relative: 0.01
  exit_policy:
    timeout: 0
  random_seed: 9527
"""
        with open("build.yaml", "w", encoding="utf-8") as f:
            f.write(yaml)

    def quantize_model(self):
        self.build_dynamic_quant_yaml()
        onnx_output_path = "onnx/model_quantized.onnx"
        quantizer = Quantization("build.yaml")
        quantizer.model = common.Model(str(self.onnx_path / "model.onnx"))
        quantizer.eval_func = self.eval_func  # already a bound method; no partial needed
        quantized_model = quantizer()
        quantized_model.save(onnx_output_path)
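
# --- Usage sketch (not part of the original script; names below are assumptions) ---
# Assuming the SetFit body was exported to ONNX under onnx/ along with its
# tokenizer, and the trained sklearn head plus a test split with "text" and
# "label" columns are in scope, the quantizer could be driven roughly like this:
#
#   from pathlib import Path
#   from datasets import load_dataset
#   from transformers import AutoTokenizer
#
#   onnx_path = Path("onnx")                       # directory containing model.onnx
#   tokenizer = AutoTokenizer.from_pretrained(onnx_path)
#   test_dataset = load_dataset("SetFit/sst2", split="test")  # assumed dataset
#   # model_head: the fitted LogisticRegression head from SetFit training
#
#   MyQuantizer(onnx_path, model_head, tokenizer, test_dataset).quantize_model()
#   # -> tunes against eval_func and writes onnx/model_quantized.onnx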