| | """Controller Layer - Business logic for prediction operations. |
| | |
| | This module implements the business logic layer following the MVC pattern. |
| | It acts as an intermediary between the API endpoints (views) and the |
| | ML models (models layer), handling: |
| | - Model lifecycle management (loading/unloading) |
| | - Request validation and preprocessing |
| | - Response formatting and label mapping |
| | - Error handling and logging |
| | |
| | The controller is designed to be thread-safe for concurrent access. |
| | """ |
| |
|
| | import logging |
| | from typing import Any, Dict, List |
| |
|
| | import numpy as np |
| |
|
| | from nygaardcodecommentclassification import config |
| | from nygaardcodecommentclassification.api.models import ModelPredictor, ModelRegistry |
| |
|
| | |
# Module-level logger. NOTE(review): uses the fixed channel name "controllers"
# rather than __name__ — presumably so all controller logs share one channel;
# confirm this matches the project's logging configuration.
logger = logging.getLogger("controllers")
| |
|
| |
|
class PredictionController:
    """Manages prediction logic, model lifecycle, and response formatting.

    This controller orchestrates the ML prediction pipeline, including:
    - Loading and managing ML models via ModelRegistry
    - Validating prediction requests against supported languages/models
    - Executing predictions through ModelPredictor
    - Mapping numeric predictions to human-readable labels

    Attributes:
        registry: ModelRegistry instance for model storage
        predictor: ModelPredictor instance for inference

    Example:
        ```python
        controller = PredictionController()
        controller.startup()  # Load models from MLflow

        results = controller.predict(
            texts=["# Calculate sum"],
            class_names=["Utils"],
            language="python",
            model_type="catboost"
        )
        # results: [{"text": "# Calculate sum", "class_name": "Utils", "labels": ["summary"]}]

        controller.shutdown()  # Release resources
        ```
    """

    # Single source of truth for the model types currently supported.
    # Shared by get_models_info() and predict() so the two can never disagree.
    _AVAILABLE_MODEL_TYPES: List[str] = ["catboost"]

    def __init__(self) -> None:
        """Initialize the prediction controller with an empty registry and predictor."""
        self.registry = ModelRegistry()
        self.predictor = ModelPredictor(self.registry)

    def startup(self) -> None:
        """Load all ML models into memory from MLflow.

        This method should be called during application startup.
        It connects to the MLflow tracking server and loads all available
        models into the registry for fast inference.

        Note:
            This operation may take several seconds depending on
            the number and size of models.
        """
        logger.info("Loading models from MLflow...")
        self.registry.load_all_models()
        logger.info("Models loaded successfully")

    def shutdown(self) -> None:
        """Release all model resources.

        Clears the model registry and frees GPU memory if applicable.
        This should be called during application shutdown.
        """
        self.registry.clear()
        logger.info("Models cleared and resources released")

    def get_models_info(self) -> Dict[str, List[str]]:
        """Return available models grouped by programming language.

        Returns:
            Dict mapping language codes to lists of available model types.
            Example: {"java": ["catboost"], "python": ["catboost"], "pharo": ["catboost"]}
        """
        # Copy the shared list so callers cannot mutate the class constant.
        return {lang: list(self._AVAILABLE_MODEL_TYPES) for lang in config.LANGUAGES}

    def predict(
        self, texts: List[str], class_names: List[str], language: str, model_type: str
    ) -> List[Dict[str, Any]]:
        """Execute multi-label classification on code comments.

        This method validates the request, runs ML inference, and formats
        the results with human-readable labels.

        Args:
            texts: List of code comment strings
            class_names: List of class names corresponding to each comment
            language: Programming language context ("java", "python", "pharo")
            model_type: Type of model to use ("catboost")

        Returns:
            List of dicts with classification results. Each dict contains:
            - "text": The original input text
            - "class_name": The class name corresponding to the input text
            - "labels": List of predicted category labels (strings)

        Raises:
            ValueError: If language is not supported, model type is unavailable,
                or texts and class_names have different lengths
            RuntimeError: If prediction fails or labels configuration is missing

        Example:
            ```python
            results = controller.predict(
                texts=["This calculates fibonacci", "TODO: optimize"],
                class_names=["MathUtils", "Calculator"],
                language="python",
                model_type="catboost"
            )
            # Returns:
            # [
            #   {"text": "This calculates fibonacci", "class_name": "MathUtils", "labels": ["summary"]},
            #   {"text": "TODO: optimize", "class_name": "Calculator", "labels": ["expand"]}
            # ]
            ```
        """
        # --- Request validation -------------------------------------------
        if language not in config.LANGUAGES:
            raise ValueError(f"Language '{language}' not supported. Available: {config.LANGUAGES}")

        if len(texts) != len(class_names):
            raise ValueError(f"Mismatch: {len(texts)} texts but {len(class_names)} class names")

        available_types = self._AVAILABLE_MODEL_TYPES
        if model_type not in available_types:
            raise ValueError(
                f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
            )

        # The model was trained on "<comment> | <class>" concatenations, so the
        # request is preprocessed into the same format before inference.
        combined_texts = [f"{text} | {class_name}" for text, class_name in zip(texts, class_names)]

        # --- Inference ----------------------------------------------------
        # The predictor also returns embeddings, which are not needed here.
        try:
            y_pred, _ = self.predictor.predict(combined_texts, language, model_type)
        except Exception as e:
            logger.error("Prediction failed for %s/%s: %s", language, model_type, e)
            raise RuntimeError(f"Internal model error: {e}") from e

        # --- Label mapping ------------------------------------------------
        try:
            labels_map = config.LABELS[language]
        except KeyError as e:
            raise RuntimeError(f"Configuration error: Labels map missing for {language}") from e

        # --- Response formatting ------------------------------------------
        # Each row of y_pred is a binary vector; a 1 at position j means the
        # j-th label applies. Fix: include "class_name" in each result, as the
        # docstring and API examples promise.
        results: List[Dict[str, Any]] = []
        for text_input, class_name, row_pred in zip(texts, class_names, y_pred):
            predicted_indices = np.where(row_pred == 1)[0]
            # int(idx) normalizes the numpy integer before the lookup.
            predicted_labels = [labels_map[int(idx)] for idx in predicted_indices]
            results.append(
                {"text": text_input, "class_name": class_name, "labels": predicted_labels}
            )

        return results
| |
|