import logging
import platform
from datetime import datetime
from typing import Any, Dict, List, Optional, Literal

import mlflow
from mlflow.tracking import MlflowClient

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Practical upper bound when counting runs per experiment; avoids unbounded fetches.
_RUN_COUNT_LIMIT = 50000


def _format_timestamp(ts: int) -> str:
    """Convert an MLflow timestamp (milliseconds since epoch) to 'YYYY-MM-DD HH:MM:SS'."""
    dt = datetime.fromtimestamp(ts / 1000.0)
    return dt.strftime("%Y-%m-%d %H:%M:%S")


def _tags_to_dict(tags: Any) -> Dict[str, str]:
    """Normalize MLflow tags into a plain ``{key: value}`` dict.

    Depending on the MLflow version and entity, ``tags`` may already be a
    dict or a list of tag objects exposing ``.key``/``.value``. The previous
    code assumed tag objects and raised AttributeError on dict-shaped tags.
    """
    if not tags:
        return {}
    if isinstance(tags, dict):
        return dict(tags)
    return {t.key: t.value for t in tags}


def set_tracking_uri(uri: str) -> Dict:
    """Set the MLflow tracking URI and verify the connection.

    Args:
        uri: Tracking-server URI; must be non-empty.

    Returns:
        The result of :func:`get_system_info` on success, or
        ``{"error": True, "message": ...}`` on failure.
    """
    if not uri:
        return {"error": True, "message": "URI cannot be empty"}
    try:
        logger.info("Setting MLflow tracking URI to %s", uri)
        mlflow.set_tracking_uri(uri)
        # Fetching system info doubles as a connectivity check.
        return get_system_info()
    except Exception as e:
        return {"error": True, "message": f"Failed to set URI: {str(e)}"}


def get_system_info() -> Dict:
    """Get MLflow system information.

    Returns:
        A dict with version/URI details plus experiment and registered-model
        counts, or ``{"error": True, "message": ...}`` on failure.
    """
    try:
        client = MlflowClient()
        # NOTE: mlflow.get_artifact_uri() implicitly starts a run when none is
        # active — an unwanted side effect for a read-only info query — so only
        # resolve it when a run is already active.
        artifact_uri = mlflow.get_artifact_uri() if mlflow.active_run() else None
        return {
            "mlflow_version": mlflow.__version__,
            "tracking_uri": mlflow.get_tracking_uri(),
            "registry_uri": mlflow.get_registry_uri(),
            "artifact_uri": artifact_uri,
            # Previously this reported mlflow.__version__ (copy/paste bug).
            "python_version": platform.python_version(),
            "server_time": _format_timestamp(int(datetime.now().timestamp() * 1000)),
            "experiment_count": len(mlflow.search_experiments()),
            "model_count": len(client.search_registered_models()),
        }
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch system info: {str(e)}"}


def _count_experiment_runs(client: MlflowClient, experiment_id: str, experiment_name: str) -> Any:
    """Count runs (active and deleted) in one experiment, capped at _RUN_COUNT_LIMIT.

    Returns an int on success, or the string "Error getting count" if the
    lookup fails (matching the original best-effort behavior).
    """
    try:
        # One search call suffices; the original probed with max_results=1 and
        # then re-fetched with the full limit, doing the same work twice.
        runs = client.search_runs(
            experiment_ids=[experiment_id],
            max_results=_RUN_COUNT_LIMIT,
            run_view_type=mlflow.entities.ViewType.ALL,
        )
        return len(runs)
    except Exception as e_runs:
        logger.warning(
            "Error getting run count for experiment '%s' (ID: %s): %s",
            experiment_name, experiment_id, str(e_runs),
        )
        return "Error getting count"


