File size: 8,479 Bytes
a2b0f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# Ultralytics YOLO 🚀, AGPL-3.0 license
import signal
import sys
from pathlib import Path
from time import sleep

import requests

from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, smart_request
from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
from ultralytics.yolo.utils.errors import HUBModelError

# Agent tag sent in heartbeat payloads: package version plus 'colab' or 'local' depending on is_colab()
AGENT_NAME = f"python-{__version__}-{'colab' if is_colab() else 'local'}"


class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

    Args:
        url (str): Model identifier used to initialize the HUB training session.

    Attributes:
        agent_id (str): Identifier for the instance communicating with the server.
        model_id (str): Identifier for the YOLOv5 model being trained.
        model_url (str): URL for the model in Ultralytics HUB.
        api_url (str): API URL for the model in Ultralytics HUB.
        auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
        rate_limits (Dict): Rate limits for different API calls (in seconds).
        timers (Dict): Timers for rate limiting.
        metrics_queue (Dict): Queue for the model's metrics.
        model (Dict): Model data fetched from Ultralytics HUB.
        alive (bool): Indicates if the heartbeat loop is active.
    """

    def __init__(self, url):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Parses the identifier, builds the auth header, fetches the model data, then starts the heartbeat loop and
        installs the SIGTERM/SIGINT handlers.

        Args:
            url (str): Model identifier used to initialize the HUB training session. Accepted forms are a full
                'https://hub.ultralytics.com/models/MODEL_ID' URL, a bare 20-character MODEL_ID, or a
                'KEY_MODELID' string made of a 42-character key and a 20-character model id joined by '_'.

        Raises:
            HUBModelError: If the identifier matches none of the accepted forms.
            ValueError: Propagated from _get_model() when the model data is unusable.
            ConnectionRefusedError: Propagated from _get_model() when the HUB server cannot be reached.
        """

        from ultralytics.hub.auth import Auth

        # Parse input: drop the web prefix so only the identifier part remains
        web_prefix = 'https://hub.ultralytics.com/models/'
        if url.startswith(web_prefix):
            url = url.split(web_prefix)[-1]
        parts = url.split('_')  # split once and reuse for both the length check and the unpacking
        if [len(x) for x in parts] == [42, 20]:
            key, model_id = parts
        elif len(url) == 20:
            key, model_id = '', url  # no key supplied with the identifier
        else:
            raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
                                f"model='https://hub.ultralytics.com/models/MODEL_ID' and try again.")

        # Authorize
        auth = Auth(key)
        self.agent_id = None  # identifies which instance is communicating with server
        self.model_id = model_id
        self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
        self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
        self.auth_header = auth.get_auth_header()
        self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
        self.timers = {}  # rate limit timers (seconds)
        self.metrics_queue = {}  # metrics queue
        self.model = self._get_model()  # needs model_id, model_url and auth_header set above
        self.alive = True  # must be set before the heartbeat loop starts, which polls this flag
        self._start_heartbeat()  # start heartbeats
        self._register_signal_handlers()
        LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')

    def _register_signal_handlers(self):
        """Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
        signal.signal(signal.SIGTERM, self._handle_signal)
        signal.signal(signal.SIGINT, self._handle_signal)

    def _handle_signal(self, signum, frame):
        """
        Handle kill signals and prevent heartbeats from being sent on Colab after termination.
        This method does not use frame, it is included as it is passed by signal.
        """
        if self.alive is True:
            LOGGER.info(f'{PREFIX}Kill signal received! ❌')
            self._stop_heartbeat()
            sys.exit(signum)

    def _stop_heartbeat(self):
        """Terminate the heartbeat loop."""
        self.alive = False

    def upload_metrics(self):
        """Post a snapshot of the queued metrics to the model's HUB API URL."""
        snapshot = self.metrics_queue.copy()  # shallow copy, so the payload is detached from the live queue dict
        body = {'metrics': snapshot, 'type': 'metrics'}
        smart_request('post', self.api_url, json=body, headers=self.auth_header, code=2)

    def _get_model(self):
        """
        Fetch model data from Ultralytics HUB and derive the training configuration from it.

        Updates self.model_id from the response. For a model with status 'new', sets self.train_args to the full
        argument set and self.model_file to the cfg/weights entry; for status 'training', sets them up to resume.
        For any other status neither attribute is set here.

        Returns:
            (dict): The 'data' entry of the HUB response.

        Raises:
            ConnectionRefusedError: If the request raises requests.exceptions.ConnectionError.
            ValueError: If the response holds no model data, the model is already trained, or its dataset entry
                is missing or empty.
        """
        api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'

        # Only the request itself is guarded; the parsing below is left to raise its own errors
        try:
            response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
        except requests.exceptions.ConnectionError as e:
            raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e

        data = response.json().get('data', None)
        if data is None:
            # Without this guard the lookups below would fail with an opaque AttributeError on None
            raise ValueError(f'No model data returned by HUB for {self.model_url}. '
                             f'Please check the model ID and try again.')

        if data.get('status', None) == 'trained':
            raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))

        if not data.get('data', None):
            raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
        self.model_id = data['id']

        if data['status'] == 'new':  # new model to start training
            self.train_args = {
                # TODO: deprecate 'batch_size' key for 'batch' in 3Q23
                'batch': data['batch' if ('batch' in data) else 'batch_size'],
                'epochs': data['epochs'],
                'imgsz': data['imgsz'],
                'patience': data['patience'],
                'device': data['device'],
                'cache': data['cache'],
                'data': data['data']}
            self.model_file = data.get('cfg') or data.get('weights')  # cfg for pretrained=False
            self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        elif data['status'] == 'training':  # existing model to resume training
            self.train_args = {'data': data['data'], 'resume': True}
            self.model_file = data['resume']

        return data

    def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
        """
        Upload a model checkpoint to Ultralytics HUB.

        A final upload is posted as 'best.pt' without threading, with a progress display, 10 retries and a
        3600 second timeout. Any other upload is posted as 'last.pt' with the default request settings. If the
        weights file is missing, a warning is logged and the request is still issued with a None file payload.

        Args:
            epoch (int): The current training epoch.
            weights (str): Path to the model weights file.
            is_best (bool): Indicates if the current model is the best one so far.
            map (float): Mean average precision of the model. The name shadows the builtin but is kept so
                existing keyword callers continue to work.
            final (bool): Indicates if the model is the final model after training.
        """
        weights_path = Path(weights)
        if weights_path.is_file():
            file = weights_path.read_bytes()  # whole checkpoint is read into memory
        else:
            LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
            file = None

        url = f'{self.api_url}/upload'
        data = {'epoch': epoch}
        if final:
            data.update({'type': 'final', 'map': map})
            smart_request('post',
                          url,
                          data=data,
                          files={'best.pt': file},
                          headers=self.auth_header,
                          retry=10,
                          timeout=3600,
                          thread=False,
                          progress=True,
                          code=4)
        else:
            data.update({'type': 'epoch', 'isBest': bool(is_best)})
            smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)

    @threaded
    def _start_heartbeat(self):
        """
        Report the agent's status to Ultralytics HUB in a loop for as long as `self.alive` is truthy.

        Each pass posts the agent name and current agent id, stores the agent id from the reply on
        `self.agent_id`, then sleeps for the 'heartbeat' rate limit.
        """
        # NOTE(review): a reply that is None or lacks a JSON body would raise here and end the loop - confirm
        # what smart_request returns on failure.
        while self.alive:
            endpoint = f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}'
            body = {'agent': AGENT_NAME, 'agentId': self.agent_id}
            reply = smart_request('post',
                                  endpoint,
                                  json=body,
                                  headers=self.auth_header,
                                  retry=0,
                                  code=5,
                                  thread=False)  # already in a thread
            reply_data = reply.json().get('data', {})
            self.agent_id = reply_data.get('agentId', None)
            pause = self.rate_limits['heartbeat']
            sleep(pause)