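"""FastAPI app for a Hugging Face Space that reacts to Hub webhooks.

When the watched dataset is updated, the Space retrains the model with the
AutoTrain CLI, deploys the result to an Inference Endpoint, and posts
notifications in the community tab.
"""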
import datetime
import os
import time
from subprocess import Popen
from typing import Optional

import requests
import uvicorn
import yaml
from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from src.models import config, WebhookPayload

app = FastAPI()

WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
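
# Required Space secrets / environment variables:
#   WEBHOOK_SECRET  - shared secret compared against the X-Webhook-Secret header
#                     (the check in the webhook handler is currently commented out)
#   HF_ACCESS_TOKEN - Hugging Face token used to call the Inference Endpoints API
#                     and to open discussions on the Hub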

@app.get("/")
async def home():
    return FileResponse("home.html")

@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    # Webhook secret verification and event filtering, currently disabled:
    # if x_webhook_secret is None:
    #     raise HTTPException(401)
    # if x_webhook_secret != WEBHOOK_SECRET:
    #     raise HTTPException(403)
    # if not (
    #     payload.event.action == "update"
    #     and payload.event.scope.startswith("repo.content")
    #     and payload.repo.name == config.input_dataset
    #     and payload.repo.type == "dataset"
    # ):
    #     # no-op
    #     return {"processed": False}

    # Run the retraining pipeline in the background so the webhook returns immediately.
    task_queue.add_task(schedule_retrain, payload)
    return {"processed": True}
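
# Rough shape of the webhook payload the handler filters on (a sketch inferred from
# the fields referenced above; the actual schema is defined by src.models.WebhookPayload):
#
#   {
#     "event": {"action": "update", "scope": "repo.content"},
#     "repo": {"name": "<namespace>/<dataset>", "type": "dataset"}
#   }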

def schedule_retrain(payload: WebhookPayload):
    """Retrain the model with AutoTrain, then deploy it and notify the community tab."""
    # Create the autotrain project, named after the current timestamp.
    project_id = str(int(datetime.datetime.now().timestamp()))
    try:
        # Point the AutoTrain config at the new project name.
        yaml_path = os.path.join(os.getcwd(), "src/config.yaml")
        with open(yaml_path) as f:
            autotrain_config = yaml.safe_load(f)
        autotrain_config["project_name"] = project_id
        with open(yaml_path, "w") as f:
            yaml.dump(autotrain_config, f, default_flow_style=False)
        # Launch the AutoTrain CLI and block until training finishes.
        result = Popen(["autotrain", "--config", yaml_path])
        result.wait()
        # project = AutoTrain.create_project(payload)
        # AutoTrain.add_data(project_id=project["id"])
        # AutoTrain.start_processing(project_id=project["id"])
    except requests.HTTPError as err:
        print("ERROR while requesting AutoTrain API:")
        print(f"  code: {err.response.status_code}")
        print(f"  {err.response.json()}")
        raise
    # Notify in the community tab
    notify_success(project_id)
    # Deploy the freshly trained model to an Inference Endpoint.
    deploy_model(project_id)
    print(result.returncode)
    return {"processed": True}

def notify_success(project_id: str):
    """Open a discussion on the input dataset announcing that retraining has started."""
    message = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
    )
    return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=message,
        token=HF_ACCESS_TOKEN,
    )


def notify_url(url: str):
    """Open a discussion on the Space sharing the URL of the deployed endpoint."""
    message = URL_TEMPLATE.format(url=url)
    return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
        repo_id="Platma/platma-retrain",
        repo_type="space",
        title="✨ Endpoint is ready!",
        description=message,
        token=HF_ACCESS_TOKEN,
    )

def deploy_model(project_id: str):
    """Create a protected text-generation Inference Endpoint for the trained model and wait until it is running."""
    api = HfApi(token=HF_ACCESS_TOKEN)
    url = "https://api.endpoints.huggingface.cloud/v2/endpoint/Platma"
    data = {
        "compute": {
            "accelerator": "gpu",
            "instanceSize": "x1",
            "instanceType": "nvidia-l4",
            "scaling": {"maxReplica": 1, "minReplica": 1, "scaleToZeroTimeout": 15},
        },
        "model": {
            "framework": "pytorch",
            "image": {
                "custom": {
                    "health_route": "/health",
                    "url": "ghcr.io/huggingface/text-generation-inference:sha-f852190",
                    "env": {
                        "MAX_BATCH_PREFILL_TOKENS": "2048",
                        "MAX_INPUT_LENGTH": "2048",
                        "MAX_TOTAL_TOKENS": "2512",
                        "MODEL_ID": "/repository",
                    },
                }
            },
            "repository": f"Platma/{project_id}",
            "secrets": {},
            "task": "text-generation",
        },
        "name": f"platma-{project_id}",
        "provider": {"region": "us-east-1", "vendor": "aws"},
        "type": "protected",
    }
    headers = {"Authorization": f"Bearer {HF_ACCESS_TOKEN}", "Content-Type": "application/json"}
    r = requests.post(url, json=data, headers=headers)
    print(r)

    # Poll the endpoint until it is running (then share its URL) or fails.
    endpoint = api.get_inference_endpoint(name=f"platma-{project_id}")
    while True:
        print("Fetching url")
        if endpoint.status == "running":
            print(endpoint)
            notify_url(endpoint.url)
            break
        if endpoint.status == "error":
            break
        time.sleep(10)
        endpoint = api.get_inference_endpoint(name=f"platma-{project_id}")
        print(endpoint)

NOTIFICATION_TEMPLATE = """\
🌸 Hello there!

Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!

(This is an automated message)
"""

URL_TEMPLATE = """\
Here is your endpoint: {url}

(This is an automated message)
"""

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)