Spaces:
Sleeping
Sleeping
import logging | |
import mlflow | |
from datetime import datetime | |
from typing import Dict, List, Optional, Literal | |
from mlflow.tracking import MlflowClient | |
# Configure logging: set up the root logger at INFO level and create the
# module-level logger used by the functions in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _format_timestamp(ts: Optional[int]) -> Optional[str]:
    """Convert an MLflow timestamp to a readable string.

    Args:
        ts: Milliseconds since the Unix epoch, or None when the timestamp
            is not set on the entity.

    Returns:
        The timestamp formatted as "YYYY-MM-DD HH:MM:SS" in the local
        timezone of the running process, or None when ``ts`` is None.
    """
    if ts is None:
        # Callers pass entity fields straight through; propagate "not set"
        # instead of raising TypeError on the division below.
        return None
    return datetime.fromtimestamp(ts / 1000.0).strftime("%Y-%m-%d %H:%M:%S")
def set_tracking_uri(uri: str) -> Dict:
    """Point MLflow at a tracking server and report the resulting system info.

    Args:
        uri: Tracking URI passed to ``mlflow.set_tracking_uri``.

    Returns:
        The dictionary produced by ``get_system_info()`` once the URI has been
        set, or ``{"error": True, "message": ...}`` when ``uri`` is empty or an
        exception is raised while applying it.
    """
    if uri:
        try:
            logger.info(f"Setting MLflow tracking URI to {uri}")
            mlflow.set_tracking_uri(uri)
            system_info = get_system_info()
            return system_info
        except Exception as exc:
            failure = f"Failed to set URI: {str(exc)}"
            return {"error": True, "message": failure}

    # Falsy URI (empty string or None): nothing to configure.
    return {"error": True, "message": "URI cannot be empty"}
def get_system_info() -> Dict:
    """Collect version, URI and size information about the current MLflow setup.

    Returns:
        On success, a dictionary with the keys:
            mlflow_version: ``mlflow.__version__``.
            tracking_uri: result of ``mlflow.get_tracking_uri()``.
            registry_uri: result of ``mlflow.get_registry_uri()``.
            artifact_uri: artifact URI of the currently active run, or None
                when no run is active.
            python_version: version of the Python interpreter running this code.
            server_time: current local time of this process, formatted by
                ``_format_timestamp`` (it is not fetched from the server).
            experiment_count: number of experiments returned by
                ``mlflow.search_experiments()``.
            model_count: number of models returned by a single
                ``client.search_registered_models()`` call.
        On failure, ``{"error": True, "message": ...}``.
    """
    # Function-scope import keeps this block self-contained.
    import platform

    try:
        client = MlflowClient()

        # mlflow.get_artifact_uri() implicitly starts a new run when none is
        # active, which would leave a stray run behind on every info request.
        # Only query it when a run already exists.
        active_run = mlflow.active_run()
        artifact_uri = mlflow.get_artifact_uri() if active_run is not None else None

        return {
            "mlflow_version": mlflow.__version__,
            "tracking_uri": mlflow.get_tracking_uri(),
            "registry_uri": mlflow.get_registry_uri(),
            "artifact_uri": artifact_uri,
            # Previously this reported mlflow.__version__ a second time.
            "python_version": platform.python_version(),
            "server_time": _format_timestamp(int(datetime.now().timestamp() * 1000)),
            "experiment_count": len(mlflow.search_experiments()),
            "model_count": len(client.search_registered_models()),
        }
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch system info: {str(e)}"}
def list_experiments(name_contains: Optional[str] = "", max_results: Optional[int] = 100) -> Dict:
    """List experiments in the MLflow tracking server, with optional filtering.

    Each returned experiment includes a run count.

    Args:
        name_contains: Optional filter to only include experiments whose names
            contain this string (case-insensitive, surrounding whitespace is
            ignored). Empty or None disables the filter.
        max_results: Maximum number of results to return (default: 100),
            applied after filtering. None means no limit after filtering.
            A negative value results in an empty list.

    Returns:
        A dictionary containing the total count of returned experiments and a
        list of their details.
        Format: {"total_experiments": count, "experiments": [exp_details, ...]}
        Each ``exp_details`` has the keys experiment_id, name,
        artifact_location, lifecycle_stage, creation_time, tags and run_count.
        ``run_count`` counts active and deleted runs, is capped at 50000, and
        is the string "Error getting count" when the lookup fails.
        Returns {"error": True, "message": ...} on failure.
    """
    logger.info(f"Fetching experiments (filter: '{name_contains}', max_results: {max_results})")
    try:
        client = MlflowClient()
        experiments = client.search_experiments()

        # Case-insensitive substring match on the experiment name.
        name_filter = name_contains.strip().lower() if name_contains else ""
        if name_filter:
            experiments = [exp for exp in experiments if name_filter in exp.name.lower()]

        if max_results is not None:
            # A negative limit yields nothing; a plain slice would instead
            # drop items from the end of the list.
            experiments = experiments[:max_results] if max_results >= 0 else []

        experiments_info = []
        for exp in experiments:
            creation_time = getattr(exp, "creation_time", None)
            tags = getattr(exp, "tags", None)
            exp_detail = {
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage,
                "creation_time": (
                    _format_timestamp(creation_time) if creation_time is not None else None
                ),
                # exp.tags is already a {key: value} dict.
                "tags": dict(tags) if tags else {},
            }

            try:
                # A single search suffices: an experiment without runs simply
                # returns an empty list, so no separate probe query is needed.
                runs = client.search_runs(
                    experiment_ids=[exp.experiment_id],
                    max_results=50000,  # Practical limit for counting
                    run_view_type=mlflow.entities.ViewType.ALL,
                )
                exp_detail["run_count"] = len(runs)
            except Exception as e_runs:
                # Best effort: a failed count must not hide the experiment itself.
                logger.warning(f"Error getting run count for experiment '{exp.name}' (ID: {exp.experiment_id}): {str(e_runs)}")
                exp_detail["run_count"] = "Error getting count"

            experiments_info.append(exp_detail)

        return {
            "total_experiments": len(experiments_info),
            "experiments": experiments_info,
        }
    except Exception as e:
        error_msg = f"Error listing experiments: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return {"error": True, "message": error_msg}
