# NOTE(review): removed non-Python page-scrape artifacts ("Spaces:" /
# "Runtime error") that were pasted above the module docstring.
| """Model Layer - ML model management and inference. | |
| This module handles the low-level ML operations including: | |
| - Model loading and storage via ModelRegistry | |
| - Inference execution via ModelPredictor | |
| Architecture: | |
| - ModelRegistry: Central storage for loaded models with lazy loading | |
| - ModelPredictor: Executes inference using registered models | |
| """ | |
| import logging | |
| import os | |
| from pathlib import Path | |
| import sys | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import dagshub | |
| import mlflow | |
| import numpy as np | |
| import torch | |
| from nygaardcodecommentclassification import config | |
# Patch torch.load to use CPU mapping by default if CUDA is not available.
# This prevents "Attempting to deserialize object on a CUDA device" errors
# when GPU-saved checkpoints are loaded on CPU-only hosts.
_original_torch_load = torch.load


def _patched_torch_load(f, map_location=None, *args, **kwargs):
    """Wrapper around torch.load that uses CPU mapping if CUDA is unavailable.

    Args:
        f: File-like object, string, or os.PathLike passed through to
            ``torch.load``.
        map_location: Optional device mapping. Defaults to ``torch.device('cpu')``
            when CUDA is not available; otherwise forwarded unchanged.
        *args: Extra positional arguments forwarded to ``torch.load``
            (e.g. ``pickle_module``).
        **kwargs: Extra keyword arguments forwarded to ``torch.load``.

    Returns:
        The object deserialized by ``torch.load``.
    """
    if map_location is None and not torch.cuda.is_available():
        map_location = torch.device('cpu')
    # Pass map_location positionally: combining `map_location=...` with *args
    # raised "got multiple values for argument 'map_location'" whenever a
    # caller supplied any additional positional argument.
    return _original_torch_load(f, map_location, *args, **kwargs)


torch.load = _patched_torch_load
# Configure the module logger with an explicit stdout handler so log output
# stays visible even if the host application never configures logging.
logger = logging.getLogger("nygaard.models")
logger.setLevel(logging.DEBUG)

# Guard against duplicate handlers when the module is reloaded.
if not logger.handlers:
    _stdout_handler = logging.StreamHandler(sys.stdout)
    _stdout_handler.setLevel(logging.DEBUG)
    _stdout_handler.setFormatter(
        logging.Formatter("[%(levelname)s] %(name)s: %(message)s")
    )
    logger.addHandler(_stdout_handler)
