# atlas-map / main.py
# Richard Guo
# nomic login
# 1779f92
import asyncio
import os
import time
from typing import Optional
from uuid import uuid4
from fastapi import FastAPI, Form, Header, HTTPException, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from huggingface_hub import create_discussion, comment_discussion
from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
from models import WebhookPayload
# Server-side webhook secret; disabled together with the comparison in the
# /webhook handler.
# WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")

# Token passed to the Hugging Face Hub discussion/comment calls; None if unset.
HUGGINGFACE_ACCESS_TOKEN = os.getenv("HUGGINGFACE_ACCESS_TOKEN")

app = FastAPI()

# TODO: use task management queue
# In-memory registry of upload jobs, keyed by task id. Each value is a dict
# holding 'status' and, once finished, 'url' and 'finish_time'.
tasks = dict()

# HTML templates are looked up in the local "templates" directory.
templates = Jinja2Templates(directory="templates")
def upload_atlas_task(task_id: str,
                      dataset_name: str,
                      atlas_api_token: str,
                      webhook_payload: Optional[WebhookPayload] = None,
                      webhook_notify: bool = False):
    """Build an Atlas map for a dataset and record the outcome in ``tasks``.

    On success the entry ``tasks[task_id]`` gains ``'url'`` and
    ``'finish_time'`` and its ``'status'`` becomes ``'done'``. On failure the
    entry gets ``'status': 'failed'``, an ``'error'`` message and a
    ``'finish_time'``, and the original exception is re-raised.

    Args:
        task_id: Key of an existing entry in ``tasks``.
        dataset_name: Dataset identifier handed to ``load_dataset_and_metadata``.
        atlas_api_token: Token handed to ``upload_dataset_to_atlas``.
        webhook_payload: Payload of the triggering webhook; its ``repo.name``
            is the repository that receives the notification.
        webhook_notify: When true, open an "Atlas Maps" discussion on the
            dataset repository and post the map URL as a comment.

    Raises:
        ValueError: If ``webhook_notify`` is true but no payload was given.
        Exception: Whatever the dataset load / Atlas upload / Hub calls raise.
    """
    try:
        dataset_dict = load_dataset_and_metadata(dataset_name)
        map_url = upload_dataset_to_atlas(dataset_dict, atlas_api_token)
    except Exception as exc:
        # Record the failure so the entry does not stay 'running' forever,
        # then let the exception propagate exactly as it did before.
        task = tasks[task_id]
        task['error'] = str(exc)
        task['finish_time'] = time.time()
        task['status'] = 'failed'
        raise

    # Publish the result fields before flipping the status, so a reader that
    # observes 'done' always finds 'url' and 'finish_time' alongside it.
    task = tasks[task_id]
    task['url'] = map_url
    task['finish_time'] = time.time()
    task['status'] = 'done'

    if not webhook_notify:
        return
    if webhook_payload is None:
        raise ValueError("webhook_notify requires a webhook_payload")

    repo_name = webhook_payload.repo.name
    discussion = create_discussion(
        repo_id=repo_name,
        title="Atlas Maps",
        token=HUGGINGFACE_ACCESS_TOKEN,
        repo_type="dataset"
    )
    comment_discussion(
        repo_id=repo_name,
        discussion_num=discussion.num,
        comment="Atlas Map: " + map_url,
        token=HUGGINGFACE_ACCESS_TOKEN,
        repo_type="dataset"
    )
@app.on_event("startup")
async def startup_event():
    """Launch the periodic ``tasks`` cleanup loop when the application starts."""
    # The event loop holds only weak references to tasks, so a task whose
    # result is discarded may be garbage-collected before it finishes. Keep a
    # strong reference on the application state for the app's lifetime.
    app.state.cleanup_task = asyncio.create_task(cleanup_tasks())
async def cleanup_tasks():
    """Periodically drop finished entries from ``tasks``.

    Runs forever: each pass removes every entry whose status is ``'done'`` and
    whose ``'finish_time'`` is more than 30 minutes old, then sleeps for 30
    minutes before the next pass.
    """
    thirty_minutes = 1800  # seconds; used as both the age limit and the sleep
    while True:
        now = time.time()
        # Collect the ids first so the dict is not mutated while iterating it.
        # An entry without 'finish_time' is treated as age zero and kept.
        expired_ids = [
            tid
            for tid, info in tasks.items()
            if info['status'] == 'done'
            and now - info.get('finish_time', now) > thirty_minutes
        ]
        for tid in expired_ids:
            del tasks[tid]
        await asyncio.sleep(thirty_minutes)
@app.get("/", response_class=HTMLResponse)
async def read_form(request: Request):
    """Serve the landing page by rendering the ``form.html`` template."""
    # The template context carries the incoming request under "request".
    context = {"request": request}
    return templates.TemplateResponse("form.html", context)
@app.post("/submit_form")
async def form_post(background_tasks: BackgroundTasks, dataset_name: str = Form(...), atlas_api_token: str = Form(...)):
    """Queue an Atlas upload for the submitted dataset.

    Registers a new entry in ``tasks`` with status ``'running'``, schedules
    ``upload_atlas_task`` as a background task, and returns the new id as
    ``{'task_id': ...}`` so the caller can poll ``/status/{task_id}``.
    """
    new_id = str(uuid4())
    # Register the entry before scheduling so the worker finds it in ``tasks``.
    tasks[new_id] = {'status': 'running'}
    background_tasks.add_task(
        upload_atlas_task,
        task_id=new_id,
        dataset_name=dataset_name,
        atlas_api_token=atlas_api_token,
    )
    return {'task_id': new_id}
@app.get("/status/{task_id}")
async def read_task(task_id: str):
    """Return the registry entry for ``task_id``.

    Unknown ids yield ``{'status': 'not found'}`` rather than an HTTP error.
    """
    return tasks.get(task_id, {'status': 'not found'})
@app.post("/webhook")
async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload, x_webhook_secret: Optional[str] = Header(default=None)):
    """Handle a repository webhook and queue an Atlas upload when relevant.

    Responds with HTTP 401 when the ``X-Webhook-Secret`` header is absent.
    Events other than a content update on a dataset repository are
    acknowledged with ``{"processed": False}``. Otherwise a new ``tasks``
    entry is created, ``upload_atlas_task`` is scheduled with notification
    enabled, and ``{'task_id': ...}`` is returned.
    """
    if x_webhook_secret is None:
        raise HTTPException(401)
    # NOTE(review): the header is not compared against a server-side secret
    # (that check is disabled); its value is forwarded below as the Atlas API
    # token. Presumably intentional - confirm.

    event = payload.event
    is_dataset_content_update = (
        event.action == "update"
        and event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
    )
    if not is_dataset_content_update:
        return {"processed": False}

    new_id = str(uuid4())
    tasks[new_id] = {'status': 'running'}
    background_tasks.add_task(
        upload_atlas_task,
        new_id,
        payload.repo.name,
        x_webhook_secret,
        payload,
        True,
    )
    return {'task_id': new_id}