Spaces:
Sleeping
Sleeping
File size: 15,224 Bytes
8528842 2a74e89 8528842 2a74e89 8528842 2a74e89 8528842 2a74e89 8528842 e586297 8528842 e586297 8528842 e586297 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 |
import logging
import mlflow
from datetime import datetime
from typing import Dict, List, Optional, Literal
from mlflow.tracking import MlflowClient
# Configure logging. basicConfig() runs when this module is imported and sets
# the root logger level to INFO; `logger` is the module-level logger used by
# the functions in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _format_timestamp(ts: int) -> str:
    """Render a millisecond epoch timestamp as ``YYYY-MM-DD HH:MM:SS``.

    The conversion goes through ``datetime.fromtimestamp`` without a timezone
    argument, so the resulting string is expressed in the local time of the
    running process.

    Args:
        ts: Milliseconds since the Unix epoch (the unit MLflow uses).

    Returns:
        The formatted local date-time string.
    """
    seconds_since_epoch = ts / 1000.0
    return f"{datetime.fromtimestamp(seconds_since_epoch):%Y-%m-%d %H:%M:%S}"
def set_tracking_uri(uri: str) -> Dict:
    """Point MLflow at a tracking server and report the resulting system info.

    Args:
        uri: Tracking URI to activate. Must be non-empty.

    Returns:
        The dictionary produced by ``get_system_info()`` after the URI has
        been set, or ``{"error": True, "message": ...}`` when the URI is empty
        or setting it raises.
    """
    if uri:
        try:
            logger.info(f"Setting MLflow tracking URI to {uri}")
            mlflow.set_tracking_uri(uri)
            # Fetching system info right away doubles as a connection check.
            system_info = get_system_info()
        except Exception as exc:
            return {"error": True, "message": f"Failed to set URI: {exc}"}
        return system_info
    return {"error": True, "message": "URI cannot be empty"}
def get_system_info() -> Dict:
    """Collect version, URI and size information about the current MLflow setup.

    Returns:
        On success, a dictionary with the keys ``mlflow_version``,
        ``tracking_uri``, ``registry_uri``, ``artifact_uri``,
        ``python_version``, ``server_time``, ``experiment_count`` and
        ``model_count``. ``artifact_uri`` is ``None`` when no run is active.
        ``server_time`` is the local clock of this process.
        On failure, ``{"error": True, "message": ...}``.
    """
    # Local import: only this function needs the interpreter version.
    import platform

    try:
        client = MlflowClient()

        # mlflow.get_artifact_uri() resolves the URI through the active run and
        # starts a new run when none is active. A read-only info query must not
        # create runs, so the artifact URI is only reported for an existing run.
        artifact_uri = None
        if mlflow.active_run() is not None:
            artifact_uri = mlflow.get_artifact_uri()

        return {
            "mlflow_version": mlflow.__version__,
            "tracking_uri": mlflow.get_tracking_uri(),
            "registry_uri": mlflow.get_registry_uri(),
            "artifact_uri": artifact_uri,
            # Report the interpreter version, not the MLflow version.
            "python_version": platform.python_version(),
            "server_time": _format_timestamp(int(datetime.now().timestamp() * 1000)),
            "experiment_count": len(mlflow.search_experiments()),
            "model_count": len(client.search_registered_models()),
        }
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch system info: {str(e)}"}
def _count_experiment_runs(client, experiment_id: str):
    """Return the number of runs (active and deleted) in one experiment, capped at 50000."""
    view_all = mlflow.entities.ViewType.ALL
    # Cheap probe first: experiments without runs never need the large query.
    probe_runs = client.search_runs(
        experiment_ids=[experiment_id],
        max_results=1,
        run_view_type=view_all,
    )
    if not probe_runs:
        return 0
    all_runs_for_count = client.search_runs(
        experiment_ids=[experiment_id],
        max_results=50000,  # Practical limit for counting
        run_view_type=view_all,
    )
    return len(all_runs_for_count)


