|
import io
import json
import logging
import os
import shutil
import tempfile
import zipfile

from datasets import load_dataset
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
# Router holding the download endpoints of this module; every route
# registered on it carries the "download" tag.
router = APIRouter(
    tags=["download"],
)
|
|
|
def _zip_dataset_snapshot(repo_id: str) -> io.BytesIO:
    """Download a dataset repository snapshot and return it as an in-memory ZIP.

    This performs blocking network and disk I/O; async callers should run it
    in a worker thread.

    Args:
        repo_id: Hugging Face dataset repository id to download.

    Returns:
        A BytesIO positioned at offset 0 that holds the ZIP archive.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        logging.info(f"Downloading dataset {repo_id}")
        snapshot_path = snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=temp_dir,
            token=os.environ.get("HF_TOKEN"),
        )
        logging.info(f"Dataset downloaded to {snapshot_path}")

        zip_io = io.BytesIO()
        with zipfile.ZipFile(zip_io, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for root, dirs, files in os.walk(snapshot_path):
                if root == snapshot_path:
                    # Recent huggingface_hub versions write download metadata
                    # and lock files under <local_dir>/.cache; that is
                    # bookkeeping, not dataset content, so keep it out of the
                    # archive.
                    dirs[:] = [d for d in dirs if d != ".cache"]
                for file in files:
                    file_path = os.path.join(root, file)
                    arc_name = os.path.relpath(file_path, snapshot_path)
                    zip_file.write(file_path, arcname=arc_name)

    # The archive lives in memory, so it remains valid after the temporary
    # directory has been removed.
    zip_io.seek(0)
    return zip_io


@router.get("/download-dataset/{session_id}")
async def download_dataset(session_id: str):
    """
    Downloads the HuggingFace dataset associated with a session and returns it to the client

    The download and compression run in a worker thread so the event loop is
    not blocked while the snapshot is fetched.

    Args:
        session_id: Session identifier

    Returns:
        ZIP file containing the dataset

    Raises:
        HTTPException: 500 if the download or archiving fails.
    """
    repo_id = f"yourbench/yourbench_{session_id}"

    try:
        zip_io = await run_in_threadpool(_zip_dataset_snapshot, repo_id)
    except Exception as e:
        # Single error boundary: previously a nested handler re-wrapped this
        # HTTPException, producing a doubled "500: ..." detail string.
        logging.exception(f"Error while downloading the dataset: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error while downloading the dataset: {str(e)}",
        ) from e

    filename = f"yourbench_{session_id}_dataset.zip"
    return StreamingResponse(
        zip_io,
        media_type="application/zip",
        # Quoted so the filename stays a single header parameter (RFC 6266).
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
|
|
# Question configs to export, in output order:
# (dataset config name, "type" value written to each question, log label).
_QUESTION_CONFIGS = (
    ("single_shot_questions", "single_shot", "single-shot"),
    ("multi_hop_questions", "multi_hop", "multi-hop"),
)


def _collect_session_questions(dataset_repo_id: str) -> list:
    """Load every question config of a session dataset into one flat list.

    Each config is loaded best-effort: a failure is logged and that config is
    skipped, so a session with only one question type still yields results.
    Ids are sequential strings across all configs, in ``_QUESTION_CONFIGS``
    order. This performs blocking I/O; async callers should use a worker
    thread.

    Args:
        dataset_repo_id: Hugging Face dataset repository id.

    Returns:
        A list of dicts with keys ``id``, ``question``, ``answer`` and ``type``.
    """
    all_questions = []
    for config_name, question_type, label in _QUESTION_CONFIGS:
        try:
            dataset = load_dataset(dataset_repo_id, config_name)
            start_id = len(all_questions)
            # Build the config's rows completely before adding them, so a
            # failure part-way through does not leave a partial config.
            loaded = [
                {
                    "id": str(start_id + offset),
                    "question": row.get("question", ""),
                    "answer": row.get("self_answer", "No answer available"),
                    "type": question_type,
                }
                for offset, row in enumerate(dataset["train"])
            ]
        except Exception as e:
            logging.error(f"Error loading {label} questions: {str(e)}")
            continue
        all_questions.extend(loaded)
        logging.info(f"Loaded {len(loaded)} {label} questions")
    return all_questions


@router.get("/download-questions/{session_id}")
async def download_questions(session_id: str):
    """
    Downloads the questions generated for a session in JSON format

    Dataset loading runs in a worker thread so the event loop is not blocked.

    Args:
        session_id: Session identifier

    Returns:
        JSON file containing only the list of generated questions

    Raises:
        HTTPException: 404 if no questions could be loaded, 500 on any other
            unexpected error.
    """
    dataset_repo_id = f"yourbench/yourbench_{session_id}"

    try:
        all_questions = await run_in_threadpool(
            _collect_session_questions, dataset_repo_id
        )

        if not all_questions:
            raise HTTPException(status_code=404, detail="No questions found for this session")

        questions_json = json.dumps(all_questions, ensure_ascii=False, indent=2)
        # A BytesIO built from bytes already starts at offset 0.
        json_bytes = io.BytesIO(questions_json.encode("utf-8"))

        filename = f"yourbench_{session_id}_questions.json"
        return StreamingResponse(
            json_bytes,
            media_type="application/json",
            # Quoted so the filename stays a single header parameter (RFC 6266).
            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
        )

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) pass through unchanged.
        raise
    except Exception as e:
        logging.exception(f"Error retrieving questions: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error downloading questions: {str(e)}",
        ) from e