def create_experiment(name: str, tags: Optional[Dict[str, str]] = None) -> Dict:
    """Create a new MLflow experiment from a name and optional tags.

    Args:
        name: Name of the experiment to create; must be non-empty.
        tags: Optional tags for the experiment. An empty dict is passed to
            MLflow when omitted.

    Returns:
        ``{"experiment_id": ..., "message": "Created experiment"}`` on success,
        or ``{"error": True, "message": ...}`` when the name is empty or MLflow
        raises.
    """
    if name:
        try:
            new_id = mlflow.create_experiment(name=name, tags=tags if tags else {})
        except Exception as exc:
            return {"error": True, "message": f"Failed to create experiment: {str(exc)}"}
        return {"experiment_id": new_id, "message": "Created experiment"}

    return {"error": True, "message": "Experiment name cannot be empty"}
def search_runs(
    experiment_id: str,
    filter_string: Optional[str] = "",
    order_string: Optional[str] = None,
    max_results: int = 100
) -> Dict:
    """Search runs in a given experiment, with filtering and ordering.

    Args:
        experiment_id: The ID of the experiment to search runs in.
        filter_string: A filter query string used to search for runs.
            It follows the MLflow search filter syntax.
            Examples:
                - "metrics.accuracy > 0.95"
                - "params.learning_rate = '0.001'"
                - "tags.environment = 'production'"
                - "attributes.status = 'FINISHED'"
                - "metrics.loss < 0.2 AND params.optimizer = 'Adam'"
            If an empty string or None is provided (the default), no filtering
            is applied by this string.
        order_string: An optional string to define the order of the results.
            It should be a single string composed of a metric, parameter, or
            attribute followed by 'ASC' (ascending) or 'DESC' (descending).
            Examples:
                - "metrics.validation_loss ASC"
                - "params.num_epochs DESC"
                - "attributes.start_time DESC"
            If None or blank, no explicit ordering is passed to MLflow.
        max_results: Maximum number of runs to return (default: 100). Must be
            a positive integer.

    Returns:
        A dictionary containing a list of runs matching the criteria or an
        error message.
        Format: {"runs": [run_details, ...]} or {"error": True, "message": ...}
        Each ``run_details`` has the keys run_id, status, start_time, end_time,
        params, metrics and tags. Runs whose data cannot be processed are
        logged and skipped.
    """
    # Validate experiment_id (must be non-empty)
    if not experiment_id:
        return {"error": True, "message": "Experiment ID cannot be empty"}

    # Validate max_results; the type check keeps None or a string from raising
    # an uncaught TypeError at the comparison.
    if not isinstance(max_results, int) or max_results <= 0:
        return {"error": True, "message": "max_results must be a positive integer"}

    # mlflow.search_runs expects a string; treat None as "no filter".
    current_filter_string = filter_string if filter_string is not None else ""

    try:
        logger.info(f"Searching runs in experiment '{experiment_id}' with filter '{current_filter_string}', order by '{order_string}', max_results '{max_results}'")
        order_by_list = [order_string] if order_string and order_string.strip() else None
        found_runs = mlflow.search_runs(
            experiment_ids=[str(experiment_id)],  # Ensure experiment_id is a string
            filter_string=current_filter_string,
            max_results=max_results,
            order_by=order_by_list,
            output_format="list"  # Get a list of Run objects instead of DataFrame
        )
    except Exception as e_search:
        logger.error(f"MLflow search_runs API call failed for experiment_id '{experiment_id}': {str(e_search)}", exc_info=True)
        return {"error": True, "message": f"MLflow search_runs API call failed: {str(e_search)}"}

    if not found_runs:
        logger.info(f"No runs found for experiment_id '{experiment_id}' with the given criteria.")
        return {"runs": []}

    processed_runs_info = []
    for run_obj in found_runs:
        run_id_for_log = run_obj.info.run_id if run_obj.info else "N/A"
        try:
            start_time_ms = run_obj.info.start_time
            end_time_ms = run_obj.info.end_time
            processed_runs_info.append({
                "run_id": run_obj.info.run_id,
                "status": run_obj.info.status,
                "start_time": _format_timestamp(start_time_ms) if start_time_ms is not None else None,
                "end_time": _format_timestamp(end_time_ms) if end_time_ms is not None else None,
                "params": dict(run_obj.data.params),
                "metrics": dict(run_obj.data.metrics),
                "tags": dict(run_obj.data.tags)
            })
        except Exception as e_process_run:
            # One malformed run must not discard the rest of the results.
            logger.warning(
                f"Failed to process data for run_id '{run_id_for_log}' in experiment '{experiment_id}'. Error: {str(e_process_run)}. Skipping this run.",
                exc_info=True
            )
            continue  # Skip to the next run

    return {"runs": processed_runs_info}
