Spaces:
Sleeping
Sleeping
| import threading | |
| import time | |
| import huggingface_hub | |
| from gradio_client import Client | |
| from trackio.sqlite_storage import SQLiteStorage | |
| from trackio.typehints import LogEntry | |
| from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name | |
class Run:
    """Handle for a single run that uploads metrics in the background.

    Entries passed to :meth:`log` are queued in memory. A daemon thread
    connects a ``Client`` to ``url`` (unless one was supplied) and then sends
    the queue to the ``/bulk_log`` endpoint roughly every 500 ms.
    :meth:`finish` stops the thread and flushes whatever is still queued.
    """

    def __init__(
        self,
        url: str,
        project: str,
        client: Client | None,
        name: str | None = None,
        config: dict | None = None,
    ):
        """Create the run and start its background sender thread.

        Args:
            url: Address handed to ``Client`` when ``client`` is ``None``.
            project: Project name recorded in every log entry.
            client: An existing client to reuse, or ``None`` to connect one
                in the background.
            name: Run name. When falsy, a name is generated from the runs
                ``SQLiteStorage`` already knows for ``project``.
            config: Stored as ``self.config``; defaults to an empty dict.
        """
        self.url = url
        self.project = project
        # One lock guards both ``_client`` and ``_queued_logs``: they are
        # touched by the caller's thread and by the background thread.
        self._client_lock = threading.Lock()
        self._client = client
        self.name = name or generate_readable_name(SQLiteStorage.get_runs(project))
        self.config = config or {}
        self._queued_logs: list[LogEntry] = []
        self._stop_flag = threading.Event()
        # Daemon thread so an unfinished run never blocks interpreter exit.
        self._client_thread = threading.Thread(
            target=self._init_client_background, daemon=True
        )
        self._client_thread.start()

    def _flush_queued_logs(self):
        """Send every queued entry to ``/bulk_log`` in a single request.

        Does nothing when the queue is empty or no client is connected yet.
        The queue is emptied only after ``predict`` returns, so a failed
        upload leaves the entries queued for a later attempt instead of
        dropping them. (If the request fails after the server already stored
        the batch, that retry can upload the same entries twice.)

        Raises:
            Exception: Whatever ``Client.predict`` raises is propagated.
        """
        with self._client_lock:
            if not self._queued_logs or self._client is None:
                return
            # The lock is held for the whole request, so ``log`` cannot
            # append between the copy and the clear.
            self._client.predict(
                api_name="/bulk_log",
                logs=self._queued_logs.copy(),
                hf_token=huggingface_hub.utils.get_token(),
            )
            self._queued_logs.clear()

    def _batch_sender(self):
        """Send batched logs every 500ms until the run is finished."""
        # ``Event.wait`` is the 500 ms pause, but it returns True as soon as
        # ``finish`` sets the flag, so shutdown is not delayed by a sleep.
        while not self._stop_flag.wait(0.5):
            self._flush_queued_logs()
        # Final drain: if the client only connected after ``finish`` was
        # called, this is the only place the queued entries get uploaded.
        self._flush_queued_logs()

    def _init_client_background(self):
        """Thread target: connect the client if needed, then run the sender.

        Connection attempts are repeated with a delay of
        ``0.1 * coefficient`` seconds, the coefficients coming from
        ``fibo()``. The loop ends on the first successful connection (or when
        ``fibo()`` is exhausted, should it be finite).
        """
        if self._client is None:
            for sleep_coefficient in fibo():
                try:
                    client = Client(self.url, verbose=False)
                except Exception:
                    # Deliberate best-effort: the connection failure is not
                    # reported, the attempt is simply repeated after a delay.
                    if sleep_coefficient is not None:
                        time.sleep(0.1 * sleep_coefficient)
                    continue
                with self._client_lock:
                    self._client = client
                break
        self._batch_sender()

    def log(self, metrics: dict, step: int | None = None):
        """Queue one set of metrics for upload.

        Args:
            metrics: Mapping of metric name to value.
            step: Optional step index stored with the entry.

        Raises:
            ValueError: If a key is in ``RESERVED_KEYS`` or starts with
                ``"__"``.
        """
        for k in metrics:
            if k in RESERVED_KEYS or k.startswith("__"):
                raise ValueError(
                    f"Please do not use this reserved key as a metric: {k}"
                )
        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            # Shallow copy: the entry is sent later from another thread, so
            # the caller must be free to reuse or mutate its own dict.
            "metrics": dict(metrics),
            "step": step,
        }
        with self._client_lock:
            self._queued_logs.append(log_entry)

    def finish(self):
        """Cleanup when run is finished.

        Signals the background thread to stop, synchronously uploads anything
        still queued (when a client is connected), then waits up to 30
        seconds for the thread to exit. If no client is connected yet, the
        thread performs the upload itself once it connects.

        Raises:
            Exception: Whatever ``Client.predict`` raises during the flush.
        """
        self._stop_flag.set()
        self._flush_queued_logs()
        if self._client_thread is not None:
            print(f"* Uploading logs to Trackio Space: {self.url} (please wait...)")
            self._client_thread.join(timeout=30)