# MLflow_mcp / mcp_mlflow_tools.py
# Uploaded by tuntun — commit 9f3c9f2 ("update")
import logging
import mlflow
from datetime import datetime
from typing import Dict, List, Optional, Literal
from mlflow.tracking import MlflowClient
# Configure root logging at INFO level, then create the logger used by this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
def _format_timestamp(ts: int) -> str:
    """Render an MLflow epoch-millisecond timestamp as a local-time string.

    Args:
        ts: Milliseconds since the Unix epoch (MLflow's timestamp unit).

    Returns:
        The timestamp as ``YYYY-MM-DD HH:MM:SS`` in the host's local timezone
        (``datetime.fromtimestamp`` is called without a tz).
    """
    seconds_since_epoch = ts / 1000.0
    local_dt = datetime.fromtimestamp(seconds_since_epoch)
    return local_dt.strftime("%Y-%m-%d %H:%M:%S")
def set_tracking_uri(uri: str) -> Dict:
    """Point MLflow at a tracking server and report on the new connection.

    Args:
        uri: Tracking URI handed to ``mlflow.set_tracking_uri``; must be non-empty.

    Returns:
        Whatever ``get_system_info()`` returns after the switch, or
        ``{"error": True, "message": ...}`` when ``uri`` is empty or setting it raised.
    """
    if uri:
        try:
            logger.info(f"Setting MLflow tracking URI to {uri}")
            mlflow.set_tracking_uri(uri)
            return get_system_info()
        except Exception as exc:
            return {"error": True, "message": f"Failed to set URI: {str(exc)}"}
    return {"error": True, "message": "URI cannot be empty"}
def get_system_info() -> Dict:
    """Collect version, URI and inventory information about the MLflow setup.

    Returns:
        A dict with ``mlflow_version``, ``tracking_uri``, ``registry_uri``,
        ``artifact_uri`` (artifact root of the currently active run, or ``None``
        when no run is active), ``python_version`` (the interpreter running this
        code), ``server_time`` (current local time of this process),
        ``experiment_count`` and ``model_count``; or
        ``{"error": True, "message": ...}`` if any lookup fails.
    """
    import platform

    try:
        client = MlflowClient()

        # mlflow.get_artifact_uri() starts a new run when none is active, which would
        # leave a stray active run behind after a read-only status call. Only report
        # the artifact root of a run that already exists.
        active_run = mlflow.active_run()
        artifact_uri = active_run.info.artifact_uri if active_run is not None else None

        # MlflowClient.search_registered_models returns a single page of results;
        # follow the page token so the count covers every registered model.
        model_count = 0
        page_token = None
        while True:
            page = client.search_registered_models(page_token=page_token)
            model_count += len(page)
            page_token = page.token
            if not page_token:
                break

        return {
            "mlflow_version": mlflow.__version__,
            "tracking_uri": mlflow.get_tracking_uri(),
            "registry_uri": mlflow.get_registry_uri(),
            "artifact_uri": artifact_uri,
            "python_version": platform.python_version(),
            "server_time": _format_timestamp(int(datetime.now().timestamp() * 1000)),
            "experiment_count": len(mlflow.search_experiments()),
            "model_count": model_count,
        }
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch system info: {str(e)}"}
def _search_all_experiments(client: "MlflowClient") -> List["mlflow.entities.Experiment"]:
    """Return every experiment the default search_experiments() call sees, across all pages."""
    experiments: List["mlflow.entities.Experiment"] = []
    page_token = None
    while True:
        # One search_experiments call returns only one page; keep following the token.
        page = client.search_experiments(page_token=page_token)
        experiments.extend(page)
        page_token = page.token
        if not page_token:
            return experiments


def _count_runs(client: "MlflowClient", experiment_id: str) -> int:
    """Count an experiment's runs (active and deleted), capped at 50,000."""
    runs = client.search_runs(
        experiment_ids=[experiment_id],
        max_results=50000,  # Practical limit for counting; bigger experiments report 50000
        run_view_type=mlflow.entities.ViewType.ALL,
    )
    return len(runs)