def list_experiments(name_contains: Optional[str] = "", max_results: Optional[int] = 100) -> Dict:
    """List experiments on the MLflow tracking server, with optional filtering.

    Includes a (capped) run count for each experiment.

    Args:
        name_contains: Optional filter to only include experiments whose names
            contain this string (case-insensitive).
        max_results: Maximum number of results to return (default: 100).
            ``None`` means no limit after filtering. A negative value yields an
            empty list.

    Returns:
        ``{"total_experiments": count, "experiments": [exp_details, ...]}``,
        or ``{"error": True, "message": ...}`` on failure.
    """
    logger.info(
        "Fetching experiments (filter: '%s', max_results: %s)", name_contains, max_results
    )
    try:
        client = MlflowClient()
        all_experiments: List[mlflow.entities.Experiment] = client.search_experiments()

        name_filter = name_contains.strip().lower() if name_contains else ""
        if name_filter:
            filtered = [exp for exp in all_experiments if name_filter in exp.name.lower()]
        else:
            filtered = all_experiments

        # Apply the max_results limit; None means "no limit", negative means "none".
        if max_results is None:
            limited = filtered
        elif max_results < 0:
            limited = []
        else:
            limited = filtered[:max_results]

        experiments_info = []
        for exp in limited:
            creation_time_str = None
            if getattr(exp, "creation_time", None) is not None:
                creation_time_str = _format_timestamp(exp.creation_time)

            exp_detail = {
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage,
                "creation_time": creation_time_str,
                "tags": _tags_to_dict(getattr(exp, "tags", None)),
                # int, or "Error getting count" when the lookup fails.
                "run_count": _count_experiment_runs(client, exp.experiment_id, exp.name),
            }
            experiments_info.append(exp_detail)

        return {
            "total_experiments": len(experiments_info),
            "experiments": experiments_info,
        }
    except Exception as e:
        error_msg = f"Error listing experiments: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return {"error": True, "message": error_msg}


def create_experiment(name: str, tags: Optional[Dict[str, str]] = None) -> Dict:
    """Create a new experiment with the given name and optional tags.

    Args:
        name: Experiment name; must be non-empty.
        tags: Optional tag dict attached to the experiment.

    Returns:
        ``{"experiment_id": ..., "message": ...}`` on success, or
        ``{"error": True, "message": ...}`` on failure.
    """
    if not name:
        return {"error": True, "message": "Experiment name cannot be empty"}
    try:
        experiment_id = mlflow.create_experiment(name=name, tags=tags or {})
        return {
            "experiment_id": experiment_id,
            "message": "Created experiment",
        }
    except Exception as e:
        return {"error": True, "message": f"Failed to create experiment: {str(e)}"}


def search_runs(
    experiment_id: str,
    filter_string: str,
    order_string: Optional[str] = None,
    max_results: int = 100
) -> Dict:
    """Search runs in a given experiment, with filtering and ordering.

    Args:
        experiment_id: The ID of the experiment to search runs in.
        filter_string: A filter query string following MLflow search syntax.
            Examples:
              - "metrics.accuracy > 0.95"
              - "params.learning_rate = '0.001'"
              - "tags.environment = 'production'"
              - "attributes.status = 'FINISHED'"
              - "metrics.loss < 0.2 AND params.optimizer = 'Adam'"
            An empty string applies no filtering. Conditions may be combined
            with 'AND' or 'OR'.
        order_string: Optional ordering clause: a metric, parameter, or
            attribute followed by 'ASC' or 'DESC', e.g.
            "metrics.validation_loss ASC", "attributes.start_time DESC".
            None or empty uses MLflow's default (usually start_time DESC).
        max_results: Maximum number of runs to return (default: 100).

    Returns:
        ``{"runs": [run_details, ...]}`` or ``{"error": True, "message": ...}``.
    """
    if not experiment_id:
        return {"error": True, "message": "Experiment ID cannot be empty"}
    if max_results <= 0:
        return {"error": True, "message": "max_results must be a positive integer"}

    # mlflow.search_runs requires a string filter; normalize None to "".
    current_filter_string = filter_string if filter_string is not None else ""

    try:
        logger.info(
            "Searching runs in experiment '%s' with filter '%s', order by '%s', max_results '%s'",
            experiment_id, current_filter_string, order_string, max_results,
        )
        order_by_list = [order_string] if order_string and order_string.strip() else None
        found_runs: List[mlflow.entities.Run] = mlflow.search_runs(
            experiment_ids=[str(experiment_id)],  # Ensure experiment_id is a string
            filter_string=current_filter_string,
            max_results=max_results,
            order_by=order_by_list,
            output_format="list",  # Run objects instead of a DataFrame
        )
    except Exception as e_search:
        logger.error(
            "MLflow search_runs API call failed for experiment_id '%s': %s",
            experiment_id, str(e_search), exc_info=True,
        )
        return {"error": True, "message": f"MLflow search_runs API call failed: {str(e_search)}"}

    if not found_runs:
        logger.info("No runs found for experiment_id '%s' with the given criteria.", experiment_id)
        return {"runs": []}

    processed_runs_info = []
    for run_obj in found_runs:
        run_id_for_log = run_obj.info.run_id if run_obj.info else "N/A"
        try:
            start_time_ms = run_obj.info.start_time
            end_time_ms = run_obj.info.end_time  # None while the run is active
            processed_runs_info.append({
                "run_id": run_obj.info.run_id,
                "status": run_obj.info.status,
                "start_time": _format_timestamp(start_time_ms) if start_time_ms is not None else None,
                "end_time": _format_timestamp(end_time_ms) if end_time_ms is not None else None,
                "params": dict(run_obj.data.params),
                "metrics": dict(run_obj.data.metrics),
                "tags": dict(run_obj.data.tags),
            })
        except Exception as e_process_run:
            # Best-effort: a malformed run is skipped, not fatal to the search.
            logger.warning(
                "Failed to process data for run_id '%s' in experiment '%s'. Error: %s. Skipping this run.",
                run_id_for_log, experiment_id, str(e_process_run), exc_info=True,
            )
            continue

    return {"runs": processed_runs_info}


