"""Utils""" from __future__ import annotations import json from pathlib import Path from typing import Literal from loguru import logger def download_model( model_name: str, model_stage: Literal["staging", "production"], model_dir: str | Path = "model", ) -> Path: """Download model from mlflow""" import mlflow.artifacts import mlflow.models from mlflow.client import MlflowClient logger.info(f"Looking for model {model_name}/{model_stage}") if isinstance(model_dir, str): model_dir = Path(model_dir) client = MlflowClient() model_versions = client.get_latest_versions(model_name, stages=[model_stage]) if len(model_versions) != 1: raise ValueError(f"No model version for {model_name}/{model_stage}") artifact_uri = model_versions[0].source model_version = model_versions[0].version logger.info(f"Found version {model_version} for {model_name}/{model_stage}") model_path = model_dir / artifact_uri.split("/")[-1] # type: ignore if model_path.exists(): logger.info(f"Found model in {model_path}, skipping download") return model_path logger.info(f"Downloading artifacts {artifact_uri} to {model_dir}") model_path = mlflow.artifacts.download_artifacts(artifact_uri, dst_path=str(model_dir)) logger.info(f"Succesfully downloaded {model_name}") model_info = mlflow.models.get_model_info(model_path) metadata = model_info.metadata metadata_path = Path(model_path) / "metadata.json" logger.info(f"Saving metadata to {metadata_path}") with open(metadata_path, "w", encoding="utf-8") as file: json.dump(metadata, file) return Path(model_path)