Saketh-Reddy commited on
Commit
b9721b4
1 Parent(s): 393b323

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +62 -2
main.py CHANGED
@@ -1,7 +1,67 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
5
  @app.get("/")
6
  def read_root():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from typing import Optional
4
+ from huggingface_hub import snapshot_download
5
+ from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
6
+ from huggingface_hub.hf_api import HfApi
7
+
8
+ from .models import config, WebhookPayload
9
 
10
  app = FastAPI()
11
 
12
+ api = HfApi(
13
+     endpoint="https://huggingface.co/<organization>",
14
+     token="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"
15
+ )
16
+
17
  @app.get("/")
18
  def read_root():
19
+ return {"Hello": "World!"}
20
+
21
+ @app.post("/webhook")
22
+ async def post_webhook(
23
+ payload: WebhookPayload,
24
+ task_queue: BackgroundTasks,
25
+ x_webhook_secret: Optional[str] = Header(default=None),
26
+ ):
27
+ if x_webhook_secret is None:
28
+ raise HTTPException(401)
29
+ if x_webhook_secret != "webhooksecret":
30
+ raise HTTPException(403)
31
+
32
+ if not (
33
+ payload.event.action == "update"
34
+ and payload.event.scope.startswith("repo.content")
35
+ and payload.repo.name == "SakethTest/ThirdParty"
36
+ and payload.repo.type == "model"
37
+ ):
38
+ # no-op
39
+ return {"processed": False}
40
+
41
+ task_queue.add_task(
42
+ update_cloned_repo,
43
+ payload
44
+ )
45
+
46
+ return {"processed": True}
47
+
48
+
49
+ def update_cloned_repo(payload: WebhookPayload):
50
+ # Create the update_cloned_repo project
51
+ try:
52
+ snapshot_download(repo_id="SakethTest/ThirdParty",local_dir="ThirdParty")
53
+ api.upload_folder(
54
+ folder_path="./ThirdParty",
55
+ repo_id="shellplc/ThirdParty",
56
+ repo_type="model",
57
+ commit_message="uploaded third party model"
58
+ )
59
+ except requests.HTTPError as err:
60
+ print("ERROR while requesting AutoTrain API:")
61
+ print(f" code: {err.response.status_code}")
62
+ print(f" {err.response.json()}")
63
+ raise
64
+ # Notify in the community tab
65
+ notify_success(project["id"])
66
+
67
+ return {"processed": True}