File size: 2,987 Bytes
55d6386
 
 
a075ab3
 
bb9ba1d
 
55d6386
a7c7f70
55d6386
 
 
4f0a423
55d6386
 
 
a075ab3
 
a7c7f70
 
55d6386
 
a7c7f70
 
55d6386
 
 
a7c7f70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55d6386
 
 
a7c7f70
 
bb9ba1d
 
 
 
 
 
 
 
 
a7c7f70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55d6386
 
 
 
 
 
 
 
a075ab3
 
a7c7f70
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import datetime
import os
import secrets
from subprocess import Popen
from typing import Optional

import requests
import uvicorn
import yaml
from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from src.models import config, WebhookPayload

# FastAPI application exposing the landing page and the webhook endpoint.
app = FastAPI()

# Webhook secret read from the environment; None when the variable is unset.
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
# Hub token handed to HfApi when opening discussions; None when unset.
HF_ACCESS_TOKEN = os.environ.get("HF_ACCESS_TOKEN")


@app.get("/")
async def home():
    """Serve the static landing page.

    Returns:
        A ``FileResponse`` for ``home.html``, resolved relative to the
        process's current working directory.
    """
    landing_page = "home.html"
    return FileResponse(landing_page)


@app.post("/webhook")
async def post_webhook(
        payload: WebhookPayload,
        task_queue: BackgroundTasks,
        x_webhook_secret: Optional[str] = Header(default=None),
):
    """Handle a Hub webhook and queue a retrain when the event is relevant.

    The retrain runs as a FastAPI background task. ``schedule_retrain`` does
    blocking work (file I/O, spawning a subprocess, a Hub API call), which
    would otherwise stall the event loop of this ``async`` handler.

    Args:
        payload: Parsed webhook body.
        task_queue: FastAPI background-task queue used to defer the retrain.
        x_webhook_secret: Value of the ``X-Webhook-Secret`` request header.

    Returns:
        ``{"processed": True}`` when a retrain was queued, or
        ``{"processed": False}`` when the event is not a content update of the
        configured input dataset.

    Raises:
        HTTPException: 401 if a secret is configured but the header is
            missing; 403 if the header does not match the configured secret.
    """
    # Authenticate the caller. The check is only enforced when WEBHOOK_SECRET
    # is configured, so deployments that never set it keep working. Set it in
    # production; otherwise anyone can start a training run.
    if WEBHOOK_SECRET:
        if x_webhook_secret is None:
            raise HTTPException(401)
        # Compare bytes in constant time. compare_digest rejects non-ASCII str.
        if not secrets.compare_digest(
            x_webhook_secret.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")
        ):
            raise HTTPException(403)

    # Only react to content updates of the configured input dataset. Without
    # this filter, any event retrains. That includes, if the webhook also
    # subscribes to discussions, the discussion notify_success opens on this
    # same dataset.
    if not (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.name == config.input_dataset
        and payload.repo.type == "dataset"
    ):
        # no-op
        return {"processed": False}

    task_queue.add_task(schedule_retrain, payload)
    return {"processed": True}


def schedule_retrain(payload: WebhookPayload):
    """Stamp ``config.yaml`` with a fresh project name and launch AutoTrain.

    Runs ``autotrain --config <cwd>/config.yaml`` as a child process without
    waiting for it to finish. It then opens a community-tab discussion
    announcing the retrain via ``notify_success``.

    Args:
        payload: The webhook payload that triggered the retrain. It is
            currently unused and kept for interface compatibility.

    Returns:
        ``{"processed": True}`` once the training process has been spawned.

    Raises:
        OSError: If ``config.yaml`` cannot be read or written, or the
            ``autotrain`` executable cannot be started.
    """
    yaml_path = os.path.join(os.getcwd(), "config.yaml")
    project_name = _new_project_name()
    _set_project_name(yaml_path, project_name)

    # Popen returns immediately; training keeps running in the background,
    # so there is no return code yet. Report the pid instead.
    process = Popen(["autotrain", "--config", yaml_path])
    print(f"Started AutoTrain project {project_name!r} (pid {process.pid})")

    # Notify in the community tab, referencing the project actually created.
    notify_success(project_name)
    return {"processed": True}


def _new_project_name() -> str:
    """Return a timestamp-based project name made only of digits and hyphens."""
    # isoformat() would contain ':', which Hub repo names do not allow.
    # Microseconds keep back-to-back retrains from colliding.
    return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")


def _set_project_name(yaml_path: str, project_name: str) -> None:
    """Rewrite the YAML config at *yaml_path* with ``project_name`` set."""
    with open(yaml_path, encoding="utf-8") as f:
        # safe_load returns None for an empty file; start from an empty mapping.
        config_doc = yaml.safe_load(f) or {}
    config_doc["project_name"] = project_name
    # Reopen in write mode so the updated mapping replaces the old contents.
    with open(yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(config_doc, f, default_flow_style=False)


def notify_success(project_id: str):
    """Open a discussion on the input dataset announcing that retraining started.

    Args:
        project_id: Identifier of the AutoTrain project. It is passed to the
            template's ``format`` call, but NOTIFICATION_TEMPLATE has no
            ``{project_id}`` placeholder, so it does not appear in the text.

    Returns:
        The result of ``HfApi.create_discussion`` for the new discussion.
    """
    template_fields = {
        "input_model": config.input_model,
        "input_dataset": config.input_dataset,
        "project_id": project_id,
    }
    message = NOTIFICATION_TEMPLATE.format(**template_fields)

    hub = HfApi(token=HF_ACCESS_TOKEN)
    return hub.create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=message,
        token=HF_ACCESS_TOKEN,
    )


# Markdown body of the "Retraining started" discussion opened by
# notify_success. It is filled via str.format(input_model=..., input_dataset=...,
# project_id=...). There is no {project_id} placeholder, so that value is ignored.
NOTIFICATION_TEMPLATE = (
    "🌸 Hello there!\n"
    "\n"
    "Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), "
    "an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) "
    "has been scheduled on AutoTrain!\n"
    "\n"
    "(This is an automated message)\n"
)

if __name__ == "__main__":
    # When executed as a script, serve the app on every network interface.
    server_host, server_port = "0.0.0.0", 8000
    uvicorn.run(app, host=server_host, port=server_port)