import asyncio
import json
import os
from datetime import datetime
from typing import Dict, Any

from fastapi import APIRouter, HTTPException
from huggingface_hub import hf_hub_download

from tasks.evaluation_task import EvaluationTask

router = APIRouter(tags=["evaluation"])

# In-memory registry of running evaluation tasks, keyed by session_id.
active_evaluation_tasks = {}


@router.post("/evaluate-benchmark")
async def evaluate_benchmark(data: Dict[str, Any]):
    """
    Start the evaluation of a benchmark for a given session.

    Args:
        data: Dictionary containing the session_id

    Returns:
        Dictionary with the status and initial logs
    """
    session_id = data.get("session_id")

    if not session_id:
        return {"error": "Missing or invalid session ID"}

    # If a task already exists for this session, reuse the slot only if it is finished.
    if session_id in active_evaluation_tasks:
        evaluation_task = active_evaluation_tasks[session_id]

        if evaluation_task.is_task_completed():
            # The previous run is done; drop it so a new evaluation can start.
            del active_evaluation_tasks[session_id]
        else:
            # An evaluation is still running for this session.
            return {
                "status": "already_running",
                "message": "An evaluation is already running for this session",
                "logs": evaluation_task.get_logs()
            }

    try:
        dataset_name = f"yourbench/yourbench_{session_id}"

        evaluation_task = EvaluationTask(session_uid=session_id, dataset_name=dataset_name)
        active_evaluation_tasks[session_id] = evaluation_task

        # Run the evaluation in the background without blocking the request.
        asyncio.create_task(evaluation_task.run())

        initial_logs = evaluation_task.get_logs()

        return {
            "status": "started",
            "message": f"Evaluation started for benchmark {dataset_name}",
            "logs": initial_logs
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
            "message": f"Error while starting the evaluation: {str(e)}"
        }
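
# Client usage (a minimal sketch, not part of the router): assuming the application is
# served locally at http://localhost:8000 with this router mounted at the root, and that
# the `httpx` package is available, starting an evaluation could look like this:
#
#   import httpx
#   response = httpx.post(
#       "http://localhost:8000/evaluate-benchmark",
#       json={"session_id": "my-session-id"},  # hypothetical session id
#   )
#   print(response.json()["status"])  # "started", "already_running", or "error"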


@router.get("/evaluation-logs/{session_id}")
async def get_evaluation_logs(session_id: str):
    """
    Retrieve the logs of a running evaluation.

    Args:
        session_id: Session ID to retrieve logs for

    Returns:
        Dictionary with logs and completion status
    """
    if session_id not in active_evaluation_tasks:
        raise HTTPException(status_code=404, detail="Evaluation task not found")

    evaluation_task = active_evaluation_tasks[session_id]
    logs = evaluation_task.get_logs()
    is_completed = evaluation_task.is_task_completed()

    # Only expose results once the task has finished and actually produced some.
    results = None
    if is_completed and hasattr(evaluation_task, 'results') and evaluation_task.results:
        results = evaluation_task.results

    progress = evaluation_task.get_progress()

    return {
        "logs": logs,
        "is_completed": is_completed,
        "results": results,
        "current_step": progress["current_step"],
        "completed_steps": progress["completed_steps"]
    }
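
# Client-side polling (a minimal sketch, not part of the router): under the same
# assumptions as above (local base URL, `httpx` installed), a caller could poll this
# endpoint until the task reports completion:
#
#   import time
#   import httpx
#   while True:
#       data = httpx.get("http://localhost:8000/evaluation-logs/my-session-id").json()
#       print(data["current_step"], data["completed_steps"])
#       if data["is_completed"]:
#           break
#       time.sleep(5)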


@router.get("/evaluation-results/{session_id}")
async def get_evaluation_results(session_id: str):
    """
    Retrieve the results of a completed evaluation.

    Args:
        session_id: Session ID to retrieve results for

    Returns:
        Dictionary with evaluation results
    """
    try:
        organization = os.getenv("HF_ORGANIZATION", "yourbench")
        dataset_name = f"{organization}/yourbench_{session_id}"

        try:
            # Results are stored alongside the dataset on the Hugging Face Hub.
            results_file = hf_hub_download(
                repo_id=dataset_name,
                repo_type="dataset",
                filename="lighteval_results.json"
            )

            with open(results_file) as f:
                results_data = json.load(f)

            # Support both the wrapped format ({"results": [...], "metadata": {...}})
            # and a bare list of results.
            if "results" in results_data and isinstance(results_data["results"], list):
                results_list = results_data["results"]
                metadata = results_data.get("metadata", {})
            else:
                results_list = results_data
                metadata = {}

            formatted_results = {
                "metadata": {
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "session_id": metadata.get("session_id", session_id),
                    "total_models_tested": len(results_list),
                    "successful_tests": len([r for r in results_list if r.get("status") == "success"])
                },
                "models_comparison": [
                    {
                        # Use .get() so a failed entry missing a field (e.g. accuracy)
                        # does not abort formatting of the whole result list.
                        "model_name": result.get("model"),
                        "provider": result.get("provider"),
                        "success": result.get("status") == "success",
                        "accuracy": result.get("accuracy"),
                        "evaluation_time": result.get("execution_time"),
                        "error": result.get("status") if result.get("status") != "success" else None
                    }
                    for result in results_list
                ]
            }

            return {
                "success": True,
                "results": formatted_results
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"Failed to load results from Hub: {str(e)}"
            }

    except Exception as e:
        return {
            "success": False,
            "message": f"Error retrieving evaluation results: {str(e)}"
        }
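
# Reading the formatted results (a minimal sketch, not part of the router): once the
# evaluation has finished and lighteval_results.json has been pushed to the dataset
# repository, a client could fetch the comparison under the same assumptions as above:
#
#   import httpx
#   payload = httpx.get("http://localhost:8000/evaluation-results/my-session-id").json()
#   if payload["success"]:
#       for entry in payload["results"]["models_comparison"]:
#           print(entry["model_name"], entry["accuracy"])
#   else:
#       print(payload["message"])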