from typing import Any, Dict
import time

from fastapi import APIRouter, HTTPException

from tasks.create_bench_config_file import CreateBenchConfigTask
from tasks.create_bench import CreateBenchTask

router = APIRouter(tags=["benchmark"])
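
# This module exposes two endpoints: POST /generate-benchmark starts a
# UnifiedBenchmarkTask in a background thread for an uploaded session, and
# GET /benchmark-progress/{session_id} lets the client poll that task's logs
# and completion status. All state below is kept in memory, so it is lost on
# restart and is not shared between worker processes.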

# Active benchmark tasks, keyed by session ID.
active_tasks = {}

# Uploaded file paths, keyed by session ID (session_id -> file path).
# The endpoints below read router.session_files, so the mapping is also
# attached to the router; other modules (presumably the upload route) are
# expected to populate it.
session_files = {}
router.session_files = session_files
@router.post("/generate-benchmark") |
|
async def generate_benchmark(data: Dict[str, Any]): |
|
""" |
|
Generate a benchmark configuration and run the ingestion process |
|
|
|
Args: |
|
data: Dictionary containing session_id |
|
|
|
Returns: |
|
Dictionary with logs and status |
|
""" |
    session_id = data.get("session_id")

    print(f"DEBUG: Session ID received: {session_id}")
    print(f"DEBUG: Available session files: {list(router.session_files.keys())}")

    if not session_id or session_id not in router.session_files:
        return {"error": "Invalid or missing session ID"}

    # If a task already exists for this session, report its current state
    # instead of starting a second run.
    if session_id in active_tasks:
        task = active_tasks[session_id]

        if task.is_task_completed():
            return {
                "status": "already_completed",
                "logs": task.get_logs(),
                "is_completed": True
            }
        else:
            return {
                "status": "already_running",
                "logs": task.get_logs(),
                "is_completed": False
            }

    file_path = router.session_files[session_id]
    all_logs = []

    try:
        # The task runs on a background thread, so this call returns
        # immediately; clients poll /benchmark-progress/{session_id} for
        # further updates.
        task = UnifiedBenchmarkTask(session_uid=session_id)
        active_tasks[session_id] = task
        task.run(file_path)

        all_logs = task.get_logs()

        return {
            "status": "running",
            "logs": all_logs
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
            "logs": all_logs
        }


@router.get("/benchmark-progress/{session_id}")
async def get_benchmark_progress(session_id: str):
    """
    Get the logs and status for a running benchmark task

    Args:
        session_id: Session ID for the task

    Returns:
        Dictionary with logs and completion status
    """
    if session_id not in active_tasks:
        raise HTTPException(status_code=404, detail="Benchmark task not found")

    task = active_tasks[session_id]
    logs = task.get_logs()
    is_completed = task.is_task_completed()

    return {
        "logs": logs,
        "is_completed": is_completed
    }
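

# Illustrative client-side flow for these endpoints (a sketch only: it assumes
# the API is served at http://localhost:8000 with no route prefix, and that
# `session_id` was obtained from the upload step):
#
#   import time
#   import httpx
#
#   httpx.post("http://localhost:8000/generate-benchmark",
#              json={"session_id": session_id})
#   while True:
#       progress = httpx.get(
#           f"http://localhost:8000/benchmark-progress/{session_id}").json()
#       print("\n".join(progress["logs"]))
#       if progress["is_completed"]:
#           break
#       time.sleep(2)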


class UnifiedBenchmarkTask:
    """
    Task that handles the entire benchmark process from configuration to completion
    """

    def __init__(self, session_uid: str):
        """
        Initialize the unified benchmark task

        Args:
            session_uid: Session ID for this task
        """
        self.session_uid = session_uid
        self.logs = []
        self.is_completed = False
        self.config_task = None
        self.bench_task = None

        self._add_log("[INFO] Initializing benchmark task")

    def _add_log(self, message: str):
        """
        Add a log message

        Args:
            message: Log message to add
        """
        # Skip duplicates: child-task logs are re-added in full on every poll
        # of the worker loop, so only genuinely new lines are recorded.
        if message not in self.logs:
            self.logs.append(message)
            print(f"[{self.session_uid}] {message}")

    def get_logs(self):
        """
        Get all logs

        Returns:
            List of log messages
        """
        return self.logs.copy()

    def is_task_completed(self):
        """
        Check if the task is completed

        Returns:
            True if completed, False otherwise
        """
        return self.is_completed

    def run(self, file_path: str):
        """
        Run the benchmark process

        Args:
            file_path: Path to the uploaded file
        """
        import threading

        thread = threading.Thread(target=self._run_process, args=(file_path,))
        thread.daemon = True
        thread.start()
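
    # run() returns as soon as the daemon thread is started: there is no join
    # or cancellation mechanism, and callers observe progress only by polling
    # get_logs() / is_task_completed() via the endpoints above.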

    def _run_process(self, file_path: str):
        """
        Internal method to run the process

        Args:
            file_path: Path to the uploaded file
        """
        try:
            # Stage 1: generate the benchmark configuration from the uploaded file.
            self._add_log("[INFO] Starting configuration process")

            from config.models_config import DEFAULT_BENCHMARK_TIMEOUT

            self.config_task = CreateBenchConfigTask(
                session_uid=self.session_uid, timeout=DEFAULT_BENCHMARK_TIMEOUT
            )

            try:
                config_path = self.config_task.run(file_path=file_path)

                # Copy the configuration task's logs into this task's log.
                config_logs = self.config_task.get_logs()
                for log in config_logs:
                    self._add_log(log)

                # If the config task did not emit its own stage-completion
                # marker, add a generic one.
                if "[SUCCESS] Stage completed: config_generation" not in self.logs:
                    self._add_log("[SUCCESS] Stage completed: configuration")

                # Stage 2: run the benchmark itself with the generated config.
                self._add_log("[INFO] Starting benchmark process")
                self.bench_task = CreateBenchTask(
                    session_uid=self.session_uid, config_path=config_path
                )
                self.bench_task.run()

                # Poll the benchmark task, forwarding its logs as they appear.
                while not self.bench_task.is_task_completed():
                    bench_logs = self.bench_task.get_logs()
                    for log in bench_logs:
                        self._add_log(log)
                    time.sleep(1)

                # Pick up any logs emitted between the last poll and completion.
                final_logs = self.bench_task.get_logs()
                for log in final_logs:
                    self._add_log(log)

                self.is_completed = True

                # Classify the final logs: JSON parsing issues on individual
                # QA pairs are treated as warnings rather than hard errors.
                has_error = any(
                    "[ERROR]" in log
                    and not ("JSONDecodeError" in log
                             or "Error processing QA pair" in log
                             or "'str' object has no attribute 'get'" in log)
                    for log in final_logs
                )
                benchmark_terminated_with_error = any(
                    "Benchmark process terminated with error code" in log for log in final_logs
                )
                benchmark_already_marked_success = any(
                    "Benchmark process completed successfully" in log for log in final_logs
                )

                json_errors_only = any(
                    ("JSONDecodeError" in log
                     or "Error processing QA pair" in log
                     or "'str' object has no attribute 'get'" in log)
                    for log in final_logs
                ) and not has_error

                if json_errors_only:
                    self._add_log("[INFO] Benchmark completed with minor JSON parsing warnings, considered successful")

                # Add a success marker unless the logs contain a hard error or
                # the benchmark already reported success itself.
                if (not has_error and not benchmark_terminated_with_error and not benchmark_already_marked_success) or json_errors_only:
                    self._add_log("[SUCCESS] Benchmark process completed successfully")

            except Exception as config_error:
                error_msg = str(config_error)

                self._add_log(f"[ERROR] Configuration failed: {error_msg}")

                if "Required models not available" in error_msg:
                    self._add_log("[ERROR] Some required models are not available at the moment. Please try again later.")

                self.is_completed = True

        except Exception as e:
            self._add_log(f"[ERROR] Benchmark process failed: {str(e)}")
            self.is_completed = True