Richard Guo commited on
Commit
81aaa4e
1 Parent(s): f47c911

huggingface cli requirement and webhook route

Browse files
Files changed (2) hide show
  1. main.py +38 -11
  2. requirements.txt +1 -0
main.py CHANGED
@@ -1,25 +1,48 @@
1
- from fastapi import FastAPI, Form, Request, BackgroundTasks
 
 
 
 
 
 
2
  from fastapi.responses import HTMLResponse
3
  from fastapi.templating import Jinja2Templates
4
-
5
- from uuid import uuid4
6
- import time
7
- import asyncio
8
 
9
  from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
10
- from models import WebhookPayload, WebhookPayloadRepo, WebhookPayloadEvent
11
 
 
 
12
 
13
  app = FastAPI()
14
  # TODO: use task management queue
15
  tasks = {}
16
  templates = Jinja2Templates(directory="templates")
17
 
18
- def upload_atlas_task(task_id, dataset_name):
 
 
 
19
  dataset_dict = load_dataset_and_metadata(dataset_name)
20
- map_url = upload_dataset_to_atlas(dataset_dict, project_name="atlas-space-test")
21
  tasks[task_id]['status'] = 'done'
22
  tasks[task_id]['url'] = map_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  @app.on_event("startup")
25
  async def startup_event():
@@ -47,7 +70,6 @@ async def form_post(background_tasks: BackgroundTasks, dataset_name: str = Form(
47
  tasks[task_id] = {'status': 'running'}
48
  #form_data = DatasetForm(dataset_name=dataset_name)
49
  background_tasks.add_task(upload_atlas_task, task_id, dataset_name)
50
-
51
  return {'task_id': task_id}
52
 
53
  @app.get("/status/{task_id}")
@@ -58,7 +80,12 @@ async def read_task(task_id: str):
58
  return tasks[task_id]
59
 
60
  @app.post("/webhook")
61
- async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload):
 
 
 
 
 
62
  if not (
63
  payload.event.action == "update"
64
  and payload.event.scope.startswith("repo.content")
@@ -69,5 +96,5 @@ async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayloa
69
  task_id = str(uuid4())
70
  tasks[task_id] = {'status': 'running'}
71
  #form_data = DatasetForm(dataset_name=dataset_name)
72
- background_tasks.add_task(upload_atlas_task, task_id, payload.repo.name)
73
  return {'task_id': task_id}
 
1
+ import asyncio
2
+ import os
3
+ import time
4
+ from typing import Optional
5
+ from uuid import uuid4
6
+
7
+ from fastapi import FastAPI, Form, Header, HTTPException, Request, BackgroundTasks
8
  from fastapi.responses import HTMLResponse
9
  from fastapi.templating import Jinja2Templates
10
+ from huggingface_hub import create_discussion, comment_discussion
 
 
 
11
 
12
  from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
13
+ from models import WebhookPayload
14
 
15
+ WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
16
+ HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")
17
 
18
  app = FastAPI()
19
  # TODO: use task management queue
20
  tasks = {}
21
  templates = Jinja2Templates(directory="templates")
22
 
23
+ def upload_atlas_task(task_id,
24
+ dataset_name,
25
+ webhook_payload: WebhookPayload = None,
26
+ webhook_notify: bool = False):
27
  dataset_dict = load_dataset_and_metadata(dataset_name)
28
+ map_url = upload_dataset_to_atlas(dataset_dict)
29
  tasks[task_id]['status'] = 'done'
30
  tasks[task_id]['url'] = map_url
31
+ tasks[task_id]['finish_time'] = time.time()
32
+
33
+ if webhook_notify:
34
+ discussion = create_discussion(
35
+ repo_id=webhook_payload.repo.id,
36
+ title="Atlas Maps",
37
+ token=HUGGINGFACE_ACCESS_TOKEN,
38
+ )
39
+ comment_discussion(
40
+ repo_id=webhook_payload.repo.id,
41
+ discussion_num=discussion.num,
42
+ comment="Atlas Map: " + map_url,
43
+ token=HUGGINGFACE_ACCESS_TOKEN
44
+ )
45
+
46
 
47
  @app.on_event("startup")
48
  async def startup_event():
 
70
  tasks[task_id] = {'status': 'running'}
71
  #form_data = DatasetForm(dataset_name=dataset_name)
72
  background_tasks.add_task(upload_atlas_task, task_id, dataset_name)
 
73
  return {'task_id': task_id}
74
 
75
  @app.get("/status/{task_id}")
 
80
  return tasks[task_id]
81
 
82
  @app.post("/webhook")
83
+ async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload, x_webhook_secret: Optional[str] = Header(default=None)):
84
+ if x_webhook_secret is None:
85
+ raise HTTPException(401)
86
+ if x_webhook_secret != WEBHOOK_SECRET:
87
+ raise HTTPException(403)
88
+
89
  if not (
90
  payload.event.action == "update"
91
  and payload.event.scope.startswith("repo.content")
 
96
  task_id = str(uuid4())
97
  tasks[task_id] = {'status': 'running'}
98
  #form_data = DatasetForm(dataset_name=dataset_name)
99
+ background_tasks.add_task(upload_atlas_task, task_id, payload.repo.name, payload, True)
100
  return {'task_id': task_id}
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  datasets==2.13.0
2
  fastapi[all]
 
3
  nomic==2.0.3
4
  pandas==1.5.3
5
  pyarrow==12.0.1
 
1
  datasets==2.13.0
2
  fastapi[all]
3
+ huggingface-hub==0.16.4
4
  nomic==2.0.3
5
  pandas==1.5.3
6
  pyarrow==12.0.1