"""FastAPI service that turns Hugging Face datasets into Atlas maps.

Uploads run as FastAPI background tasks. Their progress is tracked in the
in-memory ``tasks`` dict, which is polled through ``/status/{task_id}`` and
periodically pruned by ``cleanup_tasks``.
"""
import asyncio
import os
import time
from typing import Optional
from uuid import uuid4

from fastapi import FastAPI, Form, Header, HTTPException, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from huggingface_hub import create_discussion, comment_discussion

from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
from models import WebhookPayload

# WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")

# Finished task entries are kept this long so clients can still poll the result.
TASK_TTL_SECONDS = 1800  # 30 minutes
# Pause between two sweeps of the task registry.
CLEANUP_INTERVAL_SECONDS = 1800  # 30 minutes
# Statuses that mark a task as finished and therefore eligible for cleanup.
_FINISHED_STATUSES = ('done', 'failed')

app = FastAPI()

# TODO: use task management queue
# Maps task_id -> {'status': ..., 'url': ..., 'finish_time': ...}.
tasks = {}

# Strong reference to the cleanup loop; the event loop itself only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected.
_cleanup_task = None

templates = Jinja2Templates(directory="templates")


def upload_atlas_task(task_id: str, dataset_name: str, atlas_api_token: str,
                      webhook_payload: Optional[WebhookPayload] = None,
                      webhook_notify: bool = False):
    """Build an Atlas map for a dataset and record the outcome in ``tasks``.

    Args:
        task_id: Key of the entry in ``tasks`` to update.
        dataset_name: Dataset identifier passed to ``load_dataset_and_metadata``.
        atlas_api_token: Token passed to ``upload_dataset_to_atlas``.
        webhook_payload: Payload of the triggering webhook; its ``repo.name``
            is the repository that receives the notification discussion.
        webhook_notify: When true (and a payload is given), open a discussion
            on the dataset repository and post the map URL as a comment.

    Raises:
        Exception: Whatever the load/upload helpers raise is re-raised after
            the task has been marked as ``'failed'``.
    """
    try:
        dataset_dict = load_dataset_and_metadata(dataset_name)
        map_url = upload_dataset_to_atlas(dataset_dict, atlas_api_token)
    except Exception as exc:
        # Without this the entry would stay 'running' forever and never be
        # reclaimed by cleanup_tasks. Only the exception type is exposed, so
        # that no token or other sensitive detail leaks through /status.
        failed_entry = tasks.setdefault(task_id, {})
        failed_entry['error'] = type(exc).__name__
        failed_entry['finish_time'] = time.time()
        failed_entry['status'] = 'failed'
        raise

    entry = tasks[task_id]
    entry['url'] = map_url
    entry['finish_time'] = time.time()
    # Flip the status last so a concurrent poller never sees 'done' without a URL.
    entry['status'] = 'done'

    if webhook_notify and webhook_payload is not None:
        repo_id = webhook_payload.repo.name
        discussion = create_discussion(
            repo_id=repo_id,
            title="Atlas Maps",
            token=HUGGINGFACE_ACCESS_TOKEN,
            repo_type="dataset"
        )
        comment_discussion(
            repo_id=repo_id,
            discussion_num=discussion.num,
            comment="Atlas Map: " + map_url,
            token=HUGGINGFACE_ACCESS_TOKEN,
            repo_type="dataset"
        )


@app.on_event("startup")
async def startup_event():
    """Start the periodic cleanup loop when the application starts."""
    global _cleanup_task
    # Keep a reference so the task is not garbage-collected mid-execution.
    _cleanup_task = asyncio.create_task(cleanup_tasks())


async def cleanup_tasks():
    """Forever: drop finished task entries older than ``TASK_TTL_SECONDS``."""
    while True:
        now = time.time()
        expired = [
            task_id
            for task_id, task in tasks.items()
            if task.get('status') in _FINISHED_STATUSES
            and now - task.get('finish_time', now) > TASK_TTL_SECONDS
        ]
        for task_id in expired:
            tasks.pop(task_id, None)
        await asyncio.sleep(CLEANUP_INTERVAL_SECONDS)


@app.get("/", response_class=HTMLResponse)
async def read_form(request: Request):
    """Render the ``form.html`` template."""
    return templates.TemplateResponse("form.html", {"request": request})


@app.post("/submit_form")
async def form_post(background_tasks: BackgroundTasks,
                    dataset_name: str = Form(...),
                    atlas_api_token: str = Form(...)):
    """Register a new upload task for the submitted form and return its id."""
    task_id = str(uuid4())
    tasks[task_id] = {'status': 'running'}
    background_tasks.add_task(upload_atlas_task, task_id, dataset_name, atlas_api_token)
    return {'task_id': task_id}


@app.get("/status/{task_id}")
async def read_task(task_id: str):
    """Return the stored entry for ``task_id``, or a 'not found' status."""
    return tasks.get(task_id, {'status': 'not found'})


@app.post("/webhook")
async def post_webhook(background_tasks: BackgroundTasks,
                       payload: WebhookPayload,
                       x_webhook_secret: Optional[str] = Header(default=None)):
    """Start an upload task when a dataset repository's content is updated.

    Raises:
        HTTPException: 401 when the ``X-Webhook-Secret`` header is missing.
    """
    if x_webhook_secret is None:
        raise HTTPException(401)
    # NOTE: the header value is not compared against a server-side secret
    # (the WEBHOOK_SECRET check is disabled); it is forwarded below as the
    # Atlas API token for the upload.

    is_dataset_content_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
    )
    if not is_dataset_content_update:
        return {"processed": False}

    task_id = str(uuid4())
    tasks[task_id] = {'status': 'running'}
    background_tasks.add_task(
        upload_atlas_task, task_id, payload.repo.name, x_webhook_secret, payload, True
    )
    return {'task_id': task_id}