def _experiment_summary(client: "MlflowClient", exp: "mlflow.entities.Experiment") -> Dict:
    """Build the JSON-friendly detail dict for one experiment, including its run count."""
    creation_time = getattr(exp, "creation_time", None)
    tags = getattr(exp, "tags", None)
    summary = {
        "experiment_id": exp.experiment_id,
        "name": exp.name,
        "artifact_location": exp.artifact_location,
        "lifecycle_stage": exp.lifecycle_stage,
        "creation_time": _format_timestamp(creation_time) if creation_time is not None else None,
        "tags": dict(tags) if tags else {},  # exp.tags is already a dict {key: value}
    }
    try:
        run_count = _count_runs(client, exp.experiment_id)
    except Exception as e_runs:
        # Best effort: one experiment failing to count must not fail the whole listing.
        logger.warning(f"Error getting run count for experiment '{exp.name}' (ID: {exp.experiment_id}): {str(e_runs)}")
        run_count = "Error getting count"
    summary["run_count"] = run_count
    return summary


def list_experiments(name_contains: Optional[str] = "", max_results: Optional[int] = 100) -> Dict:
    """List all experiments in the MLflow tracking server, with optional filtering. Includes run count for each experiment.

    Args:
        name_contains: Optional filter to only include experiments whose names contain
            this string (case-insensitive; surrounding whitespace is ignored).
        max_results: Maximum number of results to return (default: 100). None means no
            limit after filtering. A negative value will result in an empty list.

    Returns:
        A dictionary containing the total count of returned experiments and a list of
        their details. Each detail includes ``run_count`` (active + deleted runs, capped
        at 50000, or the string "Error getting count" if counting failed).
        Format: {"total_experiments": count, "experiments": [exp_details, ...]}
        Returns {"error": True, "message": ...} on failure.
    """
    logger.info(f"Fetching experiments (filter: '{name_contains}', max_results: {max_results})")
    try:
        client = MlflowClient()
        experiments = _search_all_experiments(client)

        name_filter = name_contains.strip().lower() if name_contains else ""
        if name_filter:
            experiments = [exp for exp in experiments if name_filter in exp.name.lower()]

        # Apply the max_results limit after filtering; None means "no limit".
        if max_results is not None:
            experiments = experiments[:max_results] if max_results >= 0 else []

        experiments_info = [_experiment_summary(client, exp) for exp in experiments]
        return {
            "total_experiments": len(experiments_info),
            "experiments": experiments_info,
        }
    except Exception as e:
        error_msg = f"Error listing experiments: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return {"error": True, "message": error_msg}
def create_experiment(name: str, tags: Optional[Dict[str, str]] = None) -> Dict:
    """Create a new MLflow experiment.

    Args:
        name: Name of the experiment to create; must be non-empty.
        tags: Optional tags to attach to the experiment (an empty dict is sent when omitted).

    Returns:
        ``{"experiment_id": ..., "message": "Created experiment"}`` on success,
        otherwise ``{"error": True, "message": ...}``.
    """
    if name:
        try:
            new_id = mlflow.create_experiment(name=name, tags=tags or {})
        except Exception as exc:
            return {"error": True, "message": f"Failed to create experiment: {str(exc)}"}
        return {"experiment_id": new_id, "message": "Created experiment"}
    return {"error": True, "message": "Experiment name cannot be empty"}
def _run_to_dict(run_obj: "mlflow.entities.Run") -> Dict:
    """Flatten an MLflow Run into a JSON-friendly dict of info, params, metrics and tags."""
    start_time_ms = run_obj.info.start_time
    end_time_ms = run_obj.info.end_time
    return {
        "run_id": run_obj.info.run_id,
        "status": run_obj.info.status,
        "start_time": _format_timestamp(start_time_ms) if start_time_ms is not None else None,
        "end_time": _format_timestamp(end_time_ms) if end_time_ms is not None else None,
        "params": dict(run_obj.data.params),
        "metrics": dict(run_obj.data.metrics),
        "tags": dict(run_obj.data.tags),
    }


