Spaces:
Build error
Build error
File size: 2,987 Bytes
55d6386 a075ab3 bb9ba1d 55d6386 a7c7f70 55d6386 4f0a423 55d6386 a075ab3 a7c7f70 55d6386 a7c7f70 55d6386 a7c7f70 55d6386 a7c7f70 bb9ba1d a7c7f70 55d6386 a075ab3 a7c7f70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import requests
from typing import Optional
import uvicorn
from subprocess import Popen
import yaml
import datetime
from fastapi import FastAPI, Header, BackgroundTasks
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi
from src.models import config, WebhookPayload
# FastAPI application instance; the route decorators in this file register on it
# and it is the object handed to uvicorn.run in the __main__ guard.
app = FastAPI()
# Shared webhook secret read from the environment (None when the variable is unset).
# NOTE(review): the only comparison against it in this file is commented out.
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
# Hugging Face access token read from the environment (None when unset);
# notify_success passes it to HfApi and create_discussion.
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
@app.get("/")
async def home():
    """Answer ``GET /`` with the contents of ``home.html``.

    Returns:
        A ``FileResponse`` built from the relative path ``"home.html"``.
    """
    landing_page = FileResponse("home.html")
    return landing_page
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Handle ``POST /webhook`` by triggering a retrain on every request.

    Args:
        payload: Parsed request body.
        task_queue: Background task queue injected by FastAPI. Currently
            unused because the retrain is invoked inline.
        x_webhook_secret: Value of the request header bound by ``Header``
            (``None`` when absent). Currently unused because the secret
            check is disabled.

    Returns:
        ``{"processed": True}`` once ``schedule_retrain`` has returned.
    """
    # NOTE(review): the guards below are disabled, so any caller can trigger a
    # retrain. Re-enabling them also requires HTTPException, which is not among
    # the names imported from fastapi in the visible import lines.
    #
    # if x_webhook_secret is None:
    #     raise HTTPException(401)
    # if x_webhook_secret != WEBHOOK_SECRET:
    #     raise HTTPException(403)
    # if not (
    #     payload.event.action == "update"
    #     and payload.event.scope.startswith("repo.content")
    #     and payload.repo.name == config.input_dataset
    #     and payload.repo.type == "dataset"
    # ):
    #     # no-op
    #     return {"processed": False}

    # The retrain runs inline inside this async handler. The background-task
    # variant is kept here for reference:
    # task_queue.add_task(
    #     schedule_retrain,
    #     payload
    # )
    schedule_retrain(payload)

    outcome = {"processed": True}
    return outcome
def schedule_retrain(payload: WebhookPayload):
    """Stamp a fresh project name into ``config.yaml`` and launch AutoTrain.

    Loads ``config.yaml`` from the current working directory, sets its
    ``project_name`` key to the current local time in ISO 8601 form, writes
    the file back, starts ``autotrain --config <path>`` as a child process
    without waiting for it, and then calls ``notify_success``.

    Args:
        payload: The webhook payload that triggered the retrain. Not read by
            this function at present.

    Returns:
        ``{"processed": True}``.

    Raises:
        requests.HTTPError: Logged and re-raised if raised inside the ``try``
            block.
    """
    yaml_path = os.path.join(os.getcwd(), "config.yaml")

    # Create the autotrain project
    try:
        with open(yaml_path) as f:
            list_doc = yaml.safe_load(f)

        list_doc['project_name'] = datetime.datetime.now().isoformat()

        # The file must be reopened for writing: the default mode is read-only,
        # and dumping into a read-only handle raises io.UnsupportedOperation,
        # so the updated project name would never reach disk.
        with open(yaml_path, "w") as f:
            yaml.dump(list_doc, f, default_flow_style=False)

        result = Popen(['autotrain', '--config', yaml_path])
        # project = AutoTrain.create_project(payload)
        # AutoTrain.add_data(project_id=project["id"])
        # AutoTrain.start_processing(project_id=project["id"])
    except requests.HTTPError as err:
        # NOTE(review): presumably left over from the commented-out AutoTrain
        # API calls above; confirm whether anything in the try still raises it.
        print("ERROR while requesting AutoTrain API:")
        print(f"  code: {err.response.status_code}")
        print(f"  {err.response.json()}")
        raise

    # Notify in the community tab
    notify_success('vicuna')

    # The child process is never waited on or polled, so returncode is still
    # None here; it does not indicate whether training succeeded.
    print(result.returncode)
    return {"processed": True}
def notify_success(project_id: str):
    """Open a discussion on the input dataset repo announcing the retrain.

    Args:
        project_id: Forwarded to ``NOTIFICATION_TEMPLATE.format`` as the
            ``project_id`` keyword.

    Returns:
        Whatever ``HfApi.create_discussion`` returns.
    """
    # Render the discussion body from the module-level template.
    body = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
    )

    # The token is supplied both to the client and to the call itself.
    hub_client = HfApi(token=HF_ACCESS_TOKEN)
    discussion = hub_client.create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
    return discussion
# Body text of the discussion opened by notify_success, rendered with str.format.
# It uses the {input_dataset} and {input_model} placeholders; the project_id
# keyword that notify_success also passes has no placeholder here and is ignored.
NOTIFICATION_TEMPLATE = """\
🌸 Hello there!
Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!
(This is an automated message)
"""
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|