from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import onnxruntime as ort
from loguru import logger
from onnxruntime.transformers.io_binding_helper import TypeHelper


@dataclass
class ModelInfo:
    """Metadata describing the exported model."""

    base_model: str

    @classmethod
    def from_dir(cls, model_dir: Path) -> ModelInfo:
        # The export directory is expected to contain a metadata.json file
        # whose "bert_type" key names the base model.
        with open(model_dir / "metadata.json", "r", encoding="utf-8") as file:
            data = json.load(file)
        return cls(base_model=data["bert_type"])


class ONNXModel:
    def __init__(self, model: ort.InferenceSession, model_info: ModelInfo) -> None:
        self.model = model
        self.model_info = model_info
        # _model_path is a private InferenceSession attribute, hence the ignore.
        self.model_path = Path(model._model_path)  # type: ignore
        self.model_name = self.model_path.name
        self.providers = model.get_providers()
        # Infer the device from the highest-priority execution provider.
        if self.providers[0] in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
            self.device = "cuda"
        else:
            self.device = "cpu"
        # Expected numpy dtype for each graph input/output, used to cast inputs.
        self.io_types = TypeHelper.get_io_numpy_type_map(model)
        self.input_names = [el.name for el in model.get_inputs()]
        # The model is assumed to expose a single output of interest.
        self.output_name = model.get_outputs()[0].name

    @staticmethod
    def load_session(
        path: str | Path,
        provider: str = "CPUExecutionProvider",
        session_options: ort.SessionOptions | None = None,
        provider_options: dict[str, Any] | None = None,
    ) -> ort.InferenceSession:
        # Register a fallback provider so the session still loads when the
        # preferred one is unavailable: TensorRT falls back to CUDA, and
        # CUDA falls back to CPU.
        providers = [provider]
        if provider == "TensorrtExecutionProvider":
            providers.append("CUDAExecutionProvider")
        elif provider == "CUDAExecutionProvider":
            providers.append("CPUExecutionProvider")
        # A str is treated as the path to the .onnx file itself; a Path is
        # treated as an export directory containing model.onnx.
        if not isinstance(path, str):
            path = Path(path) / "model.onnx"
        # ONNX Runtime expects one options dict per registered provider.
        provider_options_list = None
        if provider_options is not None:
            provider_options_list = [provider_options] + [{} for _ in range(len(providers) - 1)]
        session = ort.InferenceSession(
            str(path),
            providers=providers,
            sess_options=session_options,
            provider_options=provider_options_list,
        )
        logger.info("Session loaded")
        return session
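
    # Usage sketch for load_session (the file path and option values below
    # are assumptions, not part of this module):
    #
    #     session = ONNXModel.load_session(
    #         "path/to/export_dir/model.onnx",
    #         provider="CUDAExecutionProvider",
    #         provider_options={"device_id": 0},
    #     )
    #
    # "device_id" is a documented CUDAExecutionProvider option; the CPU
    # fallback registered above keeps this loadable on machines without a GPU.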

    @classmethod
    def from_dir(cls, model_dir: str | Path) -> ONNXModel:
        # Normalize to Path: ModelInfo.from_dir joins paths with "/", and
        # load_session only appends model.onnx when given a Path.
        model_dir = Path(model_dir)
        return cls(cls.load_session(model_dir), ModelInfo.from_dir(model_dir))

    def __call__(self, **model_inputs: np.ndarray) -> np.ndarray:
        # Cast each input to the dtype the graph expects before running.
        model_inputs = {
            input_name: tensor.astype(self.io_types[input_name]) for input_name, tensor in model_inputs.items()
        }
        return self.model.run([self.output_name], model_inputs)[0]
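

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original module. The
    # directory path and the input names/shapes are assumptions: the export
    # directory must contain model.onnx and metadata.json, and the input
    # names must match the exported graph (typical BERT-style inputs shown).
    model = ONNXModel.from_dir("path/to/export_dir")
    batch = {
        "input_ids": np.zeros((1, 8), dtype=np.int64),
        "attention_mask": np.ones((1, 8), dtype=np.int64),
    }
    logits = model(**batch)
    logger.info("Output shape: {}", logits.shape)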