# personabot-api/app/core/mlflow_tracker.py
# (deployed via GitHub Actions; commits 5a96418, bbe01fe)
import json
import logging
import os
import tempfile
from typing import Optional

import mlflow
logger = logging.getLogger(__name__)
class MLflowTracker:
    """Thin wrapper around MLflow for logging RAG interactions and eval suites.

    All MLflow calls are best-effort: failures are caught and logged, never
    raised, so a tracking-server outage cannot break the request path.
    """

    # MLflow rejects over-long param values; keep query params under this
    # length (the full query still goes into the JSON artifact untruncated).
    _MAX_PARAM_LEN = 250
    # Minimum MAP@10 improvement required before a new reranker run counts
    # as better than the old one.
    _MAP_IMPROVEMENT_THRESHOLD = 0.02

    def __init__(self, tracking_uri: str, experiment_name: str) -> None:
        """Point MLflow at *tracking_uri* and select *experiment_name*.

        Initialization failures are downgraded to a warning so the app can
        still start when the tracking server is unreachable.
        """
        self.tracking_uri = tracking_uri
        self.experiment_name = experiment_name
        try:
            mlflow.set_tracking_uri(self.tracking_uri)
            mlflow.set_experiment(self.experiment_name)
        except Exception as e:
            logger.warning(
                "Failed to initialize MLflow Tracking at %s: %s", tracking_uri, e
            )

    def log_interaction(
        self,
        query: str,
        answer: str,
        chunks: list[dict],
        latency_ms: int,
        ragas_scores: dict,
    ) -> None:
        """Log one query interaction as a nested MLflow child run.

        Records latency and per-metric RAGAS scores as metrics, the
        (truncated) query as a param, and the full query/answer/chunk-ids
        payload as a JSON artifact — params are length-limited, so extensive
        text belongs in artifacts.

        Args:
            query: The user query.
            answer: The generated answer.
            chunks: Retrieved chunks; each chunk's ``metadata.doc_id`` is
                recorded (``"unknown"`` when absent).
            latency_ms: End-to-end latency in milliseconds.
            ragas_scores: Mapping of RAGAS metric name -> score; may be
                empty or falsy.
        """
        try:
            with mlflow.start_run(nested=True, run_name="Query_Interaction"):
                # Truncate: MLflow raises on param values past its limit,
                # which would otherwise drop the whole interaction log.
                mlflow.log_param("query", query[: self._MAX_PARAM_LEN])
                interaction_data = {
                    "query": query,
                    "answer": answer,
                    "chunks_used": [
                        c.get("metadata", {}).get("doc_id", "unknown")
                        for c in chunks
                    ],
                }
                mlflow.log_metric("latency_ms", latency_ms)
                for metric_name, score in (ragas_scores or {}).items():
                    mlflow.log_metric(f"ragas_{metric_name}", score)
                # Write the payload to a temp file and upload it; the
                # finally block guarantees cleanup even when the upload
                # fails (the original leaked the file in that case).
                fd, temp_path = tempfile.mkstemp(suffix=".json")
                try:
                    with os.fdopen(fd, "w") as f:
                        json.dump(interaction_data, f)
                    mlflow.log_artifact(temp_path, "interaction_details")
                finally:
                    os.unlink(temp_path)
        except Exception as e:
            logger.error("Failed to log interaction to MLflow: %s", e)

    def log_eval_suite(self, results: dict, filepath: str) -> None:
        """Log a full evaluation-suite run with the raw report as artifact.

        Flattens the ``"ragas"`` and ``"custom"`` score dicts in *results*
        into ``suite_<name>`` metrics and attaches the file at *filepath*
        under ``evaluation_reports``.
        """
        try:
            with mlflow.start_run(run_name="Evaluation_Suite"):
                # Both sections share the same metric-name prefix.
                for section in ("ragas", "custom"):
                    for name, value in results.get(section, {}).items():
                        mlflow.log_metric(f"suite_{name}", float(value))
                # Log the report file directly as an artifact.
                mlflow.log_artifact(filepath, "evaluation_reports")
                logger.info("Evaluation Suite saved successfully into MLflow.")
        except Exception as e:
            logger.error("Failed to log eval suite to MLflow: %s", e)

    def compare_reranker_runs(self, run_id_old: str, run_id_new: str) -> bool:
        """Return True if the new run's MAP@10 beats the old by > 0.02.

        Queries the MLflow tracking API for both run records. Missing
        ``map_at_10`` metrics default to 0.0; any API failure returns False
        so a caller never promotes a model on the strength of an error.
        """
        try:
            client = mlflow.tracking.MlflowClient(self.tracking_uri)
            old_map = client.get_run(run_id_old).data.metrics.get("map_at_10", 0.0)
            new_map = client.get_run(run_id_new).data.metrics.get("map_at_10", 0.0)
            return new_map > old_map + self._MAP_IMPROVEMENT_THRESHOLD
        except Exception as e:
            logger.error(
                "Failed comparing MLflow runs %s / %s: %s", run_id_old, run_id_new, e
            )
            return False