class ModelRegistry:
    """Central registry for ML models loaded in memory.

    This class manages the lifecycle of ML models, providing:
    - Automatic discovery and loading of models from the MLflow tracking server
    - Organized storage by language and model type
    - Memory management with explicit cleanup

    Attributes:
        _registry: Internal dictionary of loaded models, keyed first by
            language code and then by model type.

    Example:
        ```python
        registry = ModelRegistry()
        registry.load_all_models()

        # Access a loaded model
        model_entry = registry.get_model("python", "catboost")
        if model_entry:
            model = model_entry["model"]
            embedder = model_entry.get("embedder")
        ```
    """

    def __init__(self) -> None:
        """Initialize an empty model registry."""
        self._registry: Dict[str, Dict[str, Any]] = {}

    def load_all_models(self) -> None:
        """Load all ML models from the MLflow tracking server.

        Connects to the MLflow tracking server (DagsHub) and loads CatBoost
        classifiers and sentence transformer embedders for all configured
        languages.

        Environment Variables:
            DAGSHUB_USER_TOKEN: Authentication token for DagsHub/MLflow

        Note:
            - Continues loading other models if one fails
            - Logs all loading activity for debugging
        """
        logger.info("Starting to load all models from MLflow")

        # Initialize MLflow with DagsHub - uses DAGSHUB_USER_TOKEN env var for
        # auth. Set DAGSHUB_USER_TOKEN to avoid an interactive login prompt.
        dagshub_token = os.environ.get("DAGSHUB_USER_TOKEN")
        if dagshub_token:
            os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_token
            os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token
            logger.info("Using DAGSHUB_USER_TOKEN for authentication")
        else:
            logger.warning("DAGSHUB_USER_TOKEN not set - may require interactive login")

        dagshub.init(repo_owner="se4ai2526-uniba", repo_name="Nygaard", mlflow=True)
        mlflow.set_experiment("evaluating")

        # Load models for all configured languages directly from MLflow;
        # no local directory structure is needed.
        for lang in config.LANGUAGES:
            logger.info("Loading models for language: %s", lang)
            self._registry.setdefault(lang, {})
            self._load_catboost_models(lang)

        logger.info("Finished loading all models from MLflow")

    def _load_catboost_models(self, lang: str) -> None:
        """Load CatBoost models for a specific language from MLflow.

        Downloads and loads the CatBoost classifier and sentence transformer
        embedder directly from the MLflow tracking server.

        Args:
            lang: The programming language code (e.g., "python", "java")
        """
        # Find candidate CatBoost runs first, and only sort non-empty results:
        # sorting an empty DataFrame raised KeyError because the
        # "metrics.final_score" column does not exist in an empty result set.
        catboost_runs = mlflow.search_runs(
            experiment_names=["evaluating"], filter_string="tags.model = 'catboost'"
        )
        if catboost_runs.empty:
            logger.error("No CatBoost run found in 'evaluating' experiment")
            return

        # Best run = highest final_score.
        catboost_run = catboost_runs.sort_values(
            by="metrics.final_score", ascending=False
        ).iloc[0]
        catboost_run_id = catboost_run.run_id
        catboost_run_name = catboost_run.get("tags.mlflow.runName", "unknown")
        catboost_git_commit = catboost_run.get("tags.mlflow.source.git.commit")
        logger.info(
            "Found CatBoost run: '%s' (ID: %s, commit: %s)",
            catboost_run_name,
            catboost_run_id,
            catboost_git_commit,
        )

        # Pair the classifier with a matching embedder run.
        embedder_info = self._find_embedder_run(lang, catboost_git_commit)
        if embedder_info is None:
            return
        embedder_run_id, embedder_run_name = embedder_info
        logger.info(
            "Found Embedder run: '%s' (ID: %s)",
            embedder_run_name,
            embedder_run_id,
        )

        try:
            # Load the CatBoost model from MLflow
            model_uri = f"runs:/{catboost_run_id}/model_{lang}"
            logger.info(
                "[%s] Loading CatBoost classifier from run '%s' (ID: %s)...",
                lang.upper(),
                catboost_run_name,
                catboost_run_id,
            )
            model = mlflow.sklearn.load_model(model_uri)

            # Load the sentence transformer embedder from MLflow
            embedder_uri = f"runs:/{embedder_run_id}/model_{lang}"
            logger.info(
                "[%s] Loading sentence transformer from run '%s' (ID: %s)...",
                lang.upper(),
                embedder_run_name,
                embedder_run_id,
            )
            embedder = mlflow.sklearn.load_model(embedder_uri)

            # Register the model with its metadata. setdefault keeps this
            # method safe even when called outside load_all_models().
            self._registry.setdefault(lang, {})["catboost"] = {
                "model": model,
                "feature_type": "embeddings",
                "embedder": embedder,
            }
            logger.info(
                "[%s] ✓ Ready: CatBoost + %s embeddings",
                lang.upper(),
                embedder_run_name.replace("sentence_transformer_", ""),
            )
        except Exception as e:
            # Best-effort: continue with other languages rather than aborting.
            logger.error("[%s] Error loading models: %s", lang.upper(), e)

    def _find_embedder_run(
        self, lang: str, git_commit: Optional[str]
    ) -> Optional[Tuple[str, str]]:
        """Locate the sentence-transformer embedder run in MLflow.

        Prefers an embedder logged with the same git commit as the CatBoost
        run; falls back to the default embedder run name otherwise.

        Args:
            lang: Language code, used only for log context.
            git_commit: Git commit hash of the CatBoost run, or None.

        Returns:
            Tuple of (run_id, run_name), or None if no embedder run exists.
        """
        if git_commit:
            # Search for a sentence transformer logged at the same commit.
            logger.info(
                "[%s] Searching for embedder with git commit: %s",
                lang.upper(),
                git_commit,
            )
            runs = mlflow.search_runs(
                experiment_names=["evaluating"],
                filter_string=(
                    f"tags.`mlflow.source.git.commit` = '{git_commit}' "
                    "and run_name LIKE 'sentence_transformer%'"
                ),
            )
            if not runs.empty:
                run = runs.iloc[0]
                run_name = run.get("tags.mlflow.runName", "unknown")
                logger.info(
                    "[%s] Found embedder with matching git commit: '%s' (ID: %s)",
                    lang.upper(),
                    run_name,
                    run.run_id,
                )
                return run.run_id, run_name

        # Fallback: search by the default embedder run name.
        logger.info("[%s] Falling back to default embedder search", lang.upper())
        runs = mlflow.search_runs(
            experiment_names=["evaluating"],
            filter_string="run_name = 'sentence_transformer_paraphrase-MiniLM-L6-v2'",
        )
        if runs.empty:
            logger.error(
                "No embedder run found for 'sentence_transformer_paraphrase-MiniLM-L6-v2'"
            )
            return None
        run = runs.iloc[0]
        return run.run_id, run.get("tags.mlflow.runName", "unknown")

    def get_model(self, language: str, model_type: str) -> Optional[Dict[str, Any]]:
        """Retrieve a loaded model entry by language and type.

        Args:
            language: The programming language code
            model_type: The type of model

        Returns:
            Dict containing the model and metadata, or None if not found.
            The dict contains:
            - "model": The loaded ML model object
            - "feature_type": Type of features used
            - "embedder": Optional sentence transformer for embedding generation
        """
        return self._registry.get(language, {}).get(model_type)

    def clear(self) -> None:
        """Clear all models from the registry and free memory.

        This method should be called during application shutdown to
        release GPU memory and other resources.
        """
        self._registry.clear()
        # Clear CUDA cache if GPU was used
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("CUDA cache cleared")
class ModelPredictor:
    """Handles low-level prediction logic.

    Attributes:
        registry: Reference to the ModelRegistry for model access

    Example:
        ```python
        registry = ModelRegistry()
        registry.load_all_models()

        predictor = ModelPredictor(registry)
        predictions, embeddings = predictor.predict(
            texts=["# Calculate sum of list"],
            language="python",
            model_type="catboost",
        )
        # predictions: np.ndarray with shape (1, num_labels)
        ```
    """

    def __init__(self, model_registry: "ModelRegistry") -> None:
        """Initialize the predictor with a model registry.

        Args:
            model_registry: The ModelRegistry instance containing loaded models
        """
        self.registry = model_registry

    def predict(
        self, texts: List[str], language: str, model_type: str
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Execute prediction on a list of texts.

        This method handles the full inference pipeline:
        1. Retrieve the appropriate model from the registry
        2. Extract features (e.g., generate embeddings)
        3. Run model inference
        4. Return raw predictions

        Args:
            texts: List of code comment strings to classify
            language: Programming language context for model selection
            model_type: Type of model to use

        Returns:
            Tuple containing:
            - numpy array of predictions with shape (n_samples, n_labels).
            - numpy array of embeddings (if available, else None).

        Raises:
            ValueError: If the requested model is not available or
                if an unsupported feature/model type is specified
        """
        # Retrieve model entry from registry
        model_entry = self.registry.get_model(language, model_type)
        if not model_entry or "model" not in model_entry:
            raise ValueError(f"Model {model_type} not available for {language}")
        model = model_entry["model"]

        # Handle CatBoost models
        if model_type == "catboost":
            if model_entry.get("feature_type") == "embeddings":
                # Generate embeddings using the SetFit sentence transformer
                embedder = model_entry.get("embedder")
                if embedder is None:
                    raise ValueError(f"Embedder not loaded for {language}")
                # Encode texts to dense embeddings (no progress bar for API use)
                embeddings = embedder.encode(texts, show_progress_bar=False)
                # Run CatBoost prediction on embeddings
                return model.predict(embeddings), embeddings
            raise ValueError("Unsupported feature type for CatBoost")

        raise ValueError(f"Unknown model type: {model_type}")