baseten-admin committed
Commit 4180bfa · verified · 1 Parent(s): 42fcefc

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +1 -0
  2. CHANGELOG.md +24 -0
  3. __init__.py +332 -0
  4. __pycache__/__init__.cpython-310.pyc +0 -0
  5. __pycache__/cli.cpython-310.pyc +0 -0
  6. __pycache__/commit_scheduler.cpython-310.pyc +0 -0
  7. __pycache__/context_vars.cpython-310.pyc +0 -0
  8. __pycache__/deploy.cpython-310.pyc +0 -0
  9. __pycache__/dummy_commit_scheduler.cpython-310.pyc +0 -0
  10. __pycache__/file_storage.cpython-310.pyc +0 -0
  11. __pycache__/imports.cpython-310.pyc +0 -0
  12. __pycache__/media.cpython-310.pyc +0 -0
  13. __pycache__/run.cpython-310.pyc +0 -0
  14. __pycache__/sqlite_storage.cpython-310.pyc +0 -0
  15. __pycache__/table.cpython-310.pyc +0 -0
  16. __pycache__/typehints.cpython-310.pyc +0 -0
  17. __pycache__/utils.cpython-310.pyc +0 -0
  18. __pycache__/video_writer.cpython-310.pyc +0 -0
  19. assets/trackio_logo_dark.png +0 -0
  20. assets/trackio_logo_light.png +0 -0
  21. assets/trackio_logo_old.png +3 -0
  22. assets/trackio_logo_type_dark.png +0 -0
  23. assets/trackio_logo_type_dark_transparent.png +0 -0
  24. assets/trackio_logo_type_light.png +0 -0
  25. assets/trackio_logo_type_light_transparent.png +0 -0
  26. cli.py +37 -0
  27. commit_scheduler.py +391 -0
  28. context_vars.py +18 -0
  29. deploy.py +256 -0
  30. dummy_commit_scheduler.py +12 -0
  31. file_storage.py +37 -0
  32. imports.py +302 -0
  33. media.py +286 -0
  34. package.json +6 -0
  35. py.typed +0 -0
  36. run.py +180 -0
  37. sqlite_storage.py +677 -0
  38. table.py +53 -0
  39. typehints.py +18 -0
  40. ui/__init__.py +10 -0
  41. ui/__pycache__/__init__.cpython-310.pyc +0 -0
  42. ui/__pycache__/fns.cpython-310.pyc +0 -0
  43. ui/__pycache__/main.cpython-310.pyc +0 -0
  44. ui/__pycache__/run_detail.cpython-310.pyc +0 -0
  45. ui/__pycache__/runs.cpython-310.pyc +0 -0
  46. ui/fns.py +241 -0
  47. ui/helpers/__pycache__/run_selection.cpython-310.pyc +0 -0
  48. ui/helpers/run_selection.py +46 -0
  49. ui/main.py +1212 -0
  50. ui/run_detail.py +94 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text
CHANGELOG.md ADDED
@@ -0,0 +1,24 @@
# trackio

## 0.5.3

### Features

- [#300](https://github.com/gradio-app/trackio/pull/300) [`5e4cacf`](https://github.com/gradio-app/trackio/commit/5e4cacf2e7ce527b4ce60de3a5bc05d2c02c77fb) - Adds more environment variables to allow customization of the Trackio dashboard. Thanks @abidlabs!

## 0.5.2

### Features

- [#293](https://github.com/gradio-app/trackio/pull/293) [`64afc28`](https://github.com/gradio-app/trackio/commit/64afc28d3ea1dfd821472dc6bf0b8ed35a9b74be) - Ensures that the TRACKIO_DIR environment variable is respected. Thanks @abidlabs!
- [#287](https://github.com/gradio-app/trackio/pull/287) [`cd3e929`](https://github.com/gradio-app/trackio/commit/cd3e9294320949e6b8b829239069a43d5d7ff4c1) - fix(sqlite): unify .sqlite extension, allow export when DBs exist, clean WAL sidecars on import. Thanks @vaibhav-research!

### Fixes

- [#291](https://github.com/gradio-app/trackio/pull/291) [`3b5adc3`](https://github.com/gradio-app/trackio/commit/3b5adc3d1f452dbab7a714d235f4974782f93730) - Fix the wheel build. Thanks @pngwn!

## 0.5.1

### Fixes

- [#278](https://github.com/gradio-app/trackio/pull/278) [`314c054`](https://github.com/gradio-app/trackio/commit/314c05438007ddfea3383e06fd19143e27468e2d) - Fix row orientation of metrics plots. Thanks @abidlabs!
__init__.py ADDED
@@ -0,0 +1,332 @@
import hashlib
import logging
import os
import warnings
import webbrowser
from pathlib import Path
from typing import Any

from gradio.blocks import BUILT_IN_THEMES
from gradio.themes import Default as DefaultTheme
from gradio.themes import ThemeClass
from gradio_client import Client
from huggingface_hub import SpaceStorage

from trackio import context_vars, deploy, utils
from trackio.imports import import_csv, import_tf_events
from trackio.media import TrackioImage, TrackioVideo
from trackio.run import Run
from trackio.sqlite_storage import SQLiteStorage
from trackio.table import Table
from trackio.ui.main import demo
from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR

logging.getLogger("httpx").setLevel(logging.WARNING)

warnings.filterwarnings(
    "ignore",
    message="Empty session being created. Install gradio\\[oauth\\]",
    category=UserWarning,
    module="gradio.helpers",
)

__version__ = Path(__file__).parent.joinpath("version.txt").read_text().strip()

__all__ = [
    "init",
    "log",
    "finish",
    "show",
    "import_csv",
    "import_tf_events",
    "Image",
    "Video",
    "Table",
]

Image = TrackioImage
Video = TrackioVideo


config = {}

DEFAULT_THEME = "default"


def init(
    project: str,
    name: str | None = None,
    group: str | None = None,
    space_id: str | None = None,
    space_storage: SpaceStorage | None = None,
    dataset_id: str | None = None,
    config: dict | None = None,
    resume: str = "never",
    settings: Any = None,
    private: bool | None = None,
    embed: bool = True,
) -> Run:
    """
    Creates a new Trackio project and returns a [`Run`] object.

    Args:
        project (`str`):
            The name of the project (can be an existing project to continue tracking or
            a new project to start tracking from scratch).
        name (`str`, *optional*):
            The name of the run (if not provided, a default name will be generated).
        group (`str`, *optional*):
            The name of the group which this run belongs to, in order to help organize
            related runs together. You can toggle the entire group's visibility in the
            dashboard.
        space_id (`str`, *optional*):
            If provided, the project will be logged to a Hugging Face Space instead of
            a local directory. Should be a complete Space name like
            `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
            case the Space will be created in the currently-logged-in Hugging Face
            user's namespace. If the Space does not exist, it will be created. If the
            Space already exists, the project will be logged to it.
        space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
            Choice of persistent storage tier.
        dataset_id (`str`, *optional*):
            If a `space_id` is provided, a persistent Hugging Face Dataset will be
            created and the metrics will be synced to it every 5 minutes. Specify a
            Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`,
            or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace),
            or `None` (uses the same name as the Space but with the `"_dataset"`
            suffix). If the Dataset does not exist, it will be created. If the Dataset
            already exists, the project will be appended to it.
        config (`dict`, *optional*):
            A dictionary of configuration options. Provided for compatibility with
            `wandb.init()`.
        resume (`str`, *optional*, defaults to `"never"`):
            Controls how to handle resuming a run. Can be one of:

            - `"must"`: Must resume the run with the given name, raises error if run
              doesn't exist
            - `"allow"`: Resume the run if it exists, otherwise create a new run
            - `"never"`: Never resume a run, always create a new one
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        settings (`Any`, *optional*):
            Not used. Provided for compatibility with `wandb.init()`.
        embed (`bool`, *optional*, defaults to `True`):
            If running inside a Jupyter/Colab notebook, whether the dashboard should
            automatically be embedded in the cell when trackio.init() is called.

    Returns:
        `Run`: A [`Run`] object that can be used to log metrics and finish the run.
    """
    if settings is not None:
        warnings.warn(
            "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
        )

    if space_id is None and dataset_id is not None:
        raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
    space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
    url = context_vars.current_server.get()
    share_url = context_vars.current_share_server.get()

    if url is None:
        if space_id is None:
            _, url, share_url = demo.launch(
                show_api=False,
                inline=False,
                quiet=True,
                prevent_thread_lock=True,
                show_error=True,
                favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
                allowed_paths=[TRACKIO_LOGO_DIR],
            )
        else:
            url = space_id
            share_url = None
        context_vars.current_server.set(url)
        context_vars.current_share_server.set(share_url)
    if (
        context_vars.current_project.get() is None
        or context_vars.current_project.get() != project
    ):
        print(f"* Trackio project initialized: {project}")

        if dataset_id is not None:
            os.environ["TRACKIO_DATASET_ID"] = dataset_id
            print(
                f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
            )
        if space_id is None:
            print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
            if utils.is_in_notebook() and embed:
                base_url = share_url + "/" if share_url else url
                full_url = utils.get_full_url(
                    base_url, project=project, write_token=demo.write_token
                )
                utils.embed_url_in_notebook(full_url)
            else:
                utils.print_dashboard_instructions(project)
        else:
            deploy.create_space_if_not_exists(
                space_id, space_storage, dataset_id, private
            )
            user_name, space_name = space_id.split("/")
            space_url = deploy.SPACE_HOST_URL.format(
                user_name=user_name, space_name=space_name
            )
            print(f"* View dashboard by going to: {space_url}")
            if utils.is_in_notebook() and embed:
                utils.embed_url_in_notebook(space_url)
    context_vars.current_project.set(project)

    client = None
    if not space_id:
        client = Client(url, verbose=False)

    if resume == "must":
        if name is None:
            raise ValueError("Must provide a run name when resume='must'")
        if name not in SQLiteStorage.get_runs(project):
            raise ValueError(f"Run '{name}' does not exist in project '{project}'")
        resumed = True
    elif resume == "allow":
        resumed = name is not None and name in SQLiteStorage.get_runs(project)
    elif resume == "never":
        if name is not None and name in SQLiteStorage.get_runs(project):
            warnings.warn(
                f"* Warning: resume='never' but a run '{name}' already exists in "
                f"project '{project}'. Generating a new name instead. If you want "
                "to resume this run, call init() with resume='must' or resume='allow'."
            )
            name = None
        resumed = False
    else:
        raise ValueError("resume must be one of: 'must', 'allow', or 'never'")

    run = Run(
        url=url,
        project=project,
        client=client,
        name=name,
        group=group,
        config=config,
        space_id=space_id,
    )

    if resumed:
        print(f"* Resumed existing run: {run.name}")
    else:
        print(f"* Created new run: {run.name}")

    context_vars.current_run.set(run)
    globals()["config"] = run.config
    return run


def log(metrics: dict, step: int | None = None) -> None:
    """
    Logs metrics to the current run.

    Args:
        metrics (`dict`):
            A dictionary of metrics to log.
        step (`int`, *optional*):
            The step number. If not provided, the step will be incremented
            automatically.
    """
    run = context_vars.current_run.get()
    if run is None:
        raise RuntimeError("Call trackio.init() before trackio.log().")
    run.log(
        metrics=metrics,
        step=step,
    )


def finish():
    """
    Finishes the current run.
    """
    run = context_vars.current_run.get()
    if run is None:
        raise RuntimeError("Call trackio.init() before trackio.finish().")
    run.finish()


def show(
    project: str | None = None,
    theme: str | ThemeClass | None = None,
    mcp_server: bool | None = None,
):
    """
    Launches the Trackio dashboard.

    Args:
        project (`str`, *optional*):
            The name of the project whose runs to show. If not provided, all projects
            will be shown and the user can select one.
        theme (`str` or `ThemeClass`, *optional*):
            A Gradio Theme to use for the dashboard instead of the default Gradio
            theme. Can be a built-in theme (e.g. `'soft'`, `'citrus'`), a theme from
            the Hub (e.g. `"gstaff/xkcd"`), or a custom Theme class. If not provided,
            the `TRACKIO_THEME` environment variable will be used, or if that is not
            set, the default Gradio theme will be used.
        mcp_server (`bool`, *optional*):
            If `True`, the Trackio dashboard will be set up as an MCP server and
            certain functions will be added as MCP tools. If `None` (default
            behavior), then the `GRADIO_MCP_SERVER` environment variable will be used
            to determine if the MCP server should be enabled (which is `"True"` on
            Hugging Face Spaces).
    """
    theme = theme or os.environ.get("TRACKIO_THEME", DEFAULT_THEME)

    if theme != DEFAULT_THEME:
        # TODO: It's a little hacky to reproduce this theme-setting logic from Gradio
        # Blocks, but in Gradio 6.0, the theme will be set in `launch()` instead, which
        # means that we will be able to remove this code.
        if isinstance(theme, str):
            if theme.lower() in BUILT_IN_THEMES:
                theme = BUILT_IN_THEMES[theme.lower()]
            else:
                try:
                    theme = ThemeClass.from_hub(theme)
                except Exception as e:
                    warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
                    theme = DefaultTheme()
        if not isinstance(theme, ThemeClass):
            warnings.warn("Theme should be a class loaded from gradio.themes")
            theme = DefaultTheme()
        demo.theme: ThemeClass = theme
        demo.theme_css = theme._get_theme_css()
        demo.stylesheets = theme._stylesheets
        theme_hasher = hashlib.sha256()
        theme_hasher.update(demo.theme_css.encode("utf-8"))
        demo.theme_hash = theme_hasher.hexdigest()

    _mcp_server = (
        mcp_server
        if mcp_server is not None
        else os.environ.get("GRADIO_MCP_SERVER", "False") == "True"
    )

    _, url, share_url = demo.launch(
        show_api=_mcp_server,
        quiet=True,
        inline=False,
        prevent_thread_lock=True,
        favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
        allowed_paths=[TRACKIO_LOGO_DIR],
        mcp_server=_mcp_server,
    )

    base_url = share_url + "/" if share_url else url
    full_url = utils.get_full_url(
        base_url, project=project, write_token=demo.write_token
    )

    if not utils.is_in_notebook():
        print(f"* Trackio UI launched at: {full_url}")
        webbrowser.open(full_url)
        utils.block_main_thread_until_keyboard_interrupt()
    else:
        utils.embed_url_in_notebook(full_url)
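Taken together, `init`, `log`, and `finish` above make up the core logging API. A minimal local-only sketch (project, run, and metric names are placeholders):

```py
import trackio

run = trackio.init(project="my-project", name="baseline")  # launches the local dashboard
for step in range(3):
    trackio.log({"loss": 1.0 / (step + 1)}, step=step)  # logged to the current run
trackio.finish()  # marks the current run as finished
```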
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (10.5 kB).
__pycache__/cli.cpython-310.pyc ADDED
Binary file (1.23 kB).
__pycache__/commit_scheduler.cpython-310.pyc ADDED
Binary file (14.1 kB).
__pycache__/context_vars.cpython-310.pyc ADDED
Binary file (563 Bytes).
__pycache__/deploy.cpython-310.pyc ADDED
Binary file (6.25 kB).
__pycache__/dummy_commit_scheduler.cpython-310.pyc ADDED
Binary file (948 Bytes).
__pycache__/file_storage.cpython-310.pyc ADDED
Binary file (1.26 kB).
__pycache__/imports.cpython-310.pyc ADDED
Binary file (9.35 kB).
__pycache__/media.cpython-310.pyc ADDED
Binary file (9.65 kB).
__pycache__/run.cpython-310.pyc ADDED
Binary file (4.96 kB).
__pycache__/sqlite_storage.cpython-310.pyc ADDED
Binary file (19.4 kB).
__pycache__/table.cpython-310.pyc ADDED
Binary file (2.05 kB).
__pycache__/typehints.cpython-310.pyc ADDED
Binary file (789 Bytes).
__pycache__/utils.cpython-310.pyc ADDED
Binary file (17.7 kB).
__pycache__/video_writer.cpython-310.pyc ADDED
Binary file (3.21 kB).
assets/trackio_logo_dark.png ADDED
assets/trackio_logo_light.png ADDED
assets/trackio_logo_old.png ADDED

Git LFS Details

  • SHA256: 3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
assets/trackio_logo_type_dark.png ADDED
assets/trackio_logo_type_dark_transparent.png ADDED
assets/trackio_logo_type_light.png ADDED
assets/trackio_logo_type_light_transparent.png ADDED
cli.py ADDED
@@ -0,0 +1,37 @@
import argparse

from trackio import show


def main():
    parser = argparse.ArgumentParser(description="Trackio CLI")
    subparsers = parser.add_subparsers(dest="command")

    ui_parser = subparsers.add_parser(
        "show", help="Show the Trackio dashboard UI for a project"
    )
    ui_parser.add_argument(
        "--project", required=False, help="Project name to show in the dashboard"
    )
    ui_parser.add_argument(
        "--theme",
        required=False,
        default="citrus",
        help="A Gradio Theme to use for the dashboard instead of the default, can be a built-in theme (e.g. 'soft', 'citrus'), or a theme from the Hub (e.g. 'gstaff/xkcd').",
    )
    ui_parser.add_argument(
        "--mcp-server",
        action="store_true",
        help="Enable MCP server functionality. The Trackio dashboard will be set up as an MCP server and certain functions will be exposed as MCP tools.",
    )

    args = parser.parse_args()

    if args.command == "show":
        show(args.project, args.theme, args.mcp_server)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
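The CLI's `show` subcommand is a thin wrapper around `show()`; the equivalent call from Python (argument values are placeholders) is:

```py
from trackio import show

# Same three options the `show` subcommand exposes as flags:
show(project="my-project", theme="citrus", mcp_server=False)
```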
commit_scheduler.py ADDED
@@ -0,0 +1,391 @@
# Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py

import atexit
import logging
import os
import time
from concurrent.futures import Future
from dataclasses import dataclass
from io import SEEK_END, SEEK_SET, BytesIO
from pathlib import Path
from threading import Lock, Thread
from typing import Callable, Dict, List, Union

from huggingface_hub.hf_api import (
    DEFAULT_IGNORE_PATTERNS,
    CommitInfo,
    CommitOperationAdd,
    HfApi,
)
from huggingface_hub.utils import filter_repo_objects

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class _FileToUpload:
    """Temporary dataclass to store info about files to upload. Not meant to be used directly."""

    local_path: Path
    path_in_repo: str
    size_limit: int
    last_modified: float


class CommitScheduler:
    """
    Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).

    The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
    properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
    with the `stop` method. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
    to learn more about how to use it.

    Args:
        repo_id (`str`):
            The id of the repo to commit to.
        folder_path (`str` or `Path`):
            Path to the local folder to upload regularly.
        every (`int` or `float`, *optional*):
            The number of minutes between each commit. Defaults to 5 minutes.
        path_in_repo (`str`, *optional*):
            Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
            of the repository.
        repo_type (`str`, *optional*):
            The type of the repo to commit to. Defaults to `model`.
        revision (`str`, *optional*):
            The revision of the repo to commit to. Defaults to `main`.
        private (`bool`, *optional*):
            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's
            default is private. This value is ignored if the repo already exists.
        token (`str`, *optional*):
            The token to use to commit to the repo. Defaults to the token saved on the machine.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are uploaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not uploaded.
        squash_history (`bool`, *optional*):
            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
            useful to avoid degraded performance on the repo when it grows too large.
        hf_api (`HfApi`, *optional*):
            The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token, ...).
        on_before_commit (`Callable[[], None]`, *optional*):
            If specified, a function that will be called before the CommitScheduler lists files to create a commit.

    Example:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    # Scheduler uploads every 10 minutes
    >>> csv_path = Path("watched_folder/data.csv")
    >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)

    >>> with csv_path.open("a") as f:
    ...     f.write("first line")

    # Some time later (...)
    >>> with csv_path.open("a") as f:
    ...     f.write("second line")
    ```

    Example using a context manager:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
    ...     csv_path = Path("watched_folder/data.csv")
    ...     with csv_path.open("a") as f:
    ...         f.write("first line")
    ...     (...)
    ...     with csv_path.open("a") as f:
    ...         f.write("second line")

    # Scheduler is now stopped and the last commit has been triggered
    ```
    """

    def __init__(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        every: Union[int, float] = 5,
        path_in_repo: str | None = None,
        repo_type: str | None = None,
        revision: str | None = None,
        private: bool | None = None,
        token: str | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        squash_history: bool = False,
        hf_api: HfApi | None = None,
        on_before_commit: Callable[[], None] | None = None,
    ) -> None:
        self.api = hf_api or HfApi(token=token)
        self.on_before_commit = on_before_commit

        # Folder
        self.folder_path = Path(folder_path).expanduser().resolve()
        self.path_in_repo = path_in_repo or ""
        self.allow_patterns = allow_patterns

        if ignore_patterns is None:
            ignore_patterns = []
        elif isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]
        self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS

        if self.folder_path.is_file():
            raise ValueError(
                f"'folder_path' must be a directory, not a file: '{self.folder_path}'."
            )
        self.folder_path.mkdir(parents=True, exist_ok=True)

        # Repository
        repo_url = self.api.create_repo(
            repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True
        )
        self.repo_id = repo_url.repo_id
        self.repo_type = repo_type
        self.revision = revision
        self.token = token

        self.last_uploaded: Dict[Path, float] = {}
        self.last_push_time: float | None = None

        if not every > 0:
            raise ValueError(f"'every' must be a positive integer, not '{every}'.")
        self.lock = Lock()
        self.every = every
        self.squash_history = squash_history

        logger.info(
            f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes."
        )
        self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
        self._scheduler_thread.start()
        atexit.register(self._push_to_hub)

        self.__stopped = False

    def stop(self) -> None:
        """Stop the scheduler.

        A stopped scheduler cannot be restarted. Mostly for testing purposes.
        """
        self.__stopped = True

    def __enter__(self) -> "CommitScheduler":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Upload last changes before exiting
        self.trigger().result()
        self.stop()
        return

    def _run_scheduler(self) -> None:
        """Dumb thread waiting between each scheduled push to Hub."""
        while True:
            self.last_future = self.trigger()
            time.sleep(self.every * 60)
            if self.__stopped:
                break

    def trigger(self) -> Future:
        """Trigger a `push_to_hub` and return a future.

        This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
        immediately, without waiting for the next scheduled commit.
        """
        return self.api.run_as_future(self._push_to_hub)

    def _push_to_hub(self) -> CommitInfo | None:
        if self.__stopped:  # If stopped, already scheduled commits are ignored
            return None

        logger.info("(Background) scheduled commit triggered.")
        try:
            value = self.push_to_hub()
            if self.squash_history:
                logger.info("(Background) squashing repo history.")
                self.api.super_squash_history(
                    repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision
                )
            return value
        except Exception as e:
            logger.error(
                f"Error while pushing to Hub: {e}"
            )  # Depending on the setup, error might be silenced
            raise

    def push_to_hub(self) -> CommitInfo | None:
        """
        Push folder to the Hub and return the commit info.

        <Tip warning={true}>

        This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
        queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
        issues.

        </Tip>

        The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
        uploads only changed files. If no changes are found, the method returns without committing anything. If you want
        to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
        for example to compress data together in a single file before committing. For more details and examples, check
        out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
        """
        # Check files to upload (with lock)
        with self.lock:
            if self.on_before_commit is not None:
                self.on_before_commit()

            logger.debug("Listing files to upload for scheduled commit.")

            # List files from folder (taken from `_prepare_upload_folder_additions`)
            relpath_to_abspath = {
                path.relative_to(self.folder_path).as_posix(): path
                for path in sorted(
                    self.folder_path.glob("**/*")
                )  # sorted to be deterministic
                if path.is_file()
            }
            prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""

            # Filter with pattern + filter out unchanged files + retrieve current file size
            files_to_upload: List[_FileToUpload] = []
            for relpath in filter_repo_objects(
                relpath_to_abspath.keys(),
                allow_patterns=self.allow_patterns,
                ignore_patterns=self.ignore_patterns,
            ):
                local_path = relpath_to_abspath[relpath]
                stat = local_path.stat()
                if (
                    self.last_uploaded.get(local_path) is None
                    or self.last_uploaded[local_path] != stat.st_mtime
                ):
                    files_to_upload.append(
                        _FileToUpload(
                            local_path=local_path,
                            path_in_repo=prefix + relpath,
                            size_limit=stat.st_size,
                            last_modified=stat.st_mtime,
                        )
                    )

        # Return if nothing to upload
        if len(files_to_upload) == 0:
            logger.debug("Dropping scheduled commit: no changed file to upload.")
            return None

        # Convert `_FileToUpload` to `CommitOperationAdd` (=> compute file shas + limit to file size)
        logger.debug("Removing unchanged files since previous scheduled commit.")
        add_operations = [
            CommitOperationAdd(
                # TODO: Cap the file to its current size, even if the user appends data to it while a scheduled commit is happening
                # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload)
                path_or_fileobj=file_to_upload.local_path,
                path_in_repo=file_to_upload.path_in_repo,
            )
            for file_to_upload in files_to_upload
        ]

        # Upload files (append mode expected - no need for lock)
        logger.debug("Uploading files for scheduled commit.")
        commit_info = self.api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            operations=add_operations,
            commit_message="Scheduled Commit",
            revision=self.revision,
        )

        for file in files_to_upload:
            self.last_uploaded[file.local_path] = file.last_modified

        self.last_push_time = time.time()

        return commit_info


class PartialFileIO(BytesIO):
    """A file-like object that reads only the first part of a file.

    Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
    file is uploaded (i.e. the part that was available when the filesystem was first scanned).

    In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
    disturbance for the user. The object is passed to `CommitOperationAdd`.

    Only supports `read`, `tell` and `seek` methods.

    Args:
        file_path (`str` or `Path`):
            Path to the file to read.
        size_limit (`int`):
            The maximum number of bytes to read from the file. If the file is larger than this, only the first part
            will be read (and uploaded).
    """

    def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
        self._file_path = Path(file_path)
        self._file = self._file_path.open("rb")
        self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)

    def __del__(self) -> None:
        self._file.close()
        return super().__del__()

    def __repr__(self) -> str:
        return (
            f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
        )

    def __len__(self) -> int:
        return self._size_limit

    def __getattribute__(self, name: str):
        if name.startswith("_") or name in (
            "read",
            "tell",
            "seek",
        ):  # only 3 public methods supported
            return super().__getattribute__(name)
        raise NotImplementedError(f"PartialFileIO does not support '{name}'.")

    def tell(self) -> int:
        """Return the current file position."""
        return self._file.tell()

    def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
        """Change the stream position to the given offset.

        Behavior is the same as a regular file, except that the position is capped to the size limit.
        """
        if __whence == SEEK_END:
            # SEEK_END => set from the truncated end
            __offset = len(self) + __offset
            __whence = SEEK_SET

        pos = self._file.seek(__offset, __whence)
        if pos > self._size_limit:
            return self._file.seek(self._size_limit)
        return pos

    def read(self, __size: int | None = -1) -> bytes:
        """Read at most `__size` bytes from the file.

        Behavior is the same as a regular file, except that it is capped to the size limit.
        """
        current = self._file.tell()
        if __size is None or __size < 0:
            # Read until file limit
            truncated_size = self._size_limit - current
        else:
            # Read until file limit or __size
            truncated_size = min(__size, self._size_limit - current)
        return self._file.read(truncated_size)
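The `on_before_commit` hook runs under the scheduler's lock, right before files are listed, which makes it a natural place to flush buffered state to disk. A minimal sketch (the repo ID, folder, and `checkpoint_all_dbs` helper are illustrative, not part of this codebase):

```py
from trackio.commit_scheduler import CommitScheduler

def checkpoint_all_dbs() -> None:
    # Hypothetical helper: e.g. run a WAL checkpoint on each open SQLite database
    # so the files on disk are complete before the scheduler scans them.
    ...

scheduler = CommitScheduler(
    repo_id="username/my-logs",  # assumed repo
    repo_type="dataset",
    folder_path="trackio_data",  # assumed local folder
    every=5,  # minutes between pushes
    on_before_commit=checkpoint_all_dbs,  # called while holding scheduler.lock
)
```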
context_vars.py ADDED
@@ -0,0 +1,18 @@
import contextvars
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from trackio.run import Run

current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar(
    "current_run", default=None
)
current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_project", default=None
)
current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_server", default=None
)
current_share_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_share_server", default=None
)
deploy.py ADDED
@@ -0,0 +1,256 @@
import importlib.metadata
import io
import os
import time
from importlib.resources import files
from pathlib import Path

import gradio
import huggingface_hub
from gradio_client import Client, handle_file
from httpx import ReadTimeout
from huggingface_hub.errors import RepositoryNotFoundError
from requests import HTTPError

import trackio
from trackio.sqlite_storage import SQLiteStorage

SPACE_HOST_URL = "https://{user_name}-{space_name}.hf.space/"
SPACE_URL = "https://huggingface.co/spaces/{space_id}"


def _is_trackio_installed_from_source() -> bool:
    """Check if trackio is installed from source/editable install vs PyPI."""
    try:
        trackio_file = trackio.__file__
        if "site-packages" not in trackio_file:
            return True

        dist = importlib.metadata.distribution("trackio")
        if dist.files:
            files = list(dist.files)
            has_pth = any(".pth" in str(f) for f in files)
            if has_pth:
                return True

        return False
    except (
        AttributeError,
        importlib.metadata.PackageNotFoundError,
        importlib.metadata.MetadataError,
        ValueError,
        TypeError,
    ):
        return True


def deploy_as_space(
    space_id: str,
    space_storage: huggingface_hub.SpaceStorage | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
):
    if (
        os.getenv("SYSTEM") == "spaces"
    ):  # in case a repo with this function is uploaded to spaces
        return

    trackio_path = files("trackio")

    hf_api = huggingface_hub.HfApi()

    try:
        huggingface_hub.create_repo(
            space_id,
            private=private,
            space_sdk="gradio",
            space_storage=space_storage,
            repo_type="space",
            exist_ok=True,
        )
    except HTTPError as e:
        if e.response.status_code in [401, 403]:  # unauthorized or forbidden
            print("Need 'write' access token to create a Spaces repo.")
            huggingface_hub.login(add_to_git_credential=False)
            huggingface_hub.create_repo(
                space_id,
                private=private,
                space_sdk="gradio",
                space_storage=space_storage,
                repo_type="space",
                exist_ok=True,
            )
        else:
            raise ValueError(f"Failed to create Space: {e}")

    with open(Path(trackio_path, "README.md"), "r") as f:
        readme_content = f.read()
    readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
    readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
    hf_api.upload_file(
        path_or_fileobj=readme_buffer,
        path_in_repo="README.md",
        repo_id=space_id,
        repo_type="space",
    )

    # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space.
    # Make sure necessary dependencies are installed by creating a requirements.txt.
    is_source_install = _is_trackio_installed_from_source()

    if is_source_install:
        requirements_content = """pyarrow>=21.0"""
    else:
        requirements_content = f"""pyarrow>=21.0
trackio=={trackio.__version__}"""

    requirements_buffer = io.BytesIO(requirements_content.encode("utf-8"))
    hf_api.upload_file(
        path_or_fileobj=requirements_buffer,
        path_in_repo="requirements.txt",
        repo_id=space_id,
        repo_type="space",
    )

    huggingface_hub.utils.disable_progress_bars()

    if is_source_install:
        hf_api.upload_folder(
            repo_id=space_id,
            repo_type="space",
            folder_path=trackio_path,
            ignore_patterns=["README.md"],
        )
    else:
        app_file_content = """import trackio
trackio.show()"""
        app_file_buffer = io.BytesIO(app_file_content.encode("utf-8"))
        hf_api.upload_file(
            path_or_fileobj=app_file_buffer,
            path_in_repo="ui/main.py",
            repo_id=space_id,
            repo_type="space",
        )

    if hf_token := huggingface_hub.utils.get_token():
        huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
    if dataset_id is not None:
        huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)

    if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
        huggingface_hub.add_space_variable(
            space_id, "TRACKIO_LOGO_LIGHT_URL", logo_light_url
        )
    if logo_dark_url := os.environ.get("TRACKIO_LOGO_DARK_URL"):
        huggingface_hub.add_space_variable(
            space_id, "TRACKIO_LOGO_DARK_URL", logo_dark_url
        )

    if plot_order := os.environ.get("TRACKIO_PLOT_ORDER"):
        huggingface_hub.add_space_variable(space_id, "TRACKIO_PLOT_ORDER", plot_order)

    if theme := os.environ.get("TRACKIO_THEME"):
        huggingface_hub.add_space_variable(space_id, "TRACKIO_THEME", theme)


def create_space_if_not_exists(
    space_id: str,
    space_storage: huggingface_hub.SpaceStorage | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
) -> None:
    """
    Creates a new Hugging Face Space if it does not exist. If a dataset_id is provided, it will be added as a Space variable.

    Args:
        space_id: The ID of the Space to create.
        dataset_id: The ID of the Dataset to add to the Space.
        private: Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
    """
    if "/" not in space_id:
        raise ValueError(
            f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame."
        )
    if dataset_id is not None and "/" not in dataset_id:
        raise ValueError(
            f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname."
        )
    try:
        huggingface_hub.repo_info(space_id, repo_type="space")
        print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
        if dataset_id is not None:
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_DATASET_ID", dataset_id
            )
        if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_LOGO_LIGHT_URL", logo_light_url
            )
        if logo_dark_url := os.environ.get("TRACKIO_LOGO_DARK_URL"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_LOGO_DARK_URL", logo_dark_url
            )

        if plot_order := os.environ.get("TRACKIO_PLOT_ORDER"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_PLOT_ORDER", plot_order
            )

        if theme := os.environ.get("TRACKIO_THEME"):
            huggingface_hub.add_space_variable(space_id, "TRACKIO_THEME", theme)
        return
    except RepositoryNotFoundError:
        pass
    except HTTPError as e:
        if e.response.status_code in [401, 403]:  # unauthorized or forbidden
            print("Need 'write' access token to create a Spaces repo.")
            huggingface_hub.login(add_to_git_credential=False)
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_DATASET_ID", dataset_id
            )
        else:
            raise ValueError(f"Failed to create Space: {e}")

    print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
    deploy_as_space(space_id, space_storage, dataset_id, private)


def wait_until_space_exists(
    space_id: str,
) -> None:
    """
    Blocks the current thread until the space exists.
    May raise a TimeoutError if this takes quite a while.

    Args:
        space_id: The ID of the Space to wait for.
    """
    delay = 1
    for _ in range(10):
        try:
            Client(space_id, verbose=False)
            return
        except (ReadTimeout, ValueError):
            time.sleep(delay)
            delay = min(delay * 2, 30)
    raise TimeoutError("Waiting for space to exist took longer than expected")


def upload_db_to_space(project: str, space_id: str) -> None:
    """
    Uploads the database of a local Trackio project to a Hugging Face Space.

    Args:
        project: The name of the project to upload.
        space_id: The ID of the Space to upload to.
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    client = Client(space_id, verbose=False)
    client.predict(
        api_name="/upload_db_to_space",
        project=project,
        uploaded_db=handle_file(db_path),
        hf_token=huggingface_hub.utils.get_token(),
    )
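These helpers are meant to be called in sequence, which is how the import functions in `imports.py` below use them. A sketch with a placeholder Space ID and project name:

```py
from trackio import deploy

space_id = "username/my-trackio"  # placeholder Space ID
deploy.create_space_if_not_exists(space_id)
deploy.wait_until_space_exists(space_id)  # polls with exponential backoff
deploy.upload_db_to_space(project="my-project", space_id=space_id)
print(deploy.SPACE_URL.format(space_id=space_id))  # dashboard URL
```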
dummy_commit_scheduler.py ADDED
@@ -0,0 +1,12 @@
# A dummy object to fit the interface of huggingface_hub's CommitScheduler
class DummyCommitSchedulerLock:
    def __enter__(self):
        return None

    def __exit__(self, exception_type, exception_value, exception_traceback):
        pass


class DummyCommitScheduler:
    def __init__(self):
        self.lock = DummyCommitSchedulerLock()
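The point of the dummy is that storage code can take `scheduler.lock` without caring whether Hub syncing is enabled. A sketch of that pattern (the `get_scheduler` wiring and `sync_enabled` flag are illustrative, not from this codebase):

```py
from trackio.dummy_commit_scheduler import DummyCommitScheduler

def get_scheduler(sync_enabled: bool):
    # Hypothetical selector: real scheduler when syncing, no-op otherwise.
    if sync_enabled:
        from trackio.commit_scheduler import CommitScheduler
        return CommitScheduler(
            repo_id="username/logs", repo_type="dataset", folder_path="trackio_data"
        )
    return DummyCommitScheduler()

scheduler = get_scheduler(sync_enabled=False)
with scheduler.lock:  # same call site works for both schedulers
    pass  # write to the local database here
```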
file_storage.py ADDED
@@ -0,0 +1,37 @@
from pathlib import Path

try:  # absolute imports when installed
    from trackio.utils import MEDIA_DIR
except ImportError:  # relative imports for local execution on Spaces
    from utils import MEDIA_DIR


class FileStorage:
    @staticmethod
    def get_project_media_path(
        project: str,
        run: str | None = None,
        step: int | None = None,
        filename: str | None = None,
    ) -> Path:
        if filename is not None and step is None:
            raise ValueError("filename requires step")
        if step is not None and run is None:
            raise ValueError("step requires run")

        path = MEDIA_DIR / project
        if run:
            path /= run
        if step is not None:
            path /= str(step)
        if filename:
            path /= filename
        return path

    @staticmethod
    def init_project_media_path(
        project: str, run: str | None = None, step: int | None = None
    ) -> Path:
        path = FileStorage.get_project_media_path(project, run, step)
        path.mkdir(parents=True, exist_ok=True)
        return path
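Media files are laid out as `MEDIA_DIR/<project>/<run>/<step>/<filename>`. A short usage sketch with placeholder names:

```py
from trackio.file_storage import FileStorage

# Creates MEDIA_DIR/my-project/run-1/10 on disk (parents included)
dir_path = FileStorage.init_project_media_path("my-project", run="run-1", step=10)

# Pure path computation, no filesystem access
file_path = FileStorage.get_project_media_path(
    "my-project", run="run-1", step=10, filename="sample.png"
)
assert file_path == dir_path / "sample.png"
```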
imports.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ from trackio import deploy, utils
7
+ from trackio.sqlite_storage import SQLiteStorage
8
+
9
+
10
+ def import_csv(
11
+ csv_path: str | Path,
12
+ project: str,
13
+ name: str | None = None,
14
+ space_id: str | None = None,
15
+ dataset_id: str | None = None,
16
+ private: bool | None = None,
17
+ ) -> None:
18
+ """
19
+ Imports a CSV file into a Trackio project. The CSV file must contain a `"step"`
20
+ column, may optionally contain a `"timestamp"` column, and any other columns will be
21
+ treated as metrics. It should also include a header row with the column names.
22
+
23
+ TODO: call init() and return a Run object so that the user can continue to log metrics to it.
24
+
25
+ Args:
26
+ csv_path (`str` or `Path`):
27
+ The str or Path to the CSV file to import.
28
+ project (`str`):
29
+ The name of the project to import the CSV file into. Must not be an existing
30
+ project.
31
+ name (`str`, *optional*):
32
+ The name of the Run to import the CSV file into. If not provided, a default
33
+ name will be generated.
34
+ name (`str`, *optional*):
35
+ The name of the run (if not provided, a default name will be generated).
36
+ space_id (`str`, *optional*):
37
+ If provided, the project will be logged to a Hugging Face Space instead of a
38
+ local directory. Should be a complete Space name like `"username/reponame"`
39
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
40
+ be created in the currently-logged-in Hugging Face user's namespace. If the
41
+ Space does not exist, it will be created. If the Space already exists, the
42
+ project will be logged to it.
43
+ dataset_id (`str`, *optional*):
44
+ If provided, a persistent Hugging Face Dataset will be created and the
45
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
46
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
47
+ `"datasetname"` in which case the Dataset will be created in the
48
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
49
+ exist, it will be created. If the Dataset already exists, the project will
50
+ be appended to it. If not provided, the metrics will be logged to a local
51
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
52
+ will be automatically created with the same name as the Space but with the
53
+ `"_dataset"` suffix.
54
+ private (`bool`, *optional*):
55
+ Whether to make the Space private. If None (default), the repo will be
56
+ public unless the organization's default is private. This value is ignored
57
+ if the repo already exists.
58
+ """
59
+ if SQLiteStorage.get_runs(project):
60
+ raise ValueError(
61
+ f"Project '{project}' already exists. Cannot import CSV into existing project."
62
+ )
63
+
64
+ csv_path = Path(csv_path)
65
+ if not csv_path.exists():
66
+ raise FileNotFoundError(f"CSV file not found: {csv_path}")
67
+
68
+ df = pd.read_csv(csv_path)
69
+ if df.empty:
70
+ raise ValueError("CSV file is empty")
71
+
72
+ column_mapping = utils.simplify_column_names(df.columns.tolist())
73
+ df = df.rename(columns=column_mapping)
74
+
75
+ step_column = None
76
+ for col in df.columns:
77
+ if col.lower() == "step":
78
+ step_column = col
79
+ break
80
+
81
+ if step_column is None:
82
+ raise ValueError("CSV file must contain a 'step' or 'Step' column")
83
+
84
+ if name is None:
85
+ name = csv_path.stem
86
+
87
+ metrics_list = []
88
+ steps = []
89
+ timestamps = []
90
+
91
+ numeric_columns = []
92
+ for column in df.columns:
93
+ if column == step_column:
94
+ continue
95
+ if column == "timestamp":
96
+ continue
97
+
98
+ try:
99
+ pd.to_numeric(df[column], errors="raise")
100
+ numeric_columns.append(column)
101
+ except (ValueError, TypeError):
102
+ continue
103
+
104
+ for _, row in df.iterrows():
105
+ metrics = {}
106
+ for column in numeric_columns:
107
+ value = row[column]
108
+ if bool(pd.notna(value)):
109
+ metrics[column] = float(value)
110
+
111
+ if metrics:
112
+ metrics_list.append(metrics)
113
+ steps.append(int(row[step_column]))
114
+
115
+ if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])):
116
+ timestamps.append(str(row["timestamp"]))
117
+ else:
118
+ timestamps.append("")
119
+
120
+ if metrics_list:
121
+ SQLiteStorage.bulk_log(
122
+ project=project,
123
+ run=name,
124
+ metrics_list=metrics_list,
125
+ steps=steps,
126
+ timestamps=timestamps,
127
+ )
128
+
129
+ print(
130
+ f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'"
131
+ )
132
+ print(f"* Metrics found: {', '.join(metrics_list[0].keys())}")
133
+
134
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
135
+ if dataset_id is not None:
136
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
137
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
138
+
139
+ if space_id is None:
140
+ utils.print_dashboard_instructions(project)
141
+ else:
142
+ deploy.create_space_if_not_exists(
143
+ space_id=space_id, dataset_id=dataset_id, private=private
144
+ )
145
+ deploy.wait_until_space_exists(space_id=space_id)
146
+ deploy.upload_db_to_space(project=project, space_id=space_id)
147
+ print(
148
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
149
+ )
150
+
151
+
152
+ def import_tf_events(
153
+ log_dir: str | Path,
154
+ project: str,
155
+ name: str | None = None,
156
+ space_id: str | None = None,
157
+ dataset_id: str | None = None,
158
+ private: bool | None = None,
159
+ ) -> None:
160
+ """
161
+     Imports TensorFlow Events files from a directory into a Trackio project. Each
+     subdirectory in the log directory will be imported as a separate run.
+
+     Args:
+         log_dir (`str` or `Path`):
+             The path to the directory containing TensorFlow Events files.
+         project (`str`):
+             The name of the project to import the TensorFlow Events files into. Must not
+             be an existing project.
+         name (`str`, *optional*):
+             The name prefix for runs (if not provided, will use directory names). Each
+             subdirectory will create a separate run.
+         space_id (`str`, *optional*):
+             If provided, the project will be logged to a Hugging Face Space instead of a
+             local directory. Should be a complete Space name like `"username/reponame"`
+             or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
+             be created in the currently-logged-in Hugging Face user's namespace. If the
+             Space does not exist, it will be created. If the Space already exists, the
+             project will be logged to it.
+         dataset_id (`str`, *optional*):
+             If provided, a persistent Hugging Face Dataset will be created and the
+             metrics will be synced to it every 5 minutes. Should be a complete Dataset
+             name like `"username/datasetname"` or `"orgname/datasetname"`, or just
+             `"datasetname"` in which case the Dataset will be created in the
+             currently-logged-in Hugging Face user's namespace. If the Dataset does not
+             exist, it will be created. If the Dataset already exists, the project will
+             be appended to it. If not provided, the metrics will be logged to a local
+             SQLite database, unless a `space_id` is provided, in which case a Dataset
+             will be automatically created with the same name as the Space but with the
+             `"_dataset"` suffix.
+         private (`bool`, *optional*):
+             Whether to make the Space private. If None (default), the repo will be
+             public unless the organization's default is private. This value is ignored
+             if the repo already exists.
+     """
+     try:
+         from tbparse import SummaryReader
+     except ImportError:
+         raise ImportError(
+             "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`."
+         )
+
+     if SQLiteStorage.get_runs(project):
+         raise ValueError(
+             f"Project '{project}' already exists. Cannot import TF events into existing project."
+         )
+
+     path = Path(log_dir)
+     if not path.exists():
+         raise FileNotFoundError(f"TF events directory not found: {path}")
+
+     # Use tbparse to read all tfevents files in the directory structure
+     reader = SummaryReader(str(path), extra_columns={"dir_name"})
+     df = reader.scalars
+
+     if df.empty:
+         raise ValueError(f"No TensorFlow events data found in {path}")
+
+     total_imported = 0
+     imported_runs = []
+
+     # Group by dir_name to create separate runs
+     for dir_name, group_df in df.groupby("dir_name"):
+         try:
+             # Determine run name based on directory name
+             if dir_name == "":
+                 run_name = "main"  # For files in the root directory
+             else:
+                 run_name = dir_name  # Use directory name
+
+             if name:
+                 run_name = f"{name}_{run_name}"
+
+             if group_df.empty:
+                 print(f"* Skipping directory {dir_name}: no scalar data found")
+                 continue
+
+             metrics_list = []
+             steps = []
+             timestamps = []
+
+             for _, row in group_df.iterrows():
+                 # Convert row values to appropriate types
+                 tag = str(row["tag"])
+                 value = float(row["value"])
+                 step = int(row["step"])
+
+                 metrics = {tag: value}
+                 metrics_list.append(metrics)
+                 steps.append(step)
+
+                 # Use wall_time if present, else fallback
+                 if "wall_time" in group_df.columns and not bool(
+                     pd.isna(row["wall_time"])
+                 ):
+                     timestamps.append(str(row["wall_time"]))
+                 else:
+                     timestamps.append("")
+
+             if metrics_list:
+                 SQLiteStorage.bulk_log(
+                     project=project,
+                     run=str(run_name),
+                     metrics_list=metrics_list,
+                     steps=steps,
+                     timestamps=timestamps,
+                 )
+
+                 total_imported += len(metrics_list)
+                 imported_runs.append(run_name)
+
+                 print(
+                     f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'"
+                 )
+                 print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}")
+
+         except Exception as e:
+             print(f"* Error processing directory {dir_name}: {e}")
+             continue
+
+     if not imported_runs:
+         raise ValueError("No valid TensorFlow events data could be imported")
+
+     print(f"* Total imported events: {total_imported}")
+     print(f"* Created runs: {', '.join(imported_runs)}")
+
+     space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
+     if dataset_id is not None:
+         os.environ["TRACKIO_DATASET_ID"] = dataset_id
+         print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
+
+     if space_id is None:
+         utils.print_dashboard_instructions(project)
+     else:
+         deploy.create_space_if_not_exists(
+             space_id, dataset_id=dataset_id, private=private
+         )
+         deploy.wait_until_space_exists(space_id)
+         deploy.upload_db_to_space(project, space_id)
+         print(
+             f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
+         )
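For context, a minimal usage sketch of the importer added above. It assumes `import_tf_events` is re-exported at the package top level (as the other public helpers in this commit are); the log directory and project name are hypothetical:

```python
import trackio

# Each subdirectory of ./tb_logs (e.g. ./tb_logs/lr-0.01/) becomes its own run.
trackio.import_tf_events(
    log_dir="./tb_logs",         # hypothetical directory of *.tfevents.* files
    project="imported-tb-runs",  # must not already exist as a Trackio project
    name="tb",                   # optional prefix: runs become "tb_lr-0.01", etc.
)
```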
media.py ADDED
@@ -0,0 +1,286 @@
+ import os
+ import shutil
+ import uuid
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import Literal
+
+ import numpy as np
+ from PIL import Image as PILImage
+
+ try:  # absolute imports when installed
+     from trackio.file_storage import FileStorage
+     from trackio.utils import MEDIA_DIR
+     from trackio.video_writer import write_video
+ except ImportError:  # relative imports for local execution on Spaces
+     from file_storage import FileStorage
+     from utils import MEDIA_DIR
+     from video_writer import write_video
+
+
+ class TrackioMedia(ABC):
+     """
+     Abstract base class for Trackio media objects.
+     Provides shared functionality for file handling and serialization.
+     """
+
+     TYPE: str
+
+     def __init_subclass__(cls, **kwargs):
+         """Ensure subclasses define the TYPE attribute."""
+         super().__init_subclass__(**kwargs)
+         if not hasattr(cls, "TYPE") or cls.TYPE is None:
+             raise TypeError(f"Class {cls.__name__} must define TYPE attribute")
+
+     def __init__(self, value, caption: str | None = None):
+         self.caption = caption
+         self._value = value
+         self._file_path: Path | None = None
+
+         # Validate file existence for string/Path inputs
+         if isinstance(self._value, str | Path):
+             if not os.path.isfile(self._value):
+                 raise ValueError(f"File not found: {self._value}")
+
+     def _file_extension(self) -> str:
+         if self._file_path:
+             return self._file_path.suffix[1:].lower()
+         if isinstance(self._value, str | Path):
+             path = Path(self._value)
+             return path.suffix[1:].lower()
+         if hasattr(self, "_format") and self._format:
+             return self._format
+         return "unknown"
+
+     def _get_relative_file_path(self) -> Path | None:
+         return self._file_path
+
+     def _get_absolute_file_path(self) -> Path | None:
+         if self._file_path:
+             return MEDIA_DIR / self._file_path
+         return None
+
+     def _save(self, project: str, run: str, step: int = 0):
+         if self._file_path:
+             return
+
+         media_dir = FileStorage.init_project_media_path(project, run, step)
+         filename = f"{uuid.uuid4()}.{self._file_extension()}"
+         file_path = media_dir / filename
+
+         # Delegate to subclass-specific save logic
+         self._save_media(file_path)
+
+         self._file_path = file_path.relative_to(MEDIA_DIR)
+
+     @abstractmethod
+     def _save_media(self, file_path: Path):
+         """
+         Performs the actual media saving logic.
+         """
+         pass
+
+     def _to_dict(self) -> dict:
+         if not self._file_path:
+             raise ValueError("Media must be saved to file before serialization")
+         return {
+             "_type": self.TYPE,
+             "file_path": str(self._get_relative_file_path()),
+             "caption": self.caption,
+         }
+
+
+ TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image
+
+
+ class TrackioImage(TrackioMedia):
+     """
+     Initializes an Image object.
+
+     Example:
+         ```python
+         import trackio
+         import numpy as np
+         from PIL import Image
+
+         # Create an image from numpy array
+         image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+         image = trackio.Image(image_data, caption="Random image")
+         trackio.log({"my_image": image})
+
+         # Create an image from PIL Image
+         pil_image = Image.new('RGB', (100, 100), color='red')
+         image = trackio.Image(pil_image, caption="Red square")
+         trackio.log({"red_image": image})
+
+         # Create an image from file path
+         image = trackio.Image("path/to/image.jpg", caption="Photo from file")
+         trackio.log({"file_image": image})
+         ```
+
+     Args:
+         value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`, *optional*):
+             A path to an image, a PIL Image, or a numpy array of shape (height, width, channels).
+         caption (`str`, *optional*):
+             A string caption for the image.
+     """
+
+     TYPE = "trackio.image"
+
+     def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
+         super().__init__(value, caption)
+         self._format: str | None = None
+
+         if (
+             isinstance(self._value, np.ndarray | PILImage.Image)
+             and self._format is None
+         ):
+             self._format = "png"
+
+     def _as_pil(self) -> PILImage.Image | None:
+         try:
+             if isinstance(self._value, np.ndarray):
+                 arr = np.asarray(self._value).astype("uint8")
+                 return PILImage.fromarray(arr).convert("RGBA")
+             if isinstance(self._value, PILImage.Image):
+                 return self._value.convert("RGBA")
+         except Exception as e:
+             raise ValueError(f"Failed to process image data: {self._value}") from e
+         return None
+
+     def _save_media(self, file_path: Path):
+         if pil := self._as_pil():
+             pil.save(file_path, format=self._format)
+         elif isinstance(self._value, str | Path):
+             if os.path.isfile(self._value):
+                 shutil.copy(self._value, file_path)
+             else:
+                 raise ValueError(f"File not found: {self._value}")
+
+
+ TrackioVideoSourceType = str | Path | np.ndarray
+ TrackioVideoFormatType = Literal["gif", "mp4", "webm"]
+
+
+ class TrackioVideo(TrackioMedia):
+     """
+     Initializes a Video object.
+
+     Example:
+         ```python
+         import trackio
+         import numpy as np
+
+         # Create a simple video from numpy array
+         frames = np.random.randint(0, 255, (10, 3, 64, 64), dtype=np.uint8)
+         video = trackio.Video(frames, caption="Random video", fps=30)
+
+         # Create a batch of videos
+         batch_frames = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8)
+         batch_video = trackio.Video(batch_frames, caption="Batch of videos", fps=15)
+
+         # Create video from file path
+         video = trackio.Video("path/to/video.mp4", caption="Video from file")
+         ```
+
+     Args:
+         value (`str`, `Path`, or `numpy.ndarray`, *optional*):
+             A path to a video file, or a numpy array.
+             The array should be of type `np.uint8` with RGB values in the range `[0, 255]`.
+             It is expected to have a shape of either (frames, channels, height, width) or (batch, frames, channels, height, width).
+             For the latter, the videos will be tiled into a grid.
+         caption (`str`, *optional*):
+             A string caption for the video.
+         fps (`int`, *optional*):
+             Frames per second for the video. Only used when value is an ndarray. Default is `24`.
+         format (`Literal["gif", "mp4", "webm"]`, *optional*):
+             Video format ("gif", "mp4", or "webm"). Only used when value is an ndarray. Default is `"gif"`.
+     """
+
+     TYPE = "trackio.video"
+
+     def __init__(
+         self,
+         value: TrackioVideoSourceType,
+         caption: str | None = None,
+         fps: int | None = None,
+         format: TrackioVideoFormatType | None = None,
+     ):
+         super().__init__(value, caption)
+         if isinstance(value, np.ndarray):
+             if format is None:
+                 format = "gif"
+             if fps is None:
+                 fps = 24
+         self._fps = fps
+         self._format = format
+
+     @property
+     def _codec(self) -> str:
+         match self._format:
+             case "gif":
+                 return "gif"
+             case "mp4":
+                 return "h264"
+             case "webm":
+                 return "vp9"
+             case _:
+                 raise ValueError(f"Unsupported format: {self._format}")
+
+     def _save_media(self, file_path: Path):
+         if isinstance(self._value, np.ndarray):
+             video = TrackioVideo._process_ndarray(self._value)
+             write_video(file_path, video, fps=self._fps, codec=self._codec)
+         elif isinstance(self._value, str | Path):
+             if os.path.isfile(self._value):
+                 shutil.copy(self._value, file_path)
+             else:
+                 raise ValueError(f"File not found: {self._value}")
+
+     @staticmethod
+     def _process_ndarray(value: np.ndarray) -> np.ndarray:
+         # Verify value is either a 4D (single video) or 5D (batched videos) array.
+         # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width)
+         if value.ndim < 4:
+             raise ValueError(
+                 "Video requires at least 4 dimensions (frames, channels, height, width)"
+             )
+         if value.ndim > 5:
+             raise ValueError(
+                 "Videos can have at most 5 dimensions (batch, frames, channels, height, width)"
+             )
+         if value.ndim == 4:
+             # Reshape to 5D with a single batch: (1, frames, channels, height, width)
+             value = value[np.newaxis, ...]
+
+         value = TrackioVideo._tile_batched_videos(value)
+         return value
+
+     @staticmethod
+     def _tile_batched_videos(video: np.ndarray) -> np.ndarray:
+         """
+         Tiles a batch of videos into a grid of videos.
+
+         Input format: (batch, frames, channels, height, width) - original FCHW format
+         Output format: (frames, total_height, total_width, channels)
+         """
+         batch_size, frames, channels, height, width = video.shape
+
+         next_pow2 = 1 << (batch_size - 1).bit_length()
+         if batch_size != next_pow2:
+             pad_len = next_pow2 - batch_size
+             pad_shape = (pad_len, frames, channels, height, width)
+             padding = np.zeros(pad_shape, dtype=video.dtype)
+             video = np.concatenate((video, padding), axis=0)
+             batch_size = next_pow2
+
+         n_rows = 1 << ((batch_size.bit_length() - 1) // 2)
+         n_cols = batch_size // n_rows
+
+         # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width)
+         video = video.reshape(n_rows, n_cols, frames, channels, height, width)
+
+         # Rearrange dimensions to (frames, total_height, total_width, channels)
+         video = video.transpose(2, 0, 4, 1, 5, 3)
+         video = video.reshape(frames, n_rows * height, n_cols * width, channels)
+         return video
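A quick sketch (not part of the diff) that checks the grid arithmetic used by `_tile_batched_videos`: a batch of 3 videos is padded to the next power of two (4) and tiled onto a 2×2 grid:

```python
import numpy as np

# 3 videos, 10 RGB frames each, 64x64 pixels, in (batch, F, C, H, W) layout.
batch = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8)

batch_size = batch.shape[0]
next_pow2 = 1 << (batch_size - 1).bit_length()     # 4 (one blank video of padding)
n_rows = 1 << ((next_pow2.bit_length() - 1) // 2)  # 2
n_cols = next_pow2 // n_rows                       # 2
print(next_pow2, n_rows, n_cols)  # tiled output frames: (10, 2*64, 2*64, 3)
```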
package.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "name": "trackio",
+   "version": "0.5.3",
+   "description": "",
+   "python": "true"
+ }
py.typed ADDED
File without changes
run.py ADDED
@@ -0,0 +1,180 @@
+ import threading
+ import time
+ from datetime import datetime, timezone
+
+ import huggingface_hub
+ from gradio_client import Client, handle_file
+
+ from trackio import utils
+ from trackio.media import TrackioMedia
+ from trackio.sqlite_storage import SQLiteStorage
+ from trackio.table import Table
+ from trackio.typehints import LogEntry, UploadEntry
+
+ BATCH_SEND_INTERVAL = 0.5
+
+
+ class Run:
+     def __init__(
+         self,
+         url: str,
+         project: str,
+         client: Client | None,
+         name: str | None = None,
+         group: str | None = None,
+         config: dict | None = None,
+         space_id: str | None = None,
+     ):
+         self.url = url
+         self.project = project
+         self._client_lock = threading.Lock()
+         self._client_thread = None
+         self._client = client
+         self._space_id = space_id
+         self.name = name or utils.generate_readable_name(
+             SQLiteStorage.get_runs(project), space_id
+         )
+         self.group = group
+         self.config = utils.to_json_safe(config or {})
+
+         if isinstance(self.config, dict):
+             for key in self.config:
+                 if key.startswith("_"):
+                     raise ValueError(
+                         f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)"
+                     )
+
+         self.config["_Username"] = self._get_username()
+         self.config["_Created"] = datetime.now(timezone.utc).isoformat()
+         self.config["_Group"] = self.group
+
+         self._queued_logs: list[LogEntry] = []
+         self._queued_uploads: list[UploadEntry] = []
+         self._stop_flag = threading.Event()
+         self._config_logged = False
+
+         self._client_thread = threading.Thread(target=self._init_client_background)
+         self._client_thread.daemon = True
+         self._client_thread.start()
+
+     def _get_username(self) -> str | None:
+         """Get the current Hugging Face username if logged in, otherwise None."""
+         try:
+             who = huggingface_hub.whoami()
+             return who["name"] if who else None
+         except Exception:
+             return None
+
+     def _batch_sender(self):
+         """Send batched logs every BATCH_SEND_INTERVAL seconds."""
+         while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
+             # If the stop flag has been set, skip the sleep and flush the
+             # remaining logs immediately before exiting.
+             if not self._stop_flag.is_set():
+                 time.sleep(BATCH_SEND_INTERVAL)
+
+             with self._client_lock:
+                 if self._client is None:
+                     return
+                 if self._queued_logs:
+                     logs_to_send = self._queued_logs.copy()
+                     self._queued_logs.clear()
+                     self._client.predict(
+                         api_name="/bulk_log",
+                         logs=logs_to_send,
+                         hf_token=huggingface_hub.utils.get_token(),
+                     )
+                 if self._queued_uploads:
+                     uploads_to_send = self._queued_uploads.copy()
+                     self._queued_uploads.clear()
+                     self._client.predict(
+                         api_name="/bulk_upload_media",
+                         uploads=uploads_to_send,
+                         hf_token=huggingface_hub.utils.get_token(),
+                     )
+
+     def _init_client_background(self):
+         if self._client is None:
+             fib = utils.fibo()
+             for sleep_coefficient in fib:
+                 try:
+                     client = Client(self.url, verbose=False)
+
+                     with self._client_lock:
+                         self._client = client
+                     break
+                 except Exception:
+                     pass
+                 if sleep_coefficient is not None:
+                     time.sleep(0.1 * sleep_coefficient)
+
+         self._batch_sender()
+
+     def _process_media(self, metrics, step: int | None) -> dict:
+         """
+         Serialize media in metrics and upload it to the Space if needed.
+         """
+         serializable_metrics = {}
+         if not step:
+             step = 0
+         for key, value in metrics.items():
+             if isinstance(value, TrackioMedia):
+                 value._save(self.project, self.name, step)
+                 serializable_metrics[key] = value._to_dict()
+                 if self._space_id:
+                     # Upload local media when deploying to a Space
+                     upload_entry: UploadEntry = {
+                         "project": self.project,
+                         "run": self.name,
+                         "step": step,
+                         "uploaded_file": handle_file(value._get_absolute_file_path()),
+                     }
+                     with self._client_lock:
+                         self._queued_uploads.append(upload_entry)
+             else:
+                 serializable_metrics[key] = value
+         return serializable_metrics
+
+     @staticmethod
+     def _replace_tables(metrics):
+         for k, v in metrics.items():
+             if isinstance(v, Table):
+                 metrics[k] = v._to_dict()
+
+     def log(self, metrics: dict, step: int | None = None):
+         for k in metrics.keys():
+             if k in utils.RESERVED_KEYS or k.startswith("__"):
+                 raise ValueError(
+                     f"Please do not use this reserved key as a metric: {k}"
+                 )
+         Run._replace_tables(metrics)
+
+         metrics = self._process_media(metrics, step)
+         metrics = utils.serialize_values(metrics)
+
+         config_to_log = None
+         if not self._config_logged and self.config:
+             config_to_log = utils.to_json_safe(self.config)
+             self._config_logged = True
+
+         log_entry: LogEntry = {
+             "project": self.project,
+             "run": self.name,
+             "metrics": metrics,
+             "step": step,
+             "config": config_to_log,
+         }
+
+         with self._client_lock:
+             self._queued_logs.append(log_entry)
+
+     def finish(self):
+         """Cleanup when the run is finished."""
+         self._stop_flag.set()
+
+         # Wait for the batch sender to finish before joining the client thread.
+         time.sleep(2 * BATCH_SEND_INTERVAL)
+
+         if self._client_thread is not None:
+             print("* Run finished. Uploading logs to Trackio (please wait...)")
+             self._client_thread.join()
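The reconnect loop in `_init_client_background` relies on `utils.fibo()`, which lives in `utils.py` and is not shown in this diff. A plausible minimal sketch, assuming it is an endless Fibonacci generator used as a backoff coefficient:

```python
from collections.abc import Iterator

def fibo() -> Iterator[int]:
    """Yield 1, 1, 2, 3, 5, 8, ... forever.

    Multiplied by 0.1 s in the retry loop, this backs off gradually
    instead of hammering a dashboard URL that is not up yet.
    """
    a, b = 1, 1
    while True:
        yield a
        a, b = b, a + b
```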
sqlite_storage.py ADDED
@@ -0,0 +1,677 @@
+ import os
+ import platform
+ import sqlite3
+ import time
+ from datetime import datetime
+ from pathlib import Path
+ from threading import Lock
+
+ try:
+     import fcntl
+ except ImportError:  # fcntl is not available on Windows
+     fcntl = None
+
+ import huggingface_hub as hf
+ import orjson
+ import pandas as pd
+
+ try:  # absolute imports when installed from PyPI
+     from trackio.commit_scheduler import CommitScheduler
+     from trackio.dummy_commit_scheduler import DummyCommitScheduler
+     from trackio.utils import (
+         TRACKIO_DIR,
+         deserialize_values,
+         serialize_values,
+     )
+ except ImportError:  # relative imports when installed from source on Spaces
+     from commit_scheduler import CommitScheduler
+     from dummy_commit_scheduler import DummyCommitScheduler
+     from utils import TRACKIO_DIR, deserialize_values, serialize_values
+
+ DB_EXT = ".db"
+
+
+ class ProcessLock:
+     """A file-based lock that works across processes. Is a no-op on Windows."""
+
+     def __init__(self, lockfile_path: Path):
+         self.lockfile_path = lockfile_path
+         self.lockfile = None
+         self.is_windows = platform.system() == "Windows"
+
+     def __enter__(self):
+         """Acquire the lock with retry logic."""
+         if self.is_windows:
+             return self
+         self.lockfile_path.parent.mkdir(parents=True, exist_ok=True)
+         self.lockfile = open(self.lockfile_path, "w")
+
+         max_retries = 100
+         for attempt in range(max_retries):
+             try:
+                 fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
+                 return self
+             except IOError:
+                 if attempt < max_retries - 1:
+                     time.sleep(0.1)
+                 else:
+                     raise IOError("Could not acquire database lock after 10 seconds")
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Release the lock."""
+         if self.is_windows:
+             return
+
+         if self.lockfile:
+             fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
+             self.lockfile.close()
+
+
+ class SQLiteStorage:
+     _dataset_import_attempted = False
+     _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
+     _scheduler_lock = Lock()
+
+     @staticmethod
+     def _get_connection(db_path: Path) -> sqlite3.Connection:
+         conn = sqlite3.connect(str(db_path), timeout=30.0)
+         # Keep WAL for concurrency + performance on many small writes
+         conn.execute("PRAGMA journal_mode = WAL")
+         # ---- Minimal perf tweaks for many tiny transactions ----
+         # NORMAL = fsync at critical points only (safer than OFF, much faster than FULL)
+         conn.execute("PRAGMA synchronous = NORMAL")
+         # Keep temp data in memory to avoid disk hits during small writes
+         conn.execute("PRAGMA temp_store = MEMORY")
+         # Give SQLite a bit more room for cache (negative = KB, engine-managed)
+         conn.execute("PRAGMA cache_size = -20000")
+         # --------------------------------------------------------
+         conn.row_factory = sqlite3.Row
+         return conn
+
+     @staticmethod
+     def _get_process_lock(project: str) -> ProcessLock:
+         lockfile_path = TRACKIO_DIR / f"{project}.lock"
+         return ProcessLock(lockfile_path)
+
+     @staticmethod
+     def get_project_db_filename(project: str) -> str:
+         """Get the database filename for a specific project."""
+         safe_project_name = "".join(
+             c for c in project if c.isalnum() or c in ("-", "_")
+         ).rstrip()
+         if not safe_project_name:
+             safe_project_name = "default"
+         return f"{safe_project_name}{DB_EXT}"
+
+     @staticmethod
+     def get_project_db_path(project: str) -> Path:
+         """Get the database path for a specific project."""
+         filename = SQLiteStorage.get_project_db_filename(project)
+         return TRACKIO_DIR / filename
+
+     @staticmethod
+     def init_db(project: str) -> Path:
+         """
+         Initialize the SQLite database with the required tables.
+         Returns the database path.
+         """
+         db_path = SQLiteStorage.get_project_db_path(project)
+         db_path.parent.mkdir(parents=True, exist_ok=True)
+         with SQLiteStorage._get_process_lock(project):
+             with sqlite3.connect(str(db_path), timeout=30.0) as conn:
+                 conn.execute("PRAGMA journal_mode = WAL")
+                 conn.execute("PRAGMA synchronous = NORMAL")
+                 conn.execute("PRAGMA temp_store = MEMORY")
+                 conn.execute("PRAGMA cache_size = -20000")
+                 cursor = conn.cursor()
+                 cursor.execute(
+                     """
+                     CREATE TABLE IF NOT EXISTS metrics (
+                         id INTEGER PRIMARY KEY AUTOINCREMENT,
+                         timestamp TEXT NOT NULL,
+                         run_name TEXT NOT NULL,
+                         step INTEGER NOT NULL,
+                         metrics TEXT NOT NULL
+                     )
+                     """
+                 )
+                 cursor.execute(
+                     """
+                     CREATE TABLE IF NOT EXISTS configs (
+                         id INTEGER PRIMARY KEY AUTOINCREMENT,
+                         run_name TEXT NOT NULL,
+                         config TEXT NOT NULL,
+                         created_at TEXT NOT NULL,
+                         UNIQUE(run_name)
+                     )
+                     """
+                 )
+                 cursor.execute(
+                     """
+                     CREATE INDEX IF NOT EXISTS idx_metrics_run_step
+                     ON metrics(run_name, step)
+                     """
+                 )
+                 cursor.execute(
+                     """
+                     CREATE INDEX IF NOT EXISTS idx_configs_run_name
+                     ON configs(run_name)
+                     """
+                 )
+                 cursor.execute(
+                     """
+                     CREATE INDEX IF NOT EXISTS idx_metrics_run_timestamp
+                     ON metrics(run_name, timestamp)
+                     """
+                 )
+                 conn.commit()
+         return db_path
+
+     @staticmethod
+     def export_to_parquet():
+         """
+         Exports all projects' DB files as Parquet under the same path but with the extension ".parquet".
+         """
+         # don't attempt to export (potentially wrong/blank) data before importing for the first time
+         if not SQLiteStorage._dataset_import_attempted:
+             return
+         if not TRACKIO_DIR.exists():
+             return
+
+         all_paths = os.listdir(TRACKIO_DIR)
+         db_names = [f for f in all_paths if f.endswith(DB_EXT)]
+         for db_name in db_names:
+             db_path = TRACKIO_DIR / db_name
+             parquet_path = db_path.with_suffix(".parquet")
+             if (not parquet_path.exists()) or (
+                 db_path.stat().st_mtime > parquet_path.stat().st_mtime
+             ):
+                 with sqlite3.connect(str(db_path)) as conn:
+                     df = pd.read_sql("SELECT * FROM metrics", conn)
+                 # break out the single JSON metrics column into individual columns
+                 metrics = df["metrics"].copy()
+                 metrics = pd.DataFrame(
+                     metrics.apply(
+                         lambda x: deserialize_values(orjson.loads(x))
+                     ).values.tolist(),
+                     index=df.index,
+                 )
+                 del df["metrics"]
+                 for col in metrics.columns:
+                     df[col] = metrics[col]
+
+                 df.to_parquet(parquet_path)
+
+     @staticmethod
+     def _cleanup_wal_sidecars(db_path: Path) -> None:
+         """Remove leftover -wal/-shm files for a DB basename (prevents disk I/O errors)."""
+         for suffix in ("-wal", "-shm"):
+             sidecar = Path(str(db_path) + suffix)
+             try:
+                 if sidecar.exists():
+                     sidecar.unlink()
+             except Exception:
+                 pass
+
+     @staticmethod
+     def import_from_parquet():
+         """
+         Imports data into all DB files that have matching files under the same path but with the extension ".parquet".
+         """
+         if not TRACKIO_DIR.exists():
+             return
+
+         all_paths = os.listdir(TRACKIO_DIR)
+         parquet_names = [f for f in all_paths if f.endswith(".parquet")]
+         for pq_name in parquet_names:
+             parquet_path = TRACKIO_DIR / pq_name
+             db_path = parquet_path.with_suffix(DB_EXT)
+
+             SQLiteStorage._cleanup_wal_sidecars(db_path)
+
+             df = pd.read_parquet(parquet_path)
+             # fix up df to have a single JSON metrics column
+             if "metrics" not in df.columns:
+                 # separate other columns from metrics
+                 metrics = df.copy()
+                 other_cols = ["id", "timestamp", "run_name", "step"]
+                 df = df[other_cols]
+                 for col in other_cols:
+                     del metrics[col]
+                 # combine them all into a single metrics col
+                 metrics = orjson.loads(metrics.to_json(orient="records"))
+                 df["metrics"] = [orjson.dumps(serialize_values(row)) for row in metrics]
+
+             with sqlite3.connect(str(db_path), timeout=30.0) as conn:
+                 df.to_sql("metrics", conn, if_exists="replace", index=False)
+                 conn.commit()
+
+     @staticmethod
+     def get_scheduler():
+         """
+         Get the scheduler for the database based on the environment variables.
+         This applies both locally and on Spaces.
+         """
+         with SQLiteStorage._scheduler_lock:
+             if SQLiteStorage._current_scheduler is not None:
+                 return SQLiteStorage._current_scheduler
+             hf_token = os.environ.get("HF_TOKEN")
+             dataset_id = os.environ.get("TRACKIO_DATASET_ID")
+             space_repo_name = os.environ.get("SPACE_REPO_NAME")
+             if dataset_id is None or space_repo_name is None:
+                 scheduler = DummyCommitScheduler()
+             else:
+                 scheduler = CommitScheduler(
+                     repo_id=dataset_id,
+                     repo_type="dataset",
+                     folder_path=TRACKIO_DIR,
+                     private=True,
+                     allow_patterns=["*.parquet", "media/**/*"],
+                     squash_history=True,
+                     token=hf_token,
+                     on_before_commit=SQLiteStorage.export_to_parquet,
+                 )
+             SQLiteStorage._current_scheduler = scheduler
+             return scheduler
+
+     @staticmethod
+     def log(project: str, run: str, metrics: dict, step: int | None = None):
+         """
+         Safely log metrics to the database. Before logging, this method will ensure the database exists
+         and is set up with the correct tables. It also uses a cross-process lock to prevent
+         database locking errors when multiple processes access the same database.
+
+         This method is not used in the latest versions of Trackio (it was replaced by bulk_log) but
+         is kept for backwards compatibility for users who are connecting to a newer version of
+         a Trackio Spaces dashboard with an older version of Trackio installed locally.
+         """
+         db_path = SQLiteStorage.init_db(project)
+         with SQLiteStorage._get_process_lock(project):
+             with SQLiteStorage._get_connection(db_path) as conn:
+                 cursor = conn.cursor()
+                 cursor.execute(
+                     """
+                     SELECT MAX(step)
+                     FROM metrics
+                     WHERE run_name = ?
+                     """,
+                     (run,),
+                 )
+                 last_step = cursor.fetchone()[0]
+                 current_step = (
+                     0
+                     if step is None and last_step is None
+                     else (step if step is not None else last_step + 1)
+                 )
+                 current_timestamp = datetime.now().isoformat()
+                 cursor.execute(
+                     """
+                     INSERT INTO metrics
+                     (timestamp, run_name, step, metrics)
+                     VALUES (?, ?, ?, ?)
+                     """,
+                     (
+                         current_timestamp,
+                         run,
+                         current_step,
+                         orjson.dumps(serialize_values(metrics)),
+                     ),
+                 )
+                 conn.commit()
+
+     @staticmethod
+     def bulk_log(
+         project: str,
+         run: str,
+         metrics_list: list[dict],
+         steps: list[int] | None = None,
+         timestamps: list[str] | None = None,
+         config: dict | None = None,
+     ):
+         """
+         Safely log bulk metrics to the database. Before logging, this method will ensure the database exists
+         and is set up with the correct tables. It also uses a cross-process lock to prevent
+         database locking errors when multiple processes access the same database.
+         """
+         if not metrics_list:
+             return
+
+         if timestamps is None:
+             timestamps = [datetime.now().isoformat()] * len(metrics_list)
+
+         db_path = SQLiteStorage.init_db(project)
+         with SQLiteStorage._get_process_lock(project):
+             with SQLiteStorage._get_connection(db_path) as conn:
+                 cursor = conn.cursor()
+
+                 if steps is None:
+                     steps = list(range(len(metrics_list)))
+                 elif any(s is None for s in steps):
+                     cursor.execute(
+                         "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
+                     )
+                     last_step = cursor.fetchone()[0]
+                     current_step = 0 if last_step is None else last_step + 1
+                     processed_steps = []
+                     for step in steps:
+                         if step is None:
+                             processed_steps.append(current_step)
+                             current_step += 1
+                         else:
+                             processed_steps.append(step)
+                     steps = processed_steps
+
+                 if len(metrics_list) != len(steps) or len(metrics_list) != len(
+                     timestamps
+                 ):
+                     raise ValueError(
+                         "metrics_list, steps, and timestamps must have the same length"
+                     )
+
+                 data = []
+                 for i, metrics in enumerate(metrics_list):
+                     data.append(
+                         (
+                             timestamps[i],
+                             run,
+                             steps[i],
+                             orjson.dumps(serialize_values(metrics)),
+                         )
+                     )
+
+                 cursor.executemany(
+                     """
+                     INSERT INTO metrics
+                     (timestamp, run_name, step, metrics)
+                     VALUES (?, ?, ?, ?)
+                     """,
+                     data,
+                 )
+
+                 if config:
+                     current_timestamp = datetime.now().isoformat()
+                     cursor.execute(
+                         """
+                         INSERT OR REPLACE INTO configs
+                         (run_name, config, created_at)
+                         VALUES (?, ?, ?)
+                         """,
+                         (
+                             run,
+                             orjson.dumps(serialize_values(config)),
+                             current_timestamp,
+                         ),
+                     )
+
+                 conn.commit()
+
+     @staticmethod
+     def get_logs(project: str, run: str) -> list[dict]:
+         """Retrieve logs for a specific run. Logs include the step count (int) and the timestamp (ISO-format string)."""
+         db_path = SQLiteStorage.get_project_db_path(project)
+         if not db_path.exists():
+             return []
+
+         with SQLiteStorage._get_connection(db_path) as conn:
+             cursor = conn.cursor()
+             cursor.execute(
+                 """
+                 SELECT timestamp, step, metrics
+                 FROM metrics
+                 WHERE run_name = ?
+                 ORDER BY timestamp
+                 """,
+                 (run,),
+             )
+
+             rows = cursor.fetchall()
+             results = []
+             for row in rows:
+                 metrics = orjson.loads(row["metrics"])
+                 metrics = deserialize_values(metrics)
+                 metrics["timestamp"] = row["timestamp"]
+                 metrics["step"] = row["step"]
+                 results.append(metrics)
+             return results
+
+     @staticmethod
+     def load_from_dataset():
+         dataset_id = os.environ.get("TRACKIO_DATASET_ID")
+         space_repo_name = os.environ.get("SPACE_REPO_NAME")
+         if dataset_id is not None and space_repo_name is not None:
+             hfapi = hf.HfApi()
+             updated = False
+             if not TRACKIO_DIR.exists():
+                 TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
+             with SQLiteStorage.get_scheduler().lock:
+                 try:
+                     files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
+                     for file in files:
+                         # Download parquet and media assets
+                         if not (file.endswith(".parquet") or file.startswith("media/")):
+                             continue
+                         if (TRACKIO_DIR / file).exists():
+                             continue
+                         hf.hf_hub_download(
+                             dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR
+                         )
+                         updated = True
+                 except hf.errors.EntryNotFoundError:
+                     pass
+                 except hf.errors.RepositoryNotFoundError:
+                     pass
+             if updated:
+                 SQLiteStorage.import_from_parquet()
+         SQLiteStorage._dataset_import_attempted = True
+
+     @staticmethod
+     def get_projects() -> list[str]:
+         """
+         Get the list of all projects by scanning the database files in the trackio directory.
+         """
+         if not SQLiteStorage._dataset_import_attempted:
+             SQLiteStorage.load_from_dataset()
+
+         projects: set[str] = set()
+         if not TRACKIO_DIR.exists():
+             return []
+
+         for db_file in TRACKIO_DIR.glob(f"*{DB_EXT}"):
+             project_name = db_file.stem
+             projects.add(project_name)
+         return sorted(projects)
+
+     @staticmethod
+     def get_runs(project: str) -> list[str]:
+         """Get the list of all runs for a project."""
+         db_path = SQLiteStorage.get_project_db_path(project)
+         if not db_path.exists():
+             return []
+
+         with SQLiteStorage._get_connection(db_path) as conn:
+             cursor = conn.cursor()
+             cursor.execute(
+                 "SELECT DISTINCT run_name FROM metrics",
+             )
+             return [row[0] for row in cursor.fetchall()]
+
+     @staticmethod
+     def get_max_steps_for_runs(project: str) -> dict[str, int]:
+         """Get the maximum step for each run in a project."""
+         db_path = SQLiteStorage.get_project_db_path(project)
+         if not db_path.exists():
+             return {}
+
+         with SQLiteStorage._get_connection(db_path) as conn:
+             cursor = conn.cursor()
+             cursor.execute(
+                 """
+                 SELECT run_name, MAX(step) as max_step
+                 FROM metrics
+                 GROUP BY run_name
+                 """
+             )
+
+             results = {}
+             for row in cursor.fetchall():
+                 results[row["run_name"]] = row["max_step"]
+
+             return results
+
+     @staticmethod
+     def store_config(project: str, run: str, config: dict) -> None:
+         """Store the configuration for a run."""
+         db_path = SQLiteStorage.init_db(project)
+
+         with SQLiteStorage._get_process_lock(project):
+             with SQLiteStorage._get_connection(db_path) as conn:
+                 cursor = conn.cursor()
+                 current_timestamp = datetime.now().isoformat()
+
+                 cursor.execute(
+                     """
+                     INSERT OR REPLACE INTO configs
+                     (run_name, config, created_at)
+                     VALUES (?, ?, ?)
+                     """,
+                     (run, orjson.dumps(serialize_values(config)), current_timestamp),
+                 )
+                 conn.commit()
+
+     @staticmethod
+     def get_run_config(project: str, run: str) -> dict | None:
+         """Get the configuration for a specific run."""
+         db_path = SQLiteStorage.get_project_db_path(project)
+         if not db_path.exists():
+             return None
+
+         with SQLiteStorage._get_connection(db_path) as conn:
+             cursor = conn.cursor()
+             try:
+                 cursor.execute(
+                     """
+                     SELECT config FROM configs WHERE run_name = ?
+                     """,
+                     (run,),
+                 )
+
+                 row = cursor.fetchone()
+                 if row:
+                     config = orjson.loads(row["config"])
+                     return deserialize_values(config)
+                 return None
+             except sqlite3.OperationalError as e:
+                 if "no such table: configs" in str(e):
+                     return None
+                 raise
+
+     @staticmethod
+     def delete_run(project: str, run: str) -> bool:
+         """Delete a run from the database (both metrics and config)."""
+         db_path = SQLiteStorage.get_project_db_path(project)
+         if not db_path.exists():
+             return False
+
+         with SQLiteStorage._get_process_lock(project):
+             with SQLiteStorage._get_connection(db_path) as conn:
+                 cursor = conn.cursor()
+                 try:
+                     cursor.execute("DELETE FROM metrics WHERE run_name = ?", (run,))
+                     cursor.execute("DELETE FROM configs WHERE run_name = ?", (run,))
+                     conn.commit()
+                     return True
+                 except sqlite3.Error:
+                     return False
+
+     @staticmethod
+     def get_all_run_configs(project: str) -> dict[str, dict]:
+         """Get the configurations for all runs in a project."""
+         db_path = SQLiteStorage.get_project_db_path(project)
+         if not db_path.exists():
+             return {}
+
+         with SQLiteStorage._get_connection(db_path) as conn:
+             cursor = conn.cursor()
+             try:
+                 cursor.execute(
+                     """
+                     SELECT run_name, config FROM configs
+                     """
+                 )
+
+                 results = {}
+                 for row in cursor.fetchall():
+                     config = orjson.loads(row["config"])
+                     results[row["run_name"]] = deserialize_values(config)
+                 return results
+             except sqlite3.OperationalError as e:
+                 if "no such table: configs" in str(e):
+                     return {}
+                 raise
+
+     @staticmethod
+     def get_metric_values(project: str, run: str, metric_name: str) -> list[dict]:
+         """Get all values for a specific metric in a project/run."""
+         db_path = SQLiteStorage.get_project_db_path(project)
+         if not db_path.exists():
+             return []
+
+         with SQLiteStorage._get_connection(db_path) as conn:
+             cursor = conn.cursor()
+             cursor.execute(
+                 """
+                 SELECT timestamp, step, metrics
+                 FROM metrics
+                 WHERE run_name = ?
+                 ORDER BY timestamp
+                 """,
+                 (run,),
+             )
+
+             rows = cursor.fetchall()
+             results = []
+             for row in rows:
+                 metrics = orjson.loads(row["metrics"])
+                 metrics = deserialize_values(metrics)
+                 if metric_name in metrics:
+                     results.append(
+                         {
+                             "timestamp": row["timestamp"],
+                             "step": row["step"],
+                             "value": metrics[metric_name],
+                         }
+                     )
+             return results
+
+     @staticmethod
+     def get_all_metrics_for_run(project: str, run: str) -> list[str]:
+         """Get all metric names for a specific project/run."""
+         db_path = SQLiteStorage.get_project_db_path(project)
+         if not db_path.exists():
+             return []
+
+         with SQLiteStorage._get_connection(db_path) as conn:
+             cursor = conn.cursor()
+             cursor.execute(
+                 """
+                 SELECT metrics
+                 FROM metrics
+                 WHERE run_name = ?
+                 ORDER BY timestamp
+                 """,
+                 (run,),
+             )
+
+             rows = cursor.fetchall()
+             all_metrics = set()
+             for row in rows:
+                 metrics = orjson.loads(row["metrics"])
+                 metrics = deserialize_values(metrics)
+                 for key in metrics.keys():
+                     if key not in ["timestamp", "step"]:
+                         all_metrics.add(key)
+             return sorted(list(all_metrics))
+
+     def finish(self):
+         """Cleanup when the run is finished."""
+         pass
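A minimal local round trip through the storage layer above (project and run names are hypothetical; this writes a SQLite file under `TRACKIO_DIR`):

```python
from trackio.sqlite_storage import SQLiteStorage

# Write three steps in one transaction, then read them back.
SQLiteStorage.bulk_log(
    project="demo-project",
    run="run-1",
    metrics_list=[{"loss": 1.0}, {"loss": 0.5}, {"loss": 0.25}],
    steps=[0, 1, 2],
)
for entry in SQLiteStorage.get_logs("demo-project", "run-1"):
    print(entry["step"], entry["loss"], entry["timestamp"])
```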
table.py ADDED
@@ -0,0 +1,53 @@
+ from typing import Any, Literal
+
+ from pandas import DataFrame
+
+
+ class Table:
+     """
+     Initializes a Table object.
+
+     Args:
+         columns (`list[str]`, *optional*):
+             Names of the columns in the table. Optional if `data` is provided. Not
+             expected if `dataframe` is provided. Currently ignored.
+         data (`list[list[Any]]`, *optional*):
+             2D row-oriented array of values.
+         dataframe (`pandas.DataFrame`, *optional*):
+             DataFrame object used to create the table. When set, `data` and `columns`
+             arguments are ignored.
+         rows (`list[list[Any]]`, *optional*):
+             Currently ignored.
+         optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
+             Currently ignored.
+         allow_mixed_types (`bool`, *optional*, defaults to `False`):
+             Currently ignored.
+         log_mode (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
+             Currently ignored.
+     """
+
+     TYPE = "trackio.table"
+
+     def __init__(
+         self,
+         columns: list[str] | None = None,
+         data: list[list[Any]] | None = None,
+         dataframe: DataFrame | None = None,
+         rows: list[list[Any]] | None = None,
+         optional: bool | list[bool] = True,
+         allow_mixed_types: bool = False,
+         log_mode: Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"] | None = "IMMUTABLE",
+     ):
+         # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
+         # For now (like `rows`) they are included for API compatibility but don't do anything.
+
+         if dataframe is None:
+             self.data = data
+         else:
+             self.data = dataframe.to_dict(orient="records")
+
+     def _to_dict(self):
+         return {
+             "_type": self.TYPE,
+             "_value": self.data,
+         }
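Since the `Table` docstring carries no example, here is a hedged usage sketch; it assumes `Table` is re-exported as `trackio.Table`, mirroring the `trackio.Image`/`trackio.Video` examples earlier in this commit:

```python
import pandas as pd
import trackio

trackio.init(project="demo-project")  # hypothetical project name

df = pd.DataFrame({"prompt": ["hi", "bye"], "score": [0.9, 0.4]})
table = trackio.Table(dataframe=df)  # stored row-oriented via to_dict("records")
trackio.log({"eval_samples": table})
trackio.finish()
```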
typehints.py ADDED
@@ -0,0 +1,18 @@
+ from typing import Any, TypedDict
+
+ from gradio import FileData
+
+
+ class LogEntry(TypedDict):
+     project: str
+     run: str
+     metrics: dict[str, Any]
+     step: int | None
+     config: dict[str, Any] | None
+
+
+ class UploadEntry(TypedDict):
+     project: str
+     run: str
+     step: int | None
+     uploaded_file: FileData
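For reference, a sketch of how the client populates one of these TypedDicts (the values are illustrative, not from the source):

```python
from trackio.typehints import LogEntry

entry: LogEntry = {
    "project": "demo-project",
    "run": "run-1",
    "metrics": {"loss": 0.25},
    "step": 2,
    "config": None,  # only non-None on the first log call of a run
}
```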
ui/__init__.py ADDED
@@ -0,0 +1,10 @@
+ try:
+     from trackio.ui.main import demo
+     from trackio.ui.run_detail import run_detail_page
+     from trackio.ui.runs import run_page
+ except ImportError:
+     from ui.main import demo
+     from ui.run_detail import run_detail_page
+     from ui.runs import run_page
+
+ __all__ = ["demo", "run_page", "run_detail_page"]
ui/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (462 Bytes).
ui/__pycache__/fns.cpython-310.pyc ADDED
Binary file (7.59 kB).
ui/__pycache__/main.cpython-310.pyc ADDED
Binary file (26.7 kB).
ui/__pycache__/run_detail.cpython-310.pyc ADDED
Binary file (2.6 kB).
ui/__pycache__/runs.cpython-310.pyc ADDED
Binary file (6.98 kB).
ui/fns.py ADDED
@@ -0,0 +1,241 @@
+ """Shared functions for the Trackio UI."""
+
+ import os
+
+ import gradio as gr
+ import huggingface_hub as hf
+
+ try:
+     import trackio.utils as utils
+     from trackio.sqlite_storage import SQLiteStorage
+     from trackio.ui.helpers.run_selection import RunSelection
+ except ImportError:
+     import utils
+     from sqlite_storage import SQLiteStorage
+     from ui.helpers.run_selection import RunSelection
+
+ CONFIG_COLUMN_MAPPINGS = {
+     "_Username": "Username",
+     "_Created": "Created",
+     "_Group": "Group",
+ }
+ CONFIG_COLUMN_MAPPINGS_REVERSE = {v: k for k, v in CONFIG_COLUMN_MAPPINGS.items()}
+
+
+ HfApi = hf.HfApi()
+
+
+ def get_project_info() -> str | None:
+     dataset_id = os.environ.get("TRACKIO_DATASET_ID")
+     space_id = utils.get_space()
+     if utils.persistent_storage_enabled():
+         return "&#10024; Persistent Storage is enabled, logs are stored directly in this Space."
+     if dataset_id:
+         sync_status = utils.get_sync_status(SQLiteStorage.get_scheduler())
+         upgrade_message = f"New changes are synced every 5 min <span class='info-container'><input type='checkbox' class='info-checkbox' id='upgrade-info'><label for='upgrade-info' class='info-icon'>&#9432;</label><span class='info-expandable'> To avoid losing data between syncs, <a href='https://huggingface.co/spaces/{space_id}/settings' class='accent-link'>click here</a> to open this Space's settings and add Persistent Storage. Make sure data is synced prior to enabling.</span></span>"
+         if sync_status is not None:
+             info = f"&#x21bb; Backed up {sync_status} min ago to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank' class='accent-link'>{dataset_id}</a> | {upgrade_message}"
+         else:
+             info = f"&#x21bb; Not backed up yet to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank' class='accent-link'>{dataset_id}</a> | {upgrade_message}"
+         return info
+     return None
+
+
+ def get_projects(request: gr.Request):
+     projects = SQLiteStorage.get_projects()
+     if project := request.query_params.get("project"):
+         interactive = False
+     else:
+         interactive = True
+     if selected_project := request.query_params.get("selected_project"):
+         project = selected_project
+     else:
+         project = projects[0] if projects else None
+
+     return gr.Dropdown(
+         label="Project",
+         choices=projects,
+         value=project,
+         allow_custom_value=True,
+         interactive=interactive,
+         info=get_project_info(),
+     )
+
+
+ def update_navbar_value(project_dd, request: gr.Request):
+     write_token = None
+     if hasattr(request, "query_params") and request.query_params:
+         write_token = request.query_params.get("write_token")
+
+     metrics_url = f"?selected_project={project_dd}"
+     runs_url = f"runs?selected_project={project_dd}"
+
+     if write_token:
+         metrics_url += f"&write_token={write_token}"
+         runs_url += f"&write_token={write_token}"
+
+     return gr.Navbar(
+         value=[
+             ("Metrics", metrics_url),
+             ("Runs", runs_url),
+         ]
+     )
+
+
+ def check_hf_token_has_write_access(hf_token: str | None) -> None:
+     """
+     Checks whether the provided hf_token is valid and has write access to the Space
+     that Trackio is running in. If the hf_token is valid, or if Trackio is not running
+     on a Space, this function does nothing. Otherwise, it raises a PermissionError.
+     """
+     if os.getenv("SYSTEM") == "spaces":  # if we are running in Spaces
+         # check the auth token passed in
+         if hf_token is None:
+             raise PermissionError(
+                 "Expected an HF_TOKEN to be provided when logging to a Space"
+             )
+         who = HfApi.whoami(hf_token)
+         owner_name = os.getenv("SPACE_AUTHOR_NAME")
+         repo_name = os.getenv("SPACE_REPO_NAME")
+         # make sure the token user is either the author of the space,
+         # or is a member of an org that is the author.
+         orgs = [o["name"] for o in who["orgs"]]
+         if owner_name != who["name"] and owner_name not in orgs:
+             raise PermissionError(
+                 "Expected the provided hf_token to belong to the user who owns the Space, or to a member of the org that owns the Space"
+             )
+         # reject fine-grained tokens without specific repo access
+         access_token = who["auth"]["accessToken"]
+         if access_token["role"] == "fineGrained":
+             matched = False
+             for item in access_token["fineGrained"]["scoped"]:
+                 if (
+                     item["entity"]["type"] == "space"
+                     and item["entity"]["name"] == f"{owner_name}/{repo_name}"
+                     and "repo.write" in item["permissions"]
+                 ):
+                     matched = True
+                     break
+                 if (
+                     (
+                         item["entity"]["type"] == "user"
+                         or item["entity"]["type"] == "org"
+                     )
+                     and item["entity"]["name"] == owner_name
+                     and "repo.write" in item["permissions"]
+                 ):
+                     matched = True
+                     break
+             if not matched:
+                 raise PermissionError(
+                     "Expected the provided fine-grained hf_token to grant write access to the Space"
+                 )
+         # reject read-only tokens
+         elif access_token["role"] != "write":
+             raise PermissionError(
+                 "Expected the provided hf_token to have write permissions"
+             )
+
+
+ def check_oauth_token_has_write_access(oauth_token: str | None) -> None:
+     """
+     Checks whether the OAuth token provided via Gradio's OAuth is valid and has write access
+     to the Space that Trackio is running in. If the OAuth token is valid, or if Trackio is not
+     running on a Space, this function does nothing. Otherwise, it raises a PermissionError.
+     """
+     if not os.getenv("SYSTEM") == "spaces":
+         return
+     if oauth_token is None:
+         raise PermissionError(
+             "Expected an OAuth token to be provided when logging to a Space"
+         )
+     who = HfApi.whoami(oauth_token)
+     user_name = who["name"]
+     owner_name = os.getenv("SPACE_AUTHOR_NAME")
+     if user_name == owner_name:
+         return
+     # check if the user is a member of an org that owns the space with write permissions
+     for org in who["orgs"]:
+         if org["name"] == owner_name and org["roleInOrg"] == "write":
+             return
+     raise PermissionError(
+         "Expected the OAuth token to belong to the user who owns the Space, or to a member of the org that owns the Space"
+     )
+
+
+ def get_group_by_fields(project: str):
+     configs = SQLiteStorage.get_all_run_configs(project) if project else {}
+     keys = set()
+     for config in configs.values():
+         keys.update(config.keys())
+     keys.discard("_Created")
+     keys = [CONFIG_COLUMN_MAPPINGS.get(key, key) for key in keys]
+     choices = [None] + sorted(keys)
+     return gr.Dropdown(
+         choices=choices,
+         value=None,
+         interactive=True,
+     )
+
+
+ def group_runs_by_config(
+     project: str, config_key: str, filter_text: str | None = None
+ ) -> dict[str, list[str]]:
+     if not project or not config_key:
+         return {}
+     display_key = config_key
+     config_key = CONFIG_COLUMN_MAPPINGS_REVERSE.get(config_key, config_key)
+     configs = SQLiteStorage.get_all_run_configs(project)
+     groups: dict[str, list[str]] = {}
+     for run_name, config in configs.items():
+         if filter_text and filter_text not in run_name:
+             continue
+         group_name = config.get(config_key, "None")
+         label = f"{display_key}: {group_name}"
+         groups.setdefault(label, []).append(run_name)
+     for label in groups:
+         groups[label].sort()
+     sorted_groups = dict(sorted(groups.items(), key=lambda kv: kv[0].lower()))
+     return sorted_groups
+
+
+ def run_checkbox_update(selection: RunSelection, **kwargs) -> gr.CheckboxGroup:
+     return gr.CheckboxGroup(
+         choices=selection.choices,
+         value=selection.selected,
+         **kwargs,
+     )
+
+
+ def handle_run_checkbox_change(
+     selected_runs: list[str] | None, selection: RunSelection
+ ) -> RunSelection:
+     selection.select(selected_runs or [])
+     return selection
+
+
+ def handle_group_checkbox_change(
+     group_selected: list[str] | None,
+     selection: RunSelection,
+     group_runs: list[str] | None,
+ ):
+     subset, _ = selection.replace_group(group_runs or [], group_selected or [])
+     return (
+         selection,
+         gr.CheckboxGroup(value=subset),
+         run_checkbox_update(selection),
+     )
+
+
+ def handle_group_toggle(
+     select_all: bool,
+     selection: RunSelection,
+     group_runs: list[str] | None,
+ ):
+     target = list(group_runs or []) if select_all else []
+     subset, _ = selection.replace_group(group_runs or [], target)
+     return (
+         selection,
+         gr.CheckboxGroup(value=subset),
+         run_checkbox_update(selection),
+     )
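To illustrate `group_runs_by_config`: given hypothetical runs whose configs carry a `_Group` key, the display name `"Group"` is mapped back through `CONFIG_COLUMN_MAPPINGS_REVERSE` and runs are bucketed by that key's value:

```python
# Hypothetical run configs in project "demo-project":
#   run-a: {"_Group": "baseline"}, run-b: {"_Group": "ablation"}, run-c: {"_Group": "baseline"}
groups = group_runs_by_config("demo-project", "Group")
# -> {"Group: ablation": ["run-b"], "Group: baseline": ["run-a", "run-c"]}
```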
ui/helpers/__pycache__/run_selection.cpython-310.pyc ADDED
Binary file (2.04 kB).
ui/helpers/run_selection.py ADDED
@@ -0,0 +1,46 @@
+ from dataclasses import dataclass, field
+
+ try:
+     import trackio.utils as utils
+ except ImportError:
+     import utils
+
+
+ @dataclass
+ class RunSelection:
+     choices: list[str] = field(default_factory=list)
+     selected: list[str] = field(default_factory=list)
+     locked: bool = False
+
+     def update_choices(
+         self, runs: list[str], preferred: list[str] | None = None
+     ) -> bool:
+         if self.choices == runs:
+             return False
+         new_choices = set(runs) - set(self.choices)
+         self.choices = list(runs)
+         if self.locked:
+             base = set(self.selected) | new_choices
+         elif preferred:
+             base = set(preferred)
+         else:
+             base = set(runs)
+         self.selected = [run for run in self.choices if run in base]
+         return True
+
+     def select(self, runs: list[str]) -> list[str]:
+         choice_set = set(self.choices)
+         self.selected = [run for run in runs if run in choice_set]
+         self.locked = True
+         return self.selected
+
+     def replace_group(
+         self, group_runs: list[str], new_subset: list[str] | None
+     ) -> tuple[list[str], list[str]]:
+         new_subset = utils.ordered_subset(group_runs, new_subset)
+         selection_set = set(self.selected)
+         selection_set.difference_update(group_runs)
+         selection_set.update(new_subset)
+         self.selected = [run for run in self.choices if run in selection_set]
+         self.locked = True
+         return new_subset, self.selected
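The selection semantics above are easiest to see in a small trace (run names hypothetical): a fresh selection follows the full run list, while a user edit locks it so that only newly appearing runs are auto-added:

```python
from trackio.ui.helpers.run_selection import RunSelection

sel = RunSelection()
sel.update_choices(["run-a", "run-b"])  # unlocked -> everything selected
print(sel.selected)                     # ['run-a', 'run-b']

sel.select(["run-b"])                   # user unchecks run-a; selection locks
sel.update_choices(["run-a", "run-b", "run-c"])
print(sel.selected)                     # ['run-b', 'run-c'] (new run auto-added)
```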
ui/main.py ADDED
@@ -0,0 +1,1212 @@
+ """The main page for the Trackio UI."""
+
+ import os
+ import re
+ import secrets
+ import shutil
+ from dataclasses import dataclass
+ from typing import Any
+
+ import gradio as gr
+ import numpy as np
+ import pandas as pd
+
+ try:
+     import trackio.utils as utils
+     from trackio.file_storage import FileStorage
+     from trackio.media import TrackioImage, TrackioVideo
+     from trackio.sqlite_storage import SQLiteStorage
+     from trackio.table import Table
+     from trackio.typehints import LogEntry, UploadEntry
+     from trackio.ui import fns
+     from trackio.ui.helpers.run_selection import RunSelection
+     from trackio.ui.run_detail import run_detail_page
+     from trackio.ui.runs import run_page
+ except ImportError:
+     import utils
+     from file_storage import FileStorage
+     from media import TrackioImage, TrackioVideo
+     from sqlite_storage import SQLiteStorage
+     from table import Table
+     from typehints import LogEntry, UploadEntry
+     from ui import fns
+     from ui.helpers.run_selection import RunSelection
+     from ui.run_detail import run_detail_page
+     from ui.runs import run_page
+
+
+ INSTRUCTIONS_SPACES = """
+ ## Start logging with Trackio 🤗
+
+ To start logging to this Trackio dashboard, first make sure you have the Trackio library installed. You can do this by running:
+
+ ```bash
+ pip install trackio
+ ```
+
+ Then, start logging to this Trackio dashboard by passing in the `space_id` to `trackio.init()`:
+
+ ```python
+ import trackio
+ trackio.init(project="my-project", space_id="{}")
+ ```
+
+ Then call `trackio.log()` to log metrics.
+
+ ```python
+ for i in range(10):
+     trackio.log({{"loss": 1/(i+1)}})
+ ```
+
+ Finally, call `trackio.finish()` to finish the run.
+
+ ```python
+ trackio.finish()
+ ```
+ """
+
+ INSTRUCTIONS_LOCAL = """
+ ## Start logging with Trackio 🤗
+
+ You can create a new project by calling `trackio.init()`:
+
+ ```python
+ import trackio
+ trackio.init(project="my-project")
+ ```
+
+ Then call `trackio.log()` to log metrics.
+
+ ```python
+ for i in range(10):
+     trackio.log({"loss": 1/(i+1)})
+ ```
+
+ Finally, call `trackio.finish()` to finish the run.
+
+ ```python
+ trackio.finish()
+ ```
+
+ Read the [Trackio documentation](https://huggingface.co/docs/trackio/en/index) for more examples.
+ """
+
+
+ def get_runs(project) -> list[str]:
+     if not project:
+         return []
+     return SQLiteStorage.get_runs(project)
+
+
+ def get_available_metrics(project: str, runs: list[str]) -> list[str]:
+     """Get all available metrics across all runs for x-axis selection."""
+     if not project or not runs:
+         return ["step", "time"]
+
+     all_metrics = set()
+     for run in runs:
+         metrics = SQLiteStorage.get_logs(project, run)
+         if metrics:
+             df = pd.DataFrame(metrics)
+             numeric_cols = df.select_dtypes(include="number").columns
+             numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
+             all_metrics.update(numeric_cols)
+
+     all_metrics.add("step")
+     all_metrics.add("time")
+
+     sorted_metrics = utils.sort_metrics_by_prefix(list(all_metrics))
+
+     result = ["step", "time"]
+     for metric in sorted_metrics:
+         if metric not in result:
+             result.append(metric)
+
+     return result
+
+
+ @dataclass
+ class MediaData:
+     caption: str | None
+     file_path: str
+
+
+ def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]:
+     media_by_key: dict[str, list[MediaData]] = {}
+     logs = sorted(logs, key=lambda x: x.get("step", 0))
+     for log in logs:
+         for key, value in log.items():
+             if isinstance(value, dict):
+                 type = value.get("_type")
+                 if type == TrackioImage.TYPE or type == TrackioVideo.TYPE:
+                     if key not in media_by_key:
+                         media_by_key[key] = []
+                     try:
+                         media_data = MediaData(
+                             file_path=utils.MEDIA_DIR / value.get("file_path"),
+                             caption=value.get("caption"),
+                         )
+                         media_by_key[key].append(media_data)
+                     except Exception as e:
+                         print(f"Media currently unavailable: {key}: {e}")
+     return media_by_key
+
+
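
For reference, a sketch of the serialized media value that `extract_media` looks for in a log entry; the `_type` marker is referenced via `TrackioImage.TYPE` rather than hard-coded, since its exact string value is an internal detail, and the file path here is hypothetical:

```python
from trackio.media import TrackioImage

logs = [
    {
        "step": 0,
        "loss": 0.9,
        "val/sample": {
            "_type": TrackioImage.TYPE,
            "file_path": "my-project/run-1/0/sample.png",  # hypothetical relative path
            "caption": "first validation sample",
        },
    }
]
# extract_media(logs) -> {"val/sample": [MediaData(caption="first validation sample",
#                                                  file_path=utils.MEDIA_DIR / "...")]}
```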
+ def load_run_data(
+     project: str | None,
+     run: str | None,
+     smoothing_granularity: int,
+     x_axis: str,
+     log_scale: bool = False,
+ ) -> tuple[pd.DataFrame | None, dict | None]:
+     if not project or not run:
+         return None, None
+
+     logs = SQLiteStorage.get_logs(project, run)
+     if not logs:
+         return None, None
+
+     media = extract_media(logs)
+     df = pd.DataFrame(logs)
+
+     if "step" not in df.columns:
+         df["step"] = range(len(df))
+
+     if x_axis == "time" and "timestamp" in df.columns:
+         df["timestamp"] = pd.to_datetime(df["timestamp"])
+         first_timestamp = df["timestamp"].min()
+         df["time"] = (df["timestamp"] - first_timestamp).dt.total_seconds()
+         x_column = "time"
+     elif x_axis == "step":
+         x_column = "step"
+     else:
+         x_column = x_axis
+
+     if log_scale and x_column in df.columns:
+         x_vals = df[x_column]
+         if (x_vals <= 0).any():
+             df[x_column] = np.log10(np.maximum(x_vals, 0) + 1)
+         else:
+             df[x_column] = np.log10(x_vals)
+
+     if smoothing_granularity > 0:
+         numeric_cols = df.select_dtypes(include="number").columns
+         numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
+
+         df_original = df.copy()
+         df_original["run"] = run
+         df_original["data_type"] = "original"
+
+         df_smoothed = df.copy()
+         window_size = max(3, min(smoothing_granularity, len(df)))
+         df_smoothed[numeric_cols] = (
+             df_smoothed[numeric_cols]
+             .rolling(window=window_size, center=True, min_periods=1)
+             .mean()
+         )
+         df_smoothed["run"] = f"{run}_smoothed"
+         df_smoothed["data_type"] = "smoothed"
+
+         combined_df = pd.concat([df_original, df_smoothed], ignore_index=True)
+         combined_df["x_axis"] = x_column
+         return combined_df, media
+     else:
+         df["run"] = run
+         df["data_type"] = "original"
+         df["x_axis"] = x_column
+         return df, media
+
+
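
The smoothing branch above is a centered rolling mean with `min_periods=1`, so the ends of a series stay defined instead of becoming NaN. A self-contained sketch:

```python
import pandas as pd

df = pd.DataFrame({"loss": [1.0, 0.5, 0.33, 0.25, 0.2]})
window_size = 3  # max(3, min(smoothing_granularity, len(df))) in the code above
smoothed = df["loss"].rolling(window=window_size, center=True, min_periods=1).mean()
print(smoothed.round(3).tolist())  # [0.75, 0.61, 0.36, 0.26, 0.225]
```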
+ def refresh_runs(
+     project: str | None,
+     filter_text: str | None,
+     selection: RunSelection,
+     selected_runs_from_url: list[str] | None = None,
+ ):
+     if project is None:
+         runs: list[str] = []
+     else:
+         runs = get_runs(project)
+     if filter_text:
+         runs = [r for r in runs if filter_text in r]
+
+     preferred = None
+     if selected_runs_from_url:
+         preferred = [r for r in runs if r in selected_runs_from_url]
+
+     did_change = selection.update_choices(runs, preferred)
+     return (
+         fns.run_checkbox_update(selection) if did_change else gr.CheckboxGroup(),
+         gr.Textbox(label=f"Runs ({len(runs)})"),
+         selection,
+     )
+
+
+ def generate_embed(project: str, metrics: str, selection: RunSelection) -> str:
+     return utils.generate_embed_code(project, metrics, selection.selected)
+
+
+ def update_x_axis_choices(project, selection):
+     """Update x-axis dropdown choices based on available metrics."""
+     runs = selection.selected
+     available_metrics = get_available_metrics(project, runs)
+     return gr.Dropdown(
+         label="X-axis",
+         choices=available_metrics,
+         value="step",
+     )
+
+
+ def toggle_timer(cb_value):
+     if cb_value:
+         return gr.Timer(active=True)
+     else:
+         return gr.Timer(active=False)
+
+
+ def upload_db_to_space(
+     project: str, uploaded_db: gr.FileData, hf_token: str | None
+ ) -> None:
+     """
+     Uploads the database of a local Trackio project to a Hugging Face Space.
+     """
+     fns.check_hf_token_has_write_access(hf_token)
+     db_project_path = SQLiteStorage.get_project_db_path(project)
+     if os.path.exists(db_project_path):
+         raise gr.Error(
+             f"Trackio database file already exists for project {project}, cannot overwrite."
+         )
+     os.makedirs(os.path.dirname(db_project_path), exist_ok=True)
+     shutil.copy(uploaded_db["path"], db_project_path)
+
+
+ def bulk_upload_media(uploads: list[UploadEntry], hf_token: str | None) -> None:
+     """
+     Uploads media files to a Trackio dashboard. Each entry in the list is a dictionary containing the project, run, step, and media file to be uploaded.
+     """
+     fns.check_hf_token_has_write_access(hf_token)
+     for upload in uploads:
+         media_path = FileStorage.init_project_media_path(
+             upload["project"], upload["run"], upload["step"]
+         )
+         shutil.copy(upload["uploaded_file"]["path"], media_path)
+
+
+ def log(
+     project: str,
+     run: str,
+     metrics: dict[str, Any],
+     step: int | None,
+     hf_token: str | None,
+ ) -> None:
+     """
+     Note: this method is not used in the latest versions of Trackio (replaced by bulk_log) but
+     is kept for backwards compatibility for users who are connecting to a newer version of
+     a Trackio Spaces dashboard with an older version of Trackio installed locally.
+     """
+     fns.check_hf_token_has_write_access(hf_token)
+     SQLiteStorage.log(project=project, run=run, metrics=metrics, step=step)
+
+
+ def bulk_log(
+     logs: list[LogEntry],
+     hf_token: str | None,
+ ) -> None:
+     """
+     Logs a list of metrics to a Trackio dashboard. Each entry in the list is a dictionary of the project, run, a dictionary of metrics, and optionally, a step and config.
+     """
+     fns.check_hf_token_has_write_access(hf_token)
+
+     logs_by_run = {}
+     for log_entry in logs:
+         key = (log_entry["project"], log_entry["run"])
+         if key not in logs_by_run:
+             logs_by_run[key] = {"metrics": [], "steps": [], "config": None}
+         logs_by_run[key]["metrics"].append(log_entry["metrics"])
+         logs_by_run[key]["steps"].append(log_entry.get("step"))
+         if log_entry.get("config") and logs_by_run[key]["config"] is None:
+             logs_by_run[key]["config"] = log_entry["config"]
+
+     for (project, run), data in logs_by_run.items():
+         SQLiteStorage.bulk_log(
+             project=project,
+             run=run,
+             metrics_list=data["metrics"],
+             steps=data["steps"],
+             config=data["config"],
+         )
+
+
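
A sketch of the payload `bulk_log` expects: a flat list of `LogEntry` dicts that the function regroups by `(project, run)` before writing (the token must pass `check_hf_token_has_write_access`):

```python
logs = [
    {"project": "my-project", "run": "run-1", "metrics": {"loss": 0.9}, "step": 0,
     "config": {"lr": 1e-3}},
    {"project": "my-project", "run": "run-1", "metrics": {"loss": 0.7}, "step": 1},
    {"project": "my-project", "run": "run-2", "metrics": {"loss": 0.8}, "step": 0},
]
# bulk_log(logs, hf_token) issues two SQLiteStorage.bulk_log calls: one for
# ("my-project", "run-1") with steps [0, 1] and run-1's config, and one for
# ("my-project", "run-2") with steps [0] and config=None.
```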
+ def get_metric_values(
+     project: str,
+     run: str,
+     metric_name: str,
+ ) -> list[dict]:
+     """
+     Get all values for a specific metric in a project/run.
+     Returns a list of dictionaries with timestamp, step, and value.
+     """
+     return SQLiteStorage.get_metric_values(project, run, metric_name)
+
+
+ def get_runs_for_project(
+     project: str,
+ ) -> list[str]:
+     """
+     Get all runs for a given project.
+     Returns a list of run names.
+     """
+     return SQLiteStorage.get_runs(project)
+
+
+ def get_metrics_for_run(
+     project: str,
+     run: str,
+ ) -> list[str]:
+     """
+     Get all metrics for a given project and run.
+     Returns a list of metric names.
+     """
+     return SQLiteStorage.get_all_metrics_for_run(project, run)
+
+
+ def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]:
+     """
+     Filter metrics using regex pattern.
+
+     Args:
+         metrics: List of metric names to filter
+         filter_pattern: Regex pattern to match against metric names
+
+     Returns:
+         List of metric names that match the pattern
+     """
+     if not filter_pattern.strip():
+         return metrics
+
+     try:
+         pattern = re.compile(filter_pattern, re.IGNORECASE)
+         return [metric for metric in metrics if pattern.search(metric)]
+     except re.error:
+         return [
+             metric for metric in metrics if filter_pattern.lower() in metric.lower()
+         ]
+
+
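
Some example inputs and outputs for `filter_metrics_by_regex` (matching is case-insensitive, a whitespace-only pattern is treated as empty, and an invalid regex falls back to substring matching):

```python
metrics = ["train/loss", "val/loss", "val/ndcg@10", "gpu/util"]

print(filter_metrics_by_regex(metrics, "loss|ndcg@10"))
# ['train/loss', 'val/loss', 'val/ndcg@10']

print(filter_metrics_by_regex(metrics, "GPU"))
# ['gpu/util']  (re.IGNORECASE)

print(filter_metrics_by_regex(metrics, "   "))
# ['train/loss', 'val/loss', 'val/ndcg@10', 'gpu/util']  (returned unchanged)
```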
395
+
396
+ def get_all_projects() -> list[str]:
397
+ """
398
+ Get all project names.
399
+ Returns a list of project names.
400
+ """
401
+ return SQLiteStorage.get_projects()
402
+
403
+
404
+ def get_project_summary(project: str) -> dict:
405
+ """
406
+ Get a summary of a project including number of runs and recent activity.
407
+
408
+ Args:
409
+ project: Project name
410
+
411
+ Returns:
412
+ Dictionary with project summary information
413
+ """
414
+ runs = SQLiteStorage.get_runs(project)
415
+ if not runs:
416
+ return {"project": project, "num_runs": 0, "runs": [], "last_activity": None}
417
+
418
+ last_steps = SQLiteStorage.get_max_steps_for_runs(project)
419
+
420
+ return {
421
+ "project": project,
422
+ "num_runs": len(runs),
423
+ "runs": runs,
424
+ "last_activity": max(last_steps.values()) if last_steps else None,
425
+ }
426
+
427
+
428
+ def get_run_summary(project: str, run: str) -> dict:
429
+ """
430
+ Get a summary of a specific run including metrics and configuration.
431
+
432
+ Args:
433
+ project: Project name
434
+ run: Run name
435
+
436
+ Returns:
437
+ Dictionary with run summary information
438
+ """
439
+ logs = SQLiteStorage.get_logs(project, run)
440
+ metrics = SQLiteStorage.get_all_metrics_for_run(project, run)
441
+
442
+ if not logs:
443
+ return {
444
+ "project": project,
445
+ "run": run,
446
+ "num_logs": 0,
447
+ "metrics": [],
448
+ "config": None,
449
+ "last_step": None,
450
+ }
451
+
452
+ df = pd.DataFrame(logs)
453
+ config = logs[0].get("config") if logs else None
454
+ last_step = df["step"].max() if "step" in df.columns else len(logs) - 1
455
+
456
+ return {
457
+ "project": project,
458
+ "run": run,
459
+ "num_logs": len(logs),
460
+ "metrics": metrics,
461
+ "config": config,
462
+ "last_step": last_step,
463
+ }
464
+
465
+
466
+ def configure(request: gr.Request):
467
+ sidebar_param = request.query_params.get("sidebar")
468
+ match sidebar_param:
469
+ case "collapsed":
470
+ sidebar = gr.Sidebar(open=False, visible=True)
471
+ case "hidden":
472
+ sidebar = gr.Sidebar(open=False, visible=False)
473
+ case _:
474
+ sidebar = gr.Sidebar(open=True, visible=True)
475
+
476
+ metrics_param = request.query_params.get("metrics", "")
477
+ runs_param = request.query_params.get("runs", "")
478
+ selected_runs = runs_param.split(",") if runs_param else []
479
+ navbar_param = request.query_params.get("navbar")
480
+ match navbar_param:
481
+ case "hidden":
482
+ navbar = gr.Navbar(visible=False)
483
+ case _:
484
+ navbar = gr.Navbar(visible=True)
485
+
486
+ return [], sidebar, metrics_param, selected_runs, navbar
487
+
488
+
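
`configure` makes the dashboard embeddable by reading its initial state from the request URL. A hypothetical URL exercising each parameter:

```python
# https://<dashboard-host>/?sidebar=collapsed&navbar=hidden&metrics=loss&runs=run-1,run-2
#
# sidebar=collapsed  -> sidebar rendered but closed ("hidden" removes it entirely)
# navbar=hidden      -> navbar hidden
# metrics=loss       -> seeds the metric filter textbox (interpreted as a regex)
# runs=run-1,run-2   -> split on commas into the initial run selection
```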
+ def create_media_section(media_by_run: dict[str, dict[str, list[MediaData]]]):
+     with gr.Accordion(label="media"):
+         with gr.Group(elem_classes=("media-group")):
+             for run, media_by_key in media_by_run.items():
+                 with gr.Tab(label=run, elem_classes=("media-tab")):
+                     for key, media_item in media_by_key.items():
+                         gr.Gallery(
+                             [(item.file_path, item.caption) for item in media_item],
+                             label=key,
+                             columns=6,
+                             elem_classes=("media-gallery"),
+                         )
+
+
+ css = """
+ #run-cb .wrap { gap: 2px; }
+ #run-cb .wrap label {
+     line-height: 1;
+     padding: 6px;
+ }
+ .logo-light { display: block; }
+ .logo-dark { display: none; }
+ .dark .logo-light { display: none; }
+ .dark .logo-dark { display: block; }
+ .dark .caption-label { color: white; }
+
+ .info-container {
+     position: relative;
+     display: inline;
+ }
+ .info-checkbox {
+     position: absolute;
+     opacity: 0;
+     pointer-events: none;
+ }
+ .info-icon {
+     border-bottom: 1px dotted;
+     cursor: pointer;
+     user-select: none;
+     color: var(--color-accent);
+ }
+ .info-expandable {
+     display: none;
+     opacity: 0;
+     transition: opacity 0.2s ease-in-out;
+ }
+ .info-checkbox:checked ~ .info-expandable {
+     display: inline;
+     opacity: 1;
+ }
+ .info-icon:hover { opacity: 0.8; }
+ .accent-link { font-weight: bold; }
+
+ .media-gallery .fixed-height { min-height: 275px; }
+ .media-group, .media-group > div { background: none; }
+ .media-group .tabs { padding: 0.5em; }
+ .media-tab { max-height: 500px; overflow-y: scroll; }
+ """
+
+ javascript = """
+ <script>
+ function setCookie(name, value, days) {
+     var expires = "";
+     if (days) {
+         var date = new Date();
+         date.setTime(date.getTime() + (days * 24 * 60 * 60 * 1000));
+         expires = "; expires=" + date.toUTCString();
+     }
+     document.cookie = name + "=" + (value || "") + expires + "; path=/; SameSite=Lax";
+ }
+
+ function getCookie(name) {
+     var nameEQ = name + "=";
+     var ca = document.cookie.split(';');
+     for(var i=0;i < ca.length;i++) {
+         var c = ca[i];
+         while (c.charAt(0)==' ') c = c.substring(1,c.length);
+         if (c.indexOf(nameEQ) == 0) return c.substring(nameEQ.length,c.length);
+     }
+     return null;
+ }
+
+ (function() {
+     const urlParams = new URLSearchParams(window.location.search);
+     const writeToken = urlParams.get('write_token');
+
+     if (writeToken) {
+         setCookie('trackio_write_token', writeToken, 7);
+
+         // Only remove write_token from URL if not in iframe
+         // In iframes, keep it in URL as cookies may be blocked
+         const inIframe = window.self !== window.top;
+         if (!inIframe) {
+             urlParams.delete('write_token');
+             const newUrl = window.location.pathname +
+                 (urlParams.toString() ? '?' + urlParams.toString() : '') +
+                 window.location.hash;
+             window.history.replaceState({}, document.title, newUrl);
+         }
+     }
+ })();
+ </script>
+ """
+
+
+ gr.set_static_paths(paths=[utils.MEDIA_DIR])
+
+ with gr.Blocks(title="Trackio Dashboard", css=css, head=javascript) as demo:
+     with gr.Sidebar(open=False) as sidebar:
+         logo_urls = utils.get_logo_urls()
+         logo = gr.Markdown(
+             f"""
+             <img src='{logo_urls["light"]}' width='80%' class='logo-light'>
+             <img src='{logo_urls["dark"]}' width='80%' class='logo-dark'>
+             """
+         )
+         project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
+
+         embed_code = gr.Code(
+             label="Embed this view",
+             max_lines=2,
+             lines=2,
+             language="html",
+             visible=bool(os.environ.get("SPACE_HOST")),
+         )
+         with gr.Group():
+             run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
+             run_group_by_dd = gr.Dropdown(label="Group by...", choices=[], value=None)
+             grouped_runs_panel = gr.Group(visible=False)
+             run_cb = gr.CheckboxGroup(
+                 label="Runs",
+                 choices=[],
+                 interactive=True,
+                 elem_id="run-cb",
+                 show_select_all=True,
+             )
+
+         gr.HTML("<hr>")
+         realtime_cb = gr.Checkbox(label="Refresh metrics realtime", value=True)
+         smoothing_slider = gr.Slider(
+             label="Smoothing Factor",
+             minimum=0,
+             maximum=20,
+             value=10,
+             step=1,
+             info="0 = no smoothing",
+         )
+         x_axis_dd = gr.Dropdown(
+             label="X-axis",
+             choices=["step", "time"],
+             value="step",
+         )
+         log_scale_cb = gr.Checkbox(label="Log scale X-axis", value=False)
+         metric_filter_tb = gr.Textbox(
+             label="Metric Filter (regex)",
+             placeholder="e.g., loss|ndcg@10|gpu",
+             value="",
+             info="Filter metrics using regex patterns. Leave empty to show all metrics.",
+         )
+
+     navbar = gr.Navbar(value=[("Metrics", ""), ("Runs", "/runs")], main_page_name=False)
+     timer = gr.Timer(value=1)
+     metrics_subset = gr.State([])
+     selected_runs_from_url = gr.State([])
+     run_selection_state = gr.State(RunSelection())
+
+     gr.on(
+         [demo.load],
+         fn=configure,
+         outputs=[
+             metrics_subset,
+             sidebar,
+             metric_filter_tb,
+             selected_runs_from_url,
+             navbar,
+         ],
+         queue=False,
+         api_name=False,
+     )
+     gr.on(
+         [demo.load],
+         fn=fns.get_projects,
+         outputs=project_dd,
+         show_progress="hidden",
+         queue=False,
+         api_name=False,
+     )
+     gr.on(
+         [timer.tick],
+         fn=refresh_runs,
+         inputs=[project_dd, run_tb, run_selection_state, selected_runs_from_url],
+         outputs=[run_cb, run_tb, run_selection_state],
+         show_progress="hidden",
+         api_name=False,
+     )
+     gr.on(
+         [timer.tick],
+         fn=lambda: gr.Dropdown(info=fns.get_project_info()),
+         outputs=[project_dd],
+         show_progress="hidden",
+         api_name=False,
+     )
+     gr.on(
+         [demo.load, project_dd.change],
+         fn=refresh_runs,
+         inputs=[project_dd, run_tb, run_selection_state, selected_runs_from_url],
+         outputs=[run_cb, run_tb, run_selection_state],
+         show_progress="hidden",
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=update_x_axis_choices,
+         inputs=[project_dd, run_selection_state],
+         outputs=x_axis_dd,
+         show_progress="hidden",
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=generate_embed,
+         inputs=[project_dd, metric_filter_tb, run_selection_state],
+         outputs=[embed_code],
+         show_progress="hidden",
+         api_name=False,
+         queue=False,
+     ).then(
+         fns.update_navbar_value,
+         inputs=[project_dd],
+         outputs=[navbar],
+         show_progress="hidden",
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=fns.get_group_by_fields,
+         inputs=[project_dd],
+         outputs=[run_group_by_dd],
+         show_progress="hidden",
+         api_name=False,
+         queue=False,
+     )
+
+     gr.on(
+         [run_cb.input],
+         fn=update_x_axis_choices,
+         inputs=[project_dd, run_selection_state],
+         outputs=x_axis_dd,
+         show_progress="hidden",
+         queue=False,
+         api_name=False,
+     )
+     gr.on(
+         [metric_filter_tb.change, run_cb.change],
+         fn=generate_embed,
+         inputs=[project_dd, metric_filter_tb, run_selection_state],
+         outputs=embed_code,
+         show_progress="hidden",
+         api_name=False,
+         queue=False,
+     )
+
+     def toggle_group_view(group_by_dd):
+         return (
+             gr.CheckboxGroup(visible=not bool(group_by_dd)),
+             gr.Group(visible=bool(group_by_dd)),
+         )
+
+     gr.on(
+         [run_group_by_dd.change],
+         fn=toggle_group_view,
+         inputs=[run_group_by_dd],
+         outputs=[run_cb, grouped_runs_panel],
+         show_progress="hidden",
+         api_name=False,
+         queue=False,
+     )
+
+     realtime_cb.change(
+         fn=toggle_timer,
+         inputs=realtime_cb,
+         outputs=timer,
+         api_name=False,
+         queue=False,
+     )
+     run_cb.input(
+         fn=fns.handle_run_checkbox_change,
+         inputs=[run_cb, run_selection_state],
+         outputs=run_selection_state,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=generate_embed,
+         inputs=[project_dd, metric_filter_tb, run_selection_state],
+         outputs=embed_code,
+         show_progress="hidden",
+         api_name=False,
+         queue=False,
+     )
+     run_tb.input(
+         fn=refresh_runs,
+         inputs=[project_dd, run_tb, run_selection_state],
+         outputs=[run_cb, run_tb, run_selection_state],
+         api_name=False,
+         queue=False,
+         show_progress="hidden",
+     )
+
+     gr.api(
+         fn=upload_db_to_space,
+         api_name="upload_db_to_space",
+     )
+     gr.api(
+         fn=bulk_upload_media,
+         api_name="bulk_upload_media",
+     )
+     gr.api(
+         fn=log,
+         api_name="log",
+     )
+     gr.api(
+         fn=bulk_log,
+         api_name="bulk_log",
+     )
+     gr.api(
+         fn=get_metric_values,
+         api_name="get_metric_values",
+     )
+     gr.api(
+         fn=get_runs_for_project,
+         api_name="get_runs_for_project",
+     )
+     gr.api(
+         fn=get_metrics_for_run,
+         api_name="get_metrics_for_run",
+     )
+     gr.api(
+         fn=get_all_projects,
+         api_name="get_all_projects",
+     )
+     gr.api(
+         fn=get_project_summary,
+         api_name="get_project_summary",
+     )
+     gr.api(
+         fn=get_run_summary,
+         api_name="get_run_summary",
+     )
+
+     x_lim = gr.State(None)
+     last_steps = gr.State({})
+
+     def update_x_lim(select_data: gr.SelectData):
+         return select_data.index
+
+     def update_last_steps(project):
+         """Check the last step for each run to detect when new data is available."""
+         if not project:
+             return {}
+         return SQLiteStorage.get_max_steps_for_runs(project)
+
+     timer.tick(
+         fn=update_last_steps,
+         inputs=[project_dd],
+         outputs=last_steps,
+         show_progress="hidden",
+         api_name=False,
+     )
+
+     @gr.render(
+         triggers=[
+             demo.load,
+             run_cb.change,
+             last_steps.change,
+             smoothing_slider.change,
+             x_lim.change,
+             x_axis_dd.change,
+             log_scale_cb.change,
+             metric_filter_tb.change,
+         ],
+         inputs=[
+             project_dd,
+             run_cb,
+             smoothing_slider,
+             metrics_subset,
+             x_lim,
+             x_axis_dd,
+             log_scale_cb,
+             metric_filter_tb,
+         ],
+         show_progress="hidden",
+         queue=False,
+     )
+     def update_dashboard(
+         project,
+         runs,
+         smoothing_granularity,
+         metrics_subset,
+         x_lim_value,
+         x_axis,
+         log_scale,
+         metric_filter,
+     ):
+         dfs = []
+         images_by_run = {}
+         original_runs = runs.copy()
+
+         for run in runs:
+             df, images_by_key = load_run_data(
+                 project, run, smoothing_granularity, x_axis, log_scale
+             )
+             if df is not None:
+                 dfs.append(df)
+                 images_by_run[run] = images_by_key
+
+         if dfs:
+             if smoothing_granularity > 0:
+                 original_dfs = []
+                 smoothed_dfs = []
+                 for df in dfs:
+                     original_data = df[df["data_type"] == "original"]
+                     smoothed_data = df[df["data_type"] == "smoothed"]
+                     if not original_data.empty:
+                         original_dfs.append(original_data)
+                     if not smoothed_data.empty:
+                         smoothed_dfs.append(smoothed_data)
+
+                 all_dfs = original_dfs + smoothed_dfs
+                 master_df = (
+                     pd.concat(all_dfs, ignore_index=True) if all_dfs else pd.DataFrame()
+                 )
+
+             else:
+                 master_df = pd.concat(dfs, ignore_index=True)
+         else:
+             master_df = pd.DataFrame()
+
+         if master_df.empty:
+             if not SQLiteStorage.get_projects():
+                 if space_id := utils.get_space():
+                     gr.Markdown(INSTRUCTIONS_SPACES.format(space_id))
+                 else:
+                     gr.Markdown(INSTRUCTIONS_LOCAL)
+             else:
+                 gr.Markdown("*Waiting for runs to appear...*")
+             return
+
+         x_column = "step"
+         if dfs and not dfs[0].empty and "x_axis" in dfs[0].columns:
+             x_column = dfs[0]["x_axis"].iloc[0]
+
+         numeric_cols = master_df.select_dtypes(include="number").columns
+         numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
+         if x_column and x_column in numeric_cols:
+             numeric_cols.remove(x_column)
+
+         if metrics_subset:
+             numeric_cols = [c for c in numeric_cols if c in metrics_subset]
+
+         if metric_filter and metric_filter.strip():
+             numeric_cols = filter_metrics_by_regex(list(numeric_cols), metric_filter)
+
+         ordered_groups, nested_metric_groups = utils.order_metrics_by_plot_preference(
+             list(numeric_cols)
+         )
+         color_map = utils.get_color_mapping(original_runs, smoothing_granularity > 0)
+
+         metric_idx = 0
+         for group_name in ordered_groups:
+             group_data = nested_metric_groups[group_name]
+
+             total_plot_count = sum(
+                 1
+                 for m in group_data["direct_metrics"]
+                 if not master_df.dropna(subset=[m]).empty
+             ) + sum(
+                 sum(1 for m in metrics if not master_df.dropna(subset=[m]).empty)
+                 for metrics in group_data["subgroups"].values()
+             )
+             group_label = (
+                 f"{group_name} ({total_plot_count})"
+                 if total_plot_count > 0
+                 else group_name
+             )
+
+             with gr.Accordion(
+                 label=group_label,
+                 open=True,
+                 key=f"accordion-{group_name}",
+                 preserved_by_key=["value", "open"],
+             ):
+                 if group_data["direct_metrics"]:
+                     with gr.Draggable(
+                         key=f"row-{group_name}-direct", orientation="row"
+                     ):
+                         for metric_name in group_data["direct_metrics"]:
+                             metric_df = master_df.dropna(subset=[metric_name])
+                             color = "run" if "run" in metric_df.columns else None
+                             if not metric_df.empty:
+                                 plot = gr.LinePlot(
+                                     utils.downsample(
+                                         metric_df,
+                                         x_column,
+                                         metric_name,
+                                         color,
+                                         x_lim_value,
+                                     ),
+                                     x=x_column,
+                                     y=metric_name,
+                                     y_title=metric_name.split("/")[-1],
+                                     color=color,
+                                     color_map=color_map,
+                                     title=metric_name,
+                                     key=f"plot-{metric_idx}",
+                                     preserved_by_key=None,
+                                     x_lim=x_lim_value,
+                                     show_fullscreen_button=True,
+                                     min_width=400,
+                                     show_export_button=True,
+                                 )
+                                 plot.select(
+                                     update_x_lim,
+                                     outputs=x_lim,
+                                     key=f"select-{metric_idx}",
+                                 )
+                                 plot.double_click(
+                                     lambda: None,
+                                     outputs=x_lim,
+                                     key=f"double-{metric_idx}",
+                                 )
+                                 metric_idx += 1
+
+                 if group_data["subgroups"]:
+                     for subgroup_name in sorted(group_data["subgroups"].keys()):
+                         subgroup_metrics = group_data["subgroups"][subgroup_name]
+
+                         subgroup_plot_count = sum(
+                             1
+                             for m in subgroup_metrics
+                             if not master_df.dropna(subset=[m]).empty
+                         )
+                         subgroup_label = (
+                             f"{subgroup_name} ({subgroup_plot_count})"
+                             if subgroup_plot_count > 0
+                             else subgroup_name
+                         )
+
+                         with gr.Accordion(
+                             label=subgroup_label,
+                             open=True,
+                             key=f"accordion-{group_name}-{subgroup_name}",
+                             preserved_by_key=["value", "open"],
+                         ):
+                             with gr.Draggable(
+                                 key=f"row-{group_name}-{subgroup_name}",
+                                 orientation="row",
+                             ):
+                                 for metric_name in subgroup_metrics:
+                                     metric_df = master_df.dropna(subset=[metric_name])
+                                     color = (
+                                         "run" if "run" in metric_df.columns else None
+                                     )
+                                     if not metric_df.empty:
+                                         plot = gr.LinePlot(
+                                             utils.downsample(
+                                                 metric_df,
+                                                 x_column,
+                                                 metric_name,
+                                                 color,
+                                                 x_lim_value,
+                                             ),
+                                             x=x_column,
+                                             y=metric_name,
+                                             y_title=metric_name.split("/")[-1],
+                                             color=color,
+                                             color_map=color_map,
+                                             title=metric_name,
+                                             key=f"plot-{metric_idx}",
+                                             preserved_by_key=None,
+                                             x_lim=x_lim_value,
+                                             show_fullscreen_button=True,
+                                             min_width=400,
+                                             show_export_button=True,
+                                         )
+                                         plot.select(
+                                             update_x_lim,
+                                             outputs=x_lim,
+                                             key=f"select-{metric_idx}",
+                                         )
+                                         plot.double_click(
+                                             lambda: None,
+                                             outputs=x_lim,
+                                             key=f"double-{metric_idx}",
+                                         )
+                                         metric_idx += 1
+         if images_by_run and any(any(images) for images in images_by_run.values()):
+             create_media_section(images_by_run)
+
+         table_cols = master_df.select_dtypes(include="object").columns
+         table_cols = [c for c in table_cols if c not in utils.RESERVED_KEYS]
+         if metrics_subset:
+             table_cols = [c for c in table_cols if c in metrics_subset]
+         if metric_filter and metric_filter.strip():
+             table_cols = filter_metrics_by_regex(list(table_cols), metric_filter)
+
+         actual_table_count = sum(
+             1
+             for metric_name in table_cols
+             if not (metric_df := master_df.dropna(subset=[metric_name])).empty
+             and isinstance(value := metric_df[metric_name].iloc[-1], dict)
+             and value.get("_type") == Table.TYPE
+         )
+
+         if actual_table_count > 0:
+             with gr.Accordion(f"tables ({actual_table_count})", open=True):
+                 with gr.Row(key="row"):
+                     for metric_idx, metric_name in enumerate(table_cols):
+                         metric_df = master_df.dropna(subset=[metric_name])
+                         if not metric_df.empty:
+                             value = metric_df[metric_name].iloc[-1]
+                             if (
+                                 isinstance(value, dict)
+                                 and "_type" in value
+                                 and value["_type"] == Table.TYPE
+                             ):
+                                 try:
+                                     df = pd.DataFrame(value["_value"])
+                                     gr.DataFrame(
+                                         df,
+                                         label=f"{metric_name} (latest)",
+                                         key=f"table-{metric_idx}",
+                                         wrap=True,
+                                     )
+                                 except Exception as e:
+                                     gr.Warning(
+                                         f"Column {metric_name} failed to render as a table: {e}"
+                                     )
+
+     with grouped_runs_panel:
+
+         @gr.render(
+             triggers=[
+                 demo.load,
+                 project_dd.change,
+                 run_group_by_dd.change,
+                 run_tb.input,
+                 run_selection_state.change,
+             ],
+             inputs=[project_dd, run_group_by_dd, run_tb, run_selection_state],
+             show_progress="hidden",
+             queue=False,
+         )
+         def render_grouped_runs(project, group_key, filter_text, selection):
+             if not group_key:
+                 return
+             selection = selection or RunSelection()
+             groups = fns.group_runs_by_config(project, group_key, filter_text)
+
+             for label, runs in groups.items():
+                 ordered_current = utils.ordered_subset(runs, selection.selected)
+
+                 with gr.Group():
+                     show_group_cb = gr.Checkbox(
+                         label="Show/Hide",
+                         value=bool(ordered_current),
+                         key=f"show-cb-{group_key}-{label}",
+                         preserved_by_key=["value"],
+                     )
+
+                     with gr.Accordion(
+                         f"{label} ({len(runs)})",
+                         open=False,
+                         key=f"accordion-{group_key}-{label}",
+                         preserved_by_key=["open"],
+                     ):
+                         group_cb = gr.CheckboxGroup(
+                             choices=runs,
+                             value=ordered_current,
+                             show_label=False,
+                             key=f"group-cb-{group_key}-{label}",
+                         )
+
+                 gr.on(
+                     [group_cb.change],
+                     fn=fns.handle_group_checkbox_change,
+                     inputs=[
+                         group_cb,
+                         run_selection_state,
+                         gr.State(runs),
+                     ],
+                     outputs=[
+                         run_selection_state,
+                         group_cb,
+                         run_cb,
+                     ],
+                     show_progress="hidden",
+                     api_name=False,
+                     queue=False,
+                 )
+
+                 gr.on(
+                     [show_group_cb.change],
+                     fn=fns.handle_group_toggle,
+                     inputs=[
+                         show_group_cb,
+                         run_selection_state,
+                         gr.State(runs),
+                     ],
+                     outputs=[run_selection_state, group_cb, run_cb],
+                     show_progress="hidden",
+                     api_name=False,
+                     queue=False,
+                 )
+
+
+ with demo.route("Runs", show_in_navbar=False):
+     run_page.render()
+ with demo.route("Run", show_in_navbar=False):
+     run_detail_page.render()
+
+ write_token = secrets.token_urlsafe(32)
+ demo.write_token = write_token
+ run_page.write_token = write_token
+ run_detail_page.write_token = write_token
+
+ if __name__ == "__main__":
+     demo.launch(allowed_paths=[utils.TRACKIO_LOGO_DIR], show_api=False, show_error=True)
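
Since the read endpoints above are registered with `gr.api`, they can be called from another process. A sketch assuming the `gradio_client` package and a dashboard running locally on the default port (the write-gated endpoints such as `log` and `bulk_log` additionally require an `hf_token` with write access):

```python
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
print(client.predict(api_name="/get_all_projects"))
print(client.predict(project="my-project", api_name="/get_runs_for_project"))
```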
ui/run_detail.py ADDED
@@ -0,0 +1,94 @@
+ """The run detail page for the Trackio UI."""
+
+ import gradio as gr
+
+ try:
+     import trackio.utils as utils
+     from trackio.sqlite_storage import SQLiteStorage
+     from trackio.ui import fns
+ except ImportError:
+     import utils
+     from sqlite_storage import SQLiteStorage
+     from ui import fns
+
+ RUN_DETAILS_TEMPLATE = """
+ ## Run Details
+ * **Run Name:** `{run_name}`
+ * **Group:** `{group}`
+ * **Created:** {created} by {username}
+ """
+
+ with gr.Blocks() as run_detail_page:
+     with gr.Sidebar() as sidebar:
+         logo_urls = utils.get_logo_urls()
+         logo = gr.Markdown(
+             f"""
+             <img src='{logo_urls["light"]}' width='80%' class='logo-light'>
+             <img src='{logo_urls["dark"]}' width='80%' class='logo-dark'>
+             """
+         )
+         project_dd = gr.Dropdown(
+             label="Project", allow_custom_value=True, interactive=False
+         )
+         run_dd = gr.Dropdown(label="Run")
+
+     navbar = gr.Navbar(value=[("Metrics", ""), ("Runs", "/runs")], main_page_name=False)
+
+     run_details = gr.Markdown(RUN_DETAILS_TEMPLATE)
+
+     run_config = gr.JSON(label="Run Config")
+
+     def configure(request: gr.Request):
+         project = request.query_params.get("selected_project")
+         run = request.query_params.get("selected_run")
+         runs = SQLiteStorage.get_runs(project)
+         return project, gr.Dropdown(choices=runs, value=run)
+
+     def update_run_details(project, run):
+         config = SQLiteStorage.get_run_config(project, run)
+         if not config:
+             return gr.Markdown("No run details available"), {}
+
+         group = config.get("_Group", "None")
+
+         created = config.get("_Created", "Unknown")
+         if created != "Unknown":
+             created = utils.format_timestamp(created)
+
+         username = config.get("_Username", "Unknown")
+         if username and username != "None" and username != "Unknown":
+             username = f"[{username}](https://huggingface.co/{username})"
+
+         details_md = RUN_DETAILS_TEMPLATE.format(
+             run_name=run, group=group, created=created, username=username
+         )
+
+         config_display = {k: v for k, v in config.items() if not k.startswith("_")}
+
+         return gr.Markdown(details_md), config_display
+
+     gr.on(
+         [run_detail_page.load],
+         fn=configure,
+         outputs=[project_dd, run_dd],
+         show_progress="hidden",
+         queue=False,
+         api_name=False,
+     ).then(
+         fns.update_navbar_value,
+         inputs=[project_dd],
+         outputs=[navbar],
+         show_progress="hidden",
+         api_name=False,
+         queue=False,
+     )
+
+     gr.on(
+         [run_dd.change],
+         update_run_details,
+         inputs=[project_dd, run_dd],
+         outputs=[run_details, run_config],
+         show_progress="hidden",
+         api_name=False,
+         queue=False,
+     )
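
For reference, a sketch of a stored run config and how `update_run_details` splits it (the underscore-prefixed keys and the timestamp format shown are assumptions based on the lookups above):

```python
config = {
    "_Group": "baseline",
    "_Created": "2024-01-01T12:00:00",  # hypothetical timestamp format
    "_Username": "abidlabs",
    "lr": 1e-3,
    "batch_size": 32,
}
# The "_"-prefixed keys fill RUN_DETAILS_TEMPLATE (with _Username rendered as a
# Hugging Face profile link), while {"lr": 0.001, "batch_size": 32} is what
# remains for the "Run Config" JSON panel.
```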