import threading
import time
import warnings
from datetime import datetime, timezone

import huggingface_hub
from gradio_client import Client, handle_file

from trackio import utils
from trackio.histogram import Histogram
from trackio.media import TrackioMedia
from trackio.sqlite_storage import SQLiteStorage
from trackio.table import Table
from trackio.typehints import LogEntry, UploadEntry

# Seconds the background sender sleeps between flushes of the queues.
BATCH_SEND_INTERVAL = 0.5

# "_type" tags of serialized media values whose files get queued for upload.
_MEDIA_TYPES = ("trackio.image", "trackio.video", "trackio.audio")


class Run:
    """A single tracked run that batches logs and media uploads to a server.

    Log entries and media uploads are appended to in-memory queues and sent
    by a daemon thread that first connects a ``gradio_client.Client`` (when
    none was supplied) and then flushes the queues every
    ``BATCH_SEND_INTERVAL`` seconds until ``finish()`` is called.
    """

    def __init__(
        self,
        url: str,
        project: str,
        client: Client | None,
        name: str | None = None,
        group: str | None = None,
        config: dict | None = None,
        space_id: str | None = None,
    ):
        """Create the run and start its background sender thread.

        Args:
            url: Address passed to ``Client`` when ``client`` is None.
            project: Project name attached to every log and upload entry.
            client: Already-connected client, or None to connect in the
                background.
            name: Run name; a readable name is generated when omitted.
            group: Optional group label, stored in the config as ``_Group``.
            config: User configuration. Keys starting with ``_`` are
                rejected because they are reserved for internal use.
            space_id: When truthy, media files are queued for upload.

        Raises:
            ValueError: If a config key starts with ``_``.
        """
        self.url = url
        self.project = project
        # Guards ``_client`` and both queues.
        self._client_lock = threading.Lock()
        self._client_thread = None
        self._client = client
        self._space_id = space_id
        self.name = name or utils.generate_readable_name(
            SQLiteStorage.get_runs(project), space_id
        )
        self.group = group
        self.config = utils.to_json_safe(config or {})

        if isinstance(self.config, dict):
            for key in self.config:
                if key.startswith("_"):
                    raise ValueError(
                        f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)"
                    )

        # Internal metadata lives under the reserved "_" prefix.
        self.config["_Username"] = self._get_username()
        self.config["_Created"] = datetime.now(timezone.utc).isoformat()
        self.config["_Group"] = self.group

        self._queued_logs: list[LogEntry] = []
        self._queued_uploads: list[UploadEntry] = []
        self._stop_flag = threading.Event()
        # The config is attached only to the first log entry of the run.
        self._config_logged = False

        self._client_thread = threading.Thread(target=self._init_client_background)
        self._client_thread.daemon = True
        self._client_thread.start()

    def _get_username(self) -> str | None:
        """Get the current HuggingFace username if logged in, otherwise None."""
        try:
            who = huggingface_hub.whoami()
            return who["name"] if who else None
        except Exception:
            # Best effort: any failure to resolve the user means "no username".
            return None

    def _batch_sender(self):
        """Send batched logs and uploads every BATCH_SEND_INTERVAL.

        Runs until the stop flag is set and both queues are empty. Returns
        immediately if no client is available.
        """
        while (
            not self._stop_flag.is_set()
            or self._queued_logs
            or self._queued_uploads
        ):
            # Once the stop flag is set, skip the sleep so the remaining
            # entries are flushed right away.
            if not self._stop_flag.is_set():
                time.sleep(BATCH_SEND_INTERVAL)

            # Only drain the queues under the lock; the network calls happen
            # after releasing it so log() callers are not blocked on I/O.
            with self._client_lock:
                client = self._client
                if client is None:
                    return
                logs_to_send = self._queued_logs.copy()
                self._queued_logs.clear()
                uploads_to_send = self._queued_uploads.copy()
                self._queued_uploads.clear()

            if logs_to_send:
                client.predict(
                    api_name="/bulk_log",
                    logs=logs_to_send,
                    hf_token=huggingface_hub.utils.get_token(),
                )
            if uploads_to_send:
                client.predict(
                    api_name="/bulk_upload_media",
                    uploads=uploads_to_send,
                    hf_token=huggingface_hub.utils.get_token(),
                )

    def _init_client_background(self):
        """Thread target: connect the client if needed, then run the sender."""
        if self._client is None:
            for sleep_coefficient in utils.fibo():
                try:
                    client = Client(self.url, verbose=False)
                except Exception:
                    # Deliberate best effort: keep retrying, waiting
                    # 0.1 * coefficient seconds between attempts.
                    if sleep_coefficient is not None:
                        time.sleep(0.1 * sleep_coefficient)
                    continue
                with self._client_lock:
                    self._client = client
                break

        self._batch_sender()

    def _queue_upload(self, file_path, step: int | None):
        """Queue a media file for upload to space."""
        upload_entry: UploadEntry = {
            "project": self.project,
            "run": self.name,
            "step": step,
            "uploaded_file": handle_file(file_path),
        }
        with self._client_lock:
            self._queued_uploads.append(upload_entry)

    def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
        """
        Serialize media in metrics and upload to space if needed.
        """
        value._save(self.project, self.name, step)
        if self._space_id:
            self._queue_upload(value._get_absolute_file_path(), step)
        return value._to_dict()

    def _queue_media_if_present(self, value, step: int | None):
        """Queue the file behind ``value`` if it is a serialized media dict."""
        if not isinstance(value, dict) or value.get("_type") not in _MEDIA_TYPES:
            return
        file_path = value.get("file_path")
        if file_path:
            self._queue_upload(utils.MEDIA_DIR / file_path, step)

    def _scan_and_queue_media_uploads(self, table_dict: dict, step: int | None):
        """
        Scan a serialized table for media objects and queue them for upload to space.

        Each cell of every row under ``table_dict["_value"]`` is checked; a
        cell may be a single media dict or a list of them. Does nothing when
        the run has no space id.
        """
        if not self._space_id:
            return

        for row in table_dict.get("_value", []):
            for value in row.values():
                candidates = value if isinstance(value, list) else [value]
                for item in candidates:
                    self._queue_media_if_present(item, step)

    def log(self, metrics: dict, step: int | None = None):
        """Queue one set of metrics for sending by the background thread.

        Keys that are reserved or start with ``__`` are renamed with a
        ``__`` prefix (with a warning). Table, Histogram and media values are
        converted to dicts, and media files are queued for upload when the
        run has a space id. The caller's dict is not modified.

        Args:
            metrics: Mapping of metric name to value.
            step: Optional step index stored with the entry.
        """
        renamed_keys = []
        new_metrics = {}

        for k, v in metrics.items():
            if k in utils.RESERVED_KEYS or k.startswith("__"):
                new_key = f"__{k}"
                renamed_keys.append(k)
                new_metrics[new_key] = v
            else:
                new_metrics[k] = v

        if renamed_keys:
            warnings.warn(f"Reserved keys renamed: {renamed_keys} → '__{{key}}'")

        # Work on the copy from here on so the caller's dict stays untouched.
        metrics = new_metrics
        for key, value in metrics.items():
            if isinstance(value, Table):
                metrics[key] = value._to_dict(
                    project=self.project, run=self.name, step=step
                )
                self._scan_and_queue_media_uploads(metrics[key], step)
            elif isinstance(value, Histogram):
                metrics[key] = value._to_dict()
            elif isinstance(value, TrackioMedia):
                metrics[key] = self._process_media(value, step)
        metrics = utils.serialize_values(metrics)

        # Attach the config only to the first entry logged for this run.
        config_to_log = None
        if not self._config_logged and self.config:
            config_to_log = utils.to_json_safe(self.config)
            self._config_logged = True

        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
            "config": config_to_log,
        }

        with self._client_lock:
            self._queued_logs.append(log_entry)

    def finish(self):
        """Cleanup when run is finished.

        Sets the stop flag, waits two send intervals, then joins the
        background thread so queued entries are sent before returning.
        """
        self._stop_flag.set()

        time.sleep(2 * BATCH_SEND_INTERVAL)

        if self._client_thread is not None:
            print("* Run finished. Uploading logs to Trackio (please wait...)")
            self._client_thread.join()