def list_experiments(name_contains: Optional[str] = "", max_results: Optional[int] = 100) -> Dict:
    """List experiments in the MLflow tracking server, with optional filtering.

    Each returned experiment includes its run count.

    Args:
        name_contains: Optional filter to only include experiments whose names
            contain this string (case-insensitive).
        max_results: Maximum number of results to return (default: 100).
            None means no limit after filtering. A negative value will result
            in an empty list.

    Returns:
        A dictionary containing the total count of returned experiments and a
        list of their details.
        Format: {"total_experiments": count, "experiments": [exp_details, ...]}
        Each entry's "run_count" is an int, or the string "Error getting count"
        when the run lookup for that experiment failed.
        Returns {"error": True, "message": ...} on failure.
    """
    logger.info(f"Fetching experiments (filter: '{name_contains}', max_results: {max_results})")
    try:
        client = MlflowClient()
        experiments = client.search_experiments()

        # Case-insensitive substring filter on the experiment name.
        name_filter = name_contains.strip().lower() if name_contains else ""
        if name_filter:
            experiments = [exp for exp in experiments if name_filter in exp.name.lower()]

        # Apply the limit after filtering; None keeps everything.
        if max_results is not None:
            experiments = experiments[:max_results] if max_results >= 0 else []

        experiments_info = []
        for exp in experiments:
            creation_time = getattr(exp, "creation_time", None)
            exp_tags = getattr(exp, "tags", None)

            # A failing run lookup must not drop the experiment from the list.
            try:
                run_count = _count_experiment_runs(client, exp.experiment_id)
            except Exception as e_runs:
                logger.warning(f"Error getting run count for experiment '{exp.name}' (ID: {exp.experiment_id}): {str(e_runs)}")
                run_count = "Error getting count"

            experiments_info.append({
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage,
                "creation_time": _format_timestamp(creation_time) if creation_time is not None else None,
                "tags": dict(exp_tags) if exp_tags else {},
                "run_count": run_count,
            })

        return {
            "total_experiments": len(experiments_info),
            "experiments": experiments_info,
        }
    except Exception as e:
        error_msg = f"Error listing experiments: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return {"error": True, "message": error_msg}
def create_experiment(name: str, tags: Optional[Dict[str, str]] = None) -> Dict:
    """Create a new MLflow experiment with the given name and optional tags.

    Args:
        name: Name of the experiment to create. Must be non-empty.
        tags: Optional tags to attach; an empty dict is used when omitted.

    Returns:
        ``{"experiment_id": ..., "message": "Created experiment"}`` on success,
        or ``{"error": True, "message": ...}`` when the name is empty or the
        MLflow call raises.
    """
    if name:
        try:
            experiment_tags = tags if tags else {}
            new_experiment_id = mlflow.create_experiment(name=name, tags=experiment_tags)
        except Exception as exc:
            return {"error": True, "message": f"Failed to create experiment: {exc}"}
        return {"experiment_id": new_experiment_id, "message": "Created experiment"}
    return {"error": True, "message": "Experiment name cannot be empty"}
