# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from pathlib import Path
from typing import Optional, Union

import numpy as np
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args

from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging


if is_onnx_available():
    import onnxruntime as ort


logger = logging.get_logger(__name__)

# Maps ONNX Runtime tensor type strings to the corresponding NumPy scalar types.
ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}


class OnnxRuntimeModel:
    """
    Thin wrapper around an ONNX Runtime inference session that can be loaded from, and saved to, a directory or the
    Hugging Face Hub.
    """

    def __init__(self, model=None, **kwargs):
        """
        Arguments:
            model:
                The object used for inference; `__call__` invokes `model.run(None, inputs)` on it.
            kwargs:
                Optional `model_save_dir` (directory holding the model file, defaults to `None`) and
                `latest_model_name` (file name of the model inside that directory, defaults to
                `ONNX_WEIGHTS_NAME`). Other keys are ignored.
        """
        logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
        self.model = model
        self.model_save_dir = kwargs.get("model_save_dir", None)
        self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)

    def __call__(self, **kwargs):
        """
        Run the wrapped model. Every keyword argument is converted with `np.array` and passed as a named input; the
        result of `self.model.run(None, inputs)` is returned as is.
        """
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.model.run(None, inputs)

    @staticmethod
    def load_model(path: Union[str, Path], provider=None, sess_options=None):
        """
        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

        Arguments:
            path (`str` or `Path`):
                Path of the ONNX model file to load
            provider(`str`, *optional*):
                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
            sess_options (*optional*):
                Forwarded unchanged to `ort.InferenceSession` as `sess_options`.
        """
        if provider is None:
            logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)

    def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~OnnxRuntimeModel.from_pretrained`] class method. It will always save the latest_model_name.

        Arguments:
            save_directory (`str` or `Path`):
                Directory where to save the model file.
            file_name(`str`, *optional*):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
                model with a different name.
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME

        # `model_save_dir` comes straight from `__init__` kwargs and may be a plain string; normalise it so the
        # path arithmetic below works for both `str` and `Path`.
        source_dir = Path(self.model_save_dir)
        target_dir = Path(save_directory)

        src_path = source_dir.joinpath(self.latest_model_name)
        dst_path = target_dir.joinpath(model_file_name)
        try:
            shutil.copyfile(src_path, dst_path)
        except shutil.SameFileError:
            # Saving into the directory the model was loaded from: nothing to copy.
            pass

        # copy external weights (for models >2GB)
        src_path = source_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
        if src_path.exists():
            dst_path = target_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
            try:
                shutil.copyfile(src_path, dst_path)
            except shutil.SameFileError:
                pass

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        **kwargs,
    ):
        """
        Save a model to a directory, so that it can be re-loaded using the [`~OnnxRuntimeModel.from_pretrained`] class
        method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist. If it points to an existing file, an
                error is logged and nothing is saved.
            kwargs:
                Forwarded to `_save_pretrained` (e.g. `file_name`).
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        # saving model weights/files
        self._save_pretrained(save_directory, **kwargs)

    @classmethod
    @validate_hf_hub_args
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["ort.SessionOptions"] = None,
        **kwargs,
    ):
        """
        Load a model from a directory or the HF Hub.

        Arguments:
            model_id (`str` or `Path`):
                Directory from which to load
            token (`str` or `bool`):
                Is needed to load models from a private or gated repository
            revision (`str`):
                Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
            cache_dir (`Union[str, Path]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            file_name(`str`):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
                different model files from the same repository or directory.
            provider(`str`):
                The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
            sess_options (`ort.SessionOptions`, *optional*):
                Session options forwarded to `load_model`.
            kwargs (`Dict`, *optional*):
                kwargs will be passed to the model during initialization
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
        # load model from local directory
        if os.path.isdir(model_id):
            model = OnnxRuntimeModel.load_model(
                Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
            )
            kwargs["model_save_dir"] = Path(model_id)
            # Remember which file was actually loaded so that `save_pretrained` copies that file rather than
            # the default name; an explicit value supplied by the caller still takes precedence.
            kwargs.setdefault("latest_model_name", model_file_name)
        # load model from hub
        else:
            # download model
            model_cache_path = hf_hub_download(
                repo_id=model_id,
                filename=model_file_name,
                token=token,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
            )
            kwargs["model_save_dir"] = Path(model_cache_path).parent
            kwargs["latest_model_name"] = Path(model_cache_path).name
            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
        return cls(model=model, **kwargs)

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        force_download: bool = True,
        token: Optional[Union[bool, str]] = None,
        cache_dir: Optional[str] = None,
        **model_kwargs,
    ):
        """
        Load a model from a directory or the HF Hub.

        Arguments:
            model_id (`str` or `Path`):
                Directory or Hub repository to load from. A revision may be appended as `"<model_id>@<revision>"`.
            force_download (`bool`, *optional*, defaults to `True`):
                Forwarded to `_from_pretrained`.
            token (`str` or `bool`, *optional*):
                Forwarded to `_from_pretrained`.
            cache_dir (`str`, *optional*):
                Forwarded to `_from_pretrained`.
            model_kwargs:
                Additional keyword arguments forwarded to `_from_pretrained` (e.g. `file_name`, `provider`,
                `sess_options`).
        """
        revision = None
        # Split the string form once so that a `Path` containing "@" is handled the same way as a `str`
        # (a `Path` has no `.split`).
        id_parts = str(model_id).split("@")
        if len(id_parts) == 2:
            model_id, revision = id_parts

        return cls._from_pretrained(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            token=token,
            **model_kwargs,
        )