def search_runs(
    experiment_id: str,
    filter_string: str = "",
    order_string: Optional[str] = None,
    max_results: int = 100
) -> Dict:
    """Search runs in a given experiment, with filtering and ordering.

    Args:
        experiment_id: The ID of the experiment to search runs in.
        filter_string: A filter query string used to search for runs.
            It follows the MLflow search filter syntax.
            Examples:
                - "metrics.accuracy > 0.95"
                - "params.learning_rate = '0.001'"
                - "tags.environment = 'production'"
                - "attributes.status = 'FINISHED'"
                - "metrics.loss < 0.2 AND params.optimizer = 'Adam'"
            If an empty string or None is provided (the default is ""), no filtering
            is applied by this string. Multiple conditions can be combined using 'AND' or 'OR'.
        order_string: An optional string to define the order of the results.
            It should be a single string composed of a metric, parameter, or attribute
            followed by 'ASC' (ascending) or 'DESC' (descending).
            Examples:
                - "metrics.validation_loss ASC"
                - "params.num_epochs DESC"
                - "attributes.start_time DESC"
            If None or an empty string, results are ordered by MLflow's default
            (usually start_time DESC).
        max_results: Maximum number of runs to return (default: 100). Must be a positive integer.

    Returns:
        A dictionary containing a list of runs matching the criteria or an error message.
        Format: {"runs": [run_details, ...]} or {"error": True, "message": ...}
        Runs whose data cannot be processed are logged and skipped.
    """
    if not experiment_id:
        return {"error": True, "message": "Experiment ID cannot be empty"}
    # Reject missing or non-positive limits with the module's error dict instead of
    # letting `None <= 0` escape as a TypeError.
    if max_results is None or max_results <= 0:
        return {"error": True, "message": "max_results must be a positive integer"}

    current_filter_string = filter_string if filter_string is not None else ""
    try:
        logger.info(f"Searching runs in experiment '{experiment_id}' with filter '{current_filter_string}', order by '{order_string}', max_results '{max_results}'")
        order_by_list = [order_string] if order_string and order_string.strip() else None
        found_runs = mlflow.search_runs(
            experiment_ids=[str(experiment_id)],
            filter_string=current_filter_string,
            max_results=max_results,
            order_by=order_by_list,
            output_format="list",  # Run objects instead of a DataFrame
        )
    except Exception as e_search:
        logger.error(f"MLflow search_runs API call failed for experiment_id '{experiment_id}': {str(e_search)}", exc_info=True)
        return {"error": True, "message": f"MLflow search_runs API call failed: {str(e_search)}"}

    if not found_runs:
        logger.info(f"No runs found for experiment_id '{experiment_id}' with the given criteria.")
        return {"runs": []}

    processed_runs_info = []
    for run_obj in found_runs:
        try:
            processed_runs_info.append(_run_to_dict(run_obj))
        except Exception as e_process_run:
            # One malformed run should not hide the others; log it and move on.
            run_id_for_log = run_obj.info.run_id if run_obj.info else "N/A"
            logger.warning(
                f"Failed to process data for run_id '{run_id_for_log}' in experiment '{experiment_id}'. Error: {str(e_process_run)}. Skipping this run.",
                exc_info=True
            )
    return {"runs": processed_runs_info}
def list_registered_models() -> Dict:
    """List all registered models in the MLflow model registry.

    Returns:
        ``{"models": [model_details, ...]}`` where each entry has ``name``,
        ``creation_timestamp``, ``last_updated_timestamp``, ``description``,
        ``tags`` (a {key: value} dict) and ``latest_versions`` (version numbers
        taken from ``model.latest_versions``); or ``{"error": True, "message": ...}``
        on failure.
    """
    try:
        logger.info("Listing registered models")
        client = MlflowClient()

        # search_registered_models returns a single page of results; follow the
        # page token so every registered model is listed, not just the first page.
        models = []
        page_token = None
        while True:
            page = client.search_registered_models(page_token=page_token)
            models.extend(page)
            page_token = page.token
            if not page_token:
                break

        model_summaries = []
        for model in models:
            raw_tags = getattr(model, "tags", None) or {}
            # RegisteredModel.tags is a {key: value} dict; iterating it yields plain
            # string keys, so the old `tag.key` access raised AttributeError for any
            # tagged model. A list of tag entities is still accepted, defensively.
            tags = dict(raw_tags) if isinstance(raw_tags, dict) else {t.key: t.value for t in raw_tags}
            model_summaries.append({
                "name": model.name,
                "creation_timestamp": _format_timestamp(model.creation_timestamp),
                "last_updated_timestamp": _format_timestamp(model.last_updated_timestamp),
                "description": model.description or "",
                "tags": tags,
                "latest_versions": [mv.version for mv in (model.latest_versions or [])],
            })
        return {"models": model_summaries}
    except Exception as e:
        return {"error": True, "message": f"Failed to list registered models: {str(e)}"}
