import yaml
from typing import Any, Dict, List, Optional
from loguru import logger

from .base import RerankerModel
from .cross_encoder import SentenceTransformersReranker, QwenReranker


class ModelManager:
    """
    Manager for reranking models with preloading and configuration.

    Loads model configurations from a YAML file (default: ``config.yaml``),
    instantiates one reranker per configured model (chosen by its
    ``model_type``), and provides methods to preload, retrieve, and list the
    available models. When no ``model_id`` is given to ``get_model``, the
    ``default_model`` from the config is used.

    Attributes:
        models (Dict[str, RerankerModel]): Instantiated model objects keyed by model ID.
        model_configs (Dict[str, Dict[str, Any]]): The ``models`` mapping from the YAML file.
        default_model_id (Optional[str]): The ``default_model`` value from the YAML file, or None.
    """

    # Maps a config ``model_type`` value to the reranker class implementing it.
    # Every class here is constructed with the same keyword arguments
    # (model_id, model_name, model_type).
    _MODEL_CLASSES: Dict[str, type] = {
        "sentence_transformers": SentenceTransformersReranker,
        "qwen": QwenReranker,
    }

    def __init__(self, config_path: str = 'config.yaml'):
        """
        Initialize the ModelManager and load model configurations from a YAML file.

        Loading is best-effort: any failure to read or parse the file is logged
        and leaves the manager with no configured models and no default model,
        rather than raising.

        Args:
            config_path (str): Path to the YAML configuration file. Defaults to 'config.yaml'.

        Side Effects:
            Sets self.models to an empty dict, self.model_configs to the file's
            ``models`` mapping (or {}), and self.default_model_id to the file's
            ``default_model`` value (or None).
        """
        self.models: Dict[str, RerankerModel] = {}
        self.model_configs: Dict[str, Dict[str, Any]] = {}
        self.default_model_id: Optional[str] = None

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except Exception as e:
            # Deliberate best-effort boundary: the service can still start and
            # report "no models" instead of failing at construction time.
            logger.error(f"Failed to load {config_path}: {e}")
            return

        # An empty YAML document parses to None; treat it as an empty config.
        if config_data is None:
            config_data = {}
        if not isinstance(config_data, dict):
            logger.error(
                f"Failed to load {config_path}: expected a mapping at the top level, "
                f"got {type(config_data).__name__}"
            )
            return

        # `models:` with no value parses to None; normalise so callers can
        # always iterate self.model_configs.
        model_configs = config_data.get('models') or {}
        if not isinstance(model_configs, dict):
            logger.error(
                f"Ignoring 'models' in {config_path}: expected a mapping, "
                f"got {type(model_configs).__name__}"
            )
            model_configs = {}

        self.model_configs = model_configs
        self.default_model_id = config_data.get('default_model')
        logger.info(f"Loaded model configs from {config_path}")

        if self.default_model_id and self.default_model_id not in self.model_configs:
            logger.warning(
                f"default_model '{self.default_model_id}' is not defined under 'models' in {config_path}"
            )

    async def preload_all_models(self):
        """
        Preload all models defined in the configuration file.

        For each configured model, instantiates the reranker class registered
        for its ``model_type``, calls ``load()`` on it, and stores it in
        self.models. Logs the status of each model and a summary at the end.

        Failures are not raised: a model with an unknown type, a missing config
        key, or a ``load()`` error is logged and skipped, and preloading
        continues with the next model.

        NOTE(review): ``model.load()`` is called synchronously, so this coroutine
        does not yield to the event loop while a model is loading.
        """
        logger.info(f"Starting preload of {len(self.model_configs)} reranking models...")
        for model_id, config in self.model_configs.items():
            try:
                logger.info(f"Loading {model_id}...")
                model_type = config["model_type"]
                model_cls = self._MODEL_CLASSES.get(model_type)
                if model_cls is None:
                    logger.error(f"Unknown model type: {model_type}")
                    continue
                model = model_cls(
                    model_id=model_id,
                    model_name=config["model_name"],
                    model_type=model_type,
                )
                model.load()
                self.models[model_id] = model
                logger.success(f"Successfully preloaded {model_id}")
            except Exception as e:
                logger.error(f"Failed to preload {model_id}: {e}")

        loaded_count = sum(1 for m in self.models.values() if m.loaded)
        logger.success(f"Preloaded {loaded_count}/{len(self.model_configs)} models successfully")

    def get_model(self, model_id: Optional[str] = None) -> RerankerModel:
        """
        Retrieve a loaded model instance by its ID, or the default model if not specified.

        Args:
            model_id (Optional[str]): Identifier of the model to retrieve.
                If None, the config's default_model is used.

        Returns:
            RerankerModel: The loaded reranker model instance.

        Raises:
            ValueError: If no model_id is given and no default is configured,
                if the model was never instantiated, or if it is not loaded.
        """
        if model_id is None:
            if not self.default_model_id:
                raise ValueError("No model_id provided and no default_model set in config.yaml")
            model_id = self.default_model_id

        model = self.models.get(model_id)
        if model is None:
            raise ValueError(f"Model {model_id} not found")
        if not model.loaded:
            raise ValueError(f"Model {model_id} not loaded")
        return model

    def list_models(self) -> List[Dict[str, Any]]:
        """
        List all configured models with their configuration and load status.

        Every entry in the config is listed, including models that failed to
        preload (their "loaded" value is False).

        Returns:
            List[Dict[str, Any]]: One dict per configured model with keys
            "id", "name", "type", "language", "description", "repository",
            and "loaded". Missing config fields appear as None.
        """
        models_info = []
        for model_id, config in self.model_configs.items():
            model = self.models.get(model_id)
            models_info.append({
                "id": model_id,
                "name": config.get("model_name"),
                "type": config.get("model_type"),
                # Public key is singular "language" but reads the config's
                # "languages" field; kept as-is for API compatibility.
                "language": config.get("languages"),
                "description": config.get("description"),
                "repository": config.get("repository"),
                "loaded": model.loaded if model is not None else False,
            })
        return models_info