jbaselga committed on
Commit
0385645
·
verified ·
1 Parent(s): 401aee4

Create api_server.py

Browse files
Files changed (1) hide show
  1. api_server.py +59 -0
api_server.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import List, Dict, Any, Union
4
+ from datasets import load_dataset
5
+ import random
6
+ import os
7
+
8
app = FastAPI()

# Load the level-1 GAIA questions (validation split) once, at import time.
ds = load_dataset(
    "gaia-benchmark/GAIA",
    "2023_level1",
    split="validation",
    trust_remote_code=True,
)

# Question payloads served by /questions; they carry no answers.
QUESTIONS = []
# Maps each task_id to its expected final answer, used by /submit for scoring.
GROUND_TRUTH: Dict[str, str] = {}

for item in ds:
    # Ids and answers are normalised to str so lookups and comparisons
    # always operate on strings.
    task_id = str(item["task_id"])
    QUESTIONS.append(dict(task_id=task_id, question=item["Question"]))
    GROUND_TRUTH[task_id] = str(item["Final answer"])
22
+
23
class AnswerItem(BaseModel):
    """A single answer submitted for one task."""

    # Identifier of the task being answered.
    task_id: str
    # The agent's answer; may arrive as text or as a number.
    submitted_answer: Union[Union[str, int], float]
26
+
27
class Submission(BaseModel):
    """Request body for a scoring submission: who submitted and their answers."""

    # Name echoed back in the score response.
    username: str
    # NOTE(review): not read anywhere in the visible code; presumably a
    # reference to the agent's source -- confirm intended use.
    agent_code: str
    # One entry per answered task.
    answers: List[AnswerItem]
31
+
32
class ScoreResponse(BaseModel):
    """Result of scoring a submission."""

    # Name of the submitter, copied from the request.
    username: str
    # Numeric score for the submission.
    score: float
    # Number of answers that matched the expected answer.
    correct_count: int
    # Number of answers contained in the submission.
    total_attempted: int
    # Human-readable summary of the result.
    message: str
38
+
39
@app.get("/questions")
def get_questions():
    """Return a random selection of at most 20 questions.

    A fresh sample is drawn on every call; when fewer than 20 questions
    are loaded, all of them are returned (in random order).
    """
    sample_size = min(20, len(QUESTIONS))
    return random.sample(QUESTIONS, k=sample_size)
44
+
45
@app.post("/submit")
def submit(sub: Submission):
    """Score a submission against the stored ground-truth answers.

    Each submitted answer is converted to ``str``, stripped of surrounding
    whitespace and compared for exact equality with the expected answer of
    its ``task_id``. Answers whose ``task_id`` is not in the ground truth
    are never counted as correct, but still count towards the total.

    Args:
        sub: The submission containing the username and the answers.

    Returns:
        A ``ScoreResponse`` with the percentage score (0-100), the number of
        correct answers, the number of answers attempted and a summary
        message. The score is 0.0 when no answers were submitted.
    """
    correct = 0
    for ans in sub.answers:
        expected = GROUND_TRUTH.get(ans.task_id)
        # Unknown task ids must never score. With a "" default, an empty or
        # whitespace-only answer for a nonexistent task compared equal and
        # was wrongly counted as correct.
        if expected is None:
            continue
        if expected == str(ans.submitted_answer).strip():
            correct += 1

    total = len(sub.answers)
    # Guard against division by zero for an empty submission.
    score = correct / total * 100 if total > 0 else 0.0
    return ScoreResponse(
        username=sub.username,
        score=score,
        correct_count=correct,
        total_attempted=total,
        message=f"Puntuación: {correct}/{total} = {score:.1f}%",
    )