Spaces:
Sleeping
Sleeping
File size: 2,844 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import requests
from autotrain.backends.base import BaseBackend
# Base URL of the Hugging Face Inference Endpoints v2 "endpoint" API.
# The account/user name is appended to it when creating an endpoint.
ENDPOINTS_URL = "https://api.endpoints.huggingface.cloud/v2/endpoint/"
class EndpointsRunner(BaseBackend):
    """
    Backend that launches an AutoTrain job on a Hugging Face Inference Endpoint.

    The job runs inside a custom container image; the training configuration
    is handed to that container through environment variables on the endpoint.

    Methods
    -------
    create():
        Creates an endpoint instance with the configured hardware and job
        parameters and returns the endpoint's name.
    """

    @staticmethod
    def _parse_hardware(hardware):
        """
        Split an underscore-separated endpoint hardware spec into its parts.

        The spec is read positionally as
        ``vendor_region_accelerator_instanceSize_instanceType``; any fields
        after the fifth are ignored.

        Parameters
        ----------
        hardware : str
            Hardware spec taken from ``available_hardware``.

        Returns
        -------
        tuple of str
            ``(vendor, region, accelerator, instance_size, instance_type)``.

        Raises
        ------
        IndexError
            If the spec has fewer than five underscore-separated fields.
        """
        parts = hardware.split("_")
        return parts[0], parts[1], parts[2], parts[3], parts[4]

    def create(self):
        """
        Create an endpoint instance with the configured hardware and job parameters.

        Returns
        -------
        str
            The ``name`` field of the endpoint object returned by the API.

        Raises
        ------
        requests.exceptions.HTTPError
            If the API answers with an error status (4xx/5xx). This is a
            subclass of ``requests.exceptions.RequestException``.
        requests.exceptions.RequestException
            If the HTTP request itself fails (connection error, timeout, ...).
        """
        # NOTE(review): available_hardware, backend, username, params and
        # task_id are assumed to be provided by BaseBackend -- confirm there.
        hardware = self.available_hardware[self.backend]
        vendor, region, accelerator, instance_size, instance_type = self._parse_hardware(hardware)

        payload = {
            "accountId": self.username,
            "compute": {
                "accelerator": accelerator,
                "instanceSize": instance_size,
                "instanceType": instance_type,
                # A single fixed replica: one endpoint instance runs one job.
                "scaling": {"maxReplica": 1, "minReplica": 1},
            },
            "model": {
                "framework": "custom",
                "image": {
                    "custom": {
                        # Job configuration consumed by the container image below.
                        "env": {
                            "HF_TOKEN": self.params.token,
                            "AUTOTRAIN_USERNAME": self.username,
                            "PROJECT_NAME": self.params.project_name,
                            "PARAMS": self.params.model_dump_json(),
                            "DATA_PATH": self.params.data_path,
                            "TASK_ID": str(self.task_id),
                            "MODEL": self.params.model,
                            "ENDPOINT_ID": f"{self.username}/{self.params.project_name}",
                        },
                        "health_route": "/",
                        "port": 7860,
                        "url": "public.ecr.aws/z4c3o6n6/autotrain-api:latest",
                    }
                },
                "repository": "autotrain-projects/autotrain-advanced",
                "revision": "main",
                "task": "custom",
            },
            "name": self.params.project_name,
            "provider": {"region": region, "vendor": vendor},
            "type": "protected",
        }
        headers = {"Authorization": f"Bearer {self.params.token}"}
        response = requests.post(
            ENDPOINTS_URL + self.username,
            json=payload,
            headers=headers,
            timeout=120,
        )
        # Fail with a requests exception (as documented) instead of a KeyError
        # when the API returns an error body that has no "name" field.
        response.raise_for_status()
        return response.json()["name"]
|