def list_registered_models() -> Dict:
    """List all registered models.

    Returns:
        ``{"models": [model_details, ...]}`` or
        ``{"error": True, "message": ...}`` on failure.
    """
    try:
        logger.info("Listing registered models")
        client = MlflowClient()
        models = client.search_registered_models()
        return {
            "models": [
                {
                    "name": model.name,
                    "creation_timestamp": _format_timestamp(model.creation_timestamp),
                    "last_updated_timestamp": _format_timestamp(model.last_updated_timestamp),
                    "description": model.description or "",
                    "tags": _tags_to_dict(getattr(model, "tags", None)),
                    "latest_versions": [mv.version for mv in model.latest_versions],
                }
                for model in models
            ]
        }
    except Exception as e:
        return {"error": True, "message": f"Failed to list registered models: {str(e)}"}


def get_model_info(model_name: str) -> Dict:
    """Get detailed information about a registered model.

    Args:
        model_name: Registered model name; must be non-empty.

    Returns:
        ``{"model": {...}}`` including latest versions and per-version run
        summaries, or ``{"error": True, "message": ...}`` on failure.
    """
    if not model_name:
        return {"error": True, "message": "Model name cannot be empty"}
    try:
        logger.info("Fetching info for model '%s'", model_name)
        client = MlflowClient()
        model = client.get_registered_model(name=model_name)
        model_info = {
            "name": model.name,
            "creation_timestamp": _format_timestamp(model.creation_timestamp),
            "last_updated_timestamp": _format_timestamp(model.last_updated_timestamp),
            "description": model.description or "",
            "tags": _tags_to_dict(getattr(model, "tags", None)),
            "versions": [],
        }
        for mv in model.latest_versions:
            version_dict = {
                "version": mv.version,
                "current_stage": mv.current_stage,
                "creation_timestamp": _format_timestamp(mv.creation_timestamp),
                "last_updated_timestamp": _format_timestamp(mv.last_updated_timestamp),
                "run": {},
            }
            run = client.get_run(mv.run_id)
            version_dict["run"] = {
                "status": run.info.status,
                "start_time": _format_timestamp(run.info.start_time),
                "end_time": _format_timestamp(run.info.end_time) if run.info.end_time else None,
                "metrics": run.data.metrics,
            }
            model_info["versions"].append(version_dict)
        return {"model": model_info}
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch model info: {str(e)}"}


def register_model(
    run_id: str,
    model_name: str,
    description: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None
) -> Dict:
    """Register a model from a run, with optional description and tags.

    Args:
        run_id: ID of the run whose ``model`` artifact is registered.
        model_name: Target registered-model name.
        description: Optional description; a service attribution note is
            appended.
        tags: Optional extra tags merged over the service-default tags.

    Returns:
        Registration details on success, or
        ``{"error": True, "message": ...}`` on failure.
    """
    if not all([run_id, model_name]):
        return {"error": True, "message": "Run ID and model name must be non-empty"}

    # Prepare description and tags; caller-supplied tags override the defaults.
    final_description = (description or "") + " Model registered by LLM through MCP service."
    final_tags = {
        "registered_by": "mcp-llm-service",
        "registration_timestamp": datetime.now().isoformat(),
    }
    if tags:
        final_tags.update(tags)

    try:
        logger.info(
            "Registering model '%s' from run '%s/model' with description and tags.",
            model_name, run_id,
        )
        model_uri = f"runs:/{run_id}/model"
        result = mlflow.register_model(
            model_uri=model_uri,
            name=model_name,
            tags=final_tags,
        )
        # register_model does not accept a description; set it on the version.
        client = MlflowClient()
        client.update_model_version(
            name=model_name,
            version=result.version,
            description=final_description,
        )
        return {
            "model_name": model_name,
            "version": result.version,
            "description": final_description,
            "tags": final_tags,
            "message": "Registered successfully",
        }
    except Exception as e:
        return {"error": True, "message": f"Registration failed: {str(e)}"}