sergiopaniego (HF Staff) committed
Commit a33f061 · verified · 1 Parent(s): 1b5e9a8

Upload folder using huggingface_hub

Files changed (50)
  1. .gitattributes +1 -0
  2. __init__.py +284 -0
  3. __pycache__/__init__.cpython-312.pyc +0 -0
  4. __pycache__/cli.cpython-312.pyc +0 -0
  5. __pycache__/commit_scheduler.cpython-312.pyc +0 -0
  6. __pycache__/context_vars.cpython-312.pyc +0 -0
  7. __pycache__/deploy.cpython-312.pyc +0 -0
  8. __pycache__/dummy_commit_scheduler.cpython-312.pyc +0 -0
  9. __pycache__/file_storage.cpython-312.pyc +0 -0
  10. __pycache__/imports.cpython-312.pyc +0 -0
  11. __pycache__/media.cpython-312.pyc +0 -0
  12. __pycache__/run.cpython-312.pyc +0 -0
  13. __pycache__/sqlite_storage.cpython-312.pyc +0 -0
  14. __pycache__/table.cpython-312.pyc +0 -0
  15. __pycache__/typehints.cpython-312.pyc +0 -0
  16. __pycache__/utils.cpython-312.pyc +0 -0
  17. __pycache__/video_writer.cpython-312.pyc +0 -0
  18. assets/trackio_logo_dark.png +0 -0
  19. assets/trackio_logo_light.png +0 -0
  20. assets/trackio_logo_old.png +3 -0
  21. assets/trackio_logo_type_dark.png +0 -0
  22. assets/trackio_logo_type_dark_transparent.png +0 -0
  23. assets/trackio_logo_type_light.png +0 -0
  24. assets/trackio_logo_type_light_transparent.png +0 -0
  25. cli.py +32 -0
  26. commit_scheduler.py +391 -0
  27. context_vars.py +15 -0
  28. deploy.py +224 -0
  29. dummy_commit_scheduler.py +12 -0
  30. file_storage.py +37 -0
  31. imports.py +302 -0
  32. media.py +286 -0
  33. py.typed +0 -0
  34. run.py +182 -0
  35. sqlite_storage.py +559 -0
  36. table.py +55 -0
  37. typehints.py +18 -0
  38. ui/__init__.py +8 -0
  39. ui/__pycache__/__init__.cpython-312.pyc +0 -0
  40. ui/__pycache__/fns.cpython-312.pyc +0 -0
  41. ui/__pycache__/main.cpython-312.pyc +0 -0
  42. ui/__pycache__/run_detail.cpython-312.pyc +0 -0
  43. ui/__pycache__/runs.cpython-312.pyc +0 -0
  44. ui/fns.py +58 -0
  45. ui/main.py +937 -0
  46. ui/run_detail.py +90 -0
  47. ui/runs.py +236 -0
  48. utils.py +733 -0
  49. version.txt +1 -0
  50. video_writer.py +126 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
@@ -0,0 +1,284 @@
+ import hashlib
+ import os
+ import secrets
+ import warnings
+ import webbrowser
+ from pathlib import Path
+ from typing import Any
+
+ from gradio.blocks import BUILT_IN_THEMES
+ from gradio.themes import Default as DefaultTheme
+ from gradio.themes import ThemeClass
+ from gradio_client import Client
+ from huggingface_hub import SpaceStorage
+
+ from trackio import context_vars, deploy, utils
+ from trackio.imports import import_csv, import_tf_events
+ from trackio.media import TrackioImage, TrackioVideo
+ from trackio.run import Run
+ from trackio.sqlite_storage import SQLiteStorage
+ from trackio.table import Table
+ from trackio.ui.main import demo
+ from trackio.ui.runs import run_page
+ from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR
+
+ __version__ = Path(__file__).parent.joinpath("version.txt").read_text().strip()
+
+ __all__ = [
+     "init",
+     "log",
+     "finish",
+     "show",
+     "import_csv",
+     "import_tf_events",
+     "Image",
+     "Video",
+     "Table",
+ ]
+
+ Image = TrackioImage
+ Video = TrackioVideo
+
+
+ config = {}
+
+ DEFAULT_THEME = "citrus"
+
+
+ def init(
+     project: str,
+     name: str | None = None,
+     space_id: str | None = None,
+     space_storage: SpaceStorage | None = None,
+     dataset_id: str | None = None,
+     config: dict | None = None,
+     resume: str = "never",
+     settings: Any = None,
+     private: bool | None = None,
+ ) -> Run:
+     """
+     Creates a new Trackio project and returns a [`Run`] object.
+
+     Args:
+         project (`str`):
+             The name of the project (can be an existing project to continue tracking or
+             a new project to start tracking from scratch).
+         name (`str` or `None`, *optional*, defaults to `None`):
+             The name of the run (if not provided, a default name will be generated).
+         space_id (`str` or `None`, *optional*, defaults to `None`):
+             If provided, the project will be logged to a Hugging Face Space instead of
+             a local directory. Should be a complete Space name like
+             `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
+             case the Space will be created in the currently-logged-in Hugging Face
+             user's namespace. If the Space does not exist, it will be created. If the
+             Space already exists, the project will be logged to it.
+         space_storage ([`~huggingface_hub.SpaceStorage`] or `None`, *optional*, defaults to `None`):
+             Choice of persistent storage tier.
+         dataset_id (`str` or `None`, *optional*, defaults to `None`):
+             If a `space_id` is provided, a persistent Hugging Face Dataset will be
+             created and the metrics will be synced to it every 5 minutes. Specify a
+             Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`,
+             or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace),
+             or `None` (uses the same name as the Space but with the `"_dataset"`
+             suffix). If the Dataset does not exist, it will be created. If the Dataset
+             already exists, the project will be appended to it.
+         config (`dict` or `None`, *optional*, defaults to `None`):
+             A dictionary of configuration options. Provided for compatibility with
+             `wandb.init()`.
+         resume (`str`, *optional*, defaults to `"never"`):
+             Controls how to handle resuming a run. Can be one of:
+
+             - `"must"`: Must resume the run with the given name, raises error if run
+               doesn't exist
+             - `"allow"`: Resume the run if it exists, otherwise create a new run
+             - `"never"`: Never resume a run, always create a new one
+         private (`bool` or `None`, *optional*, defaults to `None`):
+             Whether to make the Space private. If None (default), the repo will be
+             public unless the organization's default is private. This value is ignored
+             if the repo already exists.
+         settings (`Any`, *optional*, defaults to `None`):
+             Not used. Provided for compatibility with `wandb.init()`.
+
+     Returns:
+         `Run`: A [`Run`] object that can be used to log metrics and finish the run.
+     """
+     if settings is not None:
+         warnings.warn(
+             "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
+         )
+
+     if space_id is None and dataset_id is not None:
+         raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
+     space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
+     url = context_vars.current_server.get()
+
+     if url is None:
+         if space_id is None:
+             _, url, _ = demo.launch(
+                 show_api=False,
+                 inline=False,
+                 quiet=True,
+                 prevent_thread_lock=True,
+                 show_error=True,
+             )
+         else:
+             url = space_id
+         context_vars.current_server.set(url)
+
+     if (
+         context_vars.current_project.get() is None
+         or context_vars.current_project.get() != project
+     ):
+         print(f"* Trackio project initialized: {project}")
+
+         if dataset_id is not None:
+             os.environ["TRACKIO_DATASET_ID"] = dataset_id
+             print(
+                 f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
+             )
+         if space_id is None:
+             print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
+             utils.print_dashboard_instructions(project)
+         else:
+             deploy.create_space_if_not_exists(
+                 space_id, space_storage, dataset_id, private
+             )
+             print(
+                 f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
+             )
+     context_vars.current_project.set(project)
+
+     client = None
+     if not space_id:
+         client = Client(url, verbose=False)
+
+     if resume == "must":
+         if name is None:
+             raise ValueError("Must provide a run name when resume='must'")
+         if name not in SQLiteStorage.get_runs(project):
+             raise ValueError(f"Run '{name}' does not exist in project '{project}'")
+         resumed = True
+     elif resume == "allow":
+         resumed = name is not None and name in SQLiteStorage.get_runs(project)
+     elif resume == "never":
+         if name is not None and name in SQLiteStorage.get_runs(project):
+             warnings.warn(
+                 f"* Warning: resume='never' but a run '{name}' already exists in "
+                 f"project '{project}'. Generating a new name instead. If you want "
+                 "to resume this run, call init() with resume='must' or resume='allow'."
+             )
+             name = None
+         resumed = False
+     else:
+         raise ValueError("resume must be one of: 'must', 'allow', or 'never'")
+
+     run = Run(
+         url=url,
+         project=project,
+         client=client,
+         name=name,
+         config=config,
+         space_id=space_id,
+     )
+
+     if resumed:
+         print(f"* Resumed existing run: {run.name}")
+     else:
+         print(f"* Created new run: {run.name}")
+
+     context_vars.current_run.set(run)
+     globals()["config"] = run.config
+     return run
+
+
+ def log(metrics: dict, step: int | None = None) -> None:
+     """
+     Logs metrics to the current run.
+
+     Args:
+         metrics (`dict`):
+             A dictionary of metrics to log.
+         step (`int` or `None`, *optional*, defaults to `None`):
+             The step number. If not provided, the step will be incremented
+             automatically.
+     """
+     run = context_vars.current_run.get()
+     if run is None:
+         raise RuntimeError("Call trackio.init() before trackio.log().")
+     run.log(
+         metrics=metrics,
+         step=step,
+     )
+
+
+ def finish():
+     """
+     Finishes the current run.
+     """
+     run = context_vars.current_run.get()
+     if run is None:
+         raise RuntimeError("Call trackio.init() before trackio.finish().")
+     run.finish()
+
+
+ def show(project: str | None = None, theme: str | ThemeClass = DEFAULT_THEME):
+     """
+     Launches the Trackio dashboard.
+
+     Args:
+         project (`str` or `None`, *optional*, defaults to `None`):
+             The name of the project whose runs to show. If not provided, all projects
+             will be shown and the user can select one.
+         theme (`str` or `ThemeClass`, *optional*, defaults to `"citrus"`):
+             A Gradio Theme to use for the dashboard instead of the default `"citrus"`,
+             can be a built-in theme (e.g. `'soft'`, `'default'`), a theme from the Hub
+             (e.g. `"gstaff/xkcd"`), or a custom Theme class.
+     """
+     write_token = secrets.token_urlsafe(32)
+
+     demo.write_token = write_token
+     run_page.write_token = write_token
+
+     if theme != DEFAULT_THEME:
+         # TODO: It's a little hacky to reproduce this theme-setting logic from Gradio Blocks,
+         # but in Gradio 6.0, the theme will be set in `launch()` instead, which means that we
+         # will be able to remove this code.
+         if isinstance(theme, str):
+             if theme.lower() in BUILT_IN_THEMES:
+                 theme = BUILT_IN_THEMES[theme.lower()]
+             else:
+                 try:
+                     theme = ThemeClass.from_hub(theme)
+                 except Exception as e:
+                     warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
+                     theme = DefaultTheme()
+         if not isinstance(theme, ThemeClass):
+             warnings.warn("Theme should be a class loaded from gradio.themes")
+             theme = DefaultTheme()
+         demo.theme: ThemeClass = theme
+         demo.theme_css = theme._get_theme_css()
+         demo.stylesheets = theme._stylesheets
+         theme_hasher = hashlib.sha256()
+         theme_hasher.update(demo.theme_css.encode("utf-8"))
+         demo.theme_hash = theme_hasher.hexdigest()
+
+     _, url, share_url = demo.launch(
+         show_api=False,
+         quiet=True,
+         inline=utils.is_in_notebook(),
+         prevent_thread_lock=True,
+         favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
+         allowed_paths=[TRACKIO_LOGO_DIR],
+     )
+
+     base_url = share_url + "/" if share_url else url
+
+     params = [f"write_token={write_token}"]
+     if project:
+         params.append(f"project={project}")
+     dashboard_url = base_url + "?" + "&".join(params)
+
+     if not utils.is_in_notebook():
+         print(f"* Trackio UI launched at: {dashboard_url}")
+         webbrowser.open(dashboard_url)
+     utils.block_except_in_notebook()
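
For quick reference, a minimal end-to-end usage sketch of the API defined above; the project, run, and metric names are illustrative only.

```python
import trackio

# Start a run in a local project; metrics are stored in a local SQLite database
run = trackio.init(project="demo-project", name="baseline", config={"lr": 1e-3})

for step in range(3):
    # Log scalar metrics; the step counter is inferred if omitted
    trackio.log({"loss": 1.0 / (step + 1)}, step=step)

trackio.finish()

# Launch the dashboard for the project (opens a browser outside of notebooks)
trackio.show(project="demo-project")
```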
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (12.5 kB)

__pycache__/cli.cpython-312.pyc ADDED
Binary file (1.44 kB)

__pycache__/commit_scheduler.cpython-312.pyc ADDED
Binary file (18.8 kB)

__pycache__/context_vars.cpython-312.pyc ADDED
Binary file (775 Bytes)

__pycache__/deploy.cpython-312.pyc ADDED
Binary file (8.74 kB)

__pycache__/dummy_commit_scheduler.cpython-312.pyc ADDED
Binary file (1.03 kB)

__pycache__/file_storage.cpython-312.pyc ADDED
Binary file (1.65 kB)

__pycache__/imports.cpython-312.pyc ADDED
Binary file (13.5 kB)

__pycache__/media.cpython-312.pyc ADDED
Binary file (14.2 kB)

__pycache__/run.cpython-312.pyc ADDED
Binary file (8.65 kB)

__pycache__/sqlite_storage.cpython-312.pyc ADDED
Binary file (27.3 kB)

__pycache__/table.cpython-312.pyc ADDED
Binary file (2.48 kB)

__pycache__/typehints.cpython-312.pyc ADDED
Binary file (920 Bytes)

__pycache__/utils.cpython-312.pyc ADDED
Binary file (21.2 kB)

__pycache__/video_writer.cpython-312.pyc ADDED
Binary file (5.34 kB)
 
assets/trackio_logo_dark.png ADDED
assets/trackio_logo_light.png ADDED
assets/trackio_logo_old.png ADDED

Git LFS Details

  • SHA256: 3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
assets/trackio_logo_type_dark.png ADDED
assets/trackio_logo_type_dark_transparent.png ADDED
assets/trackio_logo_type_light.png ADDED
assets/trackio_logo_type_light_transparent.png ADDED
cli.py ADDED
@@ -0,0 +1,32 @@
+ import argparse
+
+ from trackio import show
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Trackio CLI")
+     subparsers = parser.add_subparsers(dest="command")
+
+     ui_parser = subparsers.add_parser(
+         "show", help="Show the Trackio dashboard UI for a project"
+     )
+     ui_parser.add_argument(
+         "--project", required=False, help="Project name to show in the dashboard"
+     )
+     ui_parser.add_argument(
+         "--theme",
+         required=False,
+         default="citrus",
+         help="A Gradio Theme to use for the dashboard instead of the default 'citrus', can be a built-in theme (e.g. 'soft', 'default'), a theme from the Hub (e.g. 'gstaff/xkcd').",
+     )
+
+     args = parser.parse_args()
+
+     if args.command == "show":
+         show(args.project, args.theme)
+     else:
+         parser.print_help()
+
+
+ if __name__ == "__main__":
+     main()
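
The same entrypoint can also be exercised programmatically; a small sketch (the project name is hypothetical), equivalent to running `trackio show --project demo-project --theme soft` from a shell:

```python
import sys

from trackio.cli import main

# argparse reads sys.argv, so this mirrors the command-line invocation above
sys.argv = ["trackio", "show", "--project", "demo-project", "--theme", "soft"]
main()
```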
commit_scheduler.py ADDED
@@ -0,0 +1,391 @@
1
+ # Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py
2
+
3
+ import atexit
4
+ import logging
5
+ import os
6
+ import time
7
+ from concurrent.futures import Future
8
+ from dataclasses import dataclass
9
+ from io import SEEK_END, SEEK_SET, BytesIO
10
+ from pathlib import Path
11
+ from threading import Lock, Thread
12
+ from typing import Callable, Dict, List, Optional, Union
13
+
14
+ from huggingface_hub.hf_api import (
15
+ DEFAULT_IGNORE_PATTERNS,
16
+ CommitInfo,
17
+ CommitOperationAdd,
18
+ HfApi,
19
+ )
20
+ from huggingface_hub.utils import filter_repo_objects
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class _FileToUpload:
27
+ """Temporary dataclass to store info about files to upload. Not meant to be used directly."""
28
+
29
+ local_path: Path
30
+ path_in_repo: str
31
+ size_limit: int
32
+ last_modified: float
33
+
34
+
35
+ class CommitScheduler:
36
+ """
37
+ Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).
38
+
39
+ The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
40
+ properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
41
+ with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
42
+ to learn more about how to use it.
43
+
44
+ Args:
45
+ repo_id (`str`):
46
+ The id of the repo to commit to.
47
+ folder_path (`str` or `Path`):
48
+ Path to the local folder to upload regularly.
49
+ every (`int` or `float`, *optional*):
50
+ The number of minutes between each commit. Defaults to 5 minutes.
51
+ path_in_repo (`str`, *optional*):
52
+ Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
53
+ of the repository.
54
+ repo_type (`str`, *optional*):
55
+ The type of the repo to commit to. Defaults to `model`.
56
+ revision (`str`, *optional*):
57
+ The revision of the repo to commit to. Defaults to `main`.
58
+ private (`bool`, *optional*):
59
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
60
+ token (`str`, *optional*):
61
+ The token to use to commit to the repo. Defaults to the token saved on the machine.
62
+ allow_patterns (`List[str]` or `str`, *optional*):
63
+ If provided, only files matching at least one pattern are uploaded.
64
+ ignore_patterns (`List[str]` or `str`, *optional*):
65
+ If provided, files matching any of the patterns are not uploaded.
66
+ squash_history (`bool`, *optional*):
67
+ Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
68
+ useful to avoid degraded performances on the repo when it grows too large.
69
+ hf_api (`HfApi`, *optional*):
70
+ The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
71
+ on_before_commit (`Callable[[], None]`, *optional*):
72
+ If specified, a function that will be called before the CommitScheduler lists files to create a commit.
73
+
74
+ Example:
75
+ ```py
76
+ >>> from pathlib import Path
77
+ >>> from huggingface_hub import CommitScheduler
78
+
79
+ # Scheduler uploads every 10 minutes
80
+ >>> csv_path = Path("watched_folder/data.csv")
81
+ >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)
82
+
83
+ >>> with csv_path.open("a") as f:
84
+ ... f.write("first line")
85
+
86
+ # Some time later (...)
87
+ >>> with csv_path.open("a") as f:
88
+ ... f.write("second line")
89
+ ```
90
+
91
+ Example using a context manager:
92
+ ```py
93
+ >>> from pathlib import Path
94
+ >>> from huggingface_hub import CommitScheduler
95
+
96
+ >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
97
+ ... csv_path = Path("watched_folder/data.csv")
98
+ ... with csv_path.open("a") as f:
99
+ ... f.write("first line")
100
+ ... (...)
101
+ ... with csv_path.open("a") as f:
102
+ ... f.write("second line")
103
+
104
+ # Scheduler is now stopped and last commit have been triggered
105
+ ```
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ *,
111
+ repo_id: str,
112
+ folder_path: Union[str, Path],
113
+ every: Union[int, float] = 5,
114
+ path_in_repo: Optional[str] = None,
115
+ repo_type: Optional[str] = None,
116
+ revision: Optional[str] = None,
117
+ private: Optional[bool] = None,
118
+ token: Optional[str] = None,
119
+ allow_patterns: Optional[Union[List[str], str]] = None,
120
+ ignore_patterns: Optional[Union[List[str], str]] = None,
121
+ squash_history: bool = False,
122
+ hf_api: Optional["HfApi"] = None,
123
+ on_before_commit: Optional[Callable[[], None]] = None,
124
+ ) -> None:
125
+ self.api = hf_api or HfApi(token=token)
126
+ self.on_before_commit = on_before_commit
127
+
128
+ # Folder
129
+ self.folder_path = Path(folder_path).expanduser().resolve()
130
+ self.path_in_repo = path_in_repo or ""
131
+ self.allow_patterns = allow_patterns
132
+
133
+ if ignore_patterns is None:
134
+ ignore_patterns = []
135
+ elif isinstance(ignore_patterns, str):
136
+ ignore_patterns = [ignore_patterns]
137
+ self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS
138
+
139
+ if self.folder_path.is_file():
140
+ raise ValueError(
141
+ f"'folder_path' must be a directory, not a file: '{self.folder_path}'."
142
+ )
143
+ self.folder_path.mkdir(parents=True, exist_ok=True)
144
+
145
+ # Repository
146
+ repo_url = self.api.create_repo(
147
+ repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True
148
+ )
149
+ self.repo_id = repo_url.repo_id
150
+ self.repo_type = repo_type
151
+ self.revision = revision
152
+ self.token = token
153
+
154
+ self.last_uploaded: Dict[Path, float] = {}
155
+ self.last_push_time: float | None = None
156
+
157
+ if not every > 0:
158
+ raise ValueError(f"'every' must be a positive integer, not '{every}'.")
159
+ self.lock = Lock()
160
+ self.every = every
161
+ self.squash_history = squash_history
162
+
163
+ logger.info(
164
+ f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes."
165
+ )
166
+ self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
167
+ self._scheduler_thread.start()
168
+ atexit.register(self._push_to_hub)
169
+
170
+ self.__stopped = False
171
+
172
+ def stop(self) -> None:
173
+ """Stop the scheduler.
174
+
175
+ A stopped scheduler cannot be restarted. Mostly for tests purposes.
176
+ """
177
+ self.__stopped = True
178
+
179
+ def __enter__(self) -> "CommitScheduler":
180
+ return self
181
+
182
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
183
+ # Upload last changes before exiting
184
+ self.trigger().result()
185
+ self.stop()
186
+ return
187
+
188
+ def _run_scheduler(self) -> None:
189
+ """Dumb thread waiting between each scheduled push to Hub."""
190
+ while True:
191
+ self.last_future = self.trigger()
192
+ time.sleep(self.every * 60)
193
+ if self.__stopped:
194
+ break
195
+
196
+ def trigger(self) -> Future:
197
+ """Trigger a `push_to_hub` and return a future.
198
+
199
+ This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
200
+ immediately, without waiting for the next scheduled commit.
201
+ """
202
+ return self.api.run_as_future(self._push_to_hub)
203
+
204
+ def _push_to_hub(self) -> Optional[CommitInfo]:
205
+ if self.__stopped: # If stopped, already scheduled commits are ignored
206
+ return None
207
+
208
+ logger.info("(Background) scheduled commit triggered.")
209
+ try:
210
+ value = self.push_to_hub()
211
+ if self.squash_history:
212
+ logger.info("(Background) squashing repo history.")
213
+ self.api.super_squash_history(
214
+ repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision
215
+ )
216
+ return value
217
+ except Exception as e:
218
+ logger.error(
219
+ f"Error while pushing to Hub: {e}"
220
+ ) # Depending on the setup, error might be silenced
221
+ raise
222
+
223
+ def push_to_hub(self) -> Optional[CommitInfo]:
224
+ """
225
+ Push folder to the Hub and return the commit info.
226
+
227
+ <Tip warning={true}>
228
+
229
+ This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
230
+ queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
231
+ issues.
232
+
233
+ </Tip>
234
+
235
+ The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
236
+ uploads only changed files. If no changes are found, the method returns without committing anything. If you want
237
+ to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
238
+ for example to compress data together in a single file before committing. For more details and examples, check
239
+ out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
240
+ """
241
+ # Check files to upload (with lock)
242
+ with self.lock:
243
+ if self.on_before_commit is not None:
244
+ self.on_before_commit()
245
+
246
+ logger.debug("Listing files to upload for scheduled commit.")
247
+
248
+ # List files from folder (taken from `_prepare_upload_folder_additions`)
249
+ relpath_to_abspath = {
250
+ path.relative_to(self.folder_path).as_posix(): path
251
+ for path in sorted(
252
+ self.folder_path.glob("**/*")
253
+ ) # sorted to be deterministic
254
+ if path.is_file()
255
+ }
256
+ prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""
257
+
258
+ # Filter with pattern + filter out unchanged files + retrieve current file size
259
+ files_to_upload: List[_FileToUpload] = []
260
+ for relpath in filter_repo_objects(
261
+ relpath_to_abspath.keys(),
262
+ allow_patterns=self.allow_patterns,
263
+ ignore_patterns=self.ignore_patterns,
264
+ ):
265
+ local_path = relpath_to_abspath[relpath]
266
+ stat = local_path.stat()
267
+ if (
268
+ self.last_uploaded.get(local_path) is None
269
+ or self.last_uploaded[local_path] != stat.st_mtime
270
+ ):
271
+ files_to_upload.append(
272
+ _FileToUpload(
273
+ local_path=local_path,
274
+ path_in_repo=prefix + relpath,
275
+ size_limit=stat.st_size,
276
+ last_modified=stat.st_mtime,
277
+ )
278
+ )
279
+
280
+ # Return if nothing to upload
281
+ if len(files_to_upload) == 0:
282
+ logger.debug("Dropping schedule commit: no changed file to upload.")
283
+ return None
284
+
285
+ # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
286
+ logger.debug("Removing unchanged files since previous scheduled commit.")
287
+ add_operations = [
288
+ CommitOperationAdd(
289
+ # TODO: Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
290
+ # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload)
291
+ path_or_fileobj=file_to_upload.local_path,
292
+ path_in_repo=file_to_upload.path_in_repo,
293
+ )
294
+ for file_to_upload in files_to_upload
295
+ ]
296
+
297
+ # Upload files (append mode expected - no need for lock)
298
+ logger.debug("Uploading files for scheduled commit.")
299
+ commit_info = self.api.create_commit(
300
+ repo_id=self.repo_id,
301
+ repo_type=self.repo_type,
302
+ operations=add_operations,
303
+ commit_message="Scheduled Commit",
304
+ revision=self.revision,
305
+ )
306
+
307
+ for file in files_to_upload:
308
+ self.last_uploaded[file.local_path] = file.last_modified
309
+
310
+ self.last_push_time = time.time()
311
+
312
+ return commit_info
313
+
314
+
315
+ class PartialFileIO(BytesIO):
316
+ """A file-like object that reads only the first part of a file.
317
+
318
+ Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
319
+ file is uploaded (i.e. the part that was available when the filesystem was first scanned).
320
+
321
+ In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
322
+ disturbance for the user. The object is passed to `CommitOperationAdd`.
323
+
324
+ Only supports `read`, `tell` and `seek` methods.
325
+
326
+ Args:
327
+ file_path (`str` or `Path`):
328
+ Path to the file to read.
329
+ size_limit (`int`):
330
+ The maximum number of bytes to read from the file. If the file is larger than this, only the first part
331
+ will be read (and uploaded).
332
+ """
333
+
334
+ def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
335
+ self._file_path = Path(file_path)
336
+ self._file = self._file_path.open("rb")
337
+ self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
338
+
339
+ def __del__(self) -> None:
340
+ self._file.close()
341
+ return super().__del__()
342
+
343
+ def __repr__(self) -> str:
344
+ return (
345
+ f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
346
+ )
347
+
348
+ def __len__(self) -> int:
349
+ return self._size_limit
350
+
351
+ def __getattribute__(self, name: str):
352
+ if name.startswith("_") or name in (
353
+ "read",
354
+ "tell",
355
+ "seek",
356
+ ): # only 3 public methods supported
357
+ return super().__getattribute__(name)
358
+ raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
359
+
360
+ def tell(self) -> int:
361
+ """Return the current file position."""
362
+ return self._file.tell()
363
+
364
+ def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
365
+ """Change the stream position to the given offset.
366
+
367
+ Behavior is the same as a regular file, except that the position is capped to the size limit.
368
+ """
369
+ if __whence == SEEK_END:
370
+ # SEEK_END => set from the truncated end
371
+ __offset = len(self) + __offset
372
+ __whence = SEEK_SET
373
+
374
+ pos = self._file.seek(__offset, __whence)
375
+ if pos > self._size_limit:
376
+ return self._file.seek(self._size_limit)
377
+ return pos
378
+
379
+ def read(self, __size: Optional[int] = -1) -> bytes:
380
+ """Read at most `__size` bytes from the file.
381
+
382
+ Behavior is the same as a regular file, except that it is capped to the size limit.
383
+ """
384
+ current = self._file.tell()
385
+ if __size is None or __size < 0:
386
+ # Read until file limit
387
+ truncated_size = self._size_limit - current
388
+ else:
389
+ # Read until file limit or __size
390
+ truncated_size = min(__size, self._size_limit - current)
391
+ return self._file.read(truncated_size)
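
As the `push_to_hub` docstring suggests, the scheduler can be subclassed to change what gets committed, for example bundling the watched folder into a single archive before uploading. A hedged sketch under that assumption (the class and archive names are made up):

```python
import shutil
from pathlib import Path

from trackio.commit_scheduler import CommitScheduler


class ZippedCommitScheduler(CommitScheduler):
    """Illustrative variant that compresses the folder before each scheduled commit."""

    def push_to_hub(self):
        with self.lock:
            # Zip the whole watched folder next to it, then upload only the archive
            archive = Path(str(self.folder_path) + ".zip")
            shutil.make_archive(str(self.folder_path), "zip", self.folder_path)
            return self.api.upload_file(
                path_or_fileobj=archive,
                path_in_repo=archive.name,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
            )
```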
context_vars.py ADDED
@@ -0,0 +1,15 @@
+ import contextvars
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from trackio.run import Run
+
+ current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar(
+     "current_run", default=None
+ )
+ current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+     "current_project", default=None
+ )
+ current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+     "current_server", default=None
+ )
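
These context variables are what `log()` and `finish()` consult to find the active run; a tiny sketch of how they behave in isolation (the project name is illustrative):

```python
from trackio import context_vars

# Nothing is active until trackio.init() sets these
assert context_vars.current_run.get() is None

context_vars.current_project.set("demo-project")
print(context_vars.current_project.get())  # "demo-project"
```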
deploy.py ADDED
@@ -0,0 +1,224 @@
1
+ import importlib.metadata
2
+ import io
3
+ import os
4
+ import time
5
+ from importlib.resources import files
6
+ from pathlib import Path
7
+
8
+ import gradio
9
+ import huggingface_hub
10
+ from gradio_client import Client, handle_file
11
+ from httpx import ReadTimeout
12
+ from huggingface_hub.errors import RepositoryNotFoundError
13
+ from requests import HTTPError
14
+
15
+ import trackio
16
+ from trackio.sqlite_storage import SQLiteStorage
17
+
18
+ SPACE_URL = "https://huggingface.co/spaces/{space_id}"
19
+
20
+
21
+ def _is_trackio_installed_from_source() -> bool:
22
+ """Check if trackio is installed from source/editable install vs PyPI."""
23
+ try:
24
+ trackio_file = trackio.__file__
25
+ if "site-packages" not in trackio_file:
26
+ return True
27
+
28
+ dist = importlib.metadata.distribution("trackio")
29
+ if dist.files:
30
+ files = list(dist.files)
31
+ has_pth = any(".pth" in str(f) for f in files)
32
+ if has_pth:
33
+ return True
34
+
35
+ return False
36
+ except (
37
+ AttributeError,
38
+ importlib.metadata.PackageNotFoundError,
39
+ importlib.metadata.MetadataError,
40
+ ValueError,
41
+ TypeError,
42
+ ):
43
+ return True
44
+
45
+
46
+ def deploy_as_space(
47
+ space_id: str,
48
+ space_storage: huggingface_hub.SpaceStorage | None = None,
49
+ dataset_id: str | None = None,
50
+ private: bool | None = None,
51
+ ):
52
+ if (
53
+ os.getenv("SYSTEM") == "spaces"
54
+ ): # in case a repo with this function is uploaded to spaces
55
+ return
56
+
57
+ trackio_path = files("trackio")
58
+
59
+ hf_api = huggingface_hub.HfApi()
60
+
61
+ try:
62
+ huggingface_hub.create_repo(
63
+ space_id,
64
+ private=private,
65
+ space_sdk="gradio",
66
+ space_storage=space_storage,
67
+ repo_type="space",
68
+ exist_ok=True,
69
+ )
70
+ except HTTPError as e:
71
+ if e.response.status_code in [401, 403]: # unauthorized or forbidden
72
+ print("Need 'write' access token to create a Spaces repo.")
73
+ huggingface_hub.login(add_to_git_credential=False)
74
+ huggingface_hub.create_repo(
75
+ space_id,
76
+ private=private,
77
+ space_sdk="gradio",
78
+ space_storage=space_storage,
79
+ repo_type="space",
80
+ exist_ok=True,
81
+ )
82
+ else:
83
+ raise ValueError(f"Failed to create Space: {e}")
84
+
85
+ with open(Path(trackio_path, "README.md"), "r") as f:
86
+ readme_content = f.read()
87
+ readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
88
+ readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
89
+ hf_api.upload_file(
90
+ path_or_fileobj=readme_buffer,
91
+ path_in_repo="README.md",
92
+ repo_id=space_id,
93
+ repo_type="space",
94
+ )
95
+
96
+ # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space.
97
+ # Make sure necessary dependencies are installed by creating a requirements.txt.
98
+ is_source_install = _is_trackio_installed_from_source()
99
+
100
+ if is_source_install:
101
+ requirements_content = """pyarrow>=21.0"""
102
+ else:
103
+ requirements_content = f"""pyarrow>=21.0
104
+ trackio=={trackio.__version__}"""
105
+
106
+ requirements_buffer = io.BytesIO(requirements_content.encode("utf-8"))
107
+ hf_api.upload_file(
108
+ path_or_fileobj=requirements_buffer,
109
+ path_in_repo="requirements.txt",
110
+ repo_id=space_id,
111
+ repo_type="space",
112
+ )
113
+
114
+ huggingface_hub.utils.disable_progress_bars()
115
+
116
+ if is_source_install:
117
+ hf_api.upload_folder(
118
+ repo_id=space_id,
119
+ repo_type="space",
120
+ folder_path=trackio_path,
121
+ ignore_patterns=["README.md"],
122
+ )
123
+ else:
124
+ app_file_content = """import trackio
125
+ trackio.show()"""
126
+ app_file_buffer = io.BytesIO(app_file_content.encode("utf-8"))
127
+ hf_api.upload_file(
128
+ path_or_fileobj=app_file_buffer,
129
+ path_in_repo="ui/main.py",
130
+ repo_id=space_id,
131
+ repo_type="space",
132
+ )
133
+
134
+ if hf_token := huggingface_hub.utils.get_token():
135
+ huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
136
+ if dataset_id is not None:
137
+ huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)
138
+
139
+
140
+ def create_space_if_not_exists(
141
+ space_id: str,
142
+ space_storage: huggingface_hub.SpaceStorage | None = None,
143
+ dataset_id: str | None = None,
144
+ private: bool | None = None,
145
+ ) -> None:
146
+ """
147
+ Creates a new Hugging Face Space if it does not exist. If a dataset_id is provided, it will be added as a space variable.
148
+
149
+ Args:
150
+ space_id: The ID of the Space to create.
151
+ dataset_id: The ID of the Dataset to add to the Space.
152
+ private: Whether to make the Space private. If None (default), the repo will be
153
+ public unless the organization's default is private. This value is ignored if
154
+ the repo already exists.
155
+ """
156
+ if "/" not in space_id:
157
+ raise ValueError(
158
+ f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame."
159
+ )
160
+ if dataset_id is not None and "/" not in dataset_id:
161
+ raise ValueError(
162
+ f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname."
163
+ )
164
+ try:
165
+ huggingface_hub.repo_info(space_id, repo_type="space")
166
+ print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
167
+ if dataset_id is not None:
168
+ huggingface_hub.add_space_variable(
169
+ space_id, "TRACKIO_DATASET_ID", dataset_id
170
+ )
171
+ return
172
+ except RepositoryNotFoundError:
173
+ pass
174
+ except HTTPError as e:
175
+ if e.response.status_code in [401, 403]: # unauthorized or forbidden
176
+ print("Need 'write' access token to create a Spaces repo.")
177
+ huggingface_hub.login(add_to_git_credential=False)
178
+ huggingface_hub.add_space_variable(
179
+ space_id, "TRACKIO_DATASET_ID", dataset_id
180
+ )
181
+ else:
182
+ raise ValueError(f"Failed to create Space: {e}")
183
+
184
+ print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
185
+ deploy_as_space(space_id, space_storage, dataset_id, private)
186
+
187
+
188
+ def wait_until_space_exists(
189
+ space_id: str,
190
+ ) -> None:
191
+ """
192
+ Blocks the current thread until the space exists.
193
+ May raise a TimeoutError if this takes quite a while.
194
+
195
+ Args:
196
+ space_id: The ID of the Space to wait for.
197
+ """
198
+ delay = 1
199
+ for _ in range(10):
200
+ try:
201
+ Client(space_id, verbose=False)
202
+ return
203
+ except (ReadTimeout, ValueError):
204
+ time.sleep(delay)
205
+ delay = min(delay * 2, 30)
206
+ raise TimeoutError("Waiting for space to exist took longer than expected")
207
+
208
+
209
+ def upload_db_to_space(project: str, space_id: str) -> None:
210
+ """
211
+ Uploads the database of a local Trackio project to a Hugging Face Space.
212
+
213
+ Args:
214
+ project: The name of the project to upload.
215
+ space_id: The ID of the Space to upload to.
216
+ """
217
+ db_path = SQLiteStorage.get_project_db_path(project)
218
+ client = Client(space_id, verbose=False)
219
+ client.predict(
220
+ api_name="/upload_db_to_space",
221
+ project=project,
222
+ uploaded_db=handle_file(db_path),
223
+ hf_token=huggingface_hub.utils.get_token(),
224
+ )
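
A hedged sketch of how these helpers chain together when pushing an existing local project to a Space (the Space ID and project name are placeholders):

```python
from trackio import deploy

space_id = "username/trackio-dashboard"  # placeholder

# Create the Space if needed, wait until it responds, then copy the local DB into it
deploy.create_space_if_not_exists(space_id, dataset_id=None, private=None)
deploy.wait_until_space_exists(space_id)
deploy.upload_db_to_space(project="demo-project", space_id=space_id)
print(f"Dashboard: {deploy.SPACE_URL.format(space_id=space_id)}")
```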
dummy_commit_scheduler.py ADDED
@@ -0,0 +1,12 @@
+ # A dummy object to fit the interface of huggingface_hub's CommitScheduler
+ class DummyCommitSchedulerLock:
+     def __enter__(self):
+         return None
+
+     def __exit__(self, exception_type, exception_value, exception_traceback):
+         pass
+
+
+ class DummyCommitScheduler:
+     def __init__(self):
+         self.lock = DummyCommitSchedulerLock()
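
The dummy scheduler only needs to expose a `lock` usable as a context manager, so storage code can treat local-only and Hub-synced setups the same way; a minimal sketch:

```python
from trackio.dummy_commit_scheduler import DummyCommitScheduler

scheduler = DummyCommitScheduler()

# Same pattern the storage layer uses with a real CommitScheduler
with scheduler.lock:
    pass  # write to the local database; nothing is pushed anywhere
```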
file_storage.py ADDED
@@ -0,0 +1,37 @@
+ from pathlib import Path
+
+ try:  # absolute imports when installed
+     from trackio.utils import MEDIA_DIR
+ except ImportError:  # relative imports for local execution on Spaces
+     from utils import MEDIA_DIR
+
+
+ class FileStorage:
+     @staticmethod
+     def get_project_media_path(
+         project: str,
+         run: str | None = None,
+         step: int | None = None,
+         filename: str | None = None,
+     ) -> Path:
+         if filename is not None and step is None:
+             raise ValueError("filename requires step")
+         if step is not None and run is None:
+             raise ValueError("step requires run")
+
+         path = MEDIA_DIR / project
+         if run:
+             path /= run
+         if step is not None:
+             path /= str(step)
+         if filename:
+             path /= filename
+         return path
+
+     @staticmethod
+     def init_project_media_path(
+         project: str, run: str | None = None, step: int | None = None
+     ) -> Path:
+         path = FileStorage.get_project_media_path(project, run, step)
+         path.mkdir(parents=True, exist_ok=True)
+         return path
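
A small sketch of the directory layout these helpers produce (project, run, and step values are illustrative):

```python
from trackio.file_storage import FileStorage

# Creates MEDIA_DIR / "demo-project" / "baseline" / "0" and returns it
step_dir = FileStorage.init_project_media_path("demo-project", run="baseline", step=0)

# Path of a specific media file inside that step directory (not created on disk)
file_path = FileStorage.get_project_media_path(
    "demo-project", run="baseline", step=0, filename="image.png"
)
print(step_dir, file_path)
```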
imports.py ADDED
@@ -0,0 +1,302 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ from trackio import deploy, utils
7
+ from trackio.sqlite_storage import SQLiteStorage
8
+
9
+
10
+ def import_csv(
11
+ csv_path: str | Path,
12
+ project: str,
13
+ name: str | None = None,
14
+ space_id: str | None = None,
15
+ dataset_id: str | None = None,
16
+ private: bool | None = None,
17
+ ) -> None:
18
+ """
19
+ Imports a CSV file into a Trackio project. The CSV file must contain a `"step"`
20
+ column, may optionally contain a `"timestamp"` column, and any other columns will be
21
+ treated as metrics. It should also include a header row with the column names.
22
+
23
+ TODO: call init() and return a Run object so that the user can continue to log metrics to it.
24
+
25
+ Args:
26
+ csv_path (`str` or `Path`):
27
+ The str or Path to the CSV file to import.
28
+ project (`str`):
29
+ The name of the project to import the CSV file into. Must not be an existing
30
+ project.
31
+ name (`str` or `None`, *optional*, defaults to `None`):
32
+ The name of the Run to import the CSV file into. If not provided, a default
33
+ name will be generated.
34
+ name (`str` or `None`, *optional*, defaults to `None`):
35
+ The name of the run (if not provided, a default name will be generated).
36
+ space_id (`str` or `None`, *optional*, defaults to `None`):
37
+ If provided, the project will be logged to a Hugging Face Space instead of a
38
+ local directory. Should be a complete Space name like `"username/reponame"`
39
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
40
+ be created in the currently-logged-in Hugging Face user's namespace. If the
41
+ Space does not exist, it will be created. If the Space already exists, the
42
+ project will be logged to it.
43
+ dataset_id (`str` or `None`, *optional*, defaults to `None`):
44
+ If provided, a persistent Hugging Face Dataset will be created and the
45
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
46
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
47
+ `"datasetname"` in which case the Dataset will be created in the
48
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
49
+ exist, it will be created. If the Dataset already exists, the project will
50
+ be appended to it. If not provided, the metrics will be logged to a local
51
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
52
+ will be automatically created with the same name as the Space but with the
53
+ `"_dataset"` suffix.
54
+ private (`bool` or `None`, *optional*, defaults to `None`):
55
+ Whether to make the Space private. If None (default), the repo will be
56
+ public unless the organization's default is private. This value is ignored
57
+ if the repo already exists.
58
+ """
59
+ if SQLiteStorage.get_runs(project):
60
+ raise ValueError(
61
+ f"Project '{project}' already exists. Cannot import CSV into existing project."
62
+ )
63
+
64
+ csv_path = Path(csv_path)
65
+ if not csv_path.exists():
66
+ raise FileNotFoundError(f"CSV file not found: {csv_path}")
67
+
68
+ df = pd.read_csv(csv_path)
69
+ if df.empty:
70
+ raise ValueError("CSV file is empty")
71
+
72
+ column_mapping = utils.simplify_column_names(df.columns.tolist())
73
+ df = df.rename(columns=column_mapping)
74
+
75
+ step_column = None
76
+ for col in df.columns:
77
+ if col.lower() == "step":
78
+ step_column = col
79
+ break
80
+
81
+ if step_column is None:
82
+ raise ValueError("CSV file must contain a 'step' or 'Step' column")
83
+
84
+ if name is None:
85
+ name = csv_path.stem
86
+
87
+ metrics_list = []
88
+ steps = []
89
+ timestamps = []
90
+
91
+ numeric_columns = []
92
+ for column in df.columns:
93
+ if column == step_column:
94
+ continue
95
+ if column == "timestamp":
96
+ continue
97
+
98
+ try:
99
+ pd.to_numeric(df[column], errors="raise")
100
+ numeric_columns.append(column)
101
+ except (ValueError, TypeError):
102
+ continue
103
+
104
+ for _, row in df.iterrows():
105
+ metrics = {}
106
+ for column in numeric_columns:
107
+ value = row[column]
108
+ if bool(pd.notna(value)):
109
+ metrics[column] = float(value)
110
+
111
+ if metrics:
112
+ metrics_list.append(metrics)
113
+ steps.append(int(row[step_column]))
114
+
115
+ if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])):
116
+ timestamps.append(str(row["timestamp"]))
117
+ else:
118
+ timestamps.append("")
119
+
120
+ if metrics_list:
121
+ SQLiteStorage.bulk_log(
122
+ project=project,
123
+ run=name,
124
+ metrics_list=metrics_list,
125
+ steps=steps,
126
+ timestamps=timestamps,
127
+ )
128
+
129
+ print(
130
+ f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'"
131
+ )
132
+ print(f"* Metrics found: {', '.join(metrics_list[0].keys())}")
133
+
134
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
135
+ if dataset_id is not None:
136
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
137
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
138
+
139
+ if space_id is None:
140
+ utils.print_dashboard_instructions(project)
141
+ else:
142
+ deploy.create_space_if_not_exists(
143
+ space_id=space_id, dataset_id=dataset_id, private=private
144
+ )
145
+ deploy.wait_until_space_exists(space_id=space_id)
146
+ deploy.upload_db_to_space(project=project, space_id=space_id)
147
+ print(
148
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
149
+ )
150
+
151
+
152
+ def import_tf_events(
153
+ log_dir: str | Path,
154
+ project: str,
155
+ name: str | None = None,
156
+ space_id: str | None = None,
157
+ dataset_id: str | None = None,
158
+ private: bool | None = None,
159
+ ) -> None:
160
+ """
161
+ Imports TensorFlow Events files from a directory into a Trackio project. Each
162
+ subdirectory in the log directory will be imported as a separate run.
163
+
164
+ Args:
165
+ log_dir (`str` or `Path`):
166
+ The str or Path to the directory containing TensorFlow Events files.
167
+ project (`str`):
168
+ The name of the project to import the TensorFlow Events files into. Must not
169
+ be an existing project.
170
+ name (`str` or `None`, *optional*, defaults to `None`):
171
+ The name prefix for runs (if not provided, will use directory names). Each
172
+ subdirectory will create a separate run.
173
+ space_id (`str` or `None`, *optional*, defaults to `None`):
174
+ If provided, the project will be logged to a Hugging Face Space instead of a
175
+ local directory. Should be a complete Space name like `"username/reponame"`
176
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
177
+ be created in the currently-logged-in Hugging Face user's namespace. If the
178
+ Space does not exist, it will be created. If the Space already exists, the
179
+ project will be logged to it.
180
+ dataset_id (`str` or `None`, *optional*, defaults to `None`):
181
+ If provided, a persistent Hugging Face Dataset will be created and the
182
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
183
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
184
+ `"datasetname"` in which case the Dataset will be created in the
185
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
186
+ exist, it will be created. If the Dataset already exists, the project will
187
+ be appended to it. If not provided, the metrics will be logged to a local
188
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
189
+ will be automatically created with the same name as the Space but with the
190
+ `"_dataset"` suffix.
191
+ private (`bool` or `None`, *optional*, defaults to `None`):
192
+ Whether to make the Space private. If None (default), the repo will be
193
+ public unless the organization's default is private. This value is ignored
194
+ if the repo already exists.
195
+ """
196
+ try:
197
+ from tbparse import SummaryReader
198
+ except ImportError:
199
+ raise ImportError(
200
+ "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`."
201
+ )
202
+
203
+ if SQLiteStorage.get_runs(project):
204
+ raise ValueError(
205
+ f"Project '{project}' already exists. Cannot import TF events into existing project."
206
+ )
207
+
208
+ path = Path(log_dir)
209
+ if not path.exists():
210
+ raise FileNotFoundError(f"TF events directory not found: {path}")
211
+
212
+ # Use tbparse to read all tfevents files in the directory structure
213
+ reader = SummaryReader(str(path), extra_columns={"dir_name"})
214
+ df = reader.scalars
215
+
216
+ if df.empty:
217
+ raise ValueError(f"No TensorFlow events data found in {path}")
218
+
219
+ total_imported = 0
220
+ imported_runs = []
221
+
222
+ # Group by dir_name to create separate runs
223
+ for dir_name, group_df in df.groupby("dir_name"):
224
+ try:
225
+ # Determine run name based on directory name
226
+ if dir_name == "":
227
+ run_name = "main" # For files in the root directory
228
+ else:
229
+ run_name = dir_name # Use directory name
230
+
231
+ if name:
232
+ run_name = f"{name}_{run_name}"
233
+
234
+ if group_df.empty:
235
+ print(f"* Skipping directory {dir_name}: no scalar data found")
236
+ continue
237
+
238
+ metrics_list = []
239
+ steps = []
240
+ timestamps = []
241
+
242
+ for _, row in group_df.iterrows():
243
+ # Convert row values to appropriate types
244
+ tag = str(row["tag"])
245
+ value = float(row["value"])
246
+ step = int(row["step"])
247
+
248
+ metrics = {tag: value}
249
+ metrics_list.append(metrics)
250
+ steps.append(step)
251
+
252
+ # Use wall_time if present, else fallback
253
+ if "wall_time" in group_df.columns and not bool(
254
+ pd.isna(row["wall_time"])
255
+ ):
256
+ timestamps.append(str(row["wall_time"]))
257
+ else:
258
+ timestamps.append("")
259
+
260
+ if metrics_list:
261
+ SQLiteStorage.bulk_log(
262
+ project=project,
263
+ run=str(run_name),
264
+ metrics_list=metrics_list,
265
+ steps=steps,
266
+ timestamps=timestamps,
267
+ )
268
+
269
+ total_imported += len(metrics_list)
270
+ imported_runs.append(run_name)
271
+
272
+ print(
273
+ f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'"
274
+ )
275
+ print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}")
276
+
277
+ except Exception as e:
278
+ print(f"* Error processing directory {dir_name}: {e}")
279
+ continue
280
+
281
+ if not imported_runs:
282
+ raise ValueError("No valid TensorFlow events data could be imported")
283
+
284
+ print(f"* Total imported events: {total_imported}")
285
+ print(f"* Created runs: {', '.join(imported_runs)}")
286
+
287
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
288
+ if dataset_id is not None:
289
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
290
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
291
+
292
+ if space_id is None:
293
+ utils.print_dashboard_instructions(project)
294
+ else:
295
+ deploy.create_space_if_not_exists(
296
+ space_id, dataset_id=dataset_id, private=private
297
+ )
298
+ deploy.wait_until_space_exists(space_id)
299
+ deploy.upload_db_to_space(project, space_id)
300
+ print(
301
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
302
+ )
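
A minimal sketch of the CSV shape `import_csv` expects and the corresponding call; the file, project, and run names are made up:

```python
from pathlib import Path

import trackio

# Header row plus a "step" column are required; "timestamp" is optional and every
# other numeric column is treated as a metric.
Path("history.csv").write_text(
    "step,timestamp,loss,accuracy\n"
    "0,2024-01-01 00:00:00,0.9,0.1\n"
    "1,2024-01-01 00:01:00,0.5,0.6\n"
)

trackio.import_csv("history.csv", project="imported-project", name="baseline")
```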
media.py ADDED
@@ -0,0 +1,286 @@
1
+ import os
2
+ import shutil
3
+ import uuid
4
+ from abc import ABC, abstractmethod
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ import numpy as np
9
+ from PIL import Image as PILImage
10
+
11
+ try: # absolute imports when installed
12
+ from trackio.file_storage import FileStorage
13
+ from trackio.utils import MEDIA_DIR
14
+ from trackio.video_writer import write_video
15
+ except ImportError: # relative imports for local execution on Spaces
16
+ from file_storage import FileStorage
17
+ from utils import MEDIA_DIR
18
+ from video_writer import write_video
19
+
20
+
21
+ class TrackioMedia(ABC):
22
+ """
23
+ Abstract base class for Trackio media objects
24
+ Provides shared functionality for file handling and serialization.
25
+ """
26
+
27
+ TYPE: str
28
+
29
+ def __init_subclass__(cls, **kwargs):
30
+ """Ensure subclasses define the TYPE attribute."""
31
+ super().__init_subclass__(**kwargs)
32
+ if not hasattr(cls, "TYPE") or cls.TYPE is None:
33
+ raise TypeError(f"Class {cls.__name__} must define TYPE attribute")
34
+
35
+ def __init__(self, value, caption: str | None = None):
36
+ self.caption = caption
37
+ self._value = value
38
+ self._file_path: Path | None = None
39
+
40
+ # Validate file existence for string/Path inputs
41
+ if isinstance(self._value, str | Path):
42
+ if not os.path.isfile(self._value):
43
+ raise ValueError(f"File not found: {self._value}")
44
+
45
+ def _file_extension(self) -> str:
46
+ if self._file_path:
47
+ return self._file_path.suffix[1:].lower()
48
+ if isinstance(self._value, str | Path):
49
+ path = Path(self._value)
50
+ return path.suffix[1:].lower()
51
+ if hasattr(self, "_format") and self._format:
52
+ return self._format
53
+ return "unknown"
54
+
55
+ def _get_relative_file_path(self) -> Path | None:
56
+ return self._file_path
57
+
58
+ def _get_absolute_file_path(self) -> Path | None:
59
+ if self._file_path:
60
+ return MEDIA_DIR / self._file_path
61
+ return None
62
+
63
+ def _save(self, project: str, run: str, step: int = 0):
64
+ if self._file_path:
65
+ return
66
+
67
+ media_dir = FileStorage.init_project_media_path(project, run, step)
68
+ filename = f"{uuid.uuid4()}.{self._file_extension()}"
69
+ file_path = media_dir / filename
70
+
71
+ # Delegate to subclass-specific save logic
72
+ self._save_media(file_path)
73
+
74
+ self._file_path = file_path.relative_to(MEDIA_DIR)
75
+
76
+ @abstractmethod
77
+ def _save_media(self, file_path: Path):
78
+ """
79
+ Performs the actual media saving logic.
80
+ """
81
+ pass
82
+
83
+ def _to_dict(self) -> dict:
84
+ if not self._file_path:
85
+ raise ValueError("Media must be saved to file before serialization")
86
+ return {
87
+ "_type": self.TYPE,
88
+ "file_path": str(self._get_relative_file_path()),
89
+ "caption": self.caption,
90
+ }
91
+
92
+
93
+ TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image
94
+
95
+
96
+ class TrackioImage(TrackioMedia):
97
+ """
98
+ Initializes an Image object.
99
+
100
+ Example:
101
+ ```python
102
+ import trackio
103
+ import numpy as np
104
+ from PIL import Image
105
+
106
+ # Create an image from numpy array
107
+ image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
108
+ image = trackio.Image(image_data, caption="Random image")
109
+ trackio.log({"my_image": image})
110
+
111
+ # Create an image from PIL Image
112
+ pil_image = Image.new('RGB', (100, 100), color='red')
113
+ image = trackio.Image(pil_image, caption="Red square")
114
+ trackio.log({"red_image": image})
115
+
116
+ # Create an image from file path
117
+ image = trackio.Image("path/to/image.jpg", caption="Photo from file")
118
+ trackio.log({"file_image": image})
119
+ ```
120
+
121
+ Args:
122
+ value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`, *optional*, defaults to `None`):
123
+ A path to an image, a PIL Image, or a numpy array of shape (height, width, channels).
124
+ caption (`str`, *optional*, defaults to `None`):
125
+ A string caption for the image.
126
+ """
127
+
128
+ TYPE = "trackio.image"
129
+
130
+ def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
131
+ super().__init__(value, caption)
132
+ self._format: str | None = None
133
+
134
+ if (
135
+ isinstance(self._value, np.ndarray | PILImage.Image)
136
+ and self._format is None
137
+ ):
138
+ self._format = "png"
139
+
140
+ def _as_pil(self) -> PILImage.Image | None:
141
+ try:
142
+ if isinstance(self._value, np.ndarray):
143
+ arr = np.asarray(self._value).astype("uint8")
144
+ return PILImage.fromarray(arr).convert("RGBA")
145
+ if isinstance(self._value, PILImage.Image):
146
+ return self._value.convert("RGBA")
147
+ except Exception as e:
148
+ raise ValueError(f"Failed to process image data: {self._value}") from e
149
+ return None
150
+
151
+ def _save_media(self, file_path: Path):
152
+ if pil := self._as_pil():
153
+ pil.save(file_path, format=self._format)
154
+ elif isinstance(self._value, str | Path):
155
+ if os.path.isfile(self._value):
156
+ shutil.copy(self._value, file_path)
157
+ else:
158
+ raise ValueError(f"File not found: {self._value}")
159
+
160
+
161
+ TrackioVideoSourceType = str | Path | np.ndarray
162
+ TrackioVideoFormatType = Literal["gif", "mp4", "webm"]
163
+
164
+
165
+ class TrackioVideo(TrackioMedia):
166
+ """
167
+ Initializes a Video object.
168
+
169
+ Example:
170
+ ```python
171
+ import trackio
172
+ import numpy as np
173
+
174
+ # Create a simple video from numpy array
175
+ frames = np.random.randint(0, 255, (10, 3, 64, 64), dtype=np.uint8)
176
+ video = trackio.Video(frames, caption="Random video", fps=30)
177
+
178
+ # Create a batch of videos
179
+ batch_frames = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8)
180
+ batch_video = trackio.Video(batch_frames, caption="Batch of videos", fps=15)
181
+
182
+ # Create video from file path
183
+ video = trackio.Video("path/to/video.mp4", caption="Video from file")
184
+ ```
185
+
186
+ Args:
187
+ value (`str`, `Path`, or `numpy.ndarray`, *optional*, defaults to `None`):
188
+ A path to a video file, or a numpy array.
189
+ The array should be of type `np.uint8` with RGB values in the range `[0, 255]`.
190
+ It is expected to have shape of either (frames, channels, height, width) or (batch, frames, channels, height, width).
191
+ For the latter, the videos will be tiled into a grid.
192
+ caption (`str`, *optional*, defaults to `None`):
193
+ A string caption for the video.
194
+ fps (`int`, *optional*, defaults to `None`):
195
+ Frames per second for the video. Only used when value is an ndarray. Default is `24`.
196
+ format (`Literal["gif", "mp4", "webm"]`, *optional*, defaults to `None`):
197
+ Video format ("gif", "mp4", or "webm"). Only used when value is an ndarray. Default is "gif".
198
+ """
199
+
200
+ TYPE = "trackio.video"
201
+
202
+ def __init__(
203
+ self,
204
+ value: TrackioVideoSourceType,
205
+ caption: str | None = None,
206
+ fps: int | None = None,
207
+ format: TrackioVideoFormatType | None = None,
208
+ ):
209
+ super().__init__(value, caption)
210
+ if isinstance(value, np.ndarray):
211
+ if format is None:
212
+ format = "gif"
213
+ if fps is None:
214
+ fps = 24
215
+ self._fps = fps
216
+ self._format = format
217
+
218
+ @property
219
+ def _codec(self) -> str:
220
+ match self._format:
221
+ case "gif":
222
+ return "gif"
223
+ case "mp4":
224
+ return "h264"
225
+ case "webm":
226
+ return "vp9"
227
+ case _:
228
+ raise ValueError(f"Unsupported format: {self._format}")
229
+
230
+ def _save_media(self, file_path: Path):
231
+ if isinstance(self._value, np.ndarray):
232
+ video = TrackioVideo._process_ndarray(self._value)
233
+ write_video(file_path, video, fps=self._fps, codec=self._codec)
234
+ elif isinstance(self._value, str | Path):
235
+ if os.path.isfile(self._value):
236
+ shutil.copy(self._value, file_path)
237
+ else:
238
+ raise ValueError(f"File not found: {self._value}")
239
+
240
+ @staticmethod
241
+ def _process_ndarray(value: np.ndarray) -> np.ndarray:
242
+ # Verify value is either 4D (single video) or 5D array (batched videos).
243
+ # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width)
244
+ if value.ndim < 4:
245
+ raise ValueError(
246
+ "Video requires at least 4 dimensions (frames, channels, height, width)"
247
+ )
248
+ if value.ndim > 5:
249
+ raise ValueError(
250
+ "Videos can have at most 5 dimensions (batch, frames, channels, height, width)"
251
+ )
252
+ if value.ndim == 4:
253
+ # Reshape to 5D with single batch: (1, frames, channels, height, width)
254
+ value = value[np.newaxis, ...]
255
+
256
+ value = TrackioVideo._tile_batched_videos(value)
257
+ return value
258
+
259
+ @staticmethod
260
+ def _tile_batched_videos(video: np.ndarray) -> np.ndarray:
261
+ """
262
+ Tiles a batch of videos into a grid of videos.
263
+
264
+ Input format: (batch, frames, channels, height, width) - original FCHW format
265
+ Output format: (frames, total_height, total_width, channels)
266
+ """
267
+ batch_size, frames, channels, height, width = video.shape
268
+
269
+ next_pow2 = 1 << (batch_size - 1).bit_length()
270
+ if batch_size != next_pow2:
271
+ pad_len = next_pow2 - batch_size
272
+ pad_shape = (pad_len, frames, channels, height, width)
273
+ padding = np.zeros(pad_shape, dtype=video.dtype)
274
+ video = np.concatenate((video, padding), axis=0)
275
+ batch_size = next_pow2
276
+
277
+ n_rows = 1 << ((batch_size.bit_length() - 1) // 2)
278
+ n_cols = batch_size // n_rows
279
+
280
+ # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width)
281
+ video = video.reshape(n_rows, n_cols, frames, channels, height, width)
282
+
283
+ # Rearrange dimensions to (frames, total_height, total_width, channels)
284
+ video = video.transpose(2, 0, 4, 1, 5, 3)
285
+ video = video.reshape(frames, n_rows * height, n_cols * width, channels)
286
+ return video
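A note on `_tile_batched_videos` above: the batch is padded to the next power of two, arranged into an `n_rows x n_cols` grid, and returned frame-first in HWC layout so the video writer can consume it directly. The following standalone sketch (plain numpy, hypothetical shapes) mirrors that logic to illustrate the shape contract:

```python
import numpy as np

# Hypothetical batch: 3 videos, 10 frames, RGB, 64x64 -> (batch, frames, channels, height, width)
batch = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8)

b, f, c, h, w = batch.shape
next_pow2 = 1 << (b - 1).bit_length()              # 3 -> 4
if b != next_pow2:                                 # pad with black videos
    pad = np.zeros((next_pow2 - b, f, c, h, w), dtype=batch.dtype)
    batch = np.concatenate((batch, pad), axis=0)
    b = next_pow2

n_rows = 1 << ((b.bit_length() - 1) // 2)          # 4 -> 2 rows
n_cols = b // n_rows                               # -> 2 cols

grid = batch.reshape(n_rows, n_cols, f, c, h, w)
grid = grid.transpose(2, 0, 4, 1, 5, 3)            # (frames, rows, height, cols, width, channels)
tiled = grid.reshape(f, n_rows * h, n_cols * w, c)

print(tiled.shape)  # (10, 128, 128, 3): frames, total_height, total_width, channels
```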
py.typed ADDED
File without changes
run.py ADDED
@@ -0,0 +1,182 @@
1
+ import threading
2
+ import time
3
+ from datetime import datetime, timezone
4
+
5
+ import huggingface_hub
6
+ from gradio_client import Client, handle_file
7
+
8
+ from trackio.media import TrackioMedia
9
+ from trackio.sqlite_storage import SQLiteStorage
10
+ from trackio.table import Table
11
+ from trackio.typehints import LogEntry, UploadEntry
12
+ from trackio.utils import (
13
+ RESERVED_KEYS,
14
+ fibo,
15
+ generate_readable_name,
16
+ serialize_values,
17
+ )
18
+
19
+ BATCH_SEND_INTERVAL = 0.5
20
+
21
+
22
+ class Run:
23
+ def __init__(
24
+ self,
25
+ url: str,
26
+ project: str,
27
+ client: Client | None,
28
+ name: str | None = None,
29
+ config: dict | None = None,
30
+ space_id: str | None = None,
31
+ ):
32
+ self.url = url
33
+ self.project = project
34
+ self._client_lock = threading.Lock()
35
+ self._client_thread = None
36
+ self._client = client
37
+ self._space_id = space_id
38
+ self.name = name or generate_readable_name(
39
+ SQLiteStorage.get_runs(project), space_id
40
+ )
41
+ self.config = config or {}
42
+
43
+ for key in self.config:
44
+ if key.startswith("_"):
45
+ raise ValueError(
46
+ f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)"
47
+ )
48
+
49
+ self.config["_Username"] = self._get_username()
50
+ self.config["_Created"] = datetime.now(timezone.utc).isoformat()
51
+ self._queued_logs: list[LogEntry] = []
52
+ self._queued_uploads: list[UploadEntry] = []
53
+ self._stop_flag = threading.Event()
54
+ self._config_logged = False
55
+
56
+ self._client_thread = threading.Thread(target=self._init_client_background)
57
+ self._client_thread.daemon = True
58
+ self._client_thread.start()
59
+
60
+ def _get_username(self) -> str | None:
61
+ """Get the current HuggingFace username if logged in, otherwise None."""
62
+ try:
63
+ who = huggingface_hub.whoami()
64
+ return who["name"] if who else None
65
+ except Exception:
66
+ return None
67
+
68
+ def _batch_sender(self):
69
+ """Send batched logs every BATCH_SEND_INTERVAL."""
70
+ while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
71
+ # If the stop flag has been set, then just quickly send all
72
+ # the logs and exit.
73
+ if not self._stop_flag.is_set():
74
+ time.sleep(BATCH_SEND_INTERVAL)
75
+
76
+ with self._client_lock:
77
+ if self._client is None:
78
+ return
79
+ if self._queued_logs:
80
+ logs_to_send = self._queued_logs.copy()
81
+ self._queued_logs.clear()
82
+ self._client.predict(
83
+ api_name="/bulk_log",
84
+ logs=logs_to_send,
85
+ hf_token=huggingface_hub.utils.get_token(),
86
+ )
87
+ if self._queued_uploads:
88
+ uploads_to_send = self._queued_uploads.copy()
89
+ self._queued_uploads.clear()
90
+ self._client.predict(
91
+ api_name="/bulk_upload_media",
92
+ uploads=uploads_to_send,
93
+ hf_token=huggingface_hub.utils.get_token(),
94
+ )
95
+
96
+ def _init_client_background(self):
97
+ if self._client is None:
98
+ fib = fibo()
99
+ for sleep_coefficient in fib:
100
+ try:
101
+ client = Client(self.url, verbose=False)
102
+
103
+ with self._client_lock:
104
+ self._client = client
105
+ break
106
+ except Exception:
107
+ pass
108
+ if sleep_coefficient is not None:
109
+ time.sleep(0.1 * sleep_coefficient)
110
+
111
+ self._batch_sender()
112
+
113
+ def _process_media(self, metrics, step: int | None) -> dict:
114
+ """
115
+ Serialize media in metrics and upload to space if needed.
116
+ """
117
+ serializable_metrics = {}
118
+ if not step:
119
+ step = 0
120
+ for key, value in metrics.items():
121
+ if isinstance(value, TrackioMedia):
122
+ value._save(self.project, self.name, step)
123
+ serializable_metrics[key] = value._to_dict()
124
+ if self._space_id:
125
+ # Upload local media when deploying to space
126
+ upload_entry: UploadEntry = {
127
+ "project": self.project,
128
+ "run": self.name,
129
+ "step": step,
130
+ "uploaded_file": handle_file(value._get_absolute_file_path()),
131
+ }
132
+ with self._client_lock:
133
+ self._queued_uploads.append(upload_entry)
134
+ else:
135
+ serializable_metrics[key] = value
136
+ return serializable_metrics
137
+
138
+ @staticmethod
139
+ def _replace_tables(metrics):
140
+ for k, v in metrics.items():
141
+ if isinstance(v, Table):
142
+ metrics[k] = v._to_dict()
143
+
144
+ def log(self, metrics: dict, step: int | None = None):
145
+ for k in metrics.keys():
146
+ if k in RESERVED_KEYS or k.startswith("__"):
147
+ raise ValueError(
148
+ f"Please do not use this reserved key as a metric: {k}"
149
+ )
150
+ Run._replace_tables(metrics)
151
+
152
+ metrics = self._process_media(metrics, step)
153
+ metrics = serialize_values(metrics)
154
+
155
+ config_to_log = None
156
+ if not self._config_logged and self.config:
157
+ config_to_log = self.config
158
+ self._config_logged = True
159
+
160
+ log_entry: LogEntry = {
161
+ "project": self.project,
162
+ "run": self.name,
163
+ "metrics": metrics,
164
+ "step": step,
165
+ "config": config_to_log,
166
+ }
167
+
168
+ with self._client_lock:
169
+ self._queued_logs.append(log_entry)
170
+
171
+ def finish(self):
172
+ """Cleanup when run is finished."""
173
+ self._stop_flag.set()
174
+
175
+ # Wait for the batch sender to finish before joining the client thread.
176
+ time.sleep(2 * BATCH_SEND_INTERVAL)
177
+
178
+ if self._client_thread is not None:
179
+ print(
180
+ f"* Run finished. Uploading logs to Trackio Space: {self.url} (please wait...)"
181
+ )
182
+ self._client_thread.join()
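The design above keeps `Run.log()` cheap: it only appends to an in-memory queue under a lock, while a daemon thread (re)connects to the Space with a Fibonacci backoff and flushes the queue every `BATCH_SEND_INTERVAL` seconds via `/bulk_log`. Below is a minimal, standalone sketch of that producer/flusher pattern, with a hypothetical `send_batch` callable standing in for the Gradio client; it assumes nothing about trackio internals.

```python
import threading
import time

FLUSH_INTERVAL = 0.5  # mirrors BATCH_SEND_INTERVAL above


class BatchingLogger:
    def __init__(self, send_batch):
        self._send_batch = send_batch  # hypothetical network call
        self._queue: list[dict] = []
        self._lock = threading.Lock()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._flush_loop, daemon=True)
        self._thread.start()

    def log(self, entry: dict) -> None:
        # Cheap: enqueue only; the background thread does the slow work.
        with self._lock:
            self._queue.append(entry)

    def _flush_loop(self) -> None:
        # Keep flushing until stopped *and* the queue is drained.
        while not self._stop.is_set() or self._queue:
            if not self._stop.is_set():
                time.sleep(FLUSH_INTERVAL)
            with self._lock:
                if self._queue:
                    batch, self._queue = self._queue, []
                    self._send_batch(batch)

    def finish(self) -> None:
        self._stop.set()
        self._thread.join()


logger = BatchingLogger(send_batch=lambda batch: print(f"sent {len(batch)} entries"))
for i in range(5):
    logger.log({"step": i, "loss": 1.0 / (i + 1)})
logger.finish()
```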
sqlite_storage.py ADDED
@@ -0,0 +1,559 @@
1
+ import fcntl
2
+ import json
3
+ import os
4
+ import sqlite3
5
+ import time
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from threading import Lock
9
+
10
+ import huggingface_hub as hf
11
+ import pandas as pd
12
+
13
+ try: # absolute imports when installed
14
+ from trackio.commit_scheduler import CommitScheduler
15
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
16
+ from trackio.utils import (
17
+ TRACKIO_DIR,
18
+ deserialize_values,
19
+ serialize_values,
20
+ )
21
+ except Exception: # relative imports for local execution on Spaces
22
+ from commit_scheduler import CommitScheduler
23
+ from dummy_commit_scheduler import DummyCommitScheduler
24
+ from utils import TRACKIO_DIR, deserialize_values, serialize_values
25
+
26
+
27
+ class ProcessLock:
28
+ """A simple file-based lock that works across processes."""
29
+
30
+ def __init__(self, lockfile_path: Path):
31
+ self.lockfile_path = lockfile_path
32
+ self.lockfile = None
33
+
34
+ def __enter__(self):
35
+ """Acquire the lock with retry logic."""
36
+ self.lockfile_path.parent.mkdir(parents=True, exist_ok=True)
37
+ self.lockfile = open(self.lockfile_path, "w")
38
+
39
+ max_retries = 100
40
+ for attempt in range(max_retries):
41
+ try:
42
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
43
+ return self
44
+ except IOError:
45
+ if attempt < max_retries - 1:
46
+ time.sleep(0.1)
47
+ else:
48
+ raise IOError("Could not acquire database lock after 10 seconds")
49
+
50
+ def __exit__(self, exc_type, exc_val, exc_tb):
51
+ """Release the lock."""
52
+ if self.lockfile:
53
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
54
+ self.lockfile.close()
55
+
56
+
57
+ class SQLiteStorage:
58
+ _dataset_import_attempted = False
59
+ _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
60
+ _scheduler_lock = Lock()
61
+
62
+ @staticmethod
63
+ def _get_connection(db_path: Path) -> sqlite3.Connection:
64
+ conn = sqlite3.connect(str(db_path), timeout=30.0)
65
+ conn.execute("PRAGMA journal_mode = WAL")
66
+ conn.row_factory = sqlite3.Row
67
+ return conn
68
+
69
+ @staticmethod
70
+ def _get_process_lock(project: str) -> ProcessLock:
71
+ lockfile_path = TRACKIO_DIR / f"{project}.lock"
72
+ return ProcessLock(lockfile_path)
73
+
74
+ @staticmethod
75
+ def get_project_db_filename(project: str) -> Path:
76
+ """Get the database filename for a specific project."""
77
+ safe_project_name = "".join(
78
+ c for c in project if c.isalnum() or c in ("-", "_")
79
+ ).rstrip()
80
+ if not safe_project_name:
81
+ safe_project_name = "default"
82
+ return f"{safe_project_name}.db"
83
+
84
+ @staticmethod
85
+ def get_project_db_path(project: str) -> Path:
86
+ """Get the database path for a specific project."""
87
+ filename = SQLiteStorage.get_project_db_filename(project)
88
+ return TRACKIO_DIR / filename
89
+
90
+ @staticmethod
91
+ def init_db(project: str) -> Path:
92
+ """
93
+ Initialize the SQLite database with required tables.
94
+ If there is a dataset ID provided, copies from that dataset instead.
95
+ Returns the database path.
96
+ """
97
+ db_path = SQLiteStorage.get_project_db_path(project)
98
+ db_path.parent.mkdir(parents=True, exist_ok=True)
99
+ with SQLiteStorage._get_process_lock(project):
100
+ with sqlite3.connect(db_path, timeout=30.0) as conn:
101
+ conn.execute("PRAGMA journal_mode = WAL")
102
+ cursor = conn.cursor()
103
+ cursor.execute("""
104
+ CREATE TABLE IF NOT EXISTS metrics (
105
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
106
+ timestamp TEXT NOT NULL,
107
+ run_name TEXT NOT NULL,
108
+ step INTEGER NOT NULL,
109
+ metrics TEXT NOT NULL
110
+ )
111
+ """)
112
+ cursor.execute("""
113
+ CREATE TABLE IF NOT EXISTS configs (
114
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
115
+ run_name TEXT NOT NULL,
116
+ config TEXT NOT NULL,
117
+ created_at TEXT NOT NULL,
118
+ UNIQUE(run_name)
119
+ )
120
+ """)
121
+ cursor.execute(
122
+ """
123
+ CREATE INDEX IF NOT EXISTS idx_metrics_run_step
124
+ ON metrics(run_name, step)
125
+ """
126
+ )
127
+ cursor.execute(
128
+ """
129
+ CREATE INDEX IF NOT EXISTS idx_configs_run_name
130
+ ON configs(run_name)
131
+ """
132
+ )
133
+ conn.commit()
134
+ return db_path
135
+
136
+ @staticmethod
137
+ def export_to_parquet():
138
+ """
139
+ Exports all projects' DB files as Parquet under the same path but with extension ".parquet".
140
+ """
141
+ # don't attempt to export (potentially wrong/blank) data before importing for the first time
142
+ if not SQLiteStorage._dataset_import_attempted:
143
+ return
144
+ all_paths = os.listdir(TRACKIO_DIR)
145
+ db_paths = [f for f in all_paths if f.endswith(".db")]
146
+ for db_path in db_paths:
147
+ db_path = TRACKIO_DIR / db_path
148
+ parquet_path = db_path.with_suffix(".parquet")
149
+ if (not parquet_path.exists()) or (
150
+ db_path.stat().st_mtime > parquet_path.stat().st_mtime
151
+ ):
152
+ with sqlite3.connect(db_path) as conn:
153
+ df = pd.read_sql("SELECT * from metrics", conn)
154
+ # break out the single JSON metrics column into individual columns
155
+ metrics = df["metrics"].copy()
156
+ metrics = pd.DataFrame(
157
+ metrics.apply(
158
+ lambda x: deserialize_values(json.loads(x))
159
+ ).values.tolist(),
160
+ index=df.index,
161
+ )
162
+ del df["metrics"]
163
+ for col in metrics.columns:
164
+ df[col] = metrics[col]
165
+ df.to_parquet(parquet_path)
166
+
167
+ @staticmethod
168
+ def import_from_parquet():
169
+ """
170
+ Imports to all DB files that have matching files under the same path but with extension ".parquet".
171
+ """
172
+ all_paths = os.listdir(TRACKIO_DIR)
173
+ parquet_paths = [f for f in all_paths if f.endswith(".parquet")]
174
+ for parquet_path in parquet_paths:
175
+ parquet_path = TRACKIO_DIR / parquet_path
176
+ db_path = parquet_path.with_suffix(".db")
177
+ df = pd.read_parquet(parquet_path)
178
+ with sqlite3.connect(db_path) as conn:
179
+ # fix up df to have a single JSON metrics column
180
+ if "metrics" not in df.columns:
181
+ # separate other columns from metrics
182
+ metrics = df.copy()
183
+ other_cols = ["id", "timestamp", "run_name", "step"]
184
+ df = df[other_cols]
185
+ for col in other_cols:
186
+ del metrics[col]
187
+ # combine them all into a single metrics col
188
+ metrics = json.loads(metrics.to_json(orient="records"))
189
+ df["metrics"] = [
190
+ json.dumps(serialize_values(row)) for row in metrics
191
+ ]
192
+ df.to_sql("metrics", conn, if_exists="replace", index=False)
193
+
194
+ @staticmethod
195
+ def get_scheduler():
196
+ """
197
+ Get the scheduler for the database based on the environment variables.
198
+ This applies to both local and Spaces.
199
+ """
200
+ with SQLiteStorage._scheduler_lock:
201
+ if SQLiteStorage._current_scheduler is not None:
202
+ return SQLiteStorage._current_scheduler
203
+ hf_token = os.environ.get("HF_TOKEN")
204
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
205
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
206
+ if dataset_id is None or space_repo_name is None:
207
+ scheduler = DummyCommitScheduler()
208
+ else:
209
+ scheduler = CommitScheduler(
210
+ repo_id=dataset_id,
211
+ repo_type="dataset",
212
+ folder_path=TRACKIO_DIR,
213
+ private=True,
214
+ allow_patterns=["*.parquet", "media/**/*"],
215
+ squash_history=True,
216
+ token=hf_token,
217
+ on_before_commit=SQLiteStorage.export_to_parquet,
218
+ )
219
+ SQLiteStorage._current_scheduler = scheduler
220
+ return scheduler
221
+
222
+ @staticmethod
223
+ def log(project: str, run: str, metrics: dict, step: int | None = None):
224
+ """
225
+ Safely log metrics to the database. Before logging, this method will ensure the database exists
226
+ and is set up with the correct tables. It also uses a cross-process lock to prevent
227
+ database locking errors when multiple processes access the same database.
228
+
229
+ This method is not used in the latest versions of Trackio (replaced by bulk_log) but
230
+ is kept for backwards compatibility for users who are connecting to a newer version of
231
+ a Trackio Spaces dashboard with an older version of Trackio installed locally.
232
+ """
233
+ db_path = SQLiteStorage.init_db(project)
234
+
235
+ with SQLiteStorage._get_process_lock(project):
236
+ with SQLiteStorage._get_connection(db_path) as conn:
237
+ cursor = conn.cursor()
238
+
239
+ cursor.execute(
240
+ """
241
+ SELECT MAX(step)
242
+ FROM metrics
243
+ WHERE run_name = ?
244
+ """,
245
+ (run,),
246
+ )
247
+ last_step = cursor.fetchone()[0]
248
+ if step is None:
249
+ current_step = 0 if last_step is None else last_step + 1
250
+ else:
251
+ current_step = step
252
+
253
+ current_timestamp = datetime.now().isoformat()
254
+
255
+ cursor.execute(
256
+ """
257
+ INSERT INTO metrics
258
+ (timestamp, run_name, step, metrics)
259
+ VALUES (?, ?, ?, ?)
260
+ """,
261
+ (
262
+ current_timestamp,
263
+ run,
264
+ current_step,
265
+ json.dumps(serialize_values(metrics)),
266
+ ),
267
+ )
268
+ conn.commit()
269
+
270
+ @staticmethod
271
+ def bulk_log(
272
+ project: str,
273
+ run: str,
274
+ metrics_list: list[dict],
275
+ steps: list[int] | None = None,
276
+ timestamps: list[str] | None = None,
277
+ config: dict | None = None,
278
+ ):
279
+ """
280
+ Safely log bulk metrics to the database. Before logging, this method will ensure the database exists
281
+ and is set up with the correct tables. It also uses a cross-process lock to prevent
282
+ database locking errors when multiple processes access the same database.
283
+ """
284
+ if not metrics_list:
285
+ return
286
+
287
+ if timestamps is None:
288
+ timestamps = [datetime.now().isoformat()] * len(metrics_list)
289
+
290
+ db_path = SQLiteStorage.init_db(project)
291
+ with SQLiteStorage._get_process_lock(project):
292
+ with SQLiteStorage._get_connection(db_path) as conn:
293
+ cursor = conn.cursor()
294
+
295
+ if steps is None:
296
+ steps = list(range(len(metrics_list)))
297
+ elif any(s is None for s in steps):
298
+ cursor.execute(
299
+ "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
300
+ )
301
+ last_step = cursor.fetchone()[0]
302
+ current_step = 0 if last_step is None else last_step + 1
303
+
304
+ processed_steps = []
305
+ for step in steps:
306
+ if step is None:
307
+ processed_steps.append(current_step)
308
+ current_step += 1
309
+ else:
310
+ processed_steps.append(step)
311
+ steps = processed_steps
312
+
313
+ if len(metrics_list) != len(steps) or len(metrics_list) != len(
314
+ timestamps
315
+ ):
316
+ raise ValueError(
317
+ "metrics_list, steps, and timestamps must have the same length"
318
+ )
319
+
320
+ data = []
321
+ for i, metrics in enumerate(metrics_list):
322
+ data.append(
323
+ (
324
+ timestamps[i],
325
+ run,
326
+ steps[i],
327
+ json.dumps(serialize_values(metrics)),
328
+ )
329
+ )
330
+
331
+ cursor.executemany(
332
+ """
333
+ INSERT INTO metrics
334
+ (timestamp, run_name, step, metrics)
335
+ VALUES (?, ?, ?, ?)
336
+ """,
337
+ data,
338
+ )
339
+
340
+ if config:
341
+ current_timestamp = datetime.now().isoformat()
342
+ cursor.execute(
343
+ """
344
+ INSERT OR REPLACE INTO configs
345
+ (run_name, config, created_at)
346
+ VALUES (?, ?, ?)
347
+ """,
348
+ (run, json.dumps(serialize_values(config)), current_timestamp),
349
+ )
350
+
351
+ conn.commit()
352
+
353
+ @staticmethod
354
+ def get_logs(project: str, run: str) -> list[dict]:
355
+ """Retrieve logs for a specific run. Logs include the step count (int) and the timestamp (datetime object)."""
356
+ db_path = SQLiteStorage.get_project_db_path(project)
357
+ if not db_path.exists():
358
+ return []
359
+
360
+ with SQLiteStorage._get_connection(db_path) as conn:
361
+ cursor = conn.cursor()
362
+ cursor.execute(
363
+ """
364
+ SELECT timestamp, step, metrics
365
+ FROM metrics
366
+ WHERE run_name = ?
367
+ ORDER BY timestamp
368
+ """,
369
+ (run,),
370
+ )
371
+
372
+ rows = cursor.fetchall()
373
+ results = []
374
+ for row in rows:
375
+ metrics = json.loads(row["metrics"])
376
+ metrics = deserialize_values(metrics)
377
+ metrics["timestamp"] = row["timestamp"]
378
+ metrics["step"] = row["step"]
379
+ results.append(metrics)
380
+ return results
381
+
382
+ @staticmethod
383
+ def load_from_dataset():
384
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
385
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
386
+ if dataset_id is not None and space_repo_name is not None:
387
+ hfapi = hf.HfApi()
388
+ updated = False
389
+ if not TRACKIO_DIR.exists():
390
+ TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
391
+ with SQLiteStorage.get_scheduler().lock:
392
+ try:
393
+ files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
394
+ for file in files:
395
+ # Download parquet and media assets
396
+ if not (file.endswith(".parquet") or file.startswith("media/")):
397
+ continue
398
+ if (TRACKIO_DIR / file).exists():
399
+ continue
400
+ hf.hf_hub_download(
401
+ dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR
402
+ )
403
+ updated = True
404
+ except hf.errors.EntryNotFoundError:
405
+ pass
406
+ except hf.errors.RepositoryNotFoundError:
407
+ pass
408
+ if updated:
409
+ SQLiteStorage.import_from_parquet()
410
+ SQLiteStorage._dataset_import_attempted = True
411
+
412
+ @staticmethod
413
+ def get_projects() -> list[str]:
414
+ """
415
+ Get list of all projects by scanning the database files in the trackio directory.
416
+ """
417
+ if not SQLiteStorage._dataset_import_attempted:
418
+ SQLiteStorage.load_from_dataset()
419
+
420
+ projects: set[str] = set()
421
+ if not TRACKIO_DIR.exists():
422
+ return []
423
+
424
+ for db_file in TRACKIO_DIR.glob("*.db"):
425
+ project_name = db_file.stem
426
+ projects.add(project_name)
427
+ return sorted(projects)
428
+
429
+ @staticmethod
430
+ def get_runs(project: str) -> list[str]:
431
+ """Get list of all runs for a project."""
432
+ db_path = SQLiteStorage.get_project_db_path(project)
433
+ if not db_path.exists():
434
+ return []
435
+
436
+ with SQLiteStorage._get_connection(db_path) as conn:
437
+ cursor = conn.cursor()
438
+ cursor.execute(
439
+ "SELECT DISTINCT run_name FROM metrics",
440
+ )
441
+ return [row[0] for row in cursor.fetchall()]
442
+
443
+ @staticmethod
444
+ def get_max_steps_for_runs(project: str) -> dict[str, int]:
445
+ """Get the maximum step for each run in a project."""
446
+ db_path = SQLiteStorage.get_project_db_path(project)
447
+ if not db_path.exists():
448
+ return {}
449
+
450
+ with SQLiteStorage._get_connection(db_path) as conn:
451
+ cursor = conn.cursor()
452
+ cursor.execute(
453
+ """
454
+ SELECT run_name, MAX(step) as max_step
455
+ FROM metrics
456
+ GROUP BY run_name
457
+ """
458
+ )
459
+
460
+ results = {}
461
+ for row in cursor.fetchall():
462
+ results[row["run_name"]] = row["max_step"]
463
+
464
+ return results
465
+
466
+ @staticmethod
467
+ def store_config(project: str, run: str, config: dict) -> None:
468
+ """Store configuration for a run."""
469
+ db_path = SQLiteStorage.init_db(project)
470
+
471
+ with SQLiteStorage._get_process_lock(project):
472
+ with SQLiteStorage._get_connection(db_path) as conn:
473
+ cursor = conn.cursor()
474
+ current_timestamp = datetime.now().isoformat()
475
+
476
+ cursor.execute(
477
+ """
478
+ INSERT OR REPLACE INTO configs
479
+ (run_name, config, created_at)
480
+ VALUES (?, ?, ?)
481
+ """,
482
+ (run, json.dumps(serialize_values(config)), current_timestamp),
483
+ )
484
+ conn.commit()
485
+
486
+ @staticmethod
487
+ def get_run_config(project: str, run: str) -> dict | None:
488
+ """Get configuration for a specific run."""
489
+ db_path = SQLiteStorage.get_project_db_path(project)
490
+ if not db_path.exists():
491
+ return None
492
+
493
+ with SQLiteStorage._get_connection(db_path) as conn:
494
+ cursor = conn.cursor()
495
+ try:
496
+ cursor.execute(
497
+ """
498
+ SELECT config FROM configs WHERE run_name = ?
499
+ """,
500
+ (run,),
501
+ )
502
+
503
+ row = cursor.fetchone()
504
+ if row:
505
+ config = json.loads(row["config"])
506
+ return deserialize_values(config)
507
+ return None
508
+ except sqlite3.OperationalError as e:
509
+ if "no such table: configs" in str(e):
510
+ return None
511
+ raise
512
+
513
+ @staticmethod
514
+ def delete_run(project: str, run: str) -> bool:
515
+ """Delete a run from the database (both metrics and config)."""
516
+ db_path = SQLiteStorage.get_project_db_path(project)
517
+ if not db_path.exists():
518
+ return False
519
+
520
+ with SQLiteStorage._get_process_lock(project):
521
+ with SQLiteStorage._get_connection(db_path) as conn:
522
+ cursor = conn.cursor()
523
+ try:
524
+ cursor.execute("DELETE FROM metrics WHERE run_name = ?", (run,))
525
+ cursor.execute("DELETE FROM configs WHERE run_name = ?", (run,))
526
+ conn.commit()
527
+ return True
528
+ except sqlite3.Error:
529
+ return False
530
+
531
+ @staticmethod
532
+ def get_all_run_configs(project: str) -> dict[str, dict]:
533
+ """Get configurations for all runs in a project."""
534
+ db_path = SQLiteStorage.get_project_db_path(project)
535
+ if not db_path.exists():
536
+ return {}
537
+
538
+ with SQLiteStorage._get_connection(db_path) as conn:
539
+ cursor = conn.cursor()
540
+ try:
541
+ cursor.execute(
542
+ """
543
+ SELECT run_name, config FROM configs
544
+ """
545
+ )
546
+
547
+ results = {}
548
+ for row in cursor.fetchall():
549
+ config = json.loads(row["config"])
550
+ results[row["run_name"]] = deserialize_values(config)
551
+ return results
552
+ except sqlite3.OperationalError as e:
553
+ if "no such table: configs" in str(e):
554
+ return {}
555
+ raise
556
+
557
+ def finish(self):
558
+ """Cleanup when run is finished."""
559
+ pass
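`SQLiteStorage` is the internal storage layer that the public `trackio.log()` API ultimately writes to, but its static methods can also be exercised directly, which is useful when reasoning about what ends up in the per-project `.db` file. A hedged local sketch (hypothetical project and run names; data is written under the local Trackio directory):

```python
from trackio.sqlite_storage import SQLiteStorage

# bulk_log creates the project DB (and its tables) on first use.
SQLiteStorage.bulk_log(
    project="demo-project",
    run="run-1",
    metrics_list=[{"loss": 0.9}, {"loss": 0.7}, {"loss": 0.5}],
    steps=[0, 1, 2],
    config={"lr": 1e-3},
)

for row in SQLiteStorage.get_logs("demo-project", "run-1"):
    print(row["step"], row["timestamp"], row.get("loss"))

print(SQLiteStorage.get_runs("demo-project"))                  # ['run-1']
print(SQLiteStorage.get_run_config("demo-project", "run-1"))   # {'lr': 0.001}
```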
table.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Any, Literal, Optional, Union
2
+
3
+ from pandas import DataFrame
4
+
5
+
6
+ class Table:
7
+ """
8
+ Initializes a Table object.
9
+
10
+ Args:
11
+ columns (`list[str]`, *optional*, defaults to `None`):
12
+ Names of the columns in the table. Optional if `data` is provided. Not
13
+ expected if `dataframe` is provided. Currently ignored.
14
+ data (`list[list[Any]]`, *optional*, defaults to `None`):
15
+ 2D row-oriented array of values.
16
+ dataframe (`pandas.`DataFrame``, *optional*, defaults to `None`):
17
+ DataFrame object used to create the table. When set, `data` and `columns`
18
+ arguments are ignored.
19
+ rows (`list[list[Any]]`, *optional*, defaults to `None`):
20
+ Currently ignored.
21
+ optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
22
+ Currently ignored.
23
+ allow_mixed_types (`bool`, *optional*, defaults to `False`):
24
+ Currently ignored.
25
+ log_mode (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
26
+ Currently ignored.
27
+ """
28
+
29
+ TYPE = "trackio.table"
30
+
31
+ def __init__(
32
+ self,
33
+ columns: Optional[list[str]] = None,
34
+ data: Optional[list[list[Any]]] = None,
35
+ dataframe: Optional[DataFrame] = None,
36
+ rows: Optional[list[list[Any]]] = None,
37
+ optional: Union[bool, list[bool]] = True,
38
+ allow_mixed_types: bool = False,
39
+ log_mode: Optional[
40
+ Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]
41
+ ] = "IMMUTABLE",
42
+ ):
43
+ # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
44
+ # for now (like `rows`) they are included for API compat but don't do anything.
45
+
46
+ if dataframe is None:
47
+ self.data = data
48
+ else:
49
+ self.data = dataframe.to_dict(orient="records")
50
+
51
+ def _to_dict(self):
52
+ return {
53
+ "_type": self.TYPE,
54
+ "_value": self.data,
55
+ }
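Since most of the constructor arguments are currently no-ops, the practical inputs are `data` (row-oriented lists) or `dataframe`; both end up in the `_value` payload that gets serialized when the table is logged. A small illustration (hypothetical values):

```python
import pandas as pd
from trackio.table import Table

# Row-oriented data (the `columns` argument is currently ignored, per the docstring above)
t1 = Table(columns=["epoch", "loss"], data=[[0, 0.9], [1, 0.5]])
print(t1._to_dict())  # {'_type': 'trackio.table', '_value': [[0, 0.9], [1, 0.5]]}

# DataFrame input is stored as a list of record dicts
df = pd.DataFrame({"epoch": [0, 1], "loss": [0.9, 0.5]})
t2 = Table(dataframe=df)
print(t2._to_dict()["_value"])  # [{'epoch': 0, 'loss': 0.9}, {'epoch': 1, 'loss': 0.5}]
```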
typehints.py ADDED
@@ -0,0 +1,18 @@
1
+ from typing import Any, TypedDict
2
+
3
+ from gradio import FileData
4
+
5
+
6
+ class LogEntry(TypedDict):
7
+ project: str
8
+ run: str
9
+ metrics: dict[str, Any]
10
+ step: int | None
11
+ config: dict[str, Any] | None
12
+
13
+
14
+ class UploadEntry(TypedDict):
15
+ project: str
16
+ run: str
17
+ step: int | None
18
+ uploaded_file: FileData
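These TypedDicts describe the payloads that `Run` queues locally and that the Space's `/bulk_log` and `/bulk_upload_media` endpoints consume. A minimal illustration of the `LogEntry` shape (hypothetical values):

```python
from trackio.typehints import LogEntry

entry: LogEntry = {
    "project": "demo-project",
    "run": "run-1",
    "metrics": {"loss": 0.42},
    "step": 3,
    "config": None,  # only populated for the first entry of a run
}
```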
ui/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ try:
2
+ from trackio.ui.main import demo
3
+ from trackio.ui.runs import run_page
4
+ except ImportError:
5
+ from ui.main import demo
6
+ from ui.runs import run_page
7
+
8
+ __all__ = ["demo", "run_page"]
ui/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (413 Bytes). View file
 
ui/__pycache__/fns.cpython-312.pyc ADDED
Binary file (3.16 kB). View file
 
ui/__pycache__/main.cpython-312.pyc ADDED
Binary file (35.6 kB). View file
 
ui/__pycache__/run_detail.cpython-312.pyc ADDED
Binary file (4.04 kB). View file
 
ui/__pycache__/runs.cpython-312.pyc ADDED
Binary file (10.6 kB). View file
 
ui/fns.py ADDED
@@ -0,0 +1,58 @@
1
+ """Shared functions for the Trackio UI."""
2
+
3
+ import os
4
+
5
+ import gradio as gr
6
+
7
+ try:
8
+ import trackio.utils as utils
9
+ from trackio.sqlite_storage import SQLiteStorage
10
+ except ImportError:
11
+ import utils
12
+ from sqlite_storage import SQLiteStorage
13
+
14
+
15
+ def get_project_info() -> str | None:
16
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
17
+ space_id = os.environ.get("SPACE_ID")
18
+ if utils.persistent_storage_enabled():
19
+ return "&#10024; Persistent Storage is enabled, logs are stored directly in this Space."
20
+ if dataset_id:
21
+ sync_status = utils.get_sync_status(SQLiteStorage.get_scheduler())
22
+ upgrade_message = f"New changes are synced every 5 min <span class='info-container'><input type='checkbox' class='info-checkbox' id='upgrade-info'><label for='upgrade-info' class='info-icon'>&#9432;</label><span class='info-expandable'> To avoid losing data between syncs, <a href='https://huggingface.co/spaces/{space_id}/settings' class='accent-link'>click here</a> to open this Space's settings and add Persistent Storage. Make sure data is synced prior to enabling.</span></span>"
23
+ if sync_status is not None:
24
+ info = f"&#x21bb; Backed up {sync_status} min ago to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank' class='accent-link'>{dataset_id}</a> | {upgrade_message}"
25
+ else:
26
+ info = f"&#x21bb; Not backed up yet to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank' class='accent-link'>{dataset_id}</a> | {upgrade_message}"
27
+ return info
28
+ return None
29
+
30
+
31
+ def get_projects(request: gr.Request):
32
+ projects = SQLiteStorage.get_projects()
33
+ if project := request.query_params.get("project"):
34
+ interactive = False
35
+ else:
36
+ interactive = True
37
+ if selected_project := request.query_params.get("selected_project"):
38
+ project = selected_project
39
+ else:
40
+ project = projects[0] if projects else None
41
+
42
+ return gr.Dropdown(
43
+ label="Project",
44
+ choices=projects,
45
+ value=project,
46
+ allow_custom_value=True,
47
+ interactive=interactive,
48
+ info=get_project_info(),
49
+ )
50
+
51
+
52
+ def update_navbar_value(project_dd):
53
+ return gr.Navbar(
54
+ value=[
55
+ ("Metrics", f"?selected_project={project_dd}"),
56
+ ("Runs", f"runs?selected_project={project_dd}"),
57
+ ]
58
+ )
ui/main.py ADDED
@@ -0,0 +1,937 @@
1
+ """The main page for the Trackio UI."""
2
+
3
+ import os
4
+ import re
5
+ import shutil
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+
9
+ import gradio as gr
10
+ import huggingface_hub as hf
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ HfApi = hf.HfApi()
15
+
16
+ try:
17
+ import trackio.utils as utils
18
+ from trackio.file_storage import FileStorage
19
+ from trackio.media import TrackioImage, TrackioVideo
20
+ from trackio.sqlite_storage import SQLiteStorage
21
+ from trackio.table import Table
22
+ from trackio.typehints import LogEntry, UploadEntry
23
+ from trackio.ui import fns
24
+ from trackio.ui.run_detail import run_detail_page
25
+ from trackio.ui.runs import run_page
26
+ except ImportError:
27
+ import utils
28
+ from file_storage import FileStorage
29
+ from media import TrackioImage, TrackioVideo
30
+ from sqlite_storage import SQLiteStorage
31
+ from table import Table
32
+ from typehints import LogEntry, UploadEntry
33
+ from ui import fns
34
+ from ui.run_detail import run_detail_page
35
+ from ui.runs import run_page
36
+
37
+
38
+ def get_runs(project) -> list[str]:
39
+ if not project:
40
+ return []
41
+ return SQLiteStorage.get_runs(project)
42
+
43
+
44
+ def get_available_metrics(project: str, runs: list[str]) -> list[str]:
45
+ """Get all available metrics across all runs for x-axis selection."""
46
+ if not project or not runs:
47
+ return ["step", "time"]
48
+
49
+ all_metrics = set()
50
+ for run in runs:
51
+ metrics = SQLiteStorage.get_logs(project, run)
52
+ if metrics:
53
+ df = pd.DataFrame(metrics)
54
+ numeric_cols = df.select_dtypes(include="number").columns
55
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
56
+ all_metrics.update(numeric_cols)
57
+
58
+ all_metrics.add("step")
59
+ all_metrics.add("time")
60
+
61
+ sorted_metrics = utils.sort_metrics_by_prefix(list(all_metrics))
62
+
63
+ result = ["step", "time"]
64
+ for metric in sorted_metrics:
65
+ if metric not in result:
66
+ result.append(metric)
67
+
68
+ return result
69
+
70
+
71
+ @dataclass
72
+ class MediaData:
73
+ caption: str | None
74
+ file_path: str
75
+
76
+
77
+ def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]:
78
+ media_by_key: dict[str, list[MediaData]] = {}
79
+ logs = sorted(logs, key=lambda x: x.get("step", 0))
80
+ for log in logs:
81
+ for key, value in log.items():
82
+ if isinstance(value, dict):
83
+ type = value.get("_type")
84
+ if type == TrackioImage.TYPE or type == TrackioVideo.TYPE:
85
+ if key not in media_by_key:
86
+ media_by_key[key] = []
87
+ try:
88
+ media_data = MediaData(
89
+ file_path=utils.MEDIA_DIR / value.get("file_path"),
90
+ caption=value.get("caption"),
91
+ )
92
+ media_by_key[key].append(media_data)
93
+ except Exception as e:
94
+ print(f"Media currently unavailable: {key}: {e}")
95
+ return media_by_key
96
+
97
+
98
+ def load_run_data(
99
+ project: str | None,
100
+ run: str | None,
101
+ smoothing_granularity: int,
102
+ x_axis: str,
103
+ log_scale: bool = False,
104
+ ) -> tuple[pd.DataFrame, dict]:
105
+ if not project or not run:
106
+ return None, None
107
+
108
+ logs = SQLiteStorage.get_logs(project, run)
109
+ if not logs:
110
+ return None, None
111
+
112
+ media = extract_media(logs)
113
+ df = pd.DataFrame(logs)
114
+
115
+ if "step" not in df.columns:
116
+ df["step"] = range(len(df))
117
+
118
+ if x_axis == "time" and "timestamp" in df.columns:
119
+ df["timestamp"] = pd.to_datetime(df["timestamp"])
120
+ first_timestamp = df["timestamp"].min()
121
+ df["time"] = (df["timestamp"] - first_timestamp).dt.total_seconds()
122
+ x_column = "time"
123
+ elif x_axis == "step":
124
+ x_column = "step"
125
+ else:
126
+ x_column = x_axis
127
+
128
+ if log_scale and x_column in df.columns:
129
+ x_vals = df[x_column]
130
+ if (x_vals <= 0).any():
131
+ df[x_column] = np.log10(np.maximum(x_vals, 0) + 1)
132
+ else:
133
+ df[x_column] = np.log10(x_vals)
134
+
135
+ if smoothing_granularity > 0:
136
+ numeric_cols = df.select_dtypes(include="number").columns
137
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
138
+
139
+ df_original = df.copy()
140
+ df_original["run"] = run
141
+ df_original["data_type"] = "original"
142
+
143
+ df_smoothed = df.copy()
144
+ window_size = max(3, min(smoothing_granularity, len(df)))
145
+ df_smoothed[numeric_cols] = (
146
+ df_smoothed[numeric_cols]
147
+ .rolling(window=window_size, center=True, min_periods=1)
148
+ .mean()
149
+ )
150
+ df_smoothed["run"] = f"{run}_smoothed"
151
+ df_smoothed["data_type"] = "smoothed"
152
+
153
+ combined_df = pd.concat([df_original, df_smoothed], ignore_index=True)
154
+ combined_df["x_axis"] = x_column
155
+ return combined_df, media
156
+ else:
157
+ df["run"] = run
158
+ df["data_type"] = "original"
159
+ df["x_axis"] = x_column
160
+ return df, media
161
+
162
+
163
+ def update_runs(
164
+ project, filter_text, user_interacted_with_runs=False, selected_runs_from_url=None
165
+ ):
166
+ if project is None:
167
+ runs = []
168
+ num_runs = 0
169
+ else:
170
+ runs = get_runs(project)
171
+ num_runs = len(runs)
172
+ if filter_text:
173
+ runs = [r for r in runs if filter_text in r]
174
+
175
+ if not user_interacted_with_runs:
176
+ if selected_runs_from_url:
177
+ value = [r for r in runs if r in selected_runs_from_url]
178
+ else:
179
+ value = runs
180
+ return gr.CheckboxGroup(choices=runs, value=value), gr.Textbox(
181
+ label=f"Runs ({num_runs})"
182
+ )
183
+ else:
184
+ return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})")
185
+
186
+
187
+ def filter_runs(project, filter_text):
188
+ runs = get_runs(project)
189
+ runs = [r for r in runs if filter_text in r]
190
+ return gr.CheckboxGroup(choices=runs, value=runs)
191
+
192
+
193
+ def update_x_axis_choices(project, runs):
194
+ """Update x-axis dropdown choices based on available metrics."""
195
+ available_metrics = get_available_metrics(project, runs)
196
+ return gr.Dropdown(
197
+ label="X-axis",
198
+ choices=available_metrics,
199
+ value="step",
200
+ )
201
+
202
+
203
+ def toggle_timer(cb_value):
204
+ if cb_value:
205
+ return gr.Timer(active=True)
206
+ else:
207
+ return gr.Timer(active=False)
208
+
209
+
210
+ def check_auth(hf_token: str | None) -> None:
211
+ if os.getenv("SYSTEM") == "spaces": # if we are running in Spaces
212
+ # check auth token passed in
213
+ if hf_token is None:
214
+ raise PermissionError(
215
+ "Expected a HF_TOKEN to be provided when logging to a Space"
216
+ )
217
+ who = HfApi.whoami(hf_token)
218
+ access_token = who["auth"]["accessToken"]
219
+ owner_name = os.getenv("SPACE_AUTHOR_NAME")
220
+ repo_name = os.getenv("SPACE_REPO_NAME")
221
+ # make sure the token user is either the author of the space,
222
+ # or is a member of an org that is the author.
223
+ orgs = [o["name"] for o in who["orgs"]]
224
+ if owner_name != who["name"] and owner_name not in orgs:
225
+ raise PermissionError(
226
+ "Expected the provided hf_token to be the user owner of the space, or be a member of the org owner of the space"
227
+ )
228
+ # reject fine-grained tokens without specific repo access
229
+ if access_token["role"] == "fineGrained":
230
+ matched = False
231
+ for item in access_token["fineGrained"]["scoped"]:
232
+ if (
233
+ item["entity"]["type"] == "space"
234
+ and item["entity"]["name"] == f"{owner_name}/{repo_name}"
235
+ and "repo.write" in item["permissions"]
236
+ ):
237
+ matched = True
238
+ break
239
+ if (
240
+ (
241
+ item["entity"]["type"] == "user"
242
+ or item["entity"]["type"] == "org"
243
+ )
244
+ and item["entity"]["name"] == owner_name
245
+ and "repo.write" in item["permissions"]
246
+ ):
247
+ matched = True
248
+ break
249
+ if not matched:
250
+ raise PermissionError(
251
+ "Expected the provided hf_token with fine grained permissions to provide write access to the space"
252
+ )
253
+ # reject read-only tokens
254
+ elif access_token["role"] != "write":
255
+ raise PermissionError(
256
+ "Expected the provided hf_token to provide write permissions"
257
+ )
258
+
259
+
260
+ def upload_db_to_space(
261
+ project: str, uploaded_db: gr.FileData, hf_token: str | None
262
+ ) -> None:
263
+ check_auth(hf_token)
264
+ db_project_path = SQLiteStorage.get_project_db_path(project)
265
+ if os.path.exists(db_project_path):
266
+ raise gr.Error(
267
+ f"Trackio database file already exists for project {project}, cannot overwrite."
268
+ )
269
+ os.makedirs(os.path.dirname(db_project_path), exist_ok=True)
270
+ shutil.copy(uploaded_db["path"], db_project_path)
271
+
272
+
273
+ def bulk_upload_media(uploads: list[UploadEntry], hf_token: str | None) -> None:
274
+ check_auth(hf_token)
275
+ for upload in uploads:
276
+ media_path = FileStorage.init_project_media_path(
277
+ upload["project"], upload["run"], upload["step"]
278
+ )
279
+ shutil.copy(upload["uploaded_file"]["path"], media_path)
280
+
281
+
282
+ def log(
283
+ project: str,
284
+ run: str,
285
+ metrics: dict[str, Any],
286
+ step: int | None,
287
+ hf_token: str | None,
288
+ ) -> None:
289
+ """
290
+ Note: this method is not used in the latest versions of Trackio (replaced by bulk_log) but
291
+ is kept for backwards compatibility for users who are connecting to a newer version of
292
+ a Trackio Spaces dashboard with an older version of Trackio installed locally.
293
+ """
294
+ check_auth(hf_token)
295
+ SQLiteStorage.log(project=project, run=run, metrics=metrics, step=step)
296
+
297
+
298
+ def bulk_log(
299
+ logs: list[LogEntry],
300
+ hf_token: str | None,
301
+ ) -> None:
302
+ check_auth(hf_token)
303
+
304
+ logs_by_run = {}
305
+ for log_entry in logs:
306
+ key = (log_entry["project"], log_entry["run"])
307
+ if key not in logs_by_run:
308
+ logs_by_run[key] = {"metrics": [], "steps": [], "config": None}
309
+ logs_by_run[key]["metrics"].append(log_entry["metrics"])
310
+ logs_by_run[key]["steps"].append(log_entry.get("step"))
311
+ if log_entry.get("config") and logs_by_run[key]["config"] is None:
312
+ logs_by_run[key]["config"] = log_entry["config"]
313
+
314
+ for (project, run), data in logs_by_run.items():
315
+ SQLiteStorage.bulk_log(
316
+ project=project,
317
+ run=run,
318
+ metrics_list=data["metrics"],
319
+ steps=data["steps"],
320
+ config=data["config"],
321
+ )
322
+
323
+
324
+ def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]:
325
+ """
326
+ Filter metrics using regex pattern.
327
+
328
+ Args:
329
+ metrics: List of metric names to filter
330
+ filter_pattern: Regex pattern to match against metric names
331
+
332
+ Returns:
333
+ List of metric names that match the pattern
334
+ """
335
+ if not filter_pattern.strip():
336
+ return metrics
337
+
338
+ try:
339
+ pattern = re.compile(filter_pattern, re.IGNORECASE)
340
+ return [metric for metric in metrics if pattern.search(metric)]
341
+ except re.error:
342
+ return [
343
+ metric for metric in metrics if filter_pattern.lower() in metric.lower()
344
+ ]
345
+
346
+
347
+ def configure(request: gr.Request):
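A few hedged examples of how the filter defined directly above behaves (metric names are hypothetical):

```python
metrics = ["train/loss", "val/loss", "val/ndcg@10", "gpu_mem"]

filter_metrics_by_regex(metrics, "loss|ndcg@10")  # ['train/loss', 'val/loss', 'val/ndcg@10']
filter_metrics_by_regex(metrics, "")              # blank pattern: all metrics pass through
filter_metrics_by_regex(metrics, "val[")          # invalid regex: substring fallback -> []
```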
348
+ sidebar_param = request.query_params.get("sidebar")
349
+ match sidebar_param:
350
+ case "collapsed":
351
+ sidebar = gr.Sidebar(open=False, visible=True)
352
+ case "hidden":
353
+ sidebar = gr.Sidebar(open=False, visible=False)
354
+ case _:
355
+ sidebar = gr.Sidebar(open=True, visible=True)
356
+
357
+ metrics_param = request.query_params.get("metrics", "")
358
+ runs_param = request.query_params.get("runs", "")
359
+ selected_runs = runs_param.split(",") if runs_param else []
360
+ navbar_param = request.query_params.get("navbar")
361
+ match navbar_param:
362
+ case "hidden":
363
+ navbar = gr.Navbar(visible=False)
364
+ case _:
365
+ navbar = gr.Navbar(visible=True)
366
+
367
+ return [], sidebar, metrics_param, selected_runs, navbar
368
+
369
+
370
+ def create_media_section(media_by_run: dict[str, dict[str, list[MediaData]]]):
371
+ with gr.Accordion(label="media"):
372
+ with gr.Group(elem_classes=("media-group")):
373
+ for run, media_by_key in media_by_run.items():
374
+ with gr.Tab(label=run, elem_classes=("media-tab")):
375
+ for key, media_item in media_by_key.items():
376
+ gr.Gallery(
377
+ [(item.file_path, item.caption) for item in media_item],
378
+ label=key,
379
+ columns=6,
380
+ elem_classes=("media-gallery"),
381
+ )
382
+
383
+
384
+ css = """
385
+ #run-cb .wrap { gap: 2px; }
386
+ #run-cb .wrap label {
387
+ line-height: 1;
388
+ padding: 6px;
389
+ }
390
+ .logo-light { display: block; }
391
+ .logo-dark { display: none; }
392
+ .dark .logo-light { display: none; }
393
+ .dark .logo-dark { display: block; }
394
+ .dark .caption-label { color: white; }
395
+
396
+ .info-container {
397
+ position: relative;
398
+ display: inline;
399
+ }
400
+ .info-checkbox {
401
+ position: absolute;
402
+ opacity: 0;
403
+ pointer-events: none;
404
+ }
405
+ .info-icon {
406
+ border-bottom: 1px dotted;
407
+ cursor: pointer;
408
+ user-select: none;
409
+ color: var(--color-accent);
410
+ }
411
+ .info-expandable {
412
+ display: none;
413
+ opacity: 0;
414
+ transition: opacity 0.2s ease-in-out;
415
+ }
416
+ .info-checkbox:checked ~ .info-expandable {
417
+ display: inline;
418
+ opacity: 1;
419
+ }
420
+ .info-icon:hover { opacity: 0.8; }
421
+ .accent-link { font-weight: bold; }
422
+
423
+ .media-gallery .fixed-height { min-height: 275px; }
424
+ .media-group, .media-group > div { background: none; }
425
+ .media-group .tabs { padding: 0.5em; }
426
+ .media-tab { max-height: 500px; overflow-y: scroll; }
427
+ """
428
+
429
+ javascript = """
430
+ <script>
431
+ function setCookie(name, value, days) {
432
+ var expires = "";
433
+ if (days) {
434
+ var date = new Date();
435
+ date.setTime(date.getTime() + (days * 24 * 60 * 60 * 1000));
436
+ expires = "; expires=" + date.toUTCString();
437
+ }
438
+ document.cookie = name + "=" + (value || "") + expires + "; path=/; SameSite=Lax";
439
+ }
440
+
441
+ function getCookie(name) {
442
+ var nameEQ = name + "=";
443
+ var ca = document.cookie.split(';');
444
+ for(var i=0;i < ca.length;i++) {
445
+ var c = ca[i];
446
+ while (c.charAt(0)==' ') c = c.substring(1,c.length);
447
+ if (c.indexOf(nameEQ) == 0) return c.substring(nameEQ.length,c.length);
448
+ }
449
+ return null;
450
+ }
451
+
452
+ (function() {
453
+ const urlParams = new URLSearchParams(window.location.search);
454
+ const writeToken = urlParams.get('write_token');
455
+
456
+ if (writeToken) {
457
+ setCookie('trackio_write_token', writeToken, 7);
458
+
459
+ urlParams.delete('write_token');
460
+ const newUrl = window.location.pathname +
461
+ (urlParams.toString() ? '?' + urlParams.toString() : '') +
462
+ window.location.hash;
463
+ window.history.replaceState({}, document.title, newUrl);
464
+ }
465
+ })();
466
+ </script>
467
+ """
468
+
469
+
470
+ gr.set_static_paths(paths=[utils.MEDIA_DIR])
471
+
472
+ with gr.Blocks(title="Trackio Dashboard", css=css, head=javascript) as demo:
473
+ with gr.Sidebar(open=False) as sidebar:
474
+ logo = gr.Markdown(
475
+ f"""
476
+ <img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_light_transparent.png' width='80%' class='logo-light'>
477
+ <img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_dark_transparent.png' width='80%' class='logo-dark'>
478
+ """
479
+ )
480
+ project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
481
+
482
+ embed_code = gr.Code(
483
+ label="Embed this view",
484
+ max_lines=2,
485
+ lines=2,
486
+ language="html",
487
+ visible=bool(os.environ.get("SPACE_HOST")),
488
+ )
489
+ run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
490
+ run_cb = gr.CheckboxGroup(
491
+ label="Runs",
492
+ choices=[],
493
+ interactive=True,
494
+ elem_id="run-cb",
495
+ show_select_all=True,
496
+ )
497
+ gr.HTML("<hr>")
498
+ realtime_cb = gr.Checkbox(label="Refresh metrics realtime", value=True)
499
+ smoothing_slider = gr.Slider(
500
+ label="Smoothing Factor",
501
+ minimum=0,
502
+ maximum=20,
503
+ value=10,
504
+ step=1,
505
+ info="0 = no smoothing",
506
+ )
507
+ x_axis_dd = gr.Dropdown(
508
+ label="X-axis",
509
+ choices=["step", "time"],
510
+ value="step",
511
+ )
512
+ log_scale_cb = gr.Checkbox(label="Log scale X-axis", value=False)
513
+ metric_filter_tb = gr.Textbox(
514
+ label="Metric Filter (regex)",
515
+ placeholder="e.g., loss|ndcg@10|gpu",
516
+ value="",
517
+ info="Filter metrics using regex patterns. Leave empty to show all metrics.",
518
+ )
519
+
520
+ navbar = gr.Navbar(value=[("Metrics", ""), ("Runs", "/runs")], main_page_name=False)
521
+ timer = gr.Timer(value=1)
522
+ metrics_subset = gr.State([])
523
+ user_interacted_with_run_cb = gr.State(False)
524
+ selected_runs_from_url = gr.State([])
525
+
526
+ gr.on(
527
+ [demo.load],
528
+ fn=configure,
529
+ outputs=[
530
+ metrics_subset,
531
+ sidebar,
532
+ metric_filter_tb,
533
+ selected_runs_from_url,
534
+ navbar,
535
+ ],
536
+ queue=False,
537
+ api_name=False,
538
+ )
539
+ gr.on(
540
+ [demo.load],
541
+ fn=fns.get_projects,
542
+ outputs=project_dd,
543
+ show_progress="hidden",
544
+ queue=False,
545
+ api_name=False,
546
+ )
547
+ gr.on(
548
+ [timer.tick],
549
+ fn=update_runs,
550
+ inputs=[
551
+ project_dd,
552
+ run_tb,
553
+ user_interacted_with_run_cb,
554
+ selected_runs_from_url,
555
+ ],
556
+ outputs=[run_cb, run_tb],
557
+ show_progress="hidden",
558
+ api_name=False,
559
+ )
560
+ gr.on(
561
+ [timer.tick],
562
+ fn=lambda: gr.Dropdown(info=fns.get_project_info()),
563
+ outputs=[project_dd],
564
+ show_progress="hidden",
565
+ api_name=False,
566
+ )
567
+ gr.on(
568
+ [demo.load, project_dd.change],
569
+ fn=update_runs,
570
+ inputs=[project_dd, run_tb, gr.State(False), selected_runs_from_url],
571
+ outputs=[run_cb, run_tb],
572
+ show_progress="hidden",
573
+ queue=False,
574
+ api_name=False,
575
+ ).then(
576
+ fn=update_x_axis_choices,
577
+ inputs=[project_dd, run_cb],
578
+ outputs=x_axis_dd,
579
+ show_progress="hidden",
580
+ queue=False,
581
+ api_name=False,
582
+ ).then(
583
+ fn=utils.generate_embed_code,
584
+ inputs=[project_dd, metric_filter_tb, run_cb],
585
+ outputs=[embed_code],
586
+ show_progress="hidden",
587
+ api_name=False,
588
+ queue=False,
589
+ ).then(
590
+ fns.update_navbar_value,
591
+ inputs=[project_dd],
592
+ outputs=[navbar],
593
+ show_progress="hidden",
594
+ api_name=False,
595
+ queue=False,
596
+ )
597
+
598
+ gr.on(
599
+ [run_cb.input],
600
+ fn=update_x_axis_choices,
601
+ inputs=[project_dd, run_cb],
602
+ outputs=x_axis_dd,
603
+ show_progress="hidden",
604
+ queue=False,
605
+ api_name=False,
606
+ )
607
+ gr.on(
608
+ [metric_filter_tb.change, run_cb.change],
609
+ fn=utils.generate_embed_code,
610
+ inputs=[project_dd, metric_filter_tb, run_cb],
611
+ outputs=embed_code,
612
+ show_progress="hidden",
613
+ api_name=False,
614
+ queue=False,
615
+ )
616
+
617
+ realtime_cb.change(
618
+ fn=toggle_timer,
619
+ inputs=realtime_cb,
620
+ outputs=timer,
621
+ api_name=False,
622
+ queue=False,
623
+ )
624
+ run_cb.input(
625
+ fn=lambda: True,
626
+ outputs=user_interacted_with_run_cb,
627
+ api_name=False,
628
+ queue=False,
629
+ )
630
+ run_tb.input(
631
+ fn=filter_runs,
632
+ inputs=[project_dd, run_tb],
633
+ outputs=run_cb,
634
+ api_name=False,
635
+ queue=False,
636
+ )
637
+
638
+ gr.api(
639
+ fn=upload_db_to_space,
640
+ api_name="upload_db_to_space",
641
+ )
642
+ gr.api(
643
+ fn=bulk_upload_media,
644
+ api_name="bulk_upload_media",
645
+ )
646
+ gr.api(
647
+ fn=log,
648
+ api_name="log",
649
+ )
650
+ gr.api(
651
+ fn=bulk_log,
652
+ api_name="bulk_log",
653
+ )
654
+
655
+ x_lim = gr.State(None)
656
+ last_steps = gr.State({})
657
+
658
+ def update_x_lim(select_data: gr.SelectData):
659
+ return select_data.index
660
+
661
+ def update_last_steps(project):
662
+ """Check the last step for each run to detect when new data is available."""
663
+ if not project:
664
+ return {}
665
+ return SQLiteStorage.get_max_steps_for_runs(project)
666
+
667
+ timer.tick(
668
+ fn=update_last_steps,
669
+ inputs=[project_dd],
670
+ outputs=last_steps,
671
+ show_progress="hidden",
672
+ api_name=False,
673
+ )
674
+
675
+ @gr.render(
676
+ triggers=[
677
+ demo.load,
678
+ run_cb.change,
679
+ last_steps.change,
680
+ smoothing_slider.change,
681
+ x_lim.change,
682
+ x_axis_dd.change,
683
+ log_scale_cb.change,
684
+ metric_filter_tb.change,
685
+ ],
686
+ inputs=[
687
+ project_dd,
688
+ run_cb,
689
+ smoothing_slider,
690
+ metrics_subset,
691
+ x_lim,
692
+ x_axis_dd,
693
+ log_scale_cb,
694
+ metric_filter_tb,
695
+ ],
696
+ show_progress="hidden",
697
+ queue=False,
698
+ )
699
+ def update_dashboard(
700
+ project,
701
+ runs,
702
+ smoothing_granularity,
703
+ metrics_subset,
704
+ x_lim_value,
705
+ x_axis,
706
+ log_scale,
707
+ metric_filter,
708
+ ):
709
+ dfs = []
710
+ images_by_run = {}
711
+ original_runs = runs.copy()
712
+
713
+ for run in runs:
714
+ df, images_by_key = load_run_data(
715
+ project, run, smoothing_granularity, x_axis, log_scale
716
+ )
717
+ if df is not None:
718
+ dfs.append(df)
719
+ images_by_run[run] = images_by_key
720
+
721
+ if dfs:
722
+ if smoothing_granularity > 0:
723
+ original_dfs = []
724
+ smoothed_dfs = []
725
+ for df in dfs:
726
+ original_data = df[df["data_type"] == "original"]
727
+ smoothed_data = df[df["data_type"] == "smoothed"]
728
+ if not original_data.empty:
729
+ original_dfs.append(original_data)
730
+ if not smoothed_data.empty:
731
+ smoothed_dfs.append(smoothed_data)
732
+
733
+ all_dfs = original_dfs + smoothed_dfs
734
+ master_df = (
735
+ pd.concat(all_dfs, ignore_index=True) if all_dfs else pd.DataFrame()
736
+ )
737
+
738
+ else:
739
+ master_df = pd.concat(dfs, ignore_index=True)
740
+ else:
741
+ master_df = pd.DataFrame()
742
+
743
+ if master_df.empty:
744
+ return
745
+
746
+ x_column = "step"
747
+ if dfs and not dfs[0].empty and "x_axis" in dfs[0].columns:
748
+ x_column = dfs[0]["x_axis"].iloc[0]
749
+
750
+ numeric_cols = master_df.select_dtypes(include="number").columns
751
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
752
+ if x_column and x_column in numeric_cols:
753
+ numeric_cols.remove(x_column)
754
+
755
+ if metrics_subset:
756
+ numeric_cols = [c for c in numeric_cols if c in metrics_subset]
757
+
758
+ if metric_filter and metric_filter.strip():
759
+ numeric_cols = filter_metrics_by_regex(list(numeric_cols), metric_filter)
760
+
761
+ nested_metric_groups = utils.group_metrics_with_subprefixes(list(numeric_cols))
762
+ color_map = utils.get_color_mapping(original_runs, smoothing_granularity > 0)
763
+
764
+ metric_idx = 0
765
+ for group_name in sorted(nested_metric_groups.keys()):
766
+ group_data = nested_metric_groups[group_name]
767
+
768
+ total_plot_count = sum(
769
+ 1
770
+ for m in group_data["direct_metrics"]
771
+ if not master_df.dropna(subset=[m]).empty
772
+ ) + sum(
773
+ sum(1 for m in metrics if not master_df.dropna(subset=[m]).empty)
774
+ for metrics in group_data["subgroups"].values()
775
+ )
776
+ group_label = (
777
+ f"{group_name} ({total_plot_count})"
778
+ if total_plot_count > 0
779
+ else group_name
780
+ )
781
+
782
+ with gr.Accordion(
783
+ label=group_label,
784
+ open=True,
785
+ key=f"accordion-{group_name}",
786
+ preserved_by_key=["value", "open"],
787
+ ):
788
+ if group_data["direct_metrics"]:
789
+ with gr.Draggable(
790
+ key=f"row-{group_name}-direct", orientation="row"
791
+ ):
792
+ for metric_name in group_data["direct_metrics"]:
793
+ metric_df = master_df.dropna(subset=[metric_name])
794
+ color = "run" if "run" in metric_df.columns else None
795
+ if not metric_df.empty:
796
+ plot = gr.LinePlot(
797
+ utils.downsample(
798
+ metric_df,
799
+ x_column,
800
+ metric_name,
801
+ color,
802
+ x_lim_value,
803
+ ),
804
+ x=x_column,
805
+ y=metric_name,
806
+ y_title=metric_name.split("/")[-1],
807
+ color=color,
808
+ color_map=color_map,
809
+ title=metric_name,
810
+ key=f"plot-{metric_idx}",
811
+ preserved_by_key=None,
812
+ x_lim=x_lim_value,
813
+ show_fullscreen_button=True,
814
+ min_width=400,
815
+ )
816
+ plot.select(
817
+ update_x_lim,
818
+ outputs=x_lim,
819
+ key=f"select-{metric_idx}",
820
+ )
821
+ plot.double_click(
822
+ lambda: None,
823
+ outputs=x_lim,
824
+ key=f"double-{metric_idx}",
825
+ )
826
+ metric_idx += 1
827
+
828
+ if group_data["subgroups"]:
829
+ for subgroup_name in sorted(group_data["subgroups"].keys()):
830
+ subgroup_metrics = group_data["subgroups"][subgroup_name]
831
+
832
+ subgroup_plot_count = sum(
833
+ 1
834
+ for m in subgroup_metrics
835
+ if not master_df.dropna(subset=[m]).empty
836
+ )
837
+ subgroup_label = (
838
+ f"{subgroup_name} ({subgroup_plot_count})"
839
+ if subgroup_plot_count > 0
840
+ else subgroup_name
841
+ )
842
+
843
+ with gr.Accordion(
844
+ label=subgroup_label,
845
+ open=True,
846
+ key=f"accordion-{group_name}-{subgroup_name}",
847
+ preserved_by_key=["value", "open"],
848
+ ):
849
+ with gr.Draggable(key=f"row-{group_name}-{subgroup_name}"):
850
+ for metric_name in subgroup_metrics:
851
+ metric_df = master_df.dropna(subset=[metric_name])
852
+ color = (
853
+ "run" if "run" in metric_df.columns else None
854
+ )
855
+ if not metric_df.empty:
856
+ plot = gr.LinePlot(
857
+ utils.downsample(
858
+ metric_df,
859
+ x_column,
860
+ metric_name,
861
+ color,
862
+ x_lim_value,
863
+ ),
864
+ x=x_column,
865
+ y=metric_name,
866
+ y_title=metric_name.split("/")[-1],
867
+ color=color,
868
+ color_map=color_map,
869
+ title=metric_name,
870
+ key=f"plot-{metric_idx}",
871
+ preserved_by_key=None,
872
+ x_lim=x_lim_value,
873
+ show_fullscreen_button=True,
874
+ min_width=400,
875
+ )
876
+ plot.select(
877
+ update_x_lim,
878
+ outputs=x_lim,
879
+ key=f"select-{metric_idx}",
880
+ )
881
+ plot.double_click(
882
+ lambda: None,
883
+ outputs=x_lim,
884
+ key=f"double-{metric_idx}",
885
+ )
886
+ metric_idx += 1
887
+ if images_by_run and any(any(images) for images in images_by_run.values()):
888
+ create_media_section(images_by_run)
889
+
890
+ table_cols = master_df.select_dtypes(include="object").columns
891
+ table_cols = [c for c in table_cols if c not in utils.RESERVED_KEYS]
892
+ if metrics_subset:
893
+ table_cols = [c for c in table_cols if c in metrics_subset]
894
+ if metric_filter and metric_filter.strip():
895
+ table_cols = filter_metrics_by_regex(list(table_cols), metric_filter)
896
+
897
+ actual_table_count = sum(
898
+ 1
899
+ for metric_name in table_cols
900
+ if not (metric_df := master_df.dropna(subset=[metric_name])).empty
901
+ and isinstance(value := metric_df[metric_name].iloc[-1], dict)
902
+ and value.get("_type") == Table.TYPE
903
+ )
904
+
905
+ if actual_table_count > 0:
906
+ with gr.Accordion(f"tables ({actual_table_count})", open=True):
907
+ with gr.Row(key="row"):
908
+ for metric_idx, metric_name in enumerate(table_cols):
909
+ metric_df = master_df.dropna(subset=[metric_name])
910
+ if not metric_df.empty:
911
+ value = metric_df[metric_name].iloc[-1]
912
+ if (
913
+ isinstance(value, dict)
914
+ and "_type" in value
915
+ and value["_type"] == Table.TYPE
916
+ ):
917
+ try:
918
+ df = pd.DataFrame(value["_value"])
919
+ gr.DataFrame(
920
+ df,
921
+ label=f"{metric_name} (latest)",
922
+ key=f"table-{metric_idx}",
923
+ wrap=True,
924
+ )
925
+ except Exception as e:
926
+ gr.Warning(
927
+ f"Column {metric_name} failed to render as a table: {e}"
928
+ )
929
+
930
+
931
+ with demo.route("Runs", show_in_navbar=False):
932
+ run_page.render()
933
+ with demo.route("Run", show_in_navbar=False):
934
+ run_detail_page.render()
935
+
936
+ if __name__ == "__main__":
937
+ demo.launch(allowed_paths=[utils.TRACKIO_LOGO_DIR], show_api=False, show_error=True)
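Note on the metric filter wired up above: filter_metrics_by_regex (defined earlier in this file) is applied to the metric names before plotting. A minimal standalone sketch of that behavior, for illustration only (the real helper may differ in details):

import re

def filter_metrics_by_regex(metrics: list[str], pattern: str) -> list[str]:
    # Keep only metric names matching the pattern, e.g. "loss|ndcg@10|gpu".
    try:
        compiled = re.compile(pattern)
    except re.error:
        return metrics  # fall back to showing everything if the pattern is invalid
    return [m for m in metrics if compiled.search(m)]

print(filter_metrics_by_regex(["train/loss", "val/acc", "gpu/0/util"], "loss|gpu"))
# ['train/loss', 'gpu/0/util']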
ui/run_detail.py ADDED
@@ -0,0 +1,90 @@
1
+ """The Run detail page for the Trackio UI."""
2
+
3
+ import gradio as gr
4
+
5
+ try:
6
+ import trackio.utils as utils
7
+ from trackio.sqlite_storage import SQLiteStorage
8
+ from trackio.ui import fns
9
+ except ImportError:
10
+ import utils
11
+ from sqlite_storage import SQLiteStorage
12
+ from ui import fns
13
+
14
+ RUN_DETAILS_TEMPLATE = """
15
+ ## Run Details
16
+ * **Run Name:** `{run_name}`
17
+ * **Created:** {created} by {username}
18
+ """
19
+
20
+ with gr.Blocks() as run_detail_page:
21
+ with gr.Sidebar() as sidebar:
22
+ logo = gr.Markdown(
23
+ f"""
24
+ <img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_light_transparent.png' width='80%' class='logo-light'>
25
+ <img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_dark_transparent.png' width='80%' class='logo-dark'>
26
+ """
27
+ )
28
+ project_dd = gr.Dropdown(
29
+ label="Project", allow_custom_value=True, interactive=False
30
+ )
31
+ run_dd = gr.Dropdown(label="Run")
32
+
33
+ navbar = gr.Navbar(value=[("Metrics", ""), ("Runs", "/runs")], main_page_name=False)
34
+
35
+ run_details = gr.Markdown(RUN_DETAILS_TEMPLATE)
36
+
37
+ run_config = gr.JSON(label="Run Config")
38
+
39
+ def configure(request: gr.Request):
40
+ project = request.query_params.get("selected_project")
41
+ run = request.query_params.get("selected_run")
42
+ runs = SQLiteStorage.get_runs(project)
43
+ return project, gr.Dropdown(choices=runs, value=run)
44
+
45
+ def update_run_details(project, run):
46
+ config = SQLiteStorage.get_run_config(project, run)
47
+ if not config:
48
+ return gr.Markdown("No run details available"), {}
49
+
50
+ created = config.get("_Created", "Unknown")
51
+ if created != "Unknown":
52
+ created = utils.format_timestamp(created)
53
+
54
+ username = config.get("_Username", "Unknown")
55
+ if username and username != "None" and username != "Unknown":
56
+ username = f"[{username}](https://huggingface.co/{username})"
57
+
58
+ details_md = RUN_DETAILS_TEMPLATE.format(
59
+ run_name=run, created=created, username=username
60
+ )
61
+
62
+ config_display = {k: v for k, v in config.items() if not k.startswith("_")}
63
+
64
+ return gr.Markdown(details_md), config_display
65
+
66
+ gr.on(
67
+ [run_detail_page.load],
68
+ fn=configure,
69
+ outputs=[project_dd, run_dd],
70
+ show_progress="hidden",
71
+ queue=False,
72
+ api_name=False,
73
+ ).then(
74
+ fns.update_navbar_value,
75
+ inputs=[project_dd],
76
+ outputs=[navbar],
77
+ show_progress="hidden",
78
+ api_name=False,
79
+ queue=False,
80
+ )
81
+
82
+ gr.on(
83
+ [run_dd.change],
84
+ update_run_details,
85
+ inputs=[project_dd, run_dd],
86
+ outputs=[run_details, run_config],
87
+ show_progress="hidden",
88
+ api_name=False,
89
+ queue=False,
90
+ )
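The run detail page is driven entirely by query parameters (see configure above). A hypothetical link that opens it for a specific run, with illustrative project and run names:

from urllib.parse import urlencode

params = {"selected_project": "my-project", "selected_run": "dainty-sunset-0"}  # hypothetical values
print(f"/run?{urlencode(params)}")
# /run?selected_project=my-project&selected_run=dainty-sunset-0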
ui/runs.py ADDED
@@ -0,0 +1,236 @@
1
+ """The Runs page for the Trackio UI."""
2
+
3
+ import re
4
+
5
+ import gradio as gr
6
+ import pandas as pd
7
+
8
+ try:
9
+ import trackio.utils as utils
10
+ from trackio.sqlite_storage import SQLiteStorage
11
+ from trackio.ui import fns
12
+ except ImportError:
13
+ import utils
14
+ from sqlite_storage import SQLiteStorage
15
+ from ui import fns
16
+
17
+
18
+ def get_runs_data(project):
19
+ """Get the runs data as a pandas DataFrame."""
20
+ configs = SQLiteStorage.get_all_run_configs(project)
21
+ if not configs:
22
+ return pd.DataFrame()
23
+
24
+ df = pd.DataFrame.from_dict(configs, orient="index")
25
+ df = df.fillna("")
26
+ df.index.name = "Name"
27
+ df.reset_index(inplace=True)
28
+
29
+ column_mapping = {"_Username": "Username", "_Created": "Created"}
30
+ df.rename(columns=column_mapping, inplace=True)
31
+
32
+ if "Created" in df.columns:
33
+ df["Created"] = df["Created"].apply(utils.format_timestamp)
34
+
35
+ if "Username" in df.columns:
36
+ df["Username"] = df["Username"].apply(
37
+ lambda x: f"<a href='https://huggingface.co/{x}' style='text-decoration-style: dotted;'>{x}</a>"
38
+ if x and x != "None"
39
+ else x
40
+ )
41
+
42
+ if "Name" in df.columns:
43
+ df["Name"] = df["Name"].apply(
44
+ lambda x: f"<a href='/run?selected_project={project}&selected_run={x}'>{x}</a>"
45
+ if x and x != "None"
46
+ else x
47
+ )
48
+
49
+ df.insert(0, " ", False)
50
+
51
+ columns = list(df.columns)
52
+ if "Username" in columns and "Created" in columns:
53
+ columns.remove("Username")
54
+ columns.remove("Created")
55
+ columns.insert(2, "Username")
56
+ columns.insert(3, "Created")
57
+ df = df[columns]
58
+
59
+ return df
60
+
61
+
62
+ def get_runs_table(project):
63
+ df = get_runs_data(project)
64
+ if df.empty:
65
+ return gr.DataFrame(pd.DataFrame(), visible=False)
66
+
67
+ datatype = ["bool"] + ["markdown"] * (len(df.columns) - 1)
68
+
69
+ return gr.DataFrame(
70
+ df,
71
+ visible=True,
72
+ pinned_columns=2,
73
+ datatype=datatype,
74
+ wrap=True,
75
+ column_widths=["40px", "150px"],
76
+ interactive=True,
77
+ static_columns=list(range(1, len(df.columns))),
78
+ row_count=(len(df), "fixed"),
79
+ col_count=(len(df.columns), "fixed"),
80
+ )
81
+
82
+
83
+ def check_write_access_runs(request: gr.Request, write_token: str) -> bool:
84
+ """Check if the user has write access based on token validation."""
85
+ cookies = request.headers.get("cookie", "")
86
+ if cookies:
87
+ for cookie in cookies.split(";"):
88
+ parts = cookie.strip().split("=")
89
+ if len(parts) == 2 and parts[0] == "trackio_write_token":
90
+ return parts[1] == write_token
91
+ if hasattr(request, "query_params") and request.query_params:
92
+ token = request.query_params.get("write_token")
93
+ return token == write_token
94
+ return False
95
+
96
+
97
+ def update_delete_button(runs_data, request: gr.Request):
98
+ """Update the delete button value and interactivity based on the runs data and user write access."""
99
+ if not check_write_access_runs(request, run_page.write_token):
100
+ return gr.Button("⚠️ Need write access to delete runs", interactive=False)
101
+
102
+ num_selected = 0
103
+ if runs_data is not None and len(runs_data) > 0:
104
+ first_column_values = runs_data.iloc[:, 0].tolist()
105
+ num_selected = sum(1 for x in first_column_values if x)
106
+
107
+ if num_selected:
108
+ return gr.Button(f"Delete {num_selected} selected run(s)", interactive=True)
109
+ else:
110
+ return gr.Button("Select runs to delete", interactive=False)
111
+
112
+
113
+ def delete_selected_runs(runs_data, project, request: gr.Request):
114
+ """Delete the selected runs and refresh the table."""
115
+ if not check_write_access_runs(request, run_page.write_token):
116
+ return runs_data
117
+
118
+ first_column_values = runs_data.iloc[:, 0].tolist()
119
+ for i, selected in enumerate(first_column_values):
120
+ if selected:
121
+ run_name_raw = runs_data.iloc[i, 1]
122
+ match = re.search(r">([^<]+)<", run_name_raw)
123
+ run_name = match.group(1) if match else run_name_raw
124
+ SQLiteStorage.delete_run(project, run_name)
125
+
126
+ updated_data = get_runs_data(project)
127
+ return updated_data
128
+
129
+
130
+ with gr.Blocks() as run_page:
131
+ with gr.Sidebar() as sidebar:
132
+ logo = gr.Markdown(
133
+ f"""
134
+ <img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_light_transparent.png' width='80%' class='logo-light'>
135
+ <img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_dark_transparent.png' width='80%' class='logo-dark'>
136
+ """
137
+ )
138
+ project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
139
+
140
+ navbar = gr.Navbar(value=[("Metrics", ""), ("Runs", "/runs")], main_page_name=False)
141
+ timer = gr.Timer(value=1)
142
+ with gr.Row():
143
+ with gr.Column():
144
+ pass
145
+ with gr.Column():
146
+ with gr.Row():
147
+ delete_run_btn = gr.Button(
148
+ "⚠️ Need write access to delete runs",
149
+ interactive=False,
150
+ variant="stop",
151
+ size="sm",
152
+ )
153
+ confirm_btn = gr.Button(
154
+ "Confirm delete", variant="stop", size="sm", visible=False
155
+ )
156
+ cancel_btn = gr.Button("Cancel", size="sm", visible=False)
157
+
158
+ runs_table = gr.DataFrame()
159
+
160
+ gr.on(
161
+ [run_page.load],
162
+ fn=fns.get_projects,
163
+ outputs=project_dd,
164
+ show_progress="hidden",
165
+ queue=False,
166
+ api_name=False,
167
+ )
168
+ gr.on(
169
+ [timer.tick],
170
+ fn=lambda: gr.Dropdown(info=fns.get_project_info()),
171
+ outputs=[project_dd],
172
+ show_progress="hidden",
173
+ api_name=False,
174
+ )
175
+ gr.on(
176
+ [project_dd.change],
177
+ fn=get_runs_table,
178
+ inputs=[project_dd],
179
+ outputs=[runs_table],
180
+ show_progress="hidden",
181
+ api_name=False,
182
+ queue=False,
183
+ ).then(
184
+ fns.update_navbar_value,
185
+ inputs=[project_dd],
186
+ outputs=[navbar],
187
+ show_progress="hidden",
188
+ api_name=False,
189
+ queue=False,
190
+ )
191
+
192
+ gr.on(
193
+ [run_page.load, runs_table.change],
194
+ fn=update_delete_button,
195
+ inputs=[runs_table],
196
+ outputs=[delete_run_btn],
197
+ show_progress="hidden",
198
+ api_name=False,
199
+ queue=False,
200
+ )
201
+
202
+ gr.on(
203
+ [delete_run_btn.click],
204
+ fn=lambda: [
205
+ gr.Button(visible=False),
206
+ gr.Button(visible=True),
207
+ gr.Button(visible=True),
208
+ ],
209
+ inputs=None,
210
+ outputs=[delete_run_btn, confirm_btn, cancel_btn],
211
+ show_progress="hidden",
212
+ api_name=False,
213
+ queue=False,
214
+ )
215
+ gr.on(
216
+ [confirm_btn.click, cancel_btn.click],
217
+ fn=lambda: [
218
+ gr.Button(visible=True),
219
+ gr.Button(visible=False),
220
+ gr.Button(visible=False),
221
+ ],
222
+ inputs=None,
223
+ outputs=[delete_run_btn, confirm_btn, cancel_btn],
224
+ show_progress="hidden",
225
+ api_name=False,
226
+ queue=False,
227
+ )
228
+ gr.on(
229
+ [confirm_btn.click],
230
+ fn=delete_selected_runs,
231
+ inputs=[runs_table, project_dd],
232
+ outputs=[runs_table],
233
+ show_progress="hidden",
234
+ api_name=False,
235
+ queue=False,
236
+ )
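Run deletion on this page is gated by the write token that the JavaScript snippet in ui/main.py stores in a trackio_write_token cookie. A standalone sketch of the same cookie check performed by check_write_access_runs, with a made-up token value for illustration:

def has_write_access(cookie_header: str, write_token: str) -> bool:
    # Parse "k=v; k2=v2" cookie pairs and compare trackio_write_token to the expected token.
    for cookie in cookie_header.split(";"):
        parts = cookie.strip().split("=")
        if len(parts) == 2 and parts[0] == "trackio_write_token":
            return parts[1] == write_token
    return False

print(has_write_access("session=abc; trackio_write_token=secret123", "secret123"))  # True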
utils.py ADDED
@@ -0,0 +1,733 @@
1
+ import math
2
+ import os
3
+ import re
4
+ import time
5
+ from datetime import datetime, timezone
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ import huggingface_hub
10
+ import numpy as np
11
+ import pandas as pd
12
+ from huggingface_hub.constants import HF_HOME
13
+
14
+ if TYPE_CHECKING:
15
+ from trackio.commit_scheduler import CommitScheduler
16
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
17
+
18
+ RESERVED_KEYS = ["project", "run", "timestamp", "step", "time", "metrics"]
19
+
20
+ TRACKIO_LOGO_DIR = Path(__file__).parent / "assets"
21
+
22
+
23
+ def persistent_storage_enabled() -> bool:
24
+ return (
25
+ os.environ.get("PERSISTANT_STORAGE_ENABLED") == "true"
26
+ ) # typo in the name of the environment variable
27
+
28
+
29
+ def _get_trackio_dir() -> Path:
30
+ if persistent_storage_enabled():
31
+ return Path("/data/trackio")
32
+ return Path(HF_HOME) / "trackio"
33
+
34
+
35
+ TRACKIO_DIR = _get_trackio_dir()
36
+ MEDIA_DIR = TRACKIO_DIR / "media"
37
+
38
+
39
+ def generate_readable_name(used_names: list[str], space_id: str | None = None) -> str:
40
+ """
41
+ Generates a random, readable name like "dainty-sunset-0".
42
+ If space_id is provided, generates username-timestamp format instead.
43
+ """
44
+ if space_id is not None:
45
+ username = huggingface_hub.whoami()["name"]
46
+ timestamp = int(time.time())
47
+ return f"{username}-{timestamp}"
48
+ adjectives = [
49
+ "dainty",
50
+ "brave",
51
+ "calm",
52
+ "eager",
53
+ "fancy",
54
+ "gentle",
55
+ "happy",
56
+ "jolly",
57
+ "kind",
58
+ "lively",
59
+ "merry",
60
+ "nice",
61
+ "proud",
62
+ "quick",
63
+ "hugging",
64
+ "silly",
65
+ "tidy",
66
+ "witty",
67
+ "zealous",
68
+ "bright",
69
+ "shy",
70
+ "bold",
71
+ "clever",
72
+ "daring",
73
+ "elegant",
74
+ "faithful",
75
+ "graceful",
76
+ "honest",
77
+ "inventive",
78
+ "jovial",
79
+ "keen",
80
+ "lucky",
81
+ "modest",
82
+ "noble",
83
+ "optimistic",
84
+ "patient",
85
+ "quirky",
86
+ "resourceful",
87
+ "sincere",
88
+ "thoughtful",
89
+ "upbeat",
90
+ "valiant",
91
+ "warm",
92
+ "youthful",
93
+ "zesty",
94
+ "adventurous",
95
+ "breezy",
96
+ "cheerful",
97
+ "delightful",
98
+ "energetic",
99
+ "fearless",
100
+ "glad",
101
+ "hopeful",
102
+ "imaginative",
103
+ "joyful",
104
+ "kindly",
105
+ "luminous",
106
+ "mysterious",
107
+ "neat",
108
+ "outgoing",
109
+ "playful",
110
+ "radiant",
111
+ "spirited",
112
+ "tranquil",
113
+ "unique",
114
+ "vivid",
115
+ "wise",
116
+ "zany",
117
+ "artful",
118
+ "bubbly",
119
+ "charming",
120
+ "dazzling",
121
+ "earnest",
122
+ "festive",
123
+ "gentlemanly",
124
+ "hearty",
125
+ "intrepid",
126
+ "jubilant",
127
+ "knightly",
128
+ "lively",
129
+ "magnetic",
130
+ "nimble",
131
+ "orderly",
132
+ "peaceful",
133
+ "quick-witted",
134
+ "robust",
135
+ "sturdy",
136
+ "trusty",
137
+ "upstanding",
138
+ "vibrant",
139
+ "whimsical",
140
+ ]
141
+ nouns = [
142
+ "sunset",
143
+ "forest",
144
+ "river",
145
+ "mountain",
146
+ "breeze",
147
+ "meadow",
148
+ "ocean",
149
+ "valley",
150
+ "sky",
151
+ "field",
152
+ "cloud",
153
+ "star",
154
+ "rain",
155
+ "leaf",
156
+ "stone",
157
+ "flower",
158
+ "bird",
159
+ "tree",
160
+ "wave",
161
+ "trail",
162
+ "island",
163
+ "desert",
164
+ "hill",
165
+ "lake",
166
+ "pond",
167
+ "grove",
168
+ "canyon",
169
+ "reef",
170
+ "bay",
171
+ "peak",
172
+ "glade",
173
+ "marsh",
174
+ "cliff",
175
+ "dune",
176
+ "spring",
177
+ "brook",
178
+ "cave",
179
+ "plain",
180
+ "ridge",
181
+ "wood",
182
+ "blossom",
183
+ "petal",
184
+ "root",
185
+ "branch",
186
+ "seed",
187
+ "acorn",
188
+ "pine",
189
+ "willow",
190
+ "cedar",
191
+ "elm",
192
+ "falcon",
193
+ "eagle",
194
+ "sparrow",
195
+ "robin",
196
+ "owl",
197
+ "finch",
198
+ "heron",
199
+ "crane",
200
+ "duck",
201
+ "swan",
202
+ "fox",
203
+ "wolf",
204
+ "bear",
205
+ "deer",
206
+ "moose",
207
+ "otter",
208
+ "beaver",
209
+ "lynx",
210
+ "hare",
211
+ "badger",
212
+ "butterfly",
213
+ "bee",
214
+ "ant",
215
+ "beetle",
216
+ "dragonfly",
217
+ "firefly",
218
+ "ladybug",
219
+ "moth",
220
+ "spider",
221
+ "worm",
222
+ "coral",
223
+ "kelp",
224
+ "shell",
225
+ "pebble",
226
+ "face",
227
+ "boulder",
228
+ "cobble",
229
+ "sand",
230
+ "wavelet",
231
+ "tide",
232
+ "current",
233
+ "mist",
234
+ ]
235
+ number = 0
236
+ name = f"{adjectives[0]}-{nouns[0]}-{number}"
237
+ while name in used_names:
238
+ number += 1
239
+ adjective = adjectives[number % len(adjectives)]
240
+ noun = nouns[number % len(nouns)]
241
+ name = f"{adjective}-{noun}-{number}"
242
+ return name
243
+
244
+
245
+ def is_in_notebook():
246
+ """
247
+ Detect if code is running in a notebook environment (Jupyter, Colab, etc.).
248
+ """
249
+ try:
250
+ from IPython import get_ipython
251
+
252
+ if get_ipython() is not None:
253
+ return get_ipython().__class__.__name__ in [
254
+ "ZMQInteractiveShell", # Jupyter notebook/lab
255
+ "Shell", # IPython terminal
256
+ ] or "google.colab" in str(get_ipython())
257
+ except ImportError:
258
+ pass
259
+ return False
260
+
261
+
262
+ def block_except_in_notebook():
263
+ if is_in_notebook():
264
+ return
265
+ try:
266
+ while True:
267
+ time.sleep(0.1)
268
+ except (KeyboardInterrupt, OSError):
269
+ print("Keyboard interruption in main thread... closing dashboard.")
270
+
271
+
272
+ def simplify_column_names(columns: list[str]) -> dict[str, str]:
273
+ """
274
+ Simplifies column names to first 10 alphanumeric or "/" characters with unique suffixes.
275
+
276
+ Args:
277
+ columns: List of original column names
278
+
279
+ Returns:
280
+ Dictionary mapping original column names to simplified names
281
+ """
282
+ simplified_names = {}
283
+ used_names = set()
284
+
285
+ for col in columns:
286
+ alphanumeric = re.sub(r"[^a-zA-Z0-9/]", "", col)
287
+ base_name = alphanumeric[:10] if alphanumeric else f"col_{len(used_names)}"
288
+
289
+ final_name = base_name
290
+ suffix = 1
291
+ while final_name in used_names:
292
+ final_name = f"{base_name}_{suffix}"
293
+ suffix += 1
294
+
295
+ simplified_names[col] = final_name
296
+ used_names.add(final_name)
297
+
298
+ return simplified_names
299
+
300
+
301
+ def print_dashboard_instructions(project: str) -> None:
302
+ """
303
+ Prints instructions for viewing the Trackio dashboard.
304
+
305
+ Args:
306
+ project: The name of the project to show dashboard for.
307
+ """
308
+ YELLOW = "\033[93m"
309
+ BOLD = "\033[1m"
310
+ RESET = "\033[0m"
311
+
312
+ print("* View dashboard by running in your terminal:")
313
+ print(f'{BOLD}{YELLOW}trackio show --project "{project}"{RESET}')
314
+ print(f'* or by running in Python: trackio.show(project="{project}")')
315
+
316
+
317
+ def preprocess_space_and_dataset_ids(
318
+ space_id: str | None, dataset_id: str | None
319
+ ) -> tuple[str | None, str | None]:
320
+ if space_id is not None and "/" not in space_id:
321
+ username = huggingface_hub.whoami()["name"]
322
+ space_id = f"{username}/{space_id}"
323
+ if dataset_id is not None and "/" not in dataset_id:
324
+ username = huggingface_hub.whoami()["name"]
325
+ dataset_id = f"{username}/{dataset_id}"
326
+ if space_id is not None and dataset_id is None:
327
+ dataset_id = f"{space_id}-dataset"
328
+ return space_id, dataset_id
329
+
330
+
331
+ def fibo():
332
+ """Generator for Fibonacci backoff: 1, 1, 2, 3, 5, 8, ..."""
333
+ a, b = 1, 1
334
+ while True:
335
+ yield a
336
+ a, b = b, a + b
337
+
338
+
339
+ def format_timestamp(timestamp_str):
340
+ """Convert ISO timestamp to human-readable format like '3 minutes ago'."""
341
+ if not timestamp_str or pd.isna(timestamp_str):
342
+ return "Unknown"
343
+
344
+ try:
345
+ created_time = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
346
+ if created_time.tzinfo is None:
347
+ created_time = created_time.replace(tzinfo=timezone.utc)
348
+
349
+ now = datetime.now(timezone.utc)
350
+ diff = now - created_time
351
+
352
+ seconds = int(diff.total_seconds())
353
+ if seconds < 60:
354
+ return "Just now"
355
+ elif seconds < 3600:
356
+ minutes = seconds // 60
357
+ return f"{minutes} minute{'s' if minutes != 1 else ''} ago"
358
+ elif seconds < 86400:
359
+ hours = seconds // 3600
360
+ return f"{hours} hour{'s' if hours != 1 else ''} ago"
361
+ else:
362
+ days = seconds // 86400
363
+ return f"{days} day{'s' if days != 1 else ''} ago"
364
+ except Exception:
365
+ return "Unknown"
366
+
367
+
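A quick, illustrative check of format_timestamp defined above (assumes the trackio package is importable; the timestamps are synthetic):

from datetime import datetime, timedelta, timezone
from trackio.utils import format_timestamp

recent = (datetime.now(timezone.utc) - timedelta(minutes=3)).isoformat()
print(format_timestamp(recent))  # "3 minutes ago"
print(format_timestamp(None))    # "Unknown"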
368
+ COLOR_PALETTE = [
369
+ "#3B82F6",
370
+ "#EF4444",
371
+ "#10B981",
372
+ "#F59E0B",
373
+ "#8B5CF6",
374
+ "#EC4899",
375
+ "#06B6D4",
376
+ "#84CC16",
377
+ "#F97316",
378
+ "#6366F1",
379
+ ]
380
+
381
+
382
+ def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
383
+ """Generate color mapping for runs, with transparency for original data when smoothing is enabled."""
384
+ color_map = {}
385
+
386
+ for i, run in enumerate(runs):
387
+ base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
388
+
389
+ if smoothing:
390
+ color_map[run] = base_color + "4D"
391
+ color_map[f"{run}_smoothed"] = base_color
392
+ else:
393
+ color_map[run] = base_color
394
+
395
+ return color_map
396
+
397
+
398
+ def downsample(
399
+ df: pd.DataFrame,
400
+ x: str,
401
+ y: str,
402
+ color: str | None,
403
+ x_lim: tuple[float, float] | None = None,
404
+ ) -> pd.DataFrame:
405
+ if df.empty:
406
+ return df
407
+
408
+ columns_to_keep = [x, y]
409
+ if color is not None and color in df.columns:
410
+ columns_to_keep.append(color)
411
+ df = df[columns_to_keep].copy()
412
+
413
+ n_bins = 100
414
+
415
+ if color is not None and color in df.columns:
416
+ groups = df.groupby(color)
417
+ else:
418
+ groups = [(None, df)]
419
+
420
+ downsampled_indices = []
421
+
422
+ for _, group_df in groups:
423
+ if group_df.empty:
424
+ continue
425
+
426
+ group_df = group_df.sort_values(x)
427
+
428
+ if x_lim is not None:
429
+ x_min, x_max = x_lim
430
+ before_point = group_df[group_df[x] < x_min].tail(1)
431
+ after_point = group_df[group_df[x] > x_max].head(1)
432
+ group_df = group_df[(group_df[x] >= x_min) & (group_df[x] <= x_max)]
433
+ else:
434
+ before_point = after_point = None
435
+ x_min = group_df[x].min()
436
+ x_max = group_df[x].max()
437
+
438
+ if before_point is not None and not before_point.empty:
439
+ downsampled_indices.extend(before_point.index.tolist())
440
+ if after_point is not None and not after_point.empty:
441
+ downsampled_indices.extend(after_point.index.tolist())
442
+
443
+ if group_df.empty:
444
+ continue
445
+
446
+ if x_min == x_max:
447
+ min_y_idx = group_df[y].idxmin()
448
+ max_y_idx = group_df[y].idxmax()
449
+ if min_y_idx != max_y_idx:
450
+ downsampled_indices.extend([min_y_idx, max_y_idx])
451
+ else:
452
+ downsampled_indices.append(min_y_idx)
453
+ continue
454
+
455
+ if len(group_df) < 500:
456
+ downsampled_indices.extend(group_df.index.tolist())
457
+ continue
458
+
459
+ bins = np.linspace(x_min, x_max, n_bins + 1)
460
+ group_df["bin"] = pd.cut(
461
+ group_df[x], bins=bins, labels=False, include_lowest=True
462
+ )
463
+
464
+ for bin_idx in group_df["bin"].dropna().unique():
465
+ bin_data = group_df[group_df["bin"] == bin_idx]
466
+ if bin_data.empty:
467
+ continue
468
+
469
+ min_y_idx = bin_data[y].idxmin()
470
+ max_y_idx = bin_data[y].idxmax()
471
+
472
+ downsampled_indices.append(min_y_idx)
473
+ if min_y_idx != max_y_idx:
474
+ downsampled_indices.append(max_y_idx)
475
+
476
+ unique_indices = list(set(downsampled_indices))
477
+
478
+ downsampled_df = df.loc[unique_indices].copy()
479
+
480
+ if color is not None:
481
+ downsampled_df = (
482
+ downsampled_df.groupby(color, sort=False)[downsampled_df.columns]
483
+ .apply(lambda group: group.sort_values(x))
484
+ .reset_index(drop=True)
485
+ )
486
+ else:
487
+ downsampled_df = downsampled_df.sort_values(x).reset_index(drop=True)
488
+
489
+ downsampled_df = downsampled_df.drop(columns=["bin"], errors="ignore")
490
+
491
+ return downsampled_df
492
+
493
+
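The downsampling above keeps, within each of roughly 100 x-axis bins, the rows holding the minimum and maximum y value, so spikes survive while the point count stays bounded. A rough standalone illustration of that idea (not the actual helper):

import numpy as np
import pandas as pd

df = pd.DataFrame({"step": np.arange(10_000), "loss": np.random.rand(10_000)})
bins = pd.cut(df["step"], bins=100, labels=False, include_lowest=True)
keep = set()
for _, sub in df.groupby(bins):
    keep.add(sub["loss"].idxmin())  # lowest point in the bin
    keep.add(sub["loss"].idxmax())  # highest point in the bin
print(len(keep))  # at most ~200 of the original 10,000 rows survive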
494
+ def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
495
+ """
496
+ Sort metrics by grouping prefixes together for dropdown/list display.
497
+ Metrics without prefixes come first, then grouped by prefix.
498
+
499
+ Args:
500
+ metrics: List of metric names
501
+
502
+ Returns:
503
+ List of metric names sorted by prefix
504
+
505
+ Example:
506
+ Input: ["train/loss", "loss", "train/acc", "val/loss"]
507
+ Output: ["loss", "train/acc", "train/loss", "val/loss"]
508
+ """
509
+ groups = group_metrics_by_prefix(metrics)
510
+ result = []
511
+
512
+ if "charts" in groups:
513
+ result.extend(groups["charts"])
514
+
515
+ for group_name in sorted(groups.keys()):
516
+ if group_name != "charts":
517
+ result.extend(groups[group_name])
518
+
519
+ return result
520
+
521
+
522
+ def group_metrics_by_prefix(metrics: list[str]) -> dict[str, list[str]]:
523
+ """
524
+ Group metrics by their prefix. Metrics without prefix go to 'charts' group.
525
+
526
+ Args:
527
+ metrics: List of metric names
528
+
529
+ Returns:
530
+ Dictionary with prefix names as keys and lists of metrics as values
531
+
532
+ Example:
533
+ Input: ["loss", "accuracy", "train/loss", "train/acc", "val/loss"]
534
+ Output: {
535
+ "charts": ["loss", "accuracy"],
536
+ "train": ["train/loss", "train/acc"],
537
+ "val": ["val/loss"]
538
+ }
539
+ """
540
+ no_prefix = []
541
+ with_prefix = []
542
+
543
+ for metric in metrics:
544
+ if "/" in metric:
545
+ with_prefix.append(metric)
546
+ else:
547
+ no_prefix.append(metric)
548
+
549
+ no_prefix.sort()
550
+
551
+ prefix_groups = {}
552
+ for metric in with_prefix:
553
+ prefix = metric.split("/")[0]
554
+ if prefix not in prefix_groups:
555
+ prefix_groups[prefix] = []
556
+ prefix_groups[prefix].append(metric)
557
+
558
+ for prefix in prefix_groups:
559
+ prefix_groups[prefix].sort()
560
+
561
+ groups = {}
562
+ if no_prefix:
563
+ groups["charts"] = no_prefix
564
+
565
+ for prefix in sorted(prefix_groups.keys()):
566
+ groups[prefix] = prefix_groups[prefix]
567
+
568
+ return groups
569
+
570
+
571
+ def group_metrics_with_subprefixes(metrics: list[str]) -> dict:
572
+ """
573
+ Group metrics with simple 2-level nested structure detection.
574
+
575
+ Returns a dictionary where each prefix group can have:
576
+ - direct_metrics: list of metrics at this level (e.g., "train/acc")
577
+ - subgroups: dict of subgroup name -> list of metrics (e.g., "loss" -> ["train/loss/norm", "train/loss/unnorm"])
578
+
579
+ Example:
580
+ Input: ["loss", "train/acc", "train/loss/normalized", "train/loss/unnormalized", "val/loss"]
581
+ Output: {
582
+ "charts": {
583
+ "direct_metrics": ["loss"],
584
+ "subgroups": {}
585
+ },
586
+ "train": {
587
+ "direct_metrics": ["train/acc"],
588
+ "subgroups": {
589
+ "loss": ["train/loss/normalized", "train/loss/unnormalized"]
590
+ }
591
+ },
592
+ "val": {
593
+ "direct_metrics": ["val/loss"],
594
+ "subgroups": {}
595
+ }
596
+ }
597
+ """
598
+ result = {}
599
+
600
+ for metric in metrics:
601
+ if "/" not in metric:
602
+ if "charts" not in result:
603
+ result["charts"] = {"direct_metrics": [], "subgroups": {}}
604
+ result["charts"]["direct_metrics"].append(metric)
605
+ else:
606
+ parts = metric.split("/")
607
+ main_prefix = parts[0]
608
+
609
+ if main_prefix not in result:
610
+ result[main_prefix] = {"direct_metrics": [], "subgroups": {}}
611
+
612
+ if len(parts) == 2:
613
+ result[main_prefix]["direct_metrics"].append(metric)
614
+ else:
615
+ subprefix = parts[1]
616
+ if subprefix not in result[main_prefix]["subgroups"]:
617
+ result[main_prefix]["subgroups"][subprefix] = []
618
+ result[main_prefix]["subgroups"][subprefix].append(metric)
619
+
620
+ for group_data in result.values():
621
+ group_data["direct_metrics"].sort()
622
+ for subgroup_metrics in group_data["subgroups"].values():
623
+ subgroup_metrics.sort()
624
+
625
+ if "charts" in result and not result["charts"]["direct_metrics"]:
626
+ del result["charts"]
627
+
628
+ return result
629
+
630
+
631
+ def get_sync_status(scheduler: "CommitScheduler | DummyCommitScheduler") -> int | None:
632
+ """Get the sync status from the CommitScheduler in an integer number of minutes, or None if not synced yet."""
633
+ if getattr(
634
+ scheduler, "last_push_time", None
635
+ ): # DummyCommitScheduler doesn't have last_push_time
636
+ time_diff = time.time() - scheduler.last_push_time
637
+ return int(time_diff / 60)
638
+ else:
639
+ return None
640
+
641
+
642
+ def generate_embed_code(project: str, metrics: str, selected_runs: list | None = None) -> str:
643
+ """Generate the embed iframe code based on current settings."""
644
+ space_host = os.environ.get("SPACE_HOST", "")
645
+ if not space_host:
646
+ return ""
647
+
648
+ params = []
649
+
650
+ if project:
651
+ params.append(f"project={project}")
652
+
653
+ if metrics and metrics.strip():
654
+ params.append(f"metrics={metrics}")
655
+
656
+ if selected_runs:
657
+ runs_param = ",".join(selected_runs)
658
+ params.append(f"runs={runs_param}")
659
+
660
+ params.append("sidebar=hidden")
661
+ params.append("navbar=hidden")
662
+
663
+ query_string = "&".join(params)
664
+ embed_url = f"https://{space_host}?{query_string}"
665
+
666
+ return f'<iframe src="{embed_url}" style="width:1600px; height:500px; border:0;"></iframe>'
667
+
668
+
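Illustrative output of generate_embed_code above, using a hypothetical Space host (the function returns an empty string when SPACE_HOST is unset):

import os
from trackio.utils import generate_embed_code  # assumes the package is importable

os.environ["SPACE_HOST"] = "username-trackio.hf.space"  # hypothetical host
print(generate_embed_code("my-project", "loss", ["run-1"]))
# <iframe src="https://username-trackio.hf.space?project=my-project&metrics=loss&runs=run-1&sidebar=hidden&navbar=hidden" style="width:1600px; height:500px; border:0;"></iframe>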
669
+ def serialize_values(metrics):
670
+ """
671
+ Serialize infinity and NaN values in metrics dict to make it JSON-compliant.
672
+ Only handles top-level float values.
673
+
674
+ Converts:
675
+ - float('inf') -> "Infinity"
676
+ - float('-inf') -> "-Infinity"
677
+ - float('nan') -> "NaN"
678
+
679
+ Example:
680
+ {"loss": float('inf'), "accuracy": 0.95} -> {"loss": "Infinity", "accuracy": 0.95}
681
+ """
682
+ if not isinstance(metrics, dict):
683
+ return metrics
684
+
685
+ result = {}
686
+ for key, value in metrics.items():
687
+ if isinstance(value, float):
688
+ if math.isinf(value):
689
+ result[key] = "Infinity" if value > 0 else "-Infinity"
690
+ elif math.isnan(value):
691
+ result[key] = "NaN"
692
+ else:
693
+ result[key] = value
694
+ elif isinstance(value, np.floating):
695
+ float_val = float(value)
696
+ if math.isinf(float_val):
697
+ result[key] = "Infinity" if float_val > 0 else "-Infinity"
698
+ elif math.isnan(float_val):
699
+ result[key] = "NaN"
700
+ else:
701
+ result[key] = float_val
702
+ else:
703
+ result[key] = value
704
+ return result
705
+
706
+
707
+ def deserialize_values(metrics):
708
+ """
709
+ Deserialize infinity and NaN string values back to their numeric forms.
710
+ Only handles top-level string values.
711
+
712
+ Converts:
713
+ - "Infinity" -> float('inf')
714
+ - "-Infinity" -> float('-inf')
715
+ - "NaN" -> float('nan')
716
+
717
+ Example:
718
+ {"loss": "Infinity", "accuracy": 0.95} -> {"loss": float('inf'), "accuracy": 0.95}
719
+ """
720
+ if not isinstance(metrics, dict):
721
+ return metrics
722
+
723
+ result = {}
724
+ for key, value in metrics.items():
725
+ if value == "Infinity":
726
+ result[key] = float("inf")
727
+ elif value == "-Infinity":
728
+ result[key] = float("-inf")
729
+ elif value == "NaN":
730
+ result[key] = float("nan")
731
+ else:
732
+ result[key] = value
733
+ return result
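For reference, a short round trip through the grouping and serialization helpers defined in this file (assumes the trackio package is importable; values are illustrative):

from trackio.utils import group_metrics_with_subprefixes, serialize_values, deserialize_values

metrics = ["loss", "train/acc", "train/loss/normalized", "val/loss"]
print(group_metrics_with_subprefixes(metrics))
# {'charts': {'direct_metrics': ['loss'], 'subgroups': {}},
#  'train': {'direct_metrics': ['train/acc'], 'subgroups': {'loss': ['train/loss/normalized']}},
#  'val': {'direct_metrics': ['val/loss'], 'subgroups': {}}}

payload = serialize_values({"loss": float("inf"), "acc": 0.95})
print(payload)                      # {'loss': 'Infinity', 'acc': 0.95}
print(deserialize_values(payload))  # {'loss': inf, 'acc': 0.95}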
version.txt ADDED
@@ -0,0 +1 @@
1
+ 0.4.0
video_writer.py ADDED
@@ -0,0 +1,126 @@
1
+ import shutil
2
+ import subprocess
3
+ from pathlib import Path
4
+ from typing import Literal
5
+
6
+ import numpy as np
7
+
8
+ VideoCodec = Literal["h264", "vp9", "gif"]
9
+
10
+
11
+ def _check_ffmpeg_installed() -> None:
12
+ """Raise an error if ffmpeg is not available on the system PATH."""
13
+ if shutil.which("ffmpeg") is None:
14
+ raise RuntimeError(
15
+ "ffmpeg is required to write video but was not found on your system. "
16
+ "Please install ffmpeg and ensure it is available on your PATH."
17
+ )
18
+
19
+
20
+ def _check_array_format(video: np.ndarray) -> None:
21
+ """Raise an error if the array is not in the expected format."""
22
+ if not (video.ndim == 4 and video.shape[-1] == 3):
23
+ raise ValueError(
24
+ f"Expected RGB input shaped (F, H, W, 3), got {video.shape}. "
25
+ f"Input has {video.ndim} dimensions, expected 4."
26
+ )
27
+ if video.dtype != np.uint8:
28
+ raise TypeError(
29
+ f"Expected dtype=uint8, got {video.dtype}. "
30
+ "Please convert your video data to uint8 format."
31
+ )
32
+
33
+
34
+ def _check_path(file_path: str | Path) -> None:
35
+ """Create the parent directory if it does not exist, raising a ValueError if creation fails."""
36
+ file_path = Path(file_path)
37
+ if not file_path.parent.exists():
38
+ try:
39
+ file_path.parent.mkdir(parents=True, exist_ok=True)
40
+ except OSError as e:
41
+ raise ValueError(
42
+ f"Failed to create parent directory {file_path.parent}: {e}"
43
+ )
44
+
45
+
46
+ def write_video(
47
+ file_path: str | Path, video: np.ndarray, fps: float, codec: VideoCodec
48
+ ) -> None:
49
+ """Write an RGB uint8 video array of shape (F, H, W, 3) to file_path with ffmpeg, using the given fps and codec."""
50
+ _check_ffmpeg_installed()
51
+ _check_path(file_path)
52
+
53
+ if codec not in {"h264", "vp9", "gif"}:
54
+ raise ValueError("Unsupported codec. Use h264, vp9, or gif.")
55
+
56
+ arr = np.asarray(video)
57
+ _check_array_format(arr)
58
+
59
+ frames = np.ascontiguousarray(arr)
60
+ _, height, width, _ = frames.shape
61
+ out_path = str(file_path)
62
+
63
+ cmd = [
64
+ "ffmpeg",
65
+ "-y",
66
+ "-f",
67
+ "rawvideo",
68
+ "-s",
69
+ f"{width}x{height}",
70
+ "-pix_fmt",
71
+ "rgb24",
72
+ "-r",
73
+ str(fps),
74
+ "-i",
75
+ "-",
76
+ "-an",
77
+ ]
78
+
79
+ if codec == "gif":
80
+ video_filter = "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse"
81
+ cmd += [
82
+ "-vf",
83
+ video_filter,
84
+ "-loop",
85
+ "0",
86
+ ]
87
+ elif codec == "h264":
88
+ cmd += [
89
+ "-vcodec",
90
+ "libx264",
91
+ "-pix_fmt",
92
+ "yuv420p",
93
+ "-movflags",
94
+ "+faststart",
95
+ ]
96
+ elif codec == "vp9":
97
+ bpp = 0.08
98
+ bps = int(width * height * fps * bpp)
99
+ if bps >= 1_000_000:
100
+ bitrate = f"{round(bps / 1_000_000)}M"
101
+ elif bps >= 1_000:
102
+ bitrate = f"{round(bps / 1_000)}k"
103
+ else:
104
+ bitrate = str(max(bps, 1))
105
+ cmd += [
106
+ "-vcodec",
107
+ "libvpx-vp9",
108
+ "-b:v",
109
+ bitrate,
110
+ "-pix_fmt",
111
+ "yuv420p",
112
+ ]
113
+ cmd += [out_path]
114
+ proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
115
+ try:
116
+ for frame in frames:
117
+ proc.stdin.write(frame.tobytes())
118
+ finally:
119
+ if proc.stdin:
120
+ proc.stdin.close()
121
+ stderr = (
122
+ proc.stderr.read().decode("utf-8", errors="ignore") if proc.stderr else ""
123
+ )
124
+ ret = proc.wait()
125
+ if ret != 0:
126
+ raise RuntimeError(f"ffmpeg failed with code {ret}\n{stderr}")
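A minimal way to exercise write_video (assumes ffmpeg is on PATH and the trackio package is importable; the frames are synthetic noise):

import numpy as np
from trackio.video_writer import write_video

frames = np.random.randint(0, 256, size=(30, 64, 64, 3), dtype=np.uint8)  # 30 RGB frames, 64x64
write_video("sample.mp4", frames, fps=15, codec="h264")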