def get_model_info(model_name: str) -> Dict:
    """Get detailed information about a registered model and its latest versions.

    Args:
        model_name: Name of the registered model; must be non-empty.

    Returns:
        ``{"model": model_info}`` where ``model_info`` holds the model's name,
        timestamps, description, tags and a ``versions`` list built from
        ``model.latest_versions``. Each version carries its stage, timestamps and a
        ``run`` summary: empty when the version has no run_id, or
        ``{"error": ...}`` when the run could not be fetched.
        On failure, returns ``{"error": True, "message": ...}``.
    """
    if not model_name:
        return {"error": True, "message": "Model name cannot be empty"}
    try:
        logger.info(f"Fetching info for model '{model_name}'")
        client = MlflowClient()
        model = client.get_registered_model(name=model_name)

        raw_tags = getattr(model, "tags", None) or {}
        # RegisteredModel.tags is a {key: value} dict (iterating it yields string
        # keys, so `tag.key` would raise); accept a list of tag entities defensively.
        tags = dict(raw_tags) if isinstance(raw_tags, dict) else {t.key: t.value for t in raw_tags}

        model_info = {
            "name": model.name,
            "creation_timestamp": _format_timestamp(model.creation_timestamp),
            "last_updated_timestamp": _format_timestamp(model.last_updated_timestamp),
            "description": model.description or "",
            "tags": tags,
            "versions": [],
        }
        # NOTE(review): latest_versions presumably holds only the newest version(s),
        # not the full version history — confirm if callers expect every version.
        for mv in model.latest_versions or []:
            version_dict = {
                "version": mv.version,
                "current_stage": mv.current_stage,
                "creation_timestamp": _format_timestamp(mv.creation_timestamp),
                "last_updated_timestamp": _format_timestamp(mv.last_updated_timestamp),
                "run": {},
            }
            # A version may have no source run_id, or its run may no longer be
            # fetchable; report that on the version instead of failing the lookup.
            if mv.run_id:
                try:
                    run = client.get_run(mv.run_id)
                except Exception as e_run:
                    logger.warning(f"Failed to fetch run '{mv.run_id}' for model '{model_name}' version {mv.version}: {str(e_run)}")
                    version_dict["run"] = {"error": f"Failed to fetch run '{mv.run_id}': {str(e_run)}"}
                else:
                    start_time = run.info.start_time
                    version_dict["run"] = {
                        "status": run.info.status,
                        "start_time": _format_timestamp(start_time) if start_time is not None else None,
                        "end_time": _format_timestamp(run.info.end_time) if run.info.end_time else None,
                        "metrics": dict(run.data.metrics),
                    }
            model_info["versions"].append(version_dict)
        return {"model": model_info}
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch model info: {str(e)}"}
def _registration_description(description: Optional[str]) -> str:
    """Append the MCP provenance note to a caller description, without a stray leading space."""
    note = "Model registered by LLM through MCP service."
    return f"{description} {note}" if description else note


def register_model(
    run_id: str,
    model_name: str,
    description: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
    artifact_path: str = "model"
) -> Dict:
    """Register a model from a run, with optional description and tags.

    Args:
        run_id: ID of the run that logged the model; must be non-empty.
        model_name: Name to register the model under; must be non-empty.
        description: Optional description; a provenance note is always appended.
        tags: Optional tags merged over the default provenance tags
            (``registered_by``, ``registration_timestamp``); caller values win.
        artifact_path: Path of the model inside the run's artifacts (default
            ``"model"``), used to build the ``runs:/<run_id>/<artifact_path>`` URI.

    Returns:
        ``{"model_name", "version", "description", "tags", "message"}`` on success,
        otherwise ``{"error": True, "message": ...}``.
    """
    if not all([run_id, model_name]):
        return {"error": True, "message": "Run ID and model name must be non-empty"}

    final_description = _registration_description(description)
    final_tags = {
        "registered_by": "mcp-llm-service",
        "registration_timestamp": datetime.now().isoformat()
    }
    if tags:
        final_tags.update(tags)

    model_uri = f"runs:/{run_id}/{artifact_path or 'model'}"
    try:
        logger.info(f"Registering model '{model_name}' from '{model_uri}' with description and tags.")
        result = mlflow.register_model(
            model_uri=model_uri,
            name=model_name,
            tags=final_tags
        )
        # register_model takes no description, so set it on the new version afterwards.
        client = MlflowClient()
        client.update_model_version(
            name=model_name,
            version=result.version,
            description=final_description
        )
        return {
            "model_name": model_name,
            "version": result.version,
            "description": final_description,
            "tags": final_tags,
            "message": "Registered successfully"
        }
    except Exception as e:
        return {"error": True, "message": f"Registration failed: {str(e)}"}