# Source: pi-tagger, tagger/model.py (commit 2b6048b "init" by neggles; 1.03 kB)
from pathlib import Path
from typing import Optional
import onnxruntime as rt
from huggingface_hub import hf_hub_download
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Fetch an ONNX file from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository to download from.
        filename: File to fetch from the repository; ``.onnx`` is appended
            when the name does not already end with it.
        revision: Optional branch, tag, or commit to pin.
        token: Optional Hub access token.

    Returns:
        The local path returned by ``hf_hub_download``, passed through
        ``Path.resolve()``.

    NOTE(review): ``resolve()`` follows symlinks, so if the Hub cache stores
    the file as a symlink this returns the underlying blob path rather than a
    ``*.onnx`` path. That may matter for models whose weights live in ONNX
    external-data sidecar files — confirm before relying on it.
    """
    # Normalise the requested name so callers may omit the extension.
    onnx_name = filename if filename.endswith(".onnx") else f"{filename}.onnx"
    local_file = hf_hub_download(
        repo_id=repo_id,
        filename=onnx_name,
        revision=revision,
        token=token,
    )
    return Path(local_file).resolve()


def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
    *,
    filename: str = "model.onnx",
) -> rt.InferenceSession:
    """Download an ONNX model from the Hugging Face Hub and open an inference session.

    CUDA is requested only when this onnxruntime build reports it as an
    available provider; the CPU provider is always listed as the fallback.

    Args:
        repo_id: Hub repository to download from.
        revision: Optional branch, tag, or commit to pin.
        token: Optional Hub access token, forwarded to ``download_onnx``.
        filename: Name of the ONNX file in the repository, forwarded to
            ``download_onnx`` (which appends ``.onnx`` if missing). Defaults to
            ``"model.onnx"``, the file this function has always loaded.

    Returns:
        An ``onnxruntime.InferenceSession`` for the downloaded model.

    Raises:
        FileNotFoundError: If the downloaded path is not a file and no
            ``model.onnx`` file exists inside it either.
    """
    model_path = download_onnx(repo_id, filename=filename, revision=revision, token=token)
    if not model_path.is_file():
        # Defensive fallback kept from the original: if the path is a
        # directory, look for the default model file inside it.
        model_path = model_path.joinpath("model.onnx")
        if not model_path.is_file():
            raise FileNotFoundError(f"Model not found: {model_path}")

    # onnxruntime warns when asked for a provider it does not have, so only
    # request CUDA when this build actually exposes it.
    available = set(rt.get_available_providers())
    providers: list = []
    if "CUDAExecutionProvider" in available:
        providers.append(("CUDAExecutionProvider", {}))
    providers.append("CPUExecutionProvider")

    return rt.InferenceSession(str(model_path), providers=providers)