File size: 2,902 Bytes
df6c67d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from typing import Literal, Optional

import requests
from pydantic import BaseModel

import docker
from inference.core.devices.utils import GLOBAL_DEVICE_ID
from inference.core.env import API_BASE_URL, API_KEY
from inference.core.logger import logger
from inference.core.utils.url_utils import wrap_url
from inference.enterprise.device_manager.container_service import get_container_by_id


class Command(BaseModel):
    """Schema of a single device command.

    The field names mirror the keys read from the command payload in
    ``handle_command`` (``id``, ``containerId``, ``command``).
    """

    # Identifier of the command; echoed back as ``commandId`` when acknowledging.
    id: str
    # Identifier of the container the command targets.
    containerId: str
    # Action to perform. "start" is listed because ``handle_command``
    # dispatches on it alongside the other values.
    command: Literal[
        "restart", "stop", "start", "ping", "snapshot", "update_version"
    ]
    # presumably the identifier of the device the command was issued for —
    # TODO confirm against the API
    deviceId: str
    # presumably a timestamp of when the command was requested; unit not
    # visible from this file — TODO confirm
    requested_on: Optional[int] = None


def fetch_commands():
    """Fetch the commands queued for this device and handle each one.

    Issues a GET request to the device's ``/commands`` endpoint and passes
    every entry of the ``data`` list in the JSON response to
    ``handle_command``, in the order received.

    Raises:
        requests.RequestException: If the request fails or times out. An
            error is also raised when the response body is not valid JSON.
    """
    url = wrap_url(
        f"{API_BASE_URL}/devices/{GLOBAL_DEVICE_ID}/commands?api_key={API_KEY}"
    )
    # requests waits indefinitely without an explicit timeout; bound the call
    # so an unresponsive API cannot block the caller forever.
    resp = requests.get(url, timeout=30).json()
    # ``data`` may be missing or explicitly null; treat both as "no commands".
    for cmd in resp.get("data") or []:
        handle_command(cmd)


def handle_command(cmd_payload: dict):
    was_processed = False
    container_id = cmd_payload.get("containerId")
    container = get_container_by_id(container_id)
    if not container:
        logger.warn(f"Container with id {container_id} not found")
        ack_command(cmd_payload.get("id"), was_processed)
        return
    cmd = cmd_payload.get("command")
    data = None
    match cmd:
        case "restart":
            was_processed, data = container.restart()
        case "stop":
            was_processed, data = container.stop()
        case "ping":
            was_processed, data = container.ping()
        case "snapshot":
            was_processed, data = container.snapshot()
        case "start":
            was_processed, data = container.start()
        case "update_version":
            was_processed, data = handle_version_update(container)
        case _:
            logger.error("Unknown command: {}".format(cmd))
    return ack_command(cmd_payload.get("id"), was_processed, data=data)


def ack_command(command_id, was_processed, data=None):
    """Report the outcome of a command to the API.

    Posts a JSON body containing the API key, the command id and the
    processed flag to the device's ``/commands/ack`` endpoint.

    Args:
        command_id: Identifier of the command being acknowledged.
        was_processed: Whether the command was carried out.
        data: Optional result payload; only included in the request when
            truthy.

    Raises:
        requests.RequestException: If the request fails or times out.
    """
    post_body = {
        "api_key": API_KEY,
        "commandId": command_id,
        "wasProcessed": was_processed,
    }
    # Falsy results (None, empty containers, ...) are omitted from the body.
    if data:
        post_body["data"] = data
    url = wrap_url(f"{API_BASE_URL}/devices/{GLOBAL_DEVICE_ID}/commands/ack")
    # requests waits indefinitely without an explicit timeout; bound the call
    # so an unresponsive API cannot block command handling forever.
    requests.post(url, json=post_body, timeout=30)


def _strip_image_tag(image: str) -> str:
    """Return the image reference without its tag or digest."""
    # Drop a trailing "@<digest>" first, e.g. "repo/img@sha256:abc".
    name = image.split("@", 1)[0]
    head, sep, tail = name.rpartition(":")
    # A colon whose right-hand side still contains "/" belongs to a registry
    # "host:port" prefix (e.g. "localhost:5000/img"), not to a tag.
    if sep and "/" not in tail:
        return head
    return name


def handle_version_update(container):
    """Replace ``container`` with a new one started from its image's ``latest`` tag.

    The startup configuration of the existing container is read, the
    container is killed, and a new container is started on the host network
    with the same detach/privileged/labels/ports/environment settings.

    Args:
        container: Object exposing ``get_startup_config()`` and ``kill()``.
            The config is read as a mapping with the keys ``image``,
            ``detach``, ``privileged``, ``labels``, ``port_bindings`` and
            ``env``.

    Returns:
        tuple: ``(True, None)`` once the new container has been started,
        ``(False, None)`` if any step raised (the error is logged).
    """
    try:
        config = container.get_startup_config()
        image_name = _strip_image_tag(config["image"])
        # Resolve everything that can fail cheaply *before* killing the
        # running container, so a bad config or an unavailable Docker client
        # does not leave the device without any container.
        run_kwargs = dict(
            image=f"{image_name}:latest",
            detach=config["detach"],
            privileged=config["privileged"],
            labels=config["labels"],
            ports=config["port_bindings"],
            environment=config["env"],
            network="host",
        )
        client = docker.from_env()
        container.kill()
        # NOTE(review): as far as I can tell, ``containers.run`` only pulls
        # when the image is missing locally, so an existing local ``latest``
        # would be reused as-is — confirm whether an explicit pull is wanted.
        new_container = client.containers.run(**run_kwargs)
        logger.info(f"New container started {new_container}")
        return True, None
    except Exception as e:
        logger.error(e)
        return False, None