def search_runs(
    experiment_id: str,
    filter_string: str,
    order_string: Optional[str] = None,
    max_results: int = 100
) -> Dict:
    """Search runs in a given experiment, with filtering and ordering.

    Args:
        experiment_id: The ID of the experiment to search runs in.
        filter_string: A filter query string used to search for runs.
            It follows the MLflow search filter syntax.
            Examples:
                - "metrics.accuracy > 0.95"
                - "params.learning_rate = '0.001'"
                - "tags.environment = 'production'"
                - "attributes.status = 'FINISHED'"
                - "metrics.loss < 0.2 AND params.optimizer = 'Adam'"
            If an empty string is provided, no filtering is applied by this string.
            Multiple conditions can be combined using 'AND' or 'OR'.
        order_string: An optional string to define the order of the results.
            It should be a single string composed of a metric, parameter, or attribute
            followed by 'ASC' (ascending) or 'DESC' (descending).
            Examples:
                - "metrics.validation_loss ASC"
                - "params.num_epochs DESC"
                - "attributes.start_time DESC"
            If None or an empty string, results are ordered by MLflow's default (usually start_time DESC).
        max_results: Maximum number of runs to return (default: 100).

    Returns:
        A dictionary containing a list of runs matching the criteria or an error message.
        Format: {"runs": [run_details, ...]} or {"error": True, "message": ...}
        Runs whose details cannot be read are logged and skipped.
    """
    # Validate experiment_id (must be non-empty)
    if not experiment_id:
        return {"error": True, "message": "Experiment ID cannot be empty"}
    # Validate max_results
    if max_results <= 0:
        return {"error": True, "message": "max_results must be a positive integer"}

    # mlflow.search_runs expects a string filter; treat None as "no filter".
    current_filter_string = filter_string if filter_string is not None else ""
    # order_by takes a list; a blank order string means "use MLflow's default".
    order_by_list = [order_string] if order_string and order_string.strip() else None

    try:
        logger.info(f"Searching runs in experiment '{experiment_id}' with filter '{current_filter_string}', order by '{order_string}', max_results '{max_results}'")
        found_runs = mlflow.search_runs(
            experiment_ids=[str(experiment_id)],  # Ensure experiment_id is a string
            filter_string=current_filter_string,
            max_results=max_results,
            order_by=order_by_list,
            output_format="list"  # Get a list of Run objects instead of DataFrame
        )
    except Exception as e_search:
        logger.error(f"MLflow search_runs API call failed for experiment_id '{experiment_id}': {str(e_search)}", exc_info=True)
        return {"error": True, "message": f"MLflow search_runs API call failed: {str(e_search)}"}

    if not found_runs:
        logger.info(f"No runs found for experiment_id '{experiment_id}' with the given criteria.")
        return {"runs": []}

    processed_runs_info = []
    for run_obj in found_runs:
        run_id_for_log = run_obj.info.run_id if run_obj.info else "N/A"
        try:
            start_time_ms = run_obj.info.start_time
            end_time_ms = run_obj.info.end_time
            processed_runs_info.append({
                "run_id": run_obj.info.run_id,
                "status": run_obj.info.status,
                "start_time": _format_timestamp(start_time_ms) if start_time_ms is not None else None,
                "end_time": _format_timestamp(end_time_ms) if end_time_ms is not None else None,
                "params": dict(run_obj.data.params),
                "metrics": dict(run_obj.data.metrics),
                "tags": dict(run_obj.data.tags)
            })
        except Exception as e_process_run:
            # Best effort: one malformed run must not fail the whole search.
            logger.warning(
                f"Failed to process data for run_id '{run_id_for_log}' in experiment '{experiment_id}'. Error: {str(e_process_run)}. Skipping this run.",
                exc_info=True
            )
    return {"runs": processed_runs_info}
def list_registered_models() -> Dict:
    """List all registered models.

    Returns:
        ``{"models": [...]}`` where each entry holds ``name``,
        ``creation_timestamp``, ``last_updated_timestamp`` (formatted strings,
        or None when the registry reports no value), ``description``, ``tags``
        and ``latest_versions`` (a list of version identifiers).
        Returns ``{"error": True, "message": ...}`` on failure.
    """
    try:
        logger.info("Listing registered models")
        client = MlflowClient()
        models = client.search_registered_models()

        models_info = []
        for model in models:
            created_ms = model.creation_timestamp
            updated_ms = model.last_updated_timestamp
            models_info.append({
                "name": model.name,
                "creation_timestamp": _format_timestamp(created_ms) if created_ms is not None else None,
                "last_updated_timestamp": _format_timestamp(updated_ms) if updated_ms is not None else None,
                "description": model.description or "",
                # RegisteredModel.tags is already a {key: value} mapping;
                # iterating it yields plain strings, so copy it as a dict.
                "tags": dict(getattr(model, "tags", None) or {}),
                "latest_versions": [mv.version for mv in (model.latest_versions or [])],
            })
        return {"models": models_info}
    except Exception as e:
        return {"error": True, "message": f"Failed to list registered models: {str(e)}"}
