jeremiebasso's picture
initial commit
8fe5582
raw
history blame
No virus
1.7 kB
"""Utils"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Literal
from loguru import logger
def download_model(
    model_name: str,
    model_stage: Literal["staging", "production"],
    model_dir: str | Path = "model",
) -> Path:
    """Download a registered model from MLflow into a local directory.

    Looks up the latest version of ``model_name`` in stage ``model_stage`` and
    downloads its artifacts below ``model_dir``. The local folder is named
    after the last path component of the version's source URI. If that folder
    already exists, the download is skipped and the existing path is returned.
    After a fresh download, the model's metadata is written to
    ``metadata.json`` inside the downloaded folder.

    NOTE: the skip-download check is keyed on the folder name only, so a folder
    left behind by another version whose source URI ends with the same
    component is reused as-is.

    Args:
        model_name: Name of the registered model.
        model_stage: Registry stage to look up.
        model_dir: Directory under which the model folder is stored.

    Returns:
        Path to the local model folder.

    Raises:
        ValueError: If the registry does not return exactly one version for
            the requested stage, or if that version has no source URI.
    """
    # Imported inside the function so that importing this module does not
    # import mlflow.
    import mlflow.artifacts
    import mlflow.models
    from mlflow.client import MlflowClient

    logger.info(f"Looking for model {model_name}/{model_stage}")
    model_dir = Path(model_dir)

    client = MlflowClient()
    model_versions = client.get_latest_versions(model_name, stages=[model_stage])
    if len(model_versions) != 1:
        raise ValueError(f"No model version for {model_name}/{model_stage}")

    latest = model_versions[0]
    model_version = latest.version
    logger.info(f"Found version {model_version} for {model_name}/{model_stage}")

    # Normalise the URI once: a trailing "/" would make the last component
    # empty, so the cache path would collapse to ``model_dir`` itself and any
    # existing ``model_dir`` would be mistaken for an already-downloaded model.
    artifact_uri = (latest.source or "").rstrip("/")
    if not artifact_uri:
        raise ValueError(
            f"No artifact source for {model_name}/{model_stage} "
            f"(version {model_version})"
        )

    model_path = model_dir / artifact_uri.rsplit("/", 1)[-1]
    if model_path.exists():
        logger.info(f"Found model in {model_path}, skipping download")
        return model_path

    logger.info(f"Downloading artifacts {artifact_uri} to {model_dir}")
    downloaded_path = Path(
        mlflow.artifacts.download_artifacts(artifact_uri, dst_path=str(model_dir))
    )
    logger.info(f"Succesfully downloaded {model_name}")

    # Persist the model's metadata next to the downloaded artifacts.
    metadata = mlflow.models.get_model_info(str(downloaded_path)).metadata
    metadata_path = downloaded_path / "metadata.json"
    logger.info(f"Saving metadata to {metadata_path}")
    with metadata_path.open("w", encoding="utf-8") as file:
        json.dump(metadata, file)

    return downloaded_path