def list_registered_models() -> Dict:
    """List registered models from the MLflow model registry.

    NOTE(review): a single ``client.search_registered_models()`` call is made
    with default arguments; if the registry pages its results, only the first
    page is listed — confirm whether pagination is needed.

    Returns:
        ``{"models": [model_details, ...]}`` on success, where each entry has
        the keys name, creation_timestamp, last_updated_timestamp, description,
        tags and latest_versions (a list of version identifiers). Timestamps
        that are not set are returned as None.
        ``{"error": True, "message": ...}`` on failure.
    """

    def _fmt(ts):
        # Unset timestamps stay None instead of raising in the formatter.
        return _format_timestamp(ts) if ts is not None else None

    try:
        logger.info("Listing registered models")
        client = MlflowClient()
        models_info = []
        for model in client.search_registered_models():
            models_info.append({
                "name": model.name,
                "creation_timestamp": _fmt(model.creation_timestamp),
                "last_updated_timestamp": _fmt(model.last_updated_timestamp),
                "description": model.description or "",
                # RegisteredModel.tags is already a {key: value} dict; iterating
                # it yields key strings, so reading tag.key / tag.value raised
                # AttributeError for every tagged model.
                "tags": dict(getattr(model, "tags", None) or {}),
                "latest_versions": [mv.version for mv in (model.latest_versions or [])],
            })
        return {"models": models_info}
    except Exception as e:
        logger.error(f"Failed to list registered models: {str(e)}", exc_info=True)
        return {"error": True, "message": f"Failed to list registered models: {str(e)}"}