def get_model_info(model_name: str) -> Dict:
    """Get detailed information about a registered model.

    Args:
        model_name: Name of the registered model. Must be non-empty.

    Returns:
        ``{"model": {...}}`` with the model's name, timestamps, description,
        tags and a ``versions`` list built from the model's latest versions.
        Each version carries a ``run`` dict with the source run's status,
        times and metrics; it stays empty when the version has no run ID or
        the run cannot be fetched.
        Returns ``{"error": True, "message": ...}`` when the name is empty or
        the model lookup fails.
    """
    if not model_name:
        return {"error": True, "message": "Model name cannot be empty"}
    try:
        logger.info(f"Fetching info for model '{model_name}'")
        client = MlflowClient()
        model = client.get_registered_model(name=model_name)

        created_ms = model.creation_timestamp
        updated_ms = model.last_updated_timestamp
        model_info = {
            "name": model.name,
            "creation_timestamp": _format_timestamp(created_ms) if created_ms is not None else None,
            "last_updated_timestamp": _format_timestamp(updated_ms) if updated_ms is not None else None,
            "description": model.description or "",
            # RegisteredModel.tags is already a {key: value} mapping;
            # iterating it yields plain strings, so copy it as a dict.
            "tags": dict(getattr(model, "tags", None) or {}),
            "versions": []
        }

        for mv in model.latest_versions or []:
            mv_created_ms = mv.creation_timestamp
            mv_updated_ms = mv.last_updated_timestamp
            version_dict = {
                "version": mv.version,
                "current_stage": mv.current_stage,
                "creation_timestamp": _format_timestamp(mv_created_ms) if mv_created_ms is not None else None,
                "last_updated_timestamp": _format_timestamp(mv_updated_ms) if mv_updated_ms is not None else None,
                "run": {}
            }
            # The source run is optional context: a version registered without
            # a run, or whose run is gone, must not fail the whole response.
            if mv.run_id:
                try:
                    run = client.get_run(mv.run_id)
                    run_start_ms = run.info.start_time
                    version_dict["run"] = {
                        "status": run.info.status,
                        "start_time": _format_timestamp(run_start_ms) if run_start_ms is not None else None,
                        "end_time": _format_timestamp(run.info.end_time) if run.info.end_time else None,
                        "metrics": dict(run.data.metrics)
                    }
                except Exception as e_run:
                    logger.warning(
                        f"Could not fetch run '{mv.run_id}' for model '{model_name}' version '{mv.version}': {str(e_run)}"
                    )
            model_info["versions"].append(version_dict)
        return {"model": model_info}
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch model info: {str(e)}"}
def register_model(
    run_id: str,
    model_name: str,
    description: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None
) -> Dict:
    """Register a model from a run, with optional description and tags.

    The model is registered from the ``runs:/<run_id>/model`` URI. A provenance
    note is appended to the description, and ``registered_by`` /
    ``registration_timestamp`` tags are added; caller-supplied tags override
    these defaults on key collisions.

    Args:
        run_id: ID of the run that logged the model. Must be non-empty.
        model_name: Name to register the model under. Must be non-empty.
        description: Optional description for the new model version.
        tags: Optional extra tags.

    Returns:
        A dictionary with ``model_name``, ``version``, ``description``,
        ``tags`` and ``message`` on success, or
        ``{"error": True, "message": ...}`` on invalid input or failure.
    """
    if not all([run_id, model_name]):
        return {"error": True, "message": "Run ID and model name must be non-empty"}

    # Prepare description and tags. Without a caller description the note
    # stands alone, so no leading space is introduced.
    provenance_note = "Model registered by LLM through MCP service."
    final_description = f"{description} {provenance_note}" if description else provenance_note
    final_tags = {
        "registered_by": "mcp-llm-service",
        "registration_timestamp": datetime.now().isoformat()
    }
    if tags:
        final_tags.update(tags)

    try:
        logger.info(f"Registering model '{model_name}' from run '{run_id}/model' with description and tags.")
        model_uri = f"runs:/{run_id}/model"
        result = mlflow.register_model(
            model_uri=model_uri,
            name=model_name,
            tags=final_tags
        )
        # The description is applied to the version that was just created.
        client = MlflowClient()
        client.update_model_version(
            name=model_name,
            version=result.version,
            description=final_description
        )
        return {
            "model_name": model_name,
            "version": result.version,
            "description": final_description,
            "tags": final_tags,
            "message": "Registered successfully"
        }
    except Exception as e:
        return {"error": True, "message": f"Registration failed: {str(e)}"}