# Custom endpoint handler: exports a seq2seq checkpoint to ONNX, applies graph
# optimization and dynamic quantization with Optimum / ONNX Runtime, and serves
# the result through a Transformers summarization pipeline.
from optimum.onnxruntime import ORTOptimizer, ORTQuantizer, ORTModelForSeq2SeqLM
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
from transformers import AutoTokenizer, pipeline
import sys
import os

# Directory where the exported, optimized, and quantized ONNX files are written.
save_dir = "."

class EndpointHandler():
    def __init__(self, path=""):
        # Export the Transformers checkpoint at `path` to ONNX.
        model = ORTModelForSeq2SeqLM.from_pretrained(path, from_transformers=True)

        # Apply ONNX Runtime graph optimizations to the exported model.
        optimizer = ORTOptimizer.from_pretrained(model)
        optimization_config = OptimizationConfig(
            optimization_level=2,
            optimize_with_onnxruntime_only=False,
            optimize_for_gpu=False,
        )
        optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config)

        # Quantize each graph produced by the optimizer (encoder, decoder, and
        # decoder-with-past) using dynamic int8 quantization.
        encoder_quantizer = ORTQuantizer.from_pretrained(save_dir, file_name="encoder_model_optimized.onnx")
        decoder_quantizer = ORTQuantizer.from_pretrained(save_dir, file_name="decoder_model_optimized.onnx")
        decoder_wp_quantizer = ORTQuantizer.from_pretrained(save_dir, file_name="decoder_with_past_model_optimized.onnx")
        quantizers = [encoder_quantizer, decoder_quantizer, decoder_wp_quantizer]

        dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
        for quantizer in quantizers:
            quantizer.quantize(save_dir=save_dir, quantization_config=dqconfig)

        # Log the files produced so far; useful for debugging in the endpoint logs.
        for file_name in os.listdir(save_dir):
            sys.stderr.write(file_name + "\n")

        # Reload the optimized and quantized ONNX files as a single seq2seq model.
        optimized_model = ORTModelForSeq2SeqLM.from_pretrained(
            save_dir,
            encoder_file_name="encoder_model_optimized_quantized.onnx",
            decoder_file_name="decoder_model_optimized_quantized.onnx",
            decoder_with_past_file_name="decoder_with_past_model_optimized_quantized.onnx",
        )

        tokenizer = AutoTokenizer.from_pretrained(path)
        self.pipeline = pipeline("summarization", model=optimized_model, tokenizer=tokenizer)

    def __call__(self, data):
        # Expected payload shape: {"inputs": <text or list of texts>, "parameters": {...}}.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        if parameters is not None:
            summary = self.pipeline(inputs, **parameters)
        else:
            summary = self.pipeline(inputs)
        return summary
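
# Minimal local smoke test (an illustrative sketch, not part of the handler
# contract): the checkpoint name "t5-small" and the generation parameters below
# are assumptions chosen for demonstration only.
if __name__ == "__main__":
    handler = EndpointHandler(path="t5-small")
    payload = {
        "inputs": "ONNX Runtime graph optimization and dynamic quantization can "
                  "speed up transformer inference on CPU while keeping accuracy "
                  "close to the original model.",
        "parameters": {"max_length": 30, "min_length": 5},
    }
    print(handler(payload))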