def get_model_info(model_name: str) -> Dict:
    """Get detailed information about a registered model.

    Args:
        model_name: Name of the registered model; must be non-empty.

    Returns:
        ``{"model": model_info}`` on success. ``model_info`` has the keys name,
        creation_timestamp, last_updated_timestamp, description, tags and
        versions. ``versions`` is built from ``model.latest_versions``; each
        entry has version, current_stage, creation_timestamp,
        last_updated_timestamp and run. ``run`` holds status, start_time,
        end_time and metrics of the source run, or is an empty dict when the
        version has no run ID or the run cannot be fetched.
        ``{"error": True, "message": ...}`` when the name is empty or the model
        lookup fails.
    """
    if not model_name:
        return {"error": True, "message": "Model name cannot be empty"}

    def _fmt(ts):
        # Unset timestamps stay None instead of raising in the formatter.
        return _format_timestamp(ts) if ts is not None else None

    try:
        logger.info(f"Fetching info for model '{model_name}'")
        client = MlflowClient()
        model = client.get_registered_model(name=model_name)
        model_info = {
            "name": model.name,
            "creation_timestamp": _fmt(model.creation_timestamp),
            "last_updated_timestamp": _fmt(model.last_updated_timestamp),
            "description": model.description or "",
            # RegisteredModel.tags is already a {key: value} dict; iterating it
            # yields key strings, so tag.key / tag.value raised AttributeError.
            "tags": dict(getattr(model, "tags", None) or {}),
            "versions": []
        }
        for mv in model.latest_versions or []:
            version_dict = {
                "version": mv.version,
                "current_stage": mv.current_stage,
                "creation_timestamp": _fmt(mv.creation_timestamp),
                "last_updated_timestamp": _fmt(mv.last_updated_timestamp),
                "run": {}
            }
            # A version may have no source run, or the run may be unavailable;
            # keep the version entry and leave "run" empty in those cases.
            if mv.run_id:
                try:
                    run = client.get_run(mv.run_id)
                    version_dict["run"] = {
                        "status": run.info.status,
                        "start_time": _fmt(run.info.start_time),
                        "end_time": _format_timestamp(run.info.end_time) if run.info.end_time else None,
                        "metrics": dict(run.data.metrics)
                    }
                except Exception as e_run:
                    logger.warning(
                        f"Could not fetch run '{mv.run_id}' for model '{model_name}' version '{mv.version}': {str(e_run)}"
                    )
            model_info["versions"].append(version_dict)
        return {"model": model_info}
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch model info: {str(e)}"}
def register_model(
    run_id: str,
    model_name: str,
    description: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
    artifact_path: str = "model"
) -> Dict:
    """Register a model from a run, with optional description and tags.

    Args:
        run_id: ID of the run that logged the model; must be non-empty.
        model_name: Name to register the model under; must be non-empty.
        description: Optional description. A fixed provenance note
            ("Model registered by LLM through MCP service.") is always
            appended and the result is stored on the new model version.
        tags: Optional tags. They are merged over the default
            ``registered_by`` and ``registration_timestamp`` tags, so a
            caller-supplied key wins on conflict.
        artifact_path: Path of the model artifact inside the run
            (default: "model"); used to build the ``runs:/<run_id>/<path>`` URI.

    Returns:
        On success, a dictionary with model_name, version, description, tags
        and message. ``{"error": True, "message": ...}`` when an ID or name is
        empty or any MLflow call raises.
    """
    if not all([run_id, model_name]):
        return {"error": True, "message": "Run ID and model name must be non-empty"}

    # Join with a single space only when a description was supplied, so the
    # stored text no longer starts with a stray space.
    provenance_note = "Model registered by LLM through MCP service."
    final_description = " ".join(part for part in (description, provenance_note) if part)

    final_tags = {
        "registered_by": "mcp-llm-service",
        "registration_timestamp": datetime.now().isoformat()
    }
    if tags:
        final_tags.update(tags)

    # Fall back to the conventional location if an empty path is passed.
    model_artifact_path = artifact_path or "model"

    try:
        logger.info(f"Registering model '{model_name}' from run '{run_id}/{model_artifact_path}' with description and tags.")
        model_uri = f"runs:/{run_id}/{model_artifact_path}"
        result = mlflow.register_model(
            model_uri=model_uri,
            name=model_name,
            tags=final_tags
        )
        # The description is attached to the model version in a second call.
        client = MlflowClient()
        client.update_model_version(
            name=model_name,
            version=result.version,
            description=final_description
        )
        return {
            "model_name": model_name,
            "version": result.version,
            "description": final_description,
            "tags": final_tags,
            "message": "Registered successfully"
        }
    except Exception as e:
        return {"error": True, "message": f"Registration failed: {str(e)}"}