Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- __init__.py +168 -0
- __pycache__/__init__.cpython-311.pyc +0 -0
- __pycache__/cli.cpython-311.pyc +0 -0
- __pycache__/commit_scheduler.cpython-311.pyc +0 -0
- __pycache__/context_vars.cpython-311.pyc +0 -0
- __pycache__/deploy.cpython-311.pyc +0 -0
- __pycache__/dummy_commit_scheduler.cpython-311.pyc +0 -0
- __pycache__/imports.cpython-311.pyc +0 -0
- __pycache__/run.cpython-311.pyc +0 -0
- __pycache__/sqlite_storage.cpython-311.pyc +0 -0
- __pycache__/typehints.cpython-311.pyc +0 -0
- __pycache__/ui.cpython-311.pyc +0 -0
- __pycache__/utils.cpython-311.pyc +0 -0
- assets/trackio_logo_dark.png +0 -0
- assets/trackio_logo_light.png +0 -0
- assets/trackio_logo_old.png +3 -0
- assets/trackio_logo_type_dark.png +0 -0
- assets/trackio_logo_type_dark_transparent.png +0 -0
- assets/trackio_logo_type_light.png +0 -0
- assets/trackio_logo_type_light_transparent.png +0 -0
- cli.py +26 -0
- commit_scheduler.py +392 -0
- context_vars.py +15 -0
- deploy.py +170 -0
- dummy_commit_scheduler.py +12 -0
- imports.py +245 -0
- py.typed +0 -0
- run.py +100 -0
- sqlite_storage.py +384 -0
- typehints.py +8 -0
- ui.py +570 -0
- utils.py +410 -0
- version.txt +1 -0
    	
        .gitattributes
    CHANGED
    
    | @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text | |
| 33 | 
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 35 | 
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
|  | 
|  | |
| 33 | 
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 35 | 
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
| 36 | 
            +
            assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text
         | 
    	
        __init__.py
    ADDED
    
    | @@ -0,0 +1,168 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import warnings
         | 
| 3 | 
            +
            import webbrowser
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
            from typing import Any
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from gradio_client import Client
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from trackio import context_vars, deploy, utils
         | 
| 10 | 
            +
            from trackio.imports import import_csv, import_tf_events
         | 
| 11 | 
            +
            from trackio.run import Run
         | 
| 12 | 
            +
            from trackio.sqlite_storage import SQLiteStorage
         | 
| 13 | 
            +
            from trackio.ui import demo
         | 
| 14 | 
            +
            from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            __version__ = Path(__file__).parent.joinpath("version.txt").read_text().strip()
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            __all__ = ["init", "log", "finish", "show", "import_csv", "import_tf_events"]
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            config = {}
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def init(
         | 
| 25 | 
            +
                project: str,
         | 
| 26 | 
            +
                name: str | None = None,
         | 
| 27 | 
            +
                space_id: str | None = None,
         | 
| 28 | 
            +
                dataset_id: str | None = None,
         | 
| 29 | 
            +
                config: dict | None = None,
         | 
| 30 | 
            +
                resume: str = "never",
         | 
| 31 | 
            +
                settings: Any = None,
         | 
| 32 | 
            +
            ) -> Run:
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
                Creates a new Trackio project and returns a Run object.
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                Args:
         | 
| 37 | 
            +
                    project: The name of the project (can be an existing project to continue tracking or a new project to start tracking from scratch).
         | 
| 38 | 
            +
                    name: The name of the run (if not provided, a default name will be generated).
         | 
| 39 | 
            +
                    space_id: If provided, the project will be logged to a Hugging Face Space instead of a local directory. Should be a complete Space name like "username/reponame" or "orgname/reponame", or just "reponame" in which case the Space will be created in the currently-logged-in Hugging Face user's namespace. If the Space does not exist, it will be created. If the Space already exists, the project will be logged to it.
         | 
| 40 | 
            +
                    dataset_id: If a space_id is provided, a persistent Hugging Face Dataset will be created and the metrics will be synced to it every 5 minutes. Specify a Dataset with name like "username/datasetname" or "orgname/datasetname", or "datasetname" (uses currently-logged-in Hugging Face user's namespace), or None (uses the same name as the Space but with the "_dataset" suffix). If the Dataset does not exist, it will be created. If the Dataset already exists, the project will be appended to it.
         | 
| 41 | 
            +
                    config: A dictionary of configuration options. Provided for compatibility with wandb.init()
         | 
| 42 | 
            +
                    resume: Controls how to handle resuming a run. Can be one of:
         | 
| 43 | 
            +
                        - "must": Must resume the run with the given name, raises error if run doesn't exist
         | 
| 44 | 
            +
                        - "allow": Resume the run if it exists, otherwise create a new run
         | 
| 45 | 
            +
                        - "never": Never resume a run, always create a new one
         | 
| 46 | 
            +
                    settings: Not used. Provided for compatibility with wandb.init()
         | 
| 47 | 
            +
                """
         | 
| 48 | 
            +
                if settings is not None:
         | 
| 49 | 
            +
                    warnings.warn(
         | 
| 50 | 
            +
                        "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
         | 
| 51 | 
            +
                    )
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                if space_id is None and dataset_id is not None:
         | 
| 54 | 
            +
                    raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
         | 
| 55 | 
            +
                space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
         | 
| 56 | 
            +
                url = context_vars.current_server.get()
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                if url is None:
         | 
| 59 | 
            +
                    if space_id is None:
         | 
| 60 | 
            +
                        _, url, _ = demo.launch(
         | 
| 61 | 
            +
                            show_api=False,
         | 
| 62 | 
            +
                            inline=False,
         | 
| 63 | 
            +
                            quiet=True,
         | 
| 64 | 
            +
                            prevent_thread_lock=True,
         | 
| 65 | 
            +
                            show_error=True,
         | 
| 66 | 
            +
                        )
         | 
| 67 | 
            +
                    else:
         | 
| 68 | 
            +
                        url = space_id
         | 
| 69 | 
            +
                    context_vars.current_server.set(url)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                if (
         | 
| 72 | 
            +
                    context_vars.current_project.get() is None
         | 
| 73 | 
            +
                    or context_vars.current_project.get() != project
         | 
| 74 | 
            +
                ):
         | 
| 75 | 
            +
                    print(f"* Trackio project initialized: {project}")
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    if dataset_id is not None:
         | 
| 78 | 
            +
                        os.environ["TRACKIO_DATASET_ID"] = dataset_id
         | 
| 79 | 
            +
                        print(
         | 
| 80 | 
            +
                            f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
         | 
| 81 | 
            +
                        )
         | 
| 82 | 
            +
                    if space_id is None:
         | 
| 83 | 
            +
                        print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
         | 
| 84 | 
            +
                        utils.print_dashboard_instructions(project)
         | 
| 85 | 
            +
                    else:
         | 
| 86 | 
            +
                        deploy.create_space_if_not_exists(space_id, dataset_id)
         | 
| 87 | 
            +
                        print(
         | 
| 88 | 
            +
                            f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
         | 
| 89 | 
            +
                        )
         | 
| 90 | 
            +
                context_vars.current_project.set(project)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                client = None
         | 
| 93 | 
            +
                if not space_id:
         | 
| 94 | 
            +
                    client = Client(url, verbose=False)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                if resume == "must":
         | 
| 97 | 
            +
                    if name is None:
         | 
| 98 | 
            +
                        raise ValueError("Must provide a run name when resume='must'")
         | 
| 99 | 
            +
                    if name not in SQLiteStorage.get_runs(project):
         | 
| 100 | 
            +
                        raise ValueError(f"Run '{name}' does not exist in project '{project}'")
         | 
| 101 | 
            +
                elif resume == "allow":
         | 
| 102 | 
            +
                    if name is not None and name in SQLiteStorage.get_runs(project):
         | 
| 103 | 
            +
                        print(f"* Resuming existing run: {name}")
         | 
| 104 | 
            +
                elif resume == "never":
         | 
| 105 | 
            +
                    if name is not None and name in SQLiteStorage.get_runs(project):
         | 
| 106 | 
            +
                        name = None
         | 
| 107 | 
            +
                else:
         | 
| 108 | 
            +
                    raise ValueError("resume must be one of: 'must', 'allow', or 'never'")
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                run = Run(
         | 
| 111 | 
            +
                    url=url,
         | 
| 112 | 
            +
                    project=project,
         | 
| 113 | 
            +
                    client=client,
         | 
| 114 | 
            +
                    name=name,
         | 
| 115 | 
            +
                    config=config,
         | 
| 116 | 
            +
                )
         | 
| 117 | 
            +
                context_vars.current_run.set(run)
         | 
| 118 | 
            +
                globals()["config"] = run.config
         | 
| 119 | 
            +
                return run
         | 
| 120 | 
            +
             | 
| 121 | 
            +
             | 
| 122 | 
            +
            def log(metrics: dict, step: int | None = None) -> None:
         | 
| 123 | 
            +
                """
         | 
| 124 | 
            +
                Logs metrics to the current run.
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                Args:
         | 
| 127 | 
            +
                    metrics: A dictionary of metrics to log.
         | 
| 128 | 
            +
                    step: The step number. If not provided, the step will be incremented automatically.
         | 
| 129 | 
            +
                """
         | 
| 130 | 
            +
                run = context_vars.current_run.get()
         | 
| 131 | 
            +
                if run is None:
         | 
| 132 | 
            +
                    raise RuntimeError("Call trackio.init() before trackio.log().")
         | 
| 133 | 
            +
                run.log(
         | 
| 134 | 
            +
                    metrics=metrics,
         | 
| 135 | 
            +
                    step=step,
         | 
| 136 | 
            +
                )
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
            def finish():
         | 
| 140 | 
            +
                """
         | 
| 141 | 
            +
                Finishes the current run.
         | 
| 142 | 
            +
                """
         | 
| 143 | 
            +
                run = context_vars.current_run.get()
         | 
| 144 | 
            +
                if run is None:
         | 
| 145 | 
            +
                    raise RuntimeError("Call trackio.init() before trackio.finish().")
         | 
| 146 | 
            +
                run.finish()
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            def show(project: str | None = None):
         | 
| 150 | 
            +
                """
         | 
| 151 | 
            +
                Launches the Trackio dashboard.
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                Args:
         | 
| 154 | 
            +
                    project: The name of the project whose runs to show. If not provided, all projects will be shown and the user can select one.
         | 
| 155 | 
            +
                """
         | 
| 156 | 
            +
                _, url, share_url = demo.launch(
         | 
| 157 | 
            +
                    show_api=False,
         | 
| 158 | 
            +
                    quiet=True,
         | 
| 159 | 
            +
                    inline=False,
         | 
| 160 | 
            +
                    prevent_thread_lock=True,
         | 
| 161 | 
            +
                    favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
         | 
| 162 | 
            +
                    allowed_paths=[TRACKIO_LOGO_DIR],
         | 
| 163 | 
            +
                )
         | 
| 164 | 
            +
                base_url = share_url + "/" if share_url else url
         | 
| 165 | 
            +
                dashboard_url = base_url + f"?project={project}" if project else base_url
         | 
| 166 | 
            +
                print(f"* Trackio UI launched at: {dashboard_url}")
         | 
| 167 | 
            +
                webbrowser.open(dashboard_url)
         | 
| 168 | 
            +
                utils.block_except_in_notebook()
         | 
    	
        __pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (8.72 kB). View file | 
|  | 
    	
        __pycache__/cli.cpython-311.pyc
    ADDED
    
    | Binary file (1.29 kB). View file | 
|  | 
    	
        __pycache__/commit_scheduler.cpython-311.pyc
    ADDED
    
    | Binary file (20.3 kB). View file | 
|  | 
    	
        __pycache__/context_vars.cpython-311.pyc
    ADDED
    
    | Binary file (884 Bytes). View file | 
|  | 
    	
        __pycache__/deploy.cpython-311.pyc
    ADDED
    
    | Binary file (7.48 kB). View file | 
|  | 
    	
        __pycache__/dummy_commit_scheduler.cpython-311.pyc
    ADDED
    
    | Binary file (1.24 kB). View file | 
|  | 
    	
        __pycache__/imports.cpython-311.pyc
    ADDED
    
    | Binary file (12.9 kB). View file | 
|  | 
    	
        __pycache__/run.cpython-311.pyc
    ADDED
    
    | Binary file (5.85 kB). View file | 
|  | 
    	
        __pycache__/sqlite_storage.cpython-311.pyc
    ADDED
    
    | Binary file (22.4 kB). View file | 
|  | 
    	
        __pycache__/typehints.cpython-311.pyc
    ADDED
    
    | Binary file (692 Bytes). View file | 
|  | 
    	
        __pycache__/ui.cpython-311.pyc
    ADDED
    
    | Binary file (24.4 kB). View file | 
|  | 
    	
        __pycache__/utils.cpython-311.pyc
    ADDED
    
    | Binary file (10.9 kB). View file | 
|  | 
    	
        assets/trackio_logo_dark.png
    ADDED
    
    |   | 
    	
        assets/trackio_logo_light.png
    ADDED
    
    |   | 
    	
        assets/trackio_logo_old.png
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        assets/trackio_logo_type_dark.png
    ADDED
    
    |   | 
    	
        assets/trackio_logo_type_dark_transparent.png
    ADDED
    
    |   | 
    	
        assets/trackio_logo_type_light.png
    ADDED
    
    |   | 
    	
        assets/trackio_logo_type_light_transparent.png
    ADDED
    
    |   | 
    	
        cli.py
    ADDED
    
    | @@ -0,0 +1,26 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from trackio import show
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def main():
         | 
| 7 | 
            +
                parser = argparse.ArgumentParser(description="Trackio CLI")
         | 
| 8 | 
            +
                subparsers = parser.add_subparsers(dest="command")
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                ui_parser = subparsers.add_parser(
         | 
| 11 | 
            +
                    "show", help="Show the Trackio dashboard UI for a project"
         | 
| 12 | 
            +
                )
         | 
| 13 | 
            +
                ui_parser.add_argument(
         | 
| 14 | 
            +
                    "--project", required=False, help="Project name to show in the dashboard"
         | 
| 15 | 
            +
                )
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                args = parser.parse_args()
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                if args.command == "show":
         | 
| 20 | 
            +
                    show(args.project)
         | 
| 21 | 
            +
                else:
         | 
| 22 | 
            +
                    parser.print_help()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            if __name__ == "__main__":
         | 
| 26 | 
            +
                main()
         | 
    	
        commit_scheduler.py
    ADDED
    
    | @@ -0,0 +1,392 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import atexit
         | 
| 4 | 
            +
            import logging
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            import time
         | 
| 7 | 
            +
            from concurrent.futures import Future
         | 
| 8 | 
            +
            from dataclasses import dataclass
         | 
| 9 | 
            +
            from io import SEEK_END, SEEK_SET, BytesIO
         | 
| 10 | 
            +
            from pathlib import Path
         | 
| 11 | 
            +
            from threading import Lock, Thread
         | 
| 12 | 
            +
            from typing import Callable, Dict, List, Optional, Union
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from huggingface_hub.hf_api import (
         | 
| 15 | 
            +
                DEFAULT_IGNORE_PATTERNS,
         | 
| 16 | 
            +
                CommitInfo,
         | 
| 17 | 
            +
                CommitOperationAdd,
         | 
| 18 | 
            +
                HfApi,
         | 
| 19 | 
            +
            )
         | 
| 20 | 
            +
            from huggingface_hub.utils import filter_repo_objects
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            @dataclass(frozen=True)
         | 
| 26 | 
            +
            class _FileToUpload:
         | 
| 27 | 
            +
                """Temporary dataclass to store info about files to upload. Not meant to be used directly."""
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                local_path: Path
         | 
| 30 | 
            +
                path_in_repo: str
         | 
| 31 | 
            +
                size_limit: int
         | 
| 32 | 
            +
                last_modified: float
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            class CommitScheduler:
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
                Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
         | 
| 40 | 
            +
                properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
         | 
| 41 | 
            +
                with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
         | 
| 42 | 
            +
                to learn more about how to use it.
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                Args:
         | 
| 45 | 
            +
                    repo_id (`str`):
         | 
| 46 | 
            +
                        The id of the repo to commit to.
         | 
| 47 | 
            +
                    folder_path (`str` or `Path`):
         | 
| 48 | 
            +
                        Path to the local folder to upload regularly.
         | 
| 49 | 
            +
                    every (`int` or `float`, *optional*):
         | 
| 50 | 
            +
                        The number of minutes between each commit. Defaults to 5 minutes.
         | 
| 51 | 
            +
                    path_in_repo (`str`, *optional*):
         | 
| 52 | 
            +
                        Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
         | 
| 53 | 
            +
                        of the repository.
         | 
| 54 | 
            +
                    repo_type (`str`, *optional*):
         | 
| 55 | 
            +
                        The type of the repo to commit to. Defaults to `model`.
         | 
| 56 | 
            +
                    revision (`str`, *optional*):
         | 
| 57 | 
            +
                        The revision of the repo to commit to. Defaults to `main`.
         | 
| 58 | 
            +
                    private (`bool`, *optional*):
         | 
| 59 | 
            +
                        Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
         | 
| 60 | 
            +
                    token (`str`, *optional*):
         | 
| 61 | 
            +
                        The token to use to commit to the repo. Defaults to the token saved on the machine.
         | 
| 62 | 
            +
                    allow_patterns (`List[str]` or `str`, *optional*):
         | 
| 63 | 
            +
                        If provided, only files matching at least one pattern are uploaded.
         | 
| 64 | 
            +
                    ignore_patterns (`List[str]` or `str`, *optional*):
         | 
| 65 | 
            +
                        If provided, files matching any of the patterns are not uploaded.
         | 
| 66 | 
            +
                    squash_history (`bool`, *optional*):
         | 
| 67 | 
            +
                        Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
         | 
| 68 | 
            +
                        useful to avoid degraded performances on the repo when it grows too large.
         | 
| 69 | 
            +
                    hf_api (`HfApi`, *optional*):
         | 
| 70 | 
            +
                        The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
         | 
| 71 | 
            +
                    on_before_commit (`Callable[[], None]`, *optional*):
         | 
| 72 | 
            +
                        If specified, a function that will be called before the CommitScheduler lists files to create a commit.
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                Example:
         | 
| 75 | 
            +
                ```py
         | 
| 76 | 
            +
                >>> from pathlib import Path
         | 
| 77 | 
            +
                >>> from huggingface_hub import CommitScheduler
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                # Scheduler uploads every 10 minutes
         | 
| 80 | 
            +
                >>> csv_path = Path("watched_folder/data.csv")
         | 
| 81 | 
            +
                >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                >>> with csv_path.open("a") as f:
         | 
| 84 | 
            +
                ...     f.write("first line")
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                # Some time later (...)
         | 
| 87 | 
            +
                >>> with csv_path.open("a") as f:
         | 
| 88 | 
            +
                ...     f.write("second line")
         | 
| 89 | 
            +
                ```
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                Example using a context manager:
         | 
| 92 | 
            +
                ```py
         | 
| 93 | 
            +
                >>> from pathlib import Path
         | 
| 94 | 
            +
                >>> from huggingface_hub import CommitScheduler
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
         | 
| 97 | 
            +
                ...     csv_path = Path("watched_folder/data.csv")
         | 
| 98 | 
            +
                ...     with csv_path.open("a") as f:
         | 
| 99 | 
            +
                ...         f.write("first line")
         | 
| 100 | 
            +
                ...     (...)
         | 
| 101 | 
            +
                ...     with csv_path.open("a") as f:
         | 
| 102 | 
            +
                ...         f.write("second line")
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                # Scheduler is now stopped and last commit have been triggered
         | 
| 105 | 
            +
                ```
         | 
| 106 | 
            +
                """
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                def __init__(
         | 
| 109 | 
            +
                    self,
         | 
| 110 | 
            +
                    *,
         | 
| 111 | 
            +
                    repo_id: str,
         | 
| 112 | 
            +
                    folder_path: Union[str, Path],
         | 
| 113 | 
            +
                    every: Union[int, float] = 5,
         | 
| 114 | 
            +
                    path_in_repo: Optional[str] = None,
         | 
| 115 | 
            +
                    repo_type: Optional[str] = None,
         | 
| 116 | 
            +
                    revision: Optional[str] = None,
         | 
| 117 | 
            +
                    private: Optional[bool] = None,
         | 
| 118 | 
            +
                    token: Optional[str] = None,
         | 
| 119 | 
            +
                    allow_patterns: Optional[Union[List[str], str]] = None,
         | 
| 120 | 
            +
                    ignore_patterns: Optional[Union[List[str], str]] = None,
         | 
| 121 | 
            +
                    squash_history: bool = False,
         | 
| 122 | 
            +
                    hf_api: Optional["HfApi"] = None,
         | 
| 123 | 
            +
                    on_before_commit: Optional[Callable[[], None]] = None,
         | 
| 124 | 
            +
                ) -> None:
         | 
| 125 | 
            +
                    self.api = hf_api or HfApi(token=token)
         | 
| 126 | 
            +
                    self.on_before_commit = on_before_commit
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    # Folder
         | 
| 129 | 
            +
                    self.folder_path = Path(folder_path).expanduser().resolve()
         | 
| 130 | 
            +
                    self.path_in_repo = path_in_repo or ""
         | 
| 131 | 
            +
                    self.allow_patterns = allow_patterns
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    if ignore_patterns is None:
         | 
| 134 | 
            +
                        ignore_patterns = []
         | 
| 135 | 
            +
                    elif isinstance(ignore_patterns, str):
         | 
| 136 | 
            +
                        ignore_patterns = [ignore_patterns]
         | 
| 137 | 
            +
                    self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    if self.folder_path.is_file():
         | 
| 140 | 
            +
                        raise ValueError(
         | 
| 141 | 
            +
                            f"'folder_path' must be a directory, not a file: '{self.folder_path}'."
         | 
| 142 | 
            +
                        )
         | 
| 143 | 
            +
                    self.folder_path.mkdir(parents=True, exist_ok=True)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    # Repository
         | 
| 146 | 
            +
                    repo_url = self.api.create_repo(
         | 
| 147 | 
            +
                        repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True
         | 
| 148 | 
            +
                    )
         | 
| 149 | 
            +
                    self.repo_id = repo_url.repo_id
         | 
| 150 | 
            +
                    self.repo_type = repo_type
         | 
| 151 | 
            +
                    self.revision = revision
         | 
| 152 | 
            +
                    self.token = token
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    # Keep track of already uploaded files
         | 
| 155 | 
            +
                    self.last_uploaded: Dict[
         | 
| 156 | 
            +
                        Path, float
         | 
| 157 | 
            +
                    ] = {}  # key is local path, value is timestamp
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    # Scheduler
         | 
| 160 | 
            +
                    if not every > 0:
         | 
| 161 | 
            +
                        raise ValueError(f"'every' must be a positive integer, not '{every}'.")
         | 
| 162 | 
            +
                    self.lock = Lock()
         | 
| 163 | 
            +
                    self.every = every
         | 
| 164 | 
            +
                    self.squash_history = squash_history
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    logger.info(
         | 
| 167 | 
            +
                        f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes."
         | 
| 168 | 
            +
                    )
         | 
| 169 | 
            +
                    self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
         | 
| 170 | 
            +
                    self._scheduler_thread.start()
         | 
| 171 | 
            +
                    atexit.register(self._push_to_hub)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    self.__stopped = False
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def stop(self) -> None:
         | 
| 176 | 
            +
                    """Stop the scheduler.
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    A stopped scheduler cannot be restarted. Mostly for tests purposes.
         | 
| 179 | 
            +
                    """
         | 
| 180 | 
            +
                    self.__stopped = True
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                def __enter__(self) -> "CommitScheduler":
         | 
| 183 | 
            +
                    return self
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                def __exit__(self, exc_type, exc_value, traceback) -> None:
         | 
| 186 | 
            +
                    # Upload last changes before exiting
         | 
| 187 | 
            +
                    self.trigger().result()
         | 
| 188 | 
            +
                    self.stop()
         | 
| 189 | 
            +
                    return
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                def _run_scheduler(self) -> None:
         | 
| 192 | 
            +
                    """Dumb thread waiting between each scheduled push to Hub."""
         | 
| 193 | 
            +
                    while True:
         | 
| 194 | 
            +
                        self.last_future = self.trigger()
         | 
| 195 | 
            +
                        time.sleep(self.every * 60)
         | 
| 196 | 
            +
                        if self.__stopped:
         | 
| 197 | 
            +
                            break
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                def trigger(self) -> Future:
         | 
| 200 | 
            +
                    """Trigger a `push_to_hub` and return a future.
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
         | 
| 203 | 
            +
                    immediately, without waiting for the next scheduled commit.
         | 
| 204 | 
            +
                    """
         | 
| 205 | 
            +
                    return self.api.run_as_future(self._push_to_hub)
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                def _push_to_hub(self) -> Optional[CommitInfo]:
         | 
| 208 | 
            +
                    if self.__stopped:  # If stopped, already scheduled commits are ignored
         | 
| 209 | 
            +
                        return None
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    logger.info("(Background) scheduled commit triggered.")
         | 
| 212 | 
            +
                    try:
         | 
| 213 | 
            +
                        value = self.push_to_hub()
         | 
| 214 | 
            +
                        if self.squash_history:
         | 
| 215 | 
            +
                            logger.info("(Background) squashing repo history.")
         | 
| 216 | 
            +
                            self.api.super_squash_history(
         | 
| 217 | 
            +
                                repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision
         | 
| 218 | 
            +
                            )
         | 
| 219 | 
            +
                        return value
         | 
| 220 | 
            +
                    except Exception as e:
         | 
| 221 | 
            +
                        logger.error(
         | 
| 222 | 
            +
                            f"Error while pushing to Hub: {e}"
         | 
| 223 | 
            +
                        )  # Depending on the setup, error might be silenced
         | 
| 224 | 
            +
                        raise
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                def push_to_hub(self) -> Optional[CommitInfo]:
         | 
| 227 | 
            +
                    """
         | 
| 228 | 
            +
                    Push folder to the Hub and return the commit info.
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    <Tip warning={true}>
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
         | 
| 233 | 
            +
                    queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
         | 
| 234 | 
            +
                    issues.
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    </Tip>
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                    The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
         | 
| 239 | 
            +
                    uploads only changed files. If no changes are found, the method returns without committing anything. If you want
         | 
| 240 | 
            +
                    to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
         | 
| 241 | 
            +
                    for example to compress data together in a single file before committing. For more details and examples, check
         | 
| 242 | 
            +
                    out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
         | 
| 243 | 
            +
                    """
         | 
| 244 | 
            +
                    # Check files to upload (with lock)
         | 
| 245 | 
            +
                    with self.lock:
         | 
| 246 | 
            +
                        if self.on_before_commit is not None:
         | 
| 247 | 
            +
                            self.on_before_commit()
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                        logger.debug("Listing files to upload for scheduled commit.")
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                        # List files from folder (taken from `_prepare_upload_folder_additions`)
         | 
| 252 | 
            +
                        relpath_to_abspath = {
         | 
| 253 | 
            +
                            path.relative_to(self.folder_path).as_posix(): path
         | 
| 254 | 
            +
                            for path in sorted(
         | 
| 255 | 
            +
                                self.folder_path.glob("**/*")
         | 
| 256 | 
            +
                            )  # sorted to be deterministic
         | 
| 257 | 
            +
                            if path.is_file()
         | 
| 258 | 
            +
                        }
         | 
| 259 | 
            +
                        prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                        # Filter with pattern + filter out unchanged files + retrieve current file size
         | 
| 262 | 
            +
                        files_to_upload: List[_FileToUpload] = []
         | 
| 263 | 
            +
                        for relpath in filter_repo_objects(
         | 
| 264 | 
            +
                            relpath_to_abspath.keys(),
         | 
| 265 | 
            +
                            allow_patterns=self.allow_patterns,
         | 
| 266 | 
            +
                            ignore_patterns=self.ignore_patterns,
         | 
| 267 | 
            +
                        ):
         | 
| 268 | 
            +
                            local_path = relpath_to_abspath[relpath]
         | 
| 269 | 
            +
                            stat = local_path.stat()
         | 
| 270 | 
            +
                            if (
         | 
| 271 | 
            +
                                self.last_uploaded.get(local_path) is None
         | 
| 272 | 
            +
                                or self.last_uploaded[local_path] != stat.st_mtime
         | 
| 273 | 
            +
                            ):
         | 
| 274 | 
            +
                                files_to_upload.append(
         | 
| 275 | 
            +
                                    _FileToUpload(
         | 
| 276 | 
            +
                                        local_path=local_path,
         | 
| 277 | 
            +
                                        path_in_repo=prefix + relpath,
         | 
| 278 | 
            +
                                        size_limit=stat.st_size,
         | 
| 279 | 
            +
                                        last_modified=stat.st_mtime,
         | 
| 280 | 
            +
                                    )
         | 
| 281 | 
            +
                                )
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    # Return if nothing to upload
         | 
| 284 | 
            +
                    if len(files_to_upload) == 0:
         | 
| 285 | 
            +
                        logger.debug("Dropping schedule commit: no changed file to upload.")
         | 
| 286 | 
            +
                        return None
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
         | 
| 289 | 
            +
                    logger.debug("Removing unchanged files since previous scheduled commit.")
         | 
| 290 | 
            +
                    add_operations = [
         | 
| 291 | 
            +
                        CommitOperationAdd(
         | 
| 292 | 
            +
                            # TODO: Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
         | 
| 293 | 
            +
                            # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload)
         | 
| 294 | 
            +
                            path_or_fileobj=file_to_upload.local_path,
         | 
| 295 | 
            +
                            path_in_repo=file_to_upload.path_in_repo,
         | 
| 296 | 
            +
                        )
         | 
| 297 | 
            +
                        for file_to_upload in files_to_upload
         | 
| 298 | 
            +
                    ]
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    # Upload files (append mode expected - no need for lock)
         | 
| 301 | 
            +
                    logger.debug("Uploading files for scheduled commit.")
         | 
| 302 | 
            +
                    commit_info = self.api.create_commit(
         | 
| 303 | 
            +
                        repo_id=self.repo_id,
         | 
| 304 | 
            +
                        repo_type=self.repo_type,
         | 
| 305 | 
            +
                        operations=add_operations,
         | 
| 306 | 
            +
                        commit_message="Scheduled Commit",
         | 
| 307 | 
            +
                        revision=self.revision,
         | 
| 308 | 
            +
                    )
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                    # Successful commit: keep track of the latest "last_modified" for each file
         | 
| 311 | 
            +
                    for file in files_to_upload:
         | 
| 312 | 
            +
                        self.last_uploaded[file.local_path] = file.last_modified
         | 
| 313 | 
            +
                    return commit_info
         | 
| 314 | 
            +
             | 
| 315 | 
            +
             | 
| 316 | 
            +
            class PartialFileIO(BytesIO):
         | 
| 317 | 
            +
                """A file-like object that reads only the first part of a file.
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
         | 
| 320 | 
            +
                file is uploaded (i.e. the part that was available when the filesystem was first scanned).
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
         | 
| 323 | 
            +
                disturbance for the user. The object is passed to `CommitOperationAdd`.
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                Only supports `read`, `tell` and `seek` methods.
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                Args:
         | 
| 328 | 
            +
                    file_path (`str` or `Path`):
         | 
| 329 | 
            +
                        Path to the file to read.
         | 
| 330 | 
            +
                    size_limit (`int`):
         | 
| 331 | 
            +
                        The maximum number of bytes to read from the file. If the file is larger than this, only the first part
         | 
| 332 | 
            +
                        will be read (and uploaded).
         | 
| 333 | 
            +
                """
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
         | 
| 336 | 
            +
                    self._file_path = Path(file_path)
         | 
| 337 | 
            +
                    self._file = self._file_path.open("rb")
         | 
| 338 | 
            +
                    self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                def __del__(self) -> None:
         | 
| 341 | 
            +
                    self._file.close()
         | 
| 342 | 
            +
                    return super().__del__()
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                def __repr__(self) -> str:
         | 
| 345 | 
            +
                    return (
         | 
| 346 | 
            +
                        f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
         | 
| 347 | 
            +
                    )
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                def __len__(self) -> int:
         | 
| 350 | 
            +
                    return self._size_limit
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                def __getattribute__(self, name: str):
         | 
| 353 | 
            +
                    if name.startswith("_") or name in (
         | 
| 354 | 
            +
                        "read",
         | 
| 355 | 
            +
                        "tell",
         | 
| 356 | 
            +
                        "seek",
         | 
| 357 | 
            +
                    ):  # only 3 public methods supported
         | 
| 358 | 
            +
                        return super().__getattribute__(name)
         | 
| 359 | 
            +
                    raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                def tell(self) -> int:
         | 
| 362 | 
            +
                    """Return the current file position."""
         | 
| 363 | 
            +
                    return self._file.tell()
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
         | 
| 366 | 
            +
                    """Change the stream position to the given offset.
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                    Behavior is the same as a regular file, except that the position is capped to the size limit.
         | 
| 369 | 
            +
                    """
         | 
| 370 | 
            +
                    if __whence == SEEK_END:
         | 
| 371 | 
            +
                        # SEEK_END => set from the truncated end
         | 
| 372 | 
            +
                        __offset = len(self) + __offset
         | 
| 373 | 
            +
                        __whence = SEEK_SET
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                    pos = self._file.seek(__offset, __whence)
         | 
| 376 | 
            +
                    if pos > self._size_limit:
         | 
| 377 | 
            +
                        return self._file.seek(self._size_limit)
         | 
| 378 | 
            +
                    return pos
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                def read(self, __size: Optional[int] = -1) -> bytes:
         | 
| 381 | 
            +
                    """Read at most `__size` bytes from the file.
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                    Behavior is the same as a regular file, except that it is capped to the size limit.
         | 
| 384 | 
            +
                    """
         | 
| 385 | 
            +
                    current = self._file.tell()
         | 
| 386 | 
            +
                    if __size is None or __size < 0:
         | 
| 387 | 
            +
                        # Read until file limit
         | 
| 388 | 
            +
                        truncated_size = self._size_limit - current
         | 
| 389 | 
            +
                    else:
         | 
| 390 | 
            +
                        # Read until file limit or __size
         | 
| 391 | 
            +
                        truncated_size = min(__size, self._size_limit - current)
         | 
| 392 | 
            +
                    return self._file.read(truncated_size)
         | 
    	
        context_vars.py
    ADDED
    
    | @@ -0,0 +1,15 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import contextvars
         | 
| 2 | 
            +
            from typing import TYPE_CHECKING
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            if TYPE_CHECKING:
         | 
| 5 | 
            +
                from trackio.run import Run
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar(
         | 
| 8 | 
            +
                "current_run", default=None
         | 
| 9 | 
            +
            )
         | 
| 10 | 
            +
            current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
         | 
| 11 | 
            +
                "current_project", default=None
         | 
| 12 | 
            +
            )
         | 
| 13 | 
            +
            current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
         | 
| 14 | 
            +
                "current_server", default=None
         | 
| 15 | 
            +
            )
         | 
    	
        deploy.py
    ADDED
    
    | @@ -0,0 +1,170 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import io
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import time
         | 
| 4 | 
            +
            from importlib.resources import files
         | 
| 5 | 
            +
            from pathlib import Path
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import gradio
         | 
| 8 | 
            +
            import huggingface_hub
         | 
| 9 | 
            +
            from gradio_client import Client, handle_file
         | 
| 10 | 
            +
            from httpx import ReadTimeout
         | 
| 11 | 
            +
            from huggingface_hub.errors import RepositoryNotFoundError
         | 
| 12 | 
            +
            from requests import HTTPError
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from trackio.sqlite_storage import SQLiteStorage
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            SPACE_URL = "https://huggingface.co/spaces/{space_id}"
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def deploy_as_space(
         | 
| 20 | 
            +
                space_id: str,
         | 
| 21 | 
            +
                dataset_id: str | None = None,
         | 
| 22 | 
            +
            ):
         | 
| 23 | 
            +
                if (
         | 
| 24 | 
            +
                    os.getenv("SYSTEM") == "spaces"
         | 
| 25 | 
            +
                ):  # in case a repo with this function is uploaded to spaces
         | 
| 26 | 
            +
                    return
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                trackio_path = files("trackio")
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                hf_api = huggingface_hub.HfApi()
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                try:
         | 
| 33 | 
            +
                    huggingface_hub.create_repo(
         | 
| 34 | 
            +
                        space_id,
         | 
| 35 | 
            +
                        space_sdk="gradio",
         | 
| 36 | 
            +
                        repo_type="space",
         | 
| 37 | 
            +
                        exist_ok=True,
         | 
| 38 | 
            +
                    )
         | 
| 39 | 
            +
                except HTTPError as e:
         | 
| 40 | 
            +
                    if e.response.status_code in [401, 403]:  # unauthorized or forbidden
         | 
| 41 | 
            +
                        print("Need 'write' access token to create a Spaces repo.")
         | 
| 42 | 
            +
                        huggingface_hub.login(add_to_git_credential=False)
         | 
| 43 | 
            +
                        huggingface_hub.create_repo(
         | 
| 44 | 
            +
                            space_id,
         | 
| 45 | 
            +
                            space_sdk="gradio",
         | 
| 46 | 
            +
                            repo_type="space",
         | 
| 47 | 
            +
                            exist_ok=True,
         | 
| 48 | 
            +
                        )
         | 
| 49 | 
            +
                    else:
         | 
| 50 | 
            +
                        raise ValueError(f"Failed to create Space: {e}")
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                with open(Path(trackio_path, "README.md"), "r") as f:
         | 
| 53 | 
            +
                    readme_content = f.read()
         | 
| 54 | 
            +
                    readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
         | 
| 55 | 
            +
                    readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
         | 
| 56 | 
            +
                    hf_api.upload_file(
         | 
| 57 | 
            +
                        path_or_fileobj=readme_buffer,
         | 
| 58 | 
            +
                        path_in_repo="README.md",
         | 
| 59 | 
            +
                        repo_id=space_id,
         | 
| 60 | 
            +
                        repo_type="space",
         | 
| 61 | 
            +
                    )
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space.
         | 
| 64 | 
            +
                # Make sure necessary dependencies are installed by creating a requirements.txt.
         | 
| 65 | 
            +
                requirements_content = """
         | 
| 66 | 
            +
            pyarrow>=21.0
         | 
| 67 | 
            +
                """
         | 
| 68 | 
            +
                requirements_buffer = io.BytesIO(requirements_content.encode("utf-8"))
         | 
| 69 | 
            +
                hf_api.upload_file(
         | 
| 70 | 
            +
                    path_or_fileobj=requirements_buffer,
         | 
| 71 | 
            +
                    path_in_repo="requirements.txt",
         | 
| 72 | 
            +
                    repo_id=space_id,
         | 
| 73 | 
            +
                    repo_type="space",
         | 
| 74 | 
            +
                )
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                huggingface_hub.utils.disable_progress_bars()
         | 
| 77 | 
            +
                hf_api.upload_folder(
         | 
| 78 | 
            +
                    repo_id=space_id,
         | 
| 79 | 
            +
                    repo_type="space",
         | 
| 80 | 
            +
                    folder_path=trackio_path,
         | 
| 81 | 
            +
                    ignore_patterns=["README.md"],
         | 
| 82 | 
            +
                )
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                hf_token = huggingface_hub.utils.get_token()
         | 
| 85 | 
            +
                if hf_token is not None:
         | 
| 86 | 
            +
                    huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
         | 
| 87 | 
            +
                if dataset_id is not None:
         | 
| 88 | 
            +
                    huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            def create_space_if_not_exists(
         | 
| 92 | 
            +
                space_id: str,
         | 
| 93 | 
            +
                dataset_id: str | None = None,
         | 
| 94 | 
            +
            ) -> None:
         | 
| 95 | 
            +
                """
         | 
| 96 | 
            +
                Creates a new Hugging Face Space if it does not exist. If a dataset_id is provided, it will be added as a space variable.
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                Args:
         | 
| 99 | 
            +
                    space_id: The ID of the Space to create.
         | 
| 100 | 
            +
                    dataset_id: The ID of the Dataset to add to the Space.
         | 
| 101 | 
            +
                """
         | 
| 102 | 
            +
                if "/" not in space_id:
         | 
| 103 | 
            +
                    raise ValueError(
         | 
| 104 | 
            +
                        f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame."
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
                if dataset_id is not None and "/" not in dataset_id:
         | 
| 107 | 
            +
                    raise ValueError(
         | 
| 108 | 
            +
                        f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname."
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
                try:
         | 
| 111 | 
            +
                    huggingface_hub.repo_info(space_id, repo_type="space")
         | 
| 112 | 
            +
                    print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
         | 
| 113 | 
            +
                    if dataset_id is not None:
         | 
| 114 | 
            +
                        huggingface_hub.add_space_variable(
         | 
| 115 | 
            +
                            space_id, "TRACKIO_DATASET_ID", dataset_id
         | 
| 116 | 
            +
                        )
         | 
| 117 | 
            +
                    return
         | 
| 118 | 
            +
                except RepositoryNotFoundError:
         | 
| 119 | 
            +
                    pass
         | 
| 120 | 
            +
                except HTTPError as e:
         | 
| 121 | 
            +
                    if e.response.status_code in [401, 403]:  # unauthorized or forbidden
         | 
| 122 | 
            +
                        print("Need 'write' access token to create a Spaces repo.")
         | 
| 123 | 
            +
                        huggingface_hub.login(add_to_git_credential=False)
         | 
| 124 | 
            +
                        huggingface_hub.add_space_variable(
         | 
| 125 | 
            +
                            space_id, "TRACKIO_DATASET_ID", dataset_id
         | 
| 126 | 
            +
                        )
         | 
| 127 | 
            +
                    else:
         | 
| 128 | 
            +
                        raise ValueError(f"Failed to create Space: {e}")
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
         | 
| 131 | 
            +
                deploy_as_space(space_id, dataset_id)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            def wait_until_space_exists(
         | 
| 135 | 
            +
                space_id: str,
         | 
| 136 | 
            +
            ) -> None:
         | 
| 137 | 
            +
                """
         | 
| 138 | 
            +
                Blocks the current thread until the space exists.
         | 
| 139 | 
            +
                May raise a TimeoutError if this takes quite a while.
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                Args:
         | 
| 142 | 
            +
                    space_id: The ID of the Space to wait for.
         | 
| 143 | 
            +
                """
         | 
| 144 | 
            +
                delay = 1
         | 
| 145 | 
            +
                for _ in range(10):
         | 
| 146 | 
            +
                    try:
         | 
| 147 | 
            +
                        Client(space_id, verbose=False)
         | 
| 148 | 
            +
                        return
         | 
| 149 | 
            +
                    except (ReadTimeout, ValueError):
         | 
| 150 | 
            +
                        time.sleep(delay)
         | 
| 151 | 
            +
                        delay = min(delay * 2, 30)
         | 
| 152 | 
            +
                raise TimeoutError("Waiting for space to exist took longer than expected")
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
            def upload_db_to_space(project: str, space_id: str) -> None:
         | 
| 156 | 
            +
                """
         | 
| 157 | 
            +
                Uploads the database of a local Trackio project to a Hugging Face Space.
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                Args:
         | 
| 160 | 
            +
                    project: The name of the project to upload.
         | 
| 161 | 
            +
                    space_id: The ID of the Space to upload to.
         | 
| 162 | 
            +
                """
         | 
| 163 | 
            +
                db_path = SQLiteStorage.get_project_db_path(project)
         | 
| 164 | 
            +
                client = Client(space_id, verbose=False)
         | 
| 165 | 
            +
                client.predict(
         | 
| 166 | 
            +
                    api_name="/upload_db_to_space",
         | 
| 167 | 
            +
                    project=project,
         | 
| 168 | 
            +
                    uploaded_db=handle_file(db_path),
         | 
| 169 | 
            +
                    hf_token=huggingface_hub.utils.get_token(),
         | 
| 170 | 
            +
                )
         | 
    	
        dummy_commit_scheduler.py
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # A dummy object to fit the interface of huggingface_hub's CommitScheduler
         | 
| 2 | 
            +
            class DummyCommitSchedulerLock:
         | 
| 3 | 
            +
                def __enter__(self):
         | 
| 4 | 
            +
                    return None
         | 
| 5 | 
            +
             | 
| 6 | 
            +
                def __exit__(self, exception_type, exception_value, exception_traceback):
         | 
| 7 | 
            +
                    pass
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class DummyCommitScheduler:
         | 
| 11 | 
            +
                def __init__(self):
         | 
| 12 | 
            +
                    self.lock = DummyCommitSchedulerLock()
         | 
    	
        imports.py
    ADDED
    
    | @@ -0,0 +1,245 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from pathlib import Path
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import pandas as pd
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from trackio import deploy, utils
         | 
| 7 | 
            +
            from trackio.sqlite_storage import SQLiteStorage
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def import_csv(
         | 
| 11 | 
            +
                csv_path: str | Path,
         | 
| 12 | 
            +
                project: str,
         | 
| 13 | 
            +
                name: str | None = None,
         | 
| 14 | 
            +
                space_id: str | None = None,
         | 
| 15 | 
            +
                dataset_id: str | None = None,
         | 
| 16 | 
            +
            ) -> None:
         | 
| 17 | 
            +
                """
         | 
| 18 | 
            +
                Imports a CSV file into a Trackio project. The CSV file must contain a "step" column, may optionally
         | 
| 19 | 
            +
                contain a "timestamp" column, and any other columns will be treated as metrics. It should also include
         | 
| 20 | 
            +
                a header row with the column names.
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                TODO: call init() and return a Run object so that the user can continue to log metrics to it.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                Args:
         | 
| 25 | 
            +
                    csv_path: The str or Path to the CSV file to import.
         | 
| 26 | 
            +
                    project: The name of the project to import the CSV file into. Must not be an existing project.
         | 
| 27 | 
            +
                    name: The name of the Run to import the CSV file into. If not provided, a default name will be generated.
         | 
| 28 | 
            +
                    name: The name of the run (if not provided, a default name will be generated).
         | 
| 29 | 
            +
                    space_id: If provided, the project will be logged to a Hugging Face Space instead of a local directory. Should be a complete Space name like "username/reponame" or "orgname/reponame", or just "reponame" in which case the Space will be created in the currently-logged-in Hugging Face user's namespace. If the Space does not exist, it will be created. If the Space already exists, the project will be logged to it.
         | 
| 30 | 
            +
                    dataset_id: If provided, a persistent Hugging Face Dataset will be created and the metrics will be synced to it every 5 minutes. Should be a complete Dataset name like "username/datasetname" or "orgname/datasetname", or just "datasetname" in which case the Dataset will be created in the currently-logged-in Hugging Face user's namespace. If the Dataset does not exist, it will be created. If the Dataset already exists, the project will be appended to it. If not provided, the metrics will be logged to a local SQLite database, unless a `space_id` is provided, in which case a Dataset will be automatically created with the same name as the Space but with the "_dataset" suffix.
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                if SQLiteStorage.get_runs(project):
         | 
| 33 | 
            +
                    raise ValueError(
         | 
| 34 | 
            +
                        f"Project '{project}' already exists. Cannot import CSV into existing project."
         | 
| 35 | 
            +
                    )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                csv_path = Path(csv_path)
         | 
| 38 | 
            +
                if not csv_path.exists():
         | 
| 39 | 
            +
                    raise FileNotFoundError(f"CSV file not found: {csv_path}")
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                df = pd.read_csv(csv_path)
         | 
| 42 | 
            +
                if df.empty:
         | 
| 43 | 
            +
                    raise ValueError("CSV file is empty")
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                column_mapping = utils.simplify_column_names(df.columns.tolist())
         | 
| 46 | 
            +
                df = df.rename(columns=column_mapping)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                step_column = None
         | 
| 49 | 
            +
                for col in df.columns:
         | 
| 50 | 
            +
                    if col.lower() == "step":
         | 
| 51 | 
            +
                        step_column = col
         | 
| 52 | 
            +
                        break
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                if step_column is None:
         | 
| 55 | 
            +
                    raise ValueError("CSV file must contain a 'step' or 'Step' column")
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                if name is None:
         | 
| 58 | 
            +
                    name = csv_path.stem
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                metrics_list = []
         | 
| 61 | 
            +
                steps = []
         | 
| 62 | 
            +
                timestamps = []
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                numeric_columns = []
         | 
| 65 | 
            +
                for column in df.columns:
         | 
| 66 | 
            +
                    if column == step_column:
         | 
| 67 | 
            +
                        continue
         | 
| 68 | 
            +
                    if column == "timestamp":
         | 
| 69 | 
            +
                        continue
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    try:
         | 
| 72 | 
            +
                        pd.to_numeric(df[column], errors="raise")
         | 
| 73 | 
            +
                        numeric_columns.append(column)
         | 
| 74 | 
            +
                    except (ValueError, TypeError):
         | 
| 75 | 
            +
                        continue
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                for _, row in df.iterrows():
         | 
| 78 | 
            +
                    metrics = {}
         | 
| 79 | 
            +
                    for column in numeric_columns:
         | 
| 80 | 
            +
                        value = row[column]
         | 
| 81 | 
            +
                        if bool(pd.notna(value)):
         | 
| 82 | 
            +
                            metrics[column] = float(value)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    if metrics:
         | 
| 85 | 
            +
                        metrics_list.append(metrics)
         | 
| 86 | 
            +
                        steps.append(int(row[step_column]))
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                        if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])):
         | 
| 89 | 
            +
                            timestamps.append(str(row["timestamp"]))
         | 
| 90 | 
            +
                        else:
         | 
| 91 | 
            +
                            timestamps.append("")
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                if metrics_list:
         | 
| 94 | 
            +
                    SQLiteStorage.bulk_log(
         | 
| 95 | 
            +
                        project=project,
         | 
| 96 | 
            +
                        run=name,
         | 
| 97 | 
            +
                        metrics_list=metrics_list,
         | 
| 98 | 
            +
                        steps=steps,
         | 
| 99 | 
            +
                        timestamps=timestamps,
         | 
| 100 | 
            +
                    )
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                print(
         | 
| 103 | 
            +
                    f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'"
         | 
| 104 | 
            +
                )
         | 
| 105 | 
            +
                print(f"* Metrics found: {', '.join(metrics_list[0].keys())}")
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
         | 
| 108 | 
            +
                if dataset_id is not None:
         | 
| 109 | 
            +
                    os.environ["TRACKIO_DATASET_ID"] = dataset_id
         | 
| 110 | 
            +
                    print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                if space_id is None:
         | 
| 113 | 
            +
                    utils.print_dashboard_instructions(project)
         | 
| 114 | 
            +
                else:
         | 
| 115 | 
            +
                    deploy.create_space_if_not_exists(space_id, dataset_id)
         | 
| 116 | 
            +
                    deploy.wait_until_space_exists(space_id)
         | 
| 117 | 
            +
                    deploy.upload_db_to_space(project, space_id)
         | 
| 118 | 
            +
                    print(
         | 
| 119 | 
            +
                        f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
         | 
| 120 | 
            +
                    )
         | 
| 121 | 
            +
             | 
| 122 | 
            +
             | 
| 123 | 
            +
            def import_tf_events(
         | 
| 124 | 
            +
                log_dir: str | Path,
         | 
| 125 | 
            +
                project: str,
         | 
| 126 | 
            +
                name: str | None = None,
         | 
| 127 | 
            +
                space_id: str | None = None,
         | 
| 128 | 
            +
                dataset_id: str | None = None,
         | 
| 129 | 
            +
            ) -> None:
         | 
| 130 | 
            +
                """
         | 
| 131 | 
            +
                Imports TensorFlow Events files from a directory into a Trackio project.
         | 
| 132 | 
            +
                Each subdirectory in the log directory will be imported as a separate run.
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                Args:
         | 
| 135 | 
            +
                    log_dir: The str or Path to the directory containing TensorFlow Events files.
         | 
| 136 | 
            +
                    project: The name of the project to import the TensorFlow Events files into. Must not be an existing project.
         | 
| 137 | 
            +
                    name: The name prefix for runs (if not provided, will use directory names). Each subdirectory will create a separate run.
         | 
| 138 | 
            +
                    space_id: If provided, the project will be logged to a Hugging Face Space instead of a local directory. Should be a complete Space name like username/reponame" ororgname/reponame", or just "reponame" in which case the Space will be created in the currently-logged-in Hugging Face user's namespace. If the Space does not exist, it will be created. If the Space already exists, the project will be logged to it.
         | 
| 139 | 
            +
                    dataset_id: If provided, a persistent Hugging Face Dataset will be created and the metrics will be synced to it every 5 minutes. Should be a complete Dataset name likeusername/datasetname" or "orgname/datasetname", or just "datasetname" in which case the Dataset will be created in the currently-logged-in Hugging Face user's namespace. If the Dataset does not exist, it will be created. If the Dataset already exists, the project will be appended to it. If not provided, the metrics will be logged to a local SQLite database, unless a `space_id` is provided, in which case a Dataset will be automatically created with the same name as the Space but with the_dataset suffix.
         | 
| 140 | 
            +
                """
         | 
| 141 | 
            +
                try:
         | 
| 142 | 
            +
                    from tbparse import SummaryReader
         | 
| 143 | 
            +
                except ImportError:
         | 
| 144 | 
            +
                    raise ImportError(
         | 
| 145 | 
            +
                        "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`."
         | 
| 146 | 
            +
                    )
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                if SQLiteStorage.get_runs(project):
         | 
| 149 | 
            +
                    raise ValueError(
         | 
| 150 | 
            +
                        f"Project '{project}' already exists. Cannot import TF events into existing project."
         | 
| 151 | 
            +
                    )
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                path = Path(log_dir)
         | 
| 154 | 
            +
                if not path.exists():
         | 
| 155 | 
            +
                    raise FileNotFoundError(f"TF events directory not found: {path}")
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                # Use tbparse to read all tfevents files in the directory structure
         | 
| 158 | 
            +
                reader = SummaryReader(str(path), extra_columns={"dir_name"})
         | 
| 159 | 
            +
                df = reader.scalars
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                if df.empty:
         | 
| 162 | 
            +
                    raise ValueError(f"No TensorFlow events data found in {path}")
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                total_imported = 0
         | 
| 165 | 
            +
                imported_runs = []
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                # Group by dir_name to create separate runs
         | 
| 168 | 
            +
                for dir_name, group_df in df.groupby("dir_name"):
         | 
| 169 | 
            +
                    try:
         | 
| 170 | 
            +
                        # Determine run name based on directory name
         | 
| 171 | 
            +
                        if dir_name == "":
         | 
| 172 | 
            +
                            run_name = "main"  # For files in the root directory
         | 
| 173 | 
            +
                        else:
         | 
| 174 | 
            +
                            run_name = dir_name  # Use directory name
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                        if name:
         | 
| 177 | 
            +
                            run_name = f"{name}_{run_name}"
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                        if group_df.empty:
         | 
| 180 | 
            +
                            print(f"* Skipping directory {dir_name}: no scalar data found")
         | 
| 181 | 
            +
                            continue
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                        metrics_list = []
         | 
| 184 | 
            +
                        steps = []
         | 
| 185 | 
            +
                        timestamps = []
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                        for _, row in group_df.iterrows():
         | 
| 188 | 
            +
                            # Convert row values to appropriate types
         | 
| 189 | 
            +
                            tag = str(row["tag"])
         | 
| 190 | 
            +
                            value = float(row["value"])
         | 
| 191 | 
            +
                            step = int(row["step"])
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                            metrics = {tag: value}
         | 
| 194 | 
            +
                            metrics_list.append(metrics)
         | 
| 195 | 
            +
                            steps.append(step)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                            # Use wall_time if present, else fallback
         | 
| 198 | 
            +
                            if "wall_time" in group_df.columns and not bool(
         | 
| 199 | 
            +
                                pd.isna(row["wall_time"])
         | 
| 200 | 
            +
                            ):
         | 
| 201 | 
            +
                                timestamps.append(str(row["wall_time"]))
         | 
| 202 | 
            +
                            else:
         | 
| 203 | 
            +
                                timestamps.append("")
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                        if metrics_list:
         | 
| 206 | 
            +
                            SQLiteStorage.bulk_log(
         | 
| 207 | 
            +
                                project=project,
         | 
| 208 | 
            +
                                run=str(run_name),
         | 
| 209 | 
            +
                                metrics_list=metrics_list,
         | 
| 210 | 
            +
                                steps=steps,
         | 
| 211 | 
            +
                                timestamps=timestamps,
         | 
| 212 | 
            +
                            )
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                            total_imported += len(metrics_list)
         | 
| 215 | 
            +
                            imported_runs.append(run_name)
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                            print(
         | 
| 218 | 
            +
                                f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'"
         | 
| 219 | 
            +
                            )
         | 
| 220 | 
            +
                            print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}")
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    except Exception as e:
         | 
| 223 | 
            +
                        print(f"* Error processing directory {dir_name}: {e}")
         | 
| 224 | 
            +
                        continue
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                if not imported_runs:
         | 
| 227 | 
            +
                    raise ValueError("No valid TensorFlow events data could be imported")
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                print(f"* Total imported events: {total_imported}")
         | 
| 230 | 
            +
                print(f"* Created runs: {', '.join(imported_runs)}")
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
         | 
| 233 | 
            +
                if dataset_id is not None:
         | 
| 234 | 
            +
                    os.environ["TRACKIO_DATASET_ID"] = dataset_id
         | 
| 235 | 
            +
                    print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                if space_id is None:
         | 
| 238 | 
            +
                    utils.print_dashboard_instructions(project)
         | 
| 239 | 
            +
                else:
         | 
| 240 | 
            +
                    deploy.create_space_if_not_exists(space_id, dataset_id)
         | 
| 241 | 
            +
                    deploy.wait_until_space_exists(space_id)
         | 
| 242 | 
            +
                    deploy.upload_db_to_space(project, space_id)
         | 
| 243 | 
            +
                    print(
         | 
| 244 | 
            +
                        f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
         | 
| 245 | 
            +
                    )
         | 
    	
        py.typed
    ADDED
    
    | 
            File without changes
         | 
    	
        run.py
    ADDED
    
    | @@ -0,0 +1,100 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import threading
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import huggingface_hub
         | 
| 5 | 
            +
            from gradio_client import Client
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from trackio.sqlite_storage import SQLiteStorage
         | 
| 8 | 
            +
            from trackio.typehints import LogEntry
         | 
| 9 | 
            +
            from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class Run:
         | 
| 13 | 
            +
                def __init__(
         | 
| 14 | 
            +
                    self,
         | 
| 15 | 
            +
                    url: str,
         | 
| 16 | 
            +
                    project: str,
         | 
| 17 | 
            +
                    client: Client | None,
         | 
| 18 | 
            +
                    name: str | None = None,
         | 
| 19 | 
            +
                    config: dict | None = None,
         | 
| 20 | 
            +
                ):
         | 
| 21 | 
            +
                    self.url = url
         | 
| 22 | 
            +
                    self.project = project
         | 
| 23 | 
            +
                    self._client_lock = threading.Lock()
         | 
| 24 | 
            +
                    self._client_thread = None
         | 
| 25 | 
            +
                    self._client = client
         | 
| 26 | 
            +
                    self.name = name or generate_readable_name(SQLiteStorage.get_runs(project))
         | 
| 27 | 
            +
                    self.config = config or {}
         | 
| 28 | 
            +
                    self._queued_logs: list[LogEntry] = []
         | 
| 29 | 
            +
                    self._stop_flag = threading.Event()
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    self._client_thread = threading.Thread(target=self._init_client_background)
         | 
| 32 | 
            +
                    self._client_thread.daemon = True
         | 
| 33 | 
            +
                    self._client_thread.start()
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                def _batch_sender(self):
         | 
| 36 | 
            +
                    """Send batched logs every 500ms."""
         | 
| 37 | 
            +
                    while not self._stop_flag.is_set():
         | 
| 38 | 
            +
                        time.sleep(0.5)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                        with self._client_lock:
         | 
| 41 | 
            +
                            if self._queued_logs and self._client is not None:
         | 
| 42 | 
            +
                                logs_to_send = self._queued_logs.copy()
         | 
| 43 | 
            +
                                self._queued_logs.clear()
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                                self._client.predict(
         | 
| 46 | 
            +
                                    api_name="/bulk_log",
         | 
| 47 | 
            +
                                    logs=logs_to_send,
         | 
| 48 | 
            +
                                    hf_token=huggingface_hub.utils.get_token(),
         | 
| 49 | 
            +
                                )
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def _init_client_background(self):
         | 
| 52 | 
            +
                    if self._client is None:
         | 
| 53 | 
            +
                        fib = fibo()
         | 
| 54 | 
            +
                        for sleep_coefficient in fib:
         | 
| 55 | 
            +
                            try:
         | 
| 56 | 
            +
                                client = Client(self.url, verbose=False)
         | 
| 57 | 
            +
                                with self._client_lock:
         | 
| 58 | 
            +
                                    self._client = client
         | 
| 59 | 
            +
                                break
         | 
| 60 | 
            +
                            except Exception:
         | 
| 61 | 
            +
                                pass
         | 
| 62 | 
            +
                            if sleep_coefficient is not None:
         | 
| 63 | 
            +
                                time.sleep(0.1 * sleep_coefficient)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    self._batch_sender()
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def log(self, metrics: dict, step: int | None = None):
         | 
| 68 | 
            +
                    for k in metrics.keys():
         | 
| 69 | 
            +
                        if k in RESERVED_KEYS or k.startswith("__"):
         | 
| 70 | 
            +
                            raise ValueError(
         | 
| 71 | 
            +
                                f"Please do not use this reserved key as a metric: {k}"
         | 
| 72 | 
            +
                            )
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    log_entry: LogEntry = {
         | 
| 75 | 
            +
                        "project": self.project,
         | 
| 76 | 
            +
                        "run": self.name,
         | 
| 77 | 
            +
                        "metrics": metrics,
         | 
| 78 | 
            +
                        "step": step,
         | 
| 79 | 
            +
                    }
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    with self._client_lock:
         | 
| 82 | 
            +
                        self._queued_logs.append(log_entry)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                def finish(self):
         | 
| 85 | 
            +
                    """Cleanup when run is finished."""
         | 
| 86 | 
            +
                    self._stop_flag.set()
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    with self._client_lock:
         | 
| 89 | 
            +
                        if self._queued_logs and self._client is not None:
         | 
| 90 | 
            +
                            logs_to_send = self._queued_logs.copy()
         | 
| 91 | 
            +
                            self._queued_logs.clear()
         | 
| 92 | 
            +
                            self._client.predict(
         | 
| 93 | 
            +
                                api_name="/bulk_log",
         | 
| 94 | 
            +
                                logs=logs_to_send,
         | 
| 95 | 
            +
                                hf_token=huggingface_hub.utils.get_token(),
         | 
| 96 | 
            +
                            )
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    if self._client_thread is not None:
         | 
| 99 | 
            +
                        print(f"* Uploading logs to Trackio Space: {self.url} (please wait...)")
         | 
| 100 | 
            +
                        self._client_thread.join(timeout=30)
         | 
    	
        sqlite_storage.py
    ADDED
    
    | @@ -0,0 +1,384 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import sqlite3
         | 
| 4 | 
            +
            from datetime import datetime
         | 
| 5 | 
            +
            from pathlib import Path
         | 
| 6 | 
            +
            from threading import Lock
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import huggingface_hub as hf
         | 
| 9 | 
            +
            import pandas as pd
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            try:  # absolute imports when installed
         | 
| 12 | 
            +
                from trackio.commit_scheduler import CommitScheduler
         | 
| 13 | 
            +
                from trackio.dummy_commit_scheduler import DummyCommitScheduler
         | 
| 14 | 
            +
                from trackio.utils import TRACKIO_DIR
         | 
| 15 | 
            +
            except Exception:  # relative imports for local execution on Spaces
         | 
| 16 | 
            +
                from commit_scheduler import CommitScheduler
         | 
| 17 | 
            +
                from dummy_commit_scheduler import DummyCommitScheduler
         | 
| 18 | 
            +
                from utils import TRACKIO_DIR
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            class SQLiteStorage:
         | 
| 22 | 
            +
                _dataset_import_attempted = False
         | 
| 23 | 
            +
                _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
         | 
| 24 | 
            +
                _scheduler_lock = Lock()
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                @staticmethod
         | 
| 27 | 
            +
                def _get_connection(db_path: Path) -> sqlite3.Connection:
         | 
| 28 | 
            +
                    conn = sqlite3.connect(str(db_path))
         | 
| 29 | 
            +
                    conn.row_factory = sqlite3.Row
         | 
| 30 | 
            +
                    return conn
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                @staticmethod
         | 
| 33 | 
            +
                def get_project_db_filename(project: str) -> Path:
         | 
| 34 | 
            +
                    """Get the database filename for a specific project."""
         | 
| 35 | 
            +
                    safe_project_name = "".join(
         | 
| 36 | 
            +
                        c for c in project if c.isalnum() or c in ("-", "_")
         | 
| 37 | 
            +
                    ).rstrip()
         | 
| 38 | 
            +
                    if not safe_project_name:
         | 
| 39 | 
            +
                        safe_project_name = "default"
         | 
| 40 | 
            +
                    return f"{safe_project_name}.db"
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                @staticmethod
         | 
| 43 | 
            +
                def get_project_db_path(project: str) -> Path:
         | 
| 44 | 
            +
                    """Get the database path for a specific project."""
         | 
| 45 | 
            +
                    filename = SQLiteStorage.get_project_db_filename(project)
         | 
| 46 | 
            +
                    return TRACKIO_DIR / filename
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                @staticmethod
         | 
| 49 | 
            +
                def init_db(project: str) -> Path:
         | 
| 50 | 
            +
                    """
         | 
| 51 | 
            +
                    Initialize the SQLite database with required tables.
         | 
| 52 | 
            +
                    If there is a dataset ID provided, copies from that dataset instead.
         | 
| 53 | 
            +
                    Returns the database path.
         | 
| 54 | 
            +
                    """
         | 
| 55 | 
            +
                    db_path = SQLiteStorage.get_project_db_path(project)
         | 
| 56 | 
            +
                    db_path.parent.mkdir(parents=True, exist_ok=True)
         | 
| 57 | 
            +
                    with SQLiteStorage.get_scheduler().lock:
         | 
| 58 | 
            +
                        with sqlite3.connect(db_path) as conn:
         | 
| 59 | 
            +
                            cursor = conn.cursor()
         | 
| 60 | 
            +
                            cursor.execute("""
         | 
| 61 | 
            +
                                CREATE TABLE IF NOT EXISTS metrics (
         | 
| 62 | 
            +
                                    id INTEGER PRIMARY KEY AUTOINCREMENT,
         | 
| 63 | 
            +
                                    timestamp TEXT NOT NULL,
         | 
| 64 | 
            +
                                    run_name TEXT NOT NULL,
         | 
| 65 | 
            +
                                    step INTEGER NOT NULL,
         | 
| 66 | 
            +
                                    metrics TEXT NOT NULL
         | 
| 67 | 
            +
                                )
         | 
| 68 | 
            +
                            """)
         | 
| 69 | 
            +
                            cursor.execute(
         | 
| 70 | 
            +
                                """
         | 
| 71 | 
            +
                                CREATE INDEX IF NOT EXISTS idx_metrics_run_step
         | 
| 72 | 
            +
                                ON metrics(run_name, step)
         | 
| 73 | 
            +
                                """
         | 
| 74 | 
            +
                            )
         | 
| 75 | 
            +
                            conn.commit()
         | 
| 76 | 
            +
                    return db_path
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                @staticmethod
         | 
| 79 | 
            +
                def export_to_parquet():
         | 
| 80 | 
            +
                    """
         | 
| 81 | 
            +
                    Exports all projects' DB files as Parquet under the same path but with extension ".parquet".
         | 
| 82 | 
            +
                    """
         | 
| 83 | 
            +
                    # don't attempt to export (potentially wrong/blank) data before importing for the first time
         | 
| 84 | 
            +
                    if not SQLiteStorage._dataset_import_attempted:
         | 
| 85 | 
            +
                        return
         | 
| 86 | 
            +
                    all_paths = os.listdir(TRACKIO_DIR)
         | 
| 87 | 
            +
                    db_paths = [f for f in all_paths if f.endswith(".db")]
         | 
| 88 | 
            +
                    for db_path in db_paths:
         | 
| 89 | 
            +
                        db_path = TRACKIO_DIR / db_path
         | 
| 90 | 
            +
                        parquet_path = db_path.with_suffix(".parquet")
         | 
| 91 | 
            +
                        if (not parquet_path.exists()) or (
         | 
| 92 | 
            +
                            db_path.stat().st_mtime > parquet_path.stat().st_mtime
         | 
| 93 | 
            +
                        ):
         | 
| 94 | 
            +
                            with sqlite3.connect(db_path) as conn:
         | 
| 95 | 
            +
                                df = pd.read_sql("SELECT * from metrics", conn)
         | 
| 96 | 
            +
                            # break out the single JSON metrics column into individual columns
         | 
| 97 | 
            +
                            metrics = df["metrics"].copy()
         | 
| 98 | 
            +
                            metrics = pd.DataFrame(
         | 
| 99 | 
            +
                                metrics.apply(json.loads).values.tolist(), index=df.index
         | 
| 100 | 
            +
                            )
         | 
| 101 | 
            +
                            del df["metrics"]
         | 
| 102 | 
            +
                            for col in metrics.columns:
         | 
| 103 | 
            +
                                df[col] = metrics[col]
         | 
| 104 | 
            +
                            df.to_parquet(parquet_path)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                @staticmethod
         | 
| 107 | 
            +
                def import_from_parquet():
         | 
| 108 | 
            +
                    """
         | 
| 109 | 
            +
                    Imports to all DB files that have matching files under the same path but with extension ".parquet".
         | 
| 110 | 
            +
                    """
         | 
| 111 | 
            +
                    all_paths = os.listdir(TRACKIO_DIR)
         | 
| 112 | 
            +
                    parquet_paths = [f for f in all_paths if f.endswith(".parquet")]
         | 
| 113 | 
            +
                    for parquet_path in parquet_paths:
         | 
| 114 | 
            +
                        parquet_path = TRACKIO_DIR / parquet_path
         | 
| 115 | 
            +
                        db_path = parquet_path.with_suffix(".db")
         | 
| 116 | 
            +
                        df = pd.read_parquet(parquet_path)
         | 
| 117 | 
            +
                        with sqlite3.connect(db_path) as conn:
         | 
| 118 | 
            +
                            # fix up df to have a single JSON metrics column
         | 
| 119 | 
            +
                            if "metrics" not in df.columns:
         | 
| 120 | 
            +
                                # separate other columns from metrics
         | 
| 121 | 
            +
                                metrics = df.copy()
         | 
| 122 | 
            +
                                other_cols = ["id", "timestamp", "run_name", "step"]
         | 
| 123 | 
            +
                                df = df[other_cols]
         | 
| 124 | 
            +
                                for col in other_cols:
         | 
| 125 | 
            +
                                    del metrics[col]
         | 
| 126 | 
            +
                                # combine them all into a single metrics col
         | 
| 127 | 
            +
                                metrics = json.loads(metrics.to_json(orient="records"))
         | 
| 128 | 
            +
                                df["metrics"] = [json.dumps(row) for row in metrics]
         | 
| 129 | 
            +
                            df.to_sql("metrics", conn, if_exists="replace", index=False)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                @staticmethod
         | 
| 132 | 
            +
                def get_scheduler():
         | 
| 133 | 
            +
                    """
         | 
| 134 | 
            +
                    Get the scheduler for the database based on the environment variables.
         | 
| 135 | 
            +
                    This applies to both local and Spaces.
         | 
| 136 | 
            +
                    """
         | 
| 137 | 
            +
                    with SQLiteStorage._scheduler_lock:
         | 
| 138 | 
            +
                        if SQLiteStorage._current_scheduler is not None:
         | 
| 139 | 
            +
                            return SQLiteStorage._current_scheduler
         | 
| 140 | 
            +
                        hf_token = os.environ.get("HF_TOKEN")
         | 
| 141 | 
            +
                        dataset_id = os.environ.get("TRACKIO_DATASET_ID")
         | 
| 142 | 
            +
                        space_repo_name = os.environ.get("SPACE_REPO_NAME")
         | 
| 143 | 
            +
                        if dataset_id is None or space_repo_name is None:
         | 
| 144 | 
            +
                            scheduler = DummyCommitScheduler()
         | 
| 145 | 
            +
                        else:
         | 
| 146 | 
            +
                            scheduler = CommitScheduler(
         | 
| 147 | 
            +
                                repo_id=dataset_id,
         | 
| 148 | 
            +
                                repo_type="dataset",
         | 
| 149 | 
            +
                                folder_path=TRACKIO_DIR,
         | 
| 150 | 
            +
                                private=True,
         | 
| 151 | 
            +
                                allow_patterns="*.parquet",
         | 
| 152 | 
            +
                                squash_history=True,
         | 
| 153 | 
            +
                                token=hf_token,
         | 
| 154 | 
            +
                                on_before_commit=SQLiteStorage.export_to_parquet,
         | 
| 155 | 
            +
                            )
         | 
| 156 | 
            +
                        SQLiteStorage._current_scheduler = scheduler
         | 
| 157 | 
            +
                        return scheduler
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                @staticmethod
         | 
| 160 | 
            +
                def log(project: str, run: str, metrics: dict, step: int | None = None):
         | 
| 161 | 
            +
                    """
         | 
| 162 | 
            +
                    Safely log metrics to the database. Before logging, this method will ensure the database exists
         | 
| 163 | 
            +
                    and is set up with the correct tables. It also uses the scheduler to lock the database so
         | 
| 164 | 
            +
                    that there is no race condition when logging / syncing to the Hugging Face Dataset.
         | 
| 165 | 
            +
                    """
         | 
| 166 | 
            +
                    db_path = SQLiteStorage.init_db(project)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    with SQLiteStorage.get_scheduler().lock:
         | 
| 169 | 
            +
                        with SQLiteStorage._get_connection(db_path) as conn:
         | 
| 170 | 
            +
                            cursor = conn.cursor()
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                            cursor.execute(
         | 
| 173 | 
            +
                                """
         | 
| 174 | 
            +
                                SELECT MAX(step) 
         | 
| 175 | 
            +
                                FROM metrics 
         | 
| 176 | 
            +
                                WHERE run_name = ?
         | 
| 177 | 
            +
                                """,
         | 
| 178 | 
            +
                                (run,),
         | 
| 179 | 
            +
                            )
         | 
| 180 | 
            +
                            last_step = cursor.fetchone()[0]
         | 
| 181 | 
            +
                            if step is None:
         | 
| 182 | 
            +
                                current_step = 0 if last_step is None else last_step + 1
         | 
| 183 | 
            +
                            else:
         | 
| 184 | 
            +
                                current_step = step
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                            current_timestamp = datetime.now().isoformat()
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                            cursor.execute(
         | 
| 189 | 
            +
                                """
         | 
| 190 | 
            +
                                INSERT INTO metrics
         | 
| 191 | 
            +
                                (timestamp, run_name, step, metrics)
         | 
| 192 | 
            +
                                VALUES (?, ?, ?, ?)
         | 
| 193 | 
            +
                                """,
         | 
| 194 | 
            +
                                (
         | 
| 195 | 
            +
                                    current_timestamp,
         | 
| 196 | 
            +
                                    run,
         | 
| 197 | 
            +
                                    current_step,
         | 
| 198 | 
            +
                                    json.dumps(metrics),
         | 
| 199 | 
            +
                                ),
         | 
| 200 | 
            +
                            )
         | 
| 201 | 
            +
                            conn.commit()
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                @staticmethod
         | 
| 204 | 
            +
                def bulk_log(
         | 
| 205 | 
            +
                    project: str,
         | 
| 206 | 
            +
                    run: str,
         | 
| 207 | 
            +
                    metrics_list: list[dict],
         | 
| 208 | 
            +
                    steps: list[int] | None = None,
         | 
| 209 | 
            +
                    timestamps: list[str] | None = None,
         | 
| 210 | 
            +
                ):
         | 
| 211 | 
            +
                    """Bulk log metrics to the database with specified steps and timestamps."""
         | 
| 212 | 
            +
                    if not metrics_list:
         | 
| 213 | 
            +
                        return
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    if timestamps is None:
         | 
| 216 | 
            +
                        timestamps = [datetime.now().isoformat()] * len(metrics_list)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    db_path = SQLiteStorage.init_db(project)
         | 
| 219 | 
            +
                    with SQLiteStorage.get_scheduler().lock:
         | 
| 220 | 
            +
                        with SQLiteStorage._get_connection(db_path) as conn:
         | 
| 221 | 
            +
                            cursor = conn.cursor()
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                            if steps is None:
         | 
| 224 | 
            +
                                steps = list(range(len(metrics_list)))
         | 
| 225 | 
            +
                            elif any(s is None for s in steps):
         | 
| 226 | 
            +
                                cursor.execute(
         | 
| 227 | 
            +
                                    "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
         | 
| 228 | 
            +
                                )
         | 
| 229 | 
            +
                                last_step = cursor.fetchone()[0]
         | 
| 230 | 
            +
                                current_step = 0 if last_step is None else last_step + 1
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                                processed_steps = []
         | 
| 233 | 
            +
                                for step in steps:
         | 
| 234 | 
            +
                                    if step is None:
         | 
| 235 | 
            +
                                        processed_steps.append(current_step)
         | 
| 236 | 
            +
                                        current_step += 1
         | 
| 237 | 
            +
                                    else:
         | 
| 238 | 
            +
                                        processed_steps.append(step)
         | 
| 239 | 
            +
                                steps = processed_steps
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                            if len(metrics_list) != len(steps) or len(metrics_list) != len(
         | 
| 242 | 
            +
                                timestamps
         | 
| 243 | 
            +
                            ):
         | 
| 244 | 
            +
                                raise ValueError(
         | 
| 245 | 
            +
                                    "metrics_list, steps, and timestamps must have the same length"
         | 
| 246 | 
            +
                                )
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                            data = []
         | 
| 249 | 
            +
                            for i, metrics in enumerate(metrics_list):
         | 
| 250 | 
            +
                                data.append(
         | 
| 251 | 
            +
                                    (
         | 
| 252 | 
            +
                                        timestamps[i],
         | 
| 253 | 
            +
                                        run,
         | 
| 254 | 
            +
                                        steps[i],
         | 
| 255 | 
            +
                                        json.dumps(metrics),
         | 
| 256 | 
            +
                                    )
         | 
| 257 | 
            +
                                )
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                            cursor.executemany(
         | 
| 260 | 
            +
                                """
         | 
| 261 | 
            +
                                INSERT INTO metrics
         | 
| 262 | 
            +
                                (timestamp, run_name, step, metrics)
         | 
| 263 | 
            +
                                VALUES (?, ?, ?, ?)
         | 
| 264 | 
            +
                                """,
         | 
| 265 | 
            +
                                data,
         | 
| 266 | 
            +
                            )
         | 
| 267 | 
            +
                            conn.commit()
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                @staticmethod
         | 
| 270 | 
            +
                def get_metrics(project: str, run: str) -> list[dict]:
         | 
| 271 | 
            +
                    """Retrieve metrics for a specific run. The metrics also include the step count (int) and the timestamp (datetime object)."""
         | 
| 272 | 
            +
                    db_path = SQLiteStorage.get_project_db_path(project)
         | 
| 273 | 
            +
                    if not db_path.exists():
         | 
| 274 | 
            +
                        return []
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    with SQLiteStorage._get_connection(db_path) as conn:
         | 
| 277 | 
            +
                        cursor = conn.cursor()
         | 
| 278 | 
            +
                        cursor.execute(
         | 
| 279 | 
            +
                            """
         | 
| 280 | 
            +
                            SELECT timestamp, step, metrics
         | 
| 281 | 
            +
                            FROM metrics
         | 
| 282 | 
            +
                            WHERE run_name = ?
         | 
| 283 | 
            +
                            ORDER BY timestamp
         | 
| 284 | 
            +
                            """,
         | 
| 285 | 
            +
                            (run,),
         | 
| 286 | 
            +
                        )
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                        rows = cursor.fetchall()
         | 
| 289 | 
            +
                        results = []
         | 
| 290 | 
            +
                        for row in rows:
         | 
| 291 | 
            +
                            metrics = json.loads(row["metrics"])
         | 
| 292 | 
            +
                            metrics["timestamp"] = row["timestamp"]
         | 
| 293 | 
            +
                            metrics["step"] = row["step"]
         | 
| 294 | 
            +
                            results.append(metrics)
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                        return results
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                @staticmethod
         | 
| 299 | 
            +
                def load_from_dataset():
         | 
| 300 | 
            +
                    dataset_id = os.environ.get("TRACKIO_DATASET_ID")
         | 
| 301 | 
            +
                    space_repo_name = os.environ.get("SPACE_REPO_NAME")
         | 
| 302 | 
            +
                    if dataset_id is not None and space_repo_name is not None:
         | 
| 303 | 
            +
                        hfapi = hf.HfApi()
         | 
| 304 | 
            +
                        updated = False
         | 
| 305 | 
            +
                        if not TRACKIO_DIR.exists():
         | 
| 306 | 
            +
                            TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
         | 
| 307 | 
            +
                        with SQLiteStorage.get_scheduler().lock:
         | 
| 308 | 
            +
                            try:
         | 
| 309 | 
            +
                                files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
         | 
| 310 | 
            +
                                for file in files:
         | 
| 311 | 
            +
                                    if not file.endswith(".parquet"):
         | 
| 312 | 
            +
                                        continue
         | 
| 313 | 
            +
                                    hf.hf_hub_download(
         | 
| 314 | 
            +
                                        dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR
         | 
| 315 | 
            +
                                    )
         | 
| 316 | 
            +
                                    updated = True
         | 
| 317 | 
            +
                            except hf.errors.EntryNotFoundError:
         | 
| 318 | 
            +
                                pass
         | 
| 319 | 
            +
                            except hf.errors.RepositoryNotFoundError:
         | 
| 320 | 
            +
                                pass
         | 
| 321 | 
            +
                            if updated:
         | 
| 322 | 
            +
                                SQLiteStorage.import_from_parquet()
         | 
| 323 | 
            +
                    SQLiteStorage._dataset_import_attempted = True
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                @staticmethod
         | 
| 326 | 
            +
                def get_projects() -> list[str]:
         | 
| 327 | 
            +
                    """
         | 
| 328 | 
            +
                    Get list of all projects by scanning the database files in the trackio directory.
         | 
| 329 | 
            +
                    """
         | 
| 330 | 
            +
                    if not SQLiteStorage._dataset_import_attempted:
         | 
| 331 | 
            +
                        SQLiteStorage.load_from_dataset()
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                    projects: set[str] = set()
         | 
| 334 | 
            +
                    if not TRACKIO_DIR.exists():
         | 
| 335 | 
            +
                        return []
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    for db_file in TRACKIO_DIR.glob("*.db"):
         | 
| 338 | 
            +
                        project_name = db_file.stem
         | 
| 339 | 
            +
                        projects.add(project_name)
         | 
| 340 | 
            +
                    return sorted(projects)
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                @staticmethod
         | 
| 343 | 
            +
                def get_runs(project: str) -> list[str]:
         | 
| 344 | 
            +
                    """Get list of all runs for a project."""
         | 
| 345 | 
            +
                    db_path = SQLiteStorage.get_project_db_path(project)
         | 
| 346 | 
            +
                    if not db_path.exists():
         | 
| 347 | 
            +
                        return []
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                    with SQLiteStorage._get_connection(db_path) as conn:
         | 
| 350 | 
            +
                        cursor = conn.cursor()
         | 
| 351 | 
            +
                        cursor.execute(
         | 
| 352 | 
            +
                            "SELECT DISTINCT run_name FROM metrics",
         | 
| 353 | 
            +
                        )
         | 
| 354 | 
            +
                        return [row[0] for row in cursor.fetchall()]
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                @staticmethod
         | 
| 357 | 
            +
                def get_max_steps_for_runs(project: str, runs: list[str]) -> dict[str, int]:
         | 
| 358 | 
            +
                    """Efficiently get the maximum step for multiple runs in a single query."""
         | 
| 359 | 
            +
                    db_path = SQLiteStorage.get_project_db_path(project)
         | 
| 360 | 
            +
                    if not db_path.exists():
         | 
| 361 | 
            +
                        return {run: 0 for run in runs}
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    with SQLiteStorage._get_connection(db_path) as conn:
         | 
| 364 | 
            +
                        cursor = conn.cursor()
         | 
| 365 | 
            +
                        placeholders = ",".join("?" * len(runs))
         | 
| 366 | 
            +
                        cursor.execute(
         | 
| 367 | 
            +
                            f"""
         | 
| 368 | 
            +
                            SELECT run_name, MAX(step) as max_step
         | 
| 369 | 
            +
                            FROM metrics
         | 
| 370 | 
            +
                            WHERE run_name IN ({placeholders})
         | 
| 371 | 
            +
                            GROUP BY run_name
         | 
| 372 | 
            +
                            """,
         | 
| 373 | 
            +
                            runs,
         | 
| 374 | 
            +
                        )
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                        results = {run: 0 for run in runs}  # Default to 0 for runs with no data
         | 
| 377 | 
            +
                        for row in cursor.fetchall():
         | 
| 378 | 
            +
                            results[row["run_name"]] = row["max_step"]
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                        return results
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                def finish(self):
         | 
| 383 | 
            +
                    """Cleanup when run is finished."""
         | 
| 384 | 
            +
                    pass
         | 
    	
        typehints.py
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Any, TypedDict
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class LogEntry(TypedDict):
         | 
| 5 | 
            +
                project: str
         | 
| 6 | 
            +
                run: str
         | 
| 7 | 
            +
                metrics: dict[str, Any]
         | 
| 8 | 
            +
                step: int | None
         | 
    	
        ui.py
    ADDED
    
    | @@ -0,0 +1,570 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            import shutil
         | 
| 4 | 
            +
            from typing import Any
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import gradio as gr
         | 
| 7 | 
            +
            import huggingface_hub as hf
         | 
| 8 | 
            +
            import numpy as np
         | 
| 9 | 
            +
            import pandas as pd
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            HfApi = hf.HfApi()
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            try:
         | 
| 14 | 
            +
                from trackio.sqlite_storage import SQLiteStorage
         | 
| 15 | 
            +
                from trackio.typehints import LogEntry
         | 
| 16 | 
            +
                from trackio.utils import (
         | 
| 17 | 
            +
                    RESERVED_KEYS,
         | 
| 18 | 
            +
                    TRACKIO_LOGO_DIR,
         | 
| 19 | 
            +
                    downsample,
         | 
| 20 | 
            +
                    get_color_mapping,
         | 
| 21 | 
            +
                )
         | 
| 22 | 
            +
            except:  # noqa: E722
         | 
| 23 | 
            +
                from sqlite_storage import SQLiteStorage
         | 
| 24 | 
            +
                from typehints import LogEntry
         | 
| 25 | 
            +
                from utils import RESERVED_KEYS, TRACKIO_LOGO_DIR, downsample, get_color_mapping
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def get_projects(request: gr.Request):
         | 
| 29 | 
            +
                dataset_id = os.environ.get("TRACKIO_DATASET_ID")
         | 
| 30 | 
            +
                projects = SQLiteStorage.get_projects()
         | 
| 31 | 
            +
                if project := request.query_params.get("project"):
         | 
| 32 | 
            +
                    interactive = False
         | 
| 33 | 
            +
                else:
         | 
| 34 | 
            +
                    interactive = True
         | 
| 35 | 
            +
                    project = projects[0] if projects else None
         | 
| 36 | 
            +
                return gr.Dropdown(
         | 
| 37 | 
            +
                    label="Project",
         | 
| 38 | 
            +
                    choices=projects,
         | 
| 39 | 
            +
                    value=project,
         | 
| 40 | 
            +
                    allow_custom_value=True,
         | 
| 41 | 
            +
                    interactive=interactive,
         | 
| 42 | 
            +
                    info=f"↻ Synced to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank'>{dataset_id}</a> every 5 min"
         | 
| 43 | 
            +
                    if dataset_id
         | 
| 44 | 
            +
                    else None,
         | 
| 45 | 
            +
                )
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def get_runs(project) -> list[str]:
         | 
| 49 | 
            +
                if not project:
         | 
| 50 | 
            +
                    return []
         | 
| 51 | 
            +
                return SQLiteStorage.get_runs(project)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def get_available_metrics(project: str, runs: list[str]) -> list[str]:
         | 
| 55 | 
            +
                """Get all available metrics across all runs for x-axis selection."""
         | 
| 56 | 
            +
                if not project or not runs:
         | 
| 57 | 
            +
                    return ["step", "time"]
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                all_metrics = set()
         | 
| 60 | 
            +
                for run in runs:
         | 
| 61 | 
            +
                    metrics = SQLiteStorage.get_metrics(project, run)
         | 
| 62 | 
            +
                    if metrics:
         | 
| 63 | 
            +
                        df = pd.DataFrame(metrics)
         | 
| 64 | 
            +
                        numeric_cols = df.select_dtypes(include="number").columns
         | 
| 65 | 
            +
                        numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
         | 
| 66 | 
            +
                        all_metrics.update(numeric_cols)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                # Always include step and time as options
         | 
| 69 | 
            +
                all_metrics.add("step")
         | 
| 70 | 
            +
                all_metrics.add("time")
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                # Sort metrics by prefix
         | 
| 73 | 
            +
                sorted_metrics = sort_metrics_by_prefix(list(all_metrics))
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                # Put step and time at the beginning
         | 
| 76 | 
            +
                result = ["step", "time"]
         | 
| 77 | 
            +
                for metric in sorted_metrics:
         | 
| 78 | 
            +
                    if metric not in result:
         | 
| 79 | 
            +
                        result.append(metric)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                return result
         | 
| 82 | 
            +
             | 
| 83 | 
            +
             | 
| 84 | 
            +
            def load_run_data(
         | 
| 85 | 
            +
                project: str | None,
         | 
| 86 | 
            +
                run: str | None,
         | 
| 87 | 
            +
                smoothing: bool,
         | 
| 88 | 
            +
                x_axis: str,
         | 
| 89 | 
            +
                log_scale: bool = False,
         | 
| 90 | 
            +
            ):
         | 
| 91 | 
            +
                if not project or not run:
         | 
| 92 | 
            +
                    return None
         | 
| 93 | 
            +
                metrics = SQLiteStorage.get_metrics(project, run)
         | 
| 94 | 
            +
                if not metrics:
         | 
| 95 | 
            +
                    return None
         | 
| 96 | 
            +
                df = pd.DataFrame(metrics)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                if "step" not in df.columns:
         | 
| 99 | 
            +
                    df["step"] = range(len(df))
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                if x_axis == "time" and "timestamp" in df.columns:
         | 
| 102 | 
            +
                    df["timestamp"] = pd.to_datetime(df["timestamp"])
         | 
| 103 | 
            +
                    first_timestamp = df["timestamp"].min()
         | 
| 104 | 
            +
                    df["time"] = (df["timestamp"] - first_timestamp).dt.total_seconds()
         | 
| 105 | 
            +
                    x_column = "time"
         | 
| 106 | 
            +
                elif x_axis == "step":
         | 
| 107 | 
            +
                    x_column = "step"
         | 
| 108 | 
            +
                else:
         | 
| 109 | 
            +
                    x_column = x_axis
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                if log_scale and x_column in df.columns:
         | 
| 112 | 
            +
                    x_vals = df[x_column]
         | 
| 113 | 
            +
                    if (x_vals <= 0).any():
         | 
| 114 | 
            +
                        df[x_column] = np.log10(np.maximum(x_vals, 0) + 1)
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        df[x_column] = np.log10(x_vals)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                if smoothing:
         | 
| 119 | 
            +
                    numeric_cols = df.select_dtypes(include="number").columns
         | 
| 120 | 
            +
                    numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    df_original = df.copy()
         | 
| 123 | 
            +
                    df_original["run"] = f"{run}_original"
         | 
| 124 | 
            +
                    df_original["data_type"] = "original"
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    df_smoothed = df.copy()
         | 
| 127 | 
            +
                    window_size = max(3, min(10, len(df) // 10))  # Adaptive window size
         | 
| 128 | 
            +
                    df_smoothed[numeric_cols] = (
         | 
| 129 | 
            +
                        df_smoothed[numeric_cols]
         | 
| 130 | 
            +
                        .rolling(window=window_size, center=True, min_periods=1)
         | 
| 131 | 
            +
                        .mean()
         | 
| 132 | 
            +
                    )
         | 
| 133 | 
            +
                    df_smoothed["run"] = f"{run}_smoothed"
         | 
| 134 | 
            +
                    df_smoothed["data_type"] = "smoothed"
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    combined_df = pd.concat([df_original, df_smoothed], ignore_index=True)
         | 
| 137 | 
            +
                    combined_df["x_axis"] = x_column
         | 
| 138 | 
            +
                    return combined_df
         | 
| 139 | 
            +
                else:
         | 
| 140 | 
            +
                    df["run"] = run
         | 
| 141 | 
            +
                    df["data_type"] = "original"
         | 
| 142 | 
            +
                    df["x_axis"] = x_column
         | 
| 143 | 
            +
                    return df
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
            def update_runs(project, filter_text, user_interacted_with_runs=False):
         | 
| 147 | 
            +
                if project is None:
         | 
| 148 | 
            +
                    runs = []
         | 
| 149 | 
            +
                    num_runs = 0
         | 
| 150 | 
            +
                else:
         | 
| 151 | 
            +
                    runs = get_runs(project)
         | 
| 152 | 
            +
                    num_runs = len(runs)
         | 
| 153 | 
            +
                    if filter_text:
         | 
| 154 | 
            +
                        runs = [r for r in runs if filter_text in r]
         | 
| 155 | 
            +
                if not user_interacted_with_runs:
         | 
| 156 | 
            +
                    return gr.CheckboxGroup(choices=runs, value=runs), gr.Textbox(
         | 
| 157 | 
            +
                        label=f"Runs ({num_runs})"
         | 
| 158 | 
            +
                    )
         | 
| 159 | 
            +
                else:
         | 
| 160 | 
            +
                    return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})")
         | 
| 161 | 
            +
             | 
| 162 | 
            +
             | 
| 163 | 
            +
            def filter_runs(project, filter_text):
         | 
| 164 | 
            +
                runs = get_runs(project)
         | 
| 165 | 
            +
                runs = [r for r in runs if filter_text in r]
         | 
| 166 | 
            +
                return gr.CheckboxGroup(choices=runs, value=runs)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
             | 
| 169 | 
            +
            def update_x_axis_choices(project, runs):
         | 
| 170 | 
            +
                """Update x-axis dropdown choices based on available metrics."""
         | 
| 171 | 
            +
                available_metrics = get_available_metrics(project, runs)
         | 
| 172 | 
            +
                return gr.Dropdown(
         | 
| 173 | 
            +
                    label="X-axis",
         | 
| 174 | 
            +
                    choices=available_metrics,
         | 
| 175 | 
            +
                    value="step",
         | 
| 176 | 
            +
                )
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
            def toggle_timer(cb_value):
         | 
| 180 | 
            +
                if cb_value:
         | 
| 181 | 
            +
                    return gr.Timer(active=True)
         | 
| 182 | 
            +
                else:
         | 
| 183 | 
            +
                    return gr.Timer(active=False)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            def check_auth(hf_token: str | None) -> None:
         | 
| 187 | 
            +
                if os.getenv("SYSTEM") == "spaces":  # if we are running in Spaces
         | 
| 188 | 
            +
                    # check auth token passed in
         | 
| 189 | 
            +
                    if hf_token is None:
         | 
| 190 | 
            +
                        raise PermissionError(
         | 
| 191 | 
            +
                            "Expected a HF_TOKEN to be provided when logging to a Space"
         | 
| 192 | 
            +
                        )
         | 
| 193 | 
            +
                    who = HfApi.whoami(hf_token)
         | 
| 194 | 
            +
                    access_token = who["auth"]["accessToken"]
         | 
| 195 | 
            +
                    owner_name = os.getenv("SPACE_AUTHOR_NAME")
         | 
| 196 | 
            +
                    repo_name = os.getenv("SPACE_REPO_NAME")
         | 
| 197 | 
            +
                    # make sure the token user is either the author of the space,
         | 
| 198 | 
            +
                    # or is a member of an org that is the author.
         | 
| 199 | 
            +
                    orgs = [o["name"] for o in who["orgs"]]
         | 
| 200 | 
            +
                    if owner_name != who["name"] and owner_name not in orgs:
         | 
| 201 | 
            +
                        raise PermissionError(
         | 
| 202 | 
            +
                            "Expected the provided hf_token to be the user owner of the space, or be a member of the org owner of the space"
         | 
| 203 | 
            +
                        )
         | 
| 204 | 
            +
                    # reject fine-grained tokens without specific repo access
         | 
| 205 | 
            +
                    if access_token["role"] == "fineGrained":
         | 
| 206 | 
            +
                        matched = False
         | 
| 207 | 
            +
                        for item in access_token["fineGrained"]["scoped"]:
         | 
| 208 | 
            +
                            if (
         | 
| 209 | 
            +
                                item["entity"]["type"] == "space"
         | 
| 210 | 
            +
                                and item["entity"]["name"] == f"{owner_name}/{repo_name}"
         | 
| 211 | 
            +
                                and "repo.write" in item["permissions"]
         | 
| 212 | 
            +
                            ):
         | 
| 213 | 
            +
                                matched = True
         | 
| 214 | 
            +
                                break
         | 
| 215 | 
            +
                            if (
         | 
| 216 | 
            +
                                (
         | 
| 217 | 
            +
                                    item["entity"]["type"] == "user"
         | 
| 218 | 
            +
                                    or item["entity"]["type"] == "org"
         | 
| 219 | 
            +
                                )
         | 
| 220 | 
            +
                                and item["entity"]["name"] == owner_name
         | 
| 221 | 
            +
                                and "repo.write" in item["permissions"]
         | 
| 222 | 
            +
                            ):
         | 
| 223 | 
            +
                                matched = True
         | 
| 224 | 
            +
                                break
         | 
| 225 | 
            +
                        if not matched:
         | 
| 226 | 
            +
                            raise PermissionError(
         | 
| 227 | 
            +
                                "Expected the provided hf_token with fine grained permissions to provide write access to the space"
         | 
| 228 | 
            +
                            )
         | 
| 229 | 
            +
                    # reject read-only tokens
         | 
| 230 | 
            +
                    elif access_token["role"] != "write":
         | 
| 231 | 
            +
                        raise PermissionError(
         | 
| 232 | 
            +
                            "Expected the provided hf_token to provide write permissions"
         | 
| 233 | 
            +
                        )
         | 
| 234 | 
            +
             | 
| 235 | 
            +
             | 
| 236 | 
            +
            def upload_db_to_space(
         | 
| 237 | 
            +
                project: str, uploaded_db: gr.FileData, hf_token: str | None
         | 
| 238 | 
            +
            ) -> None:
         | 
| 239 | 
            +
                check_auth(hf_token)
         | 
| 240 | 
            +
                db_project_path = SQLiteStorage.get_project_db_path(project)
         | 
| 241 | 
            +
                if os.path.exists(db_project_path):
         | 
| 242 | 
            +
                    raise gr.Error(
         | 
| 243 | 
            +
                        f"Trackio database file already exists for project {project}, cannot overwrite."
         | 
| 244 | 
            +
                    )
         | 
| 245 | 
            +
                os.makedirs(os.path.dirname(db_project_path), exist_ok=True)
         | 
| 246 | 
            +
                shutil.copy(uploaded_db["path"], db_project_path)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
             | 
| 249 | 
            +
            def log(
         | 
| 250 | 
            +
                project: str,
         | 
| 251 | 
            +
                run: str,
         | 
| 252 | 
            +
                metrics: dict[str, Any],
         | 
| 253 | 
            +
                step: int | None,
         | 
| 254 | 
            +
                hf_token: str | None,
         | 
| 255 | 
            +
            ) -> None:
         | 
| 256 | 
            +
                check_auth(hf_token)
         | 
| 257 | 
            +
                SQLiteStorage.log(project=project, run=run, metrics=metrics, step=step)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
             | 
| 260 | 
            +
            def bulk_log(
         | 
| 261 | 
            +
                logs: list[LogEntry],
         | 
| 262 | 
            +
                hf_token: str | None,
         | 
| 263 | 
            +
            ) -> None:
         | 
| 264 | 
            +
                check_auth(hf_token)
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                logs_by_run = {}
         | 
| 267 | 
            +
                for log_entry in logs:
         | 
| 268 | 
            +
                    key = (log_entry["project"], log_entry["run"])
         | 
| 269 | 
            +
                    if key not in logs_by_run:
         | 
| 270 | 
            +
                        logs_by_run[key] = {"metrics": [], "steps": []}
         | 
| 271 | 
            +
                    logs_by_run[key]["metrics"].append(log_entry["metrics"])
         | 
| 272 | 
            +
                    logs_by_run[key]["steps"].append(log_entry.get("step"))
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                for (project, run), data in logs_by_run.items():
         | 
| 275 | 
            +
                    SQLiteStorage.bulk_log(
         | 
| 276 | 
            +
                        project=project,
         | 
| 277 | 
            +
                        run=run,
         | 
| 278 | 
            +
                        metrics_list=data["metrics"],
         | 
| 279 | 
            +
                        steps=data["steps"],
         | 
| 280 | 
            +
                    )
         | 
| 281 | 
            +
             | 
| 282 | 
            +
             | 
| 283 | 
            +
            def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]:
         | 
| 284 | 
            +
                """
         | 
| 285 | 
            +
                Filter metrics using regex pattern.
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                Args:
         | 
| 288 | 
            +
                    metrics: List of metric names to filter
         | 
| 289 | 
            +
                    filter_pattern: Regex pattern to match against metric names
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                Returns:
         | 
| 292 | 
            +
                    List of metric names that match the pattern
         | 
| 293 | 
            +
                """
         | 
| 294 | 
            +
                if not filter_pattern.strip():
         | 
| 295 | 
            +
                    return metrics
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                try:
         | 
| 298 | 
            +
                    pattern = re.compile(filter_pattern, re.IGNORECASE)
         | 
| 299 | 
            +
                    return [metric for metric in metrics if pattern.search(metric)]
         | 
| 300 | 
            +
                except re.error:
         | 
| 301 | 
            +
                    return [
         | 
| 302 | 
            +
                        metric for metric in metrics if filter_pattern.lower() in metric.lower()
         | 
| 303 | 
            +
                    ]
         | 
| 304 | 
            +
             | 
| 305 | 
            +
             | 
| 306 | 
            +
            def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
         | 
| 307 | 
            +
                """
         | 
| 308 | 
            +
                Sort metrics by grouping prefixes together.
         | 
| 309 | 
            +
                Metrics without prefixes come first, then grouped by prefix.
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                Example:
         | 
| 312 | 
            +
                Input: ["train/loss", "loss", "train/acc", "val/loss"]
         | 
| 313 | 
            +
                Output: ["loss", "train/acc", "train/loss", "val/loss"]
         | 
| 314 | 
            +
                """
         | 
| 315 | 
            +
                no_prefix = []
         | 
| 316 | 
            +
                with_prefix = []
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                for metric in metrics:
         | 
| 319 | 
            +
                    if "/" in metric:
         | 
| 320 | 
            +
                        with_prefix.append(metric)
         | 
| 321 | 
            +
                    else:
         | 
| 322 | 
            +
                        no_prefix.append(metric)
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                no_prefix.sort()
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                prefix_groups = {}
         | 
| 327 | 
            +
                for metric in with_prefix:
         | 
| 328 | 
            +
                    prefix = metric.split("/")[0]
         | 
| 329 | 
            +
                    if prefix not in prefix_groups:
         | 
| 330 | 
            +
                        prefix_groups[prefix] = []
         | 
| 331 | 
            +
                    prefix_groups[prefix].append(metric)
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                sorted_with_prefix = []
         | 
| 334 | 
            +
                for prefix in sorted(prefix_groups.keys()):
         | 
| 335 | 
            +
                    sorted_with_prefix.extend(sorted(prefix_groups[prefix]))
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                return no_prefix + sorted_with_prefix
         | 
| 338 | 
            +
             | 
| 339 | 
            +
             | 
| 340 | 
            +
            def configure(request: gr.Request):
         | 
| 341 | 
            +
                sidebar_param = request.query_params.get("sidebar")
         | 
| 342 | 
            +
                match sidebar_param:
         | 
| 343 | 
            +
                    case "collapsed":
         | 
| 344 | 
            +
                        sidebar = gr.Sidebar(open=False, visible=True)
         | 
| 345 | 
            +
                    case "hidden":
         | 
| 346 | 
            +
                        sidebar = gr.Sidebar(open=False, visible=False)
         | 
| 347 | 
            +
                    case _:
         | 
| 348 | 
            +
                        sidebar = gr.Sidebar(open=True, visible=True)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                if metrics := request.query_params.get("metrics"):
         | 
| 351 | 
            +
                    return metrics.split(","), sidebar
         | 
| 352 | 
            +
                else:
         | 
| 353 | 
            +
                    return [], sidebar
         | 
| 354 | 
            +
             | 
| 355 | 
            +
             | 
| 356 | 
            +
            css = """
         | 
| 357 | 
            +
            #run-cb .wrap { gap: 2px; }
         | 
| 358 | 
            +
            #run-cb .wrap label {
         | 
| 359 | 
            +
                line-height: 1;
         | 
| 360 | 
            +
                padding: 6px;
         | 
| 361 | 
            +
            }
         | 
| 362 | 
            +
            .logo-light { display: block; } 
         | 
| 363 | 
            +
            .logo-dark { display: none; }
         | 
| 364 | 
            +
            .dark .logo-light { display: none; }
         | 
| 365 | 
            +
            .dark .logo-dark { display: block; }
         | 
| 366 | 
            +
            """
         | 
| 367 | 
            +
             | 
| 368 | 
            +
            with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo:
         | 
| 369 | 
            +
                with gr.Sidebar(open=False) as sidebar:
         | 
| 370 | 
            +
                    logo = gr.Markdown(
         | 
| 371 | 
            +
                        f"""
         | 
| 372 | 
            +
                            <img src='/gradio_api/file={TRACKIO_LOGO_DIR}/trackio_logo_type_light_transparent.png' width='80%' class='logo-light'>
         | 
| 373 | 
            +
                            <img src='/gradio_api/file={TRACKIO_LOGO_DIR}/trackio_logo_type_dark_transparent.png' width='80%' class='logo-dark'>            
         | 
| 374 | 
            +
                        """
         | 
| 375 | 
            +
                    )
         | 
| 376 | 
            +
                    project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
         | 
| 377 | 
            +
                    run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
         | 
| 378 | 
            +
                    run_cb = gr.CheckboxGroup(
         | 
| 379 | 
            +
                        label="Runs", choices=[], interactive=True, elem_id="run-cb"
         | 
| 380 | 
            +
                    )
         | 
| 381 | 
            +
                    gr.HTML("<hr>")
         | 
| 382 | 
            +
                    realtime_cb = gr.Checkbox(label="Refresh metrics realtime", value=True)
         | 
| 383 | 
            +
                    smoothing_cb = gr.Checkbox(label="Smooth metrics", value=True)
         | 
| 384 | 
            +
                    x_axis_dd = gr.Dropdown(
         | 
| 385 | 
            +
                        label="X-axis",
         | 
| 386 | 
            +
                        choices=["step", "time"],
         | 
| 387 | 
            +
                        value="step",
         | 
| 388 | 
            +
                    )
         | 
| 389 | 
            +
                    log_scale_cb = gr.Checkbox(label="Log scale X-axis", value=False)
         | 
| 390 | 
            +
                    metric_filter_tb = gr.Textbox(
         | 
| 391 | 
            +
                        label="Metric Filter (regex)",
         | 
| 392 | 
            +
                        placeholder="e.g., loss|ndcg@10|gpu",
         | 
| 393 | 
            +
                        value="",
         | 
| 394 | 
            +
                        info="Filter metrics using regex patterns. Leave empty to show all metrics.",
         | 
| 395 | 
            +
                    )
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                timer = gr.Timer(value=1)
         | 
| 398 | 
            +
                metrics_subset = gr.State([])
         | 
| 399 | 
            +
                user_interacted_with_run_cb = gr.State(False)
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                gr.on([demo.load], fn=configure, outputs=[metrics_subset, sidebar])
         | 
| 402 | 
            +
                gr.on(
         | 
| 403 | 
            +
                    [demo.load],
         | 
| 404 | 
            +
                    fn=get_projects,
         | 
| 405 | 
            +
                    outputs=project_dd,
         | 
| 406 | 
            +
                    show_progress="hidden",
         | 
| 407 | 
            +
                )
         | 
| 408 | 
            +
                gr.on(
         | 
| 409 | 
            +
                    [timer.tick],
         | 
| 410 | 
            +
                    fn=update_runs,
         | 
| 411 | 
            +
                    inputs=[project_dd, run_tb, user_interacted_with_run_cb],
         | 
| 412 | 
            +
                    outputs=[run_cb, run_tb],
         | 
| 413 | 
            +
                    show_progress="hidden",
         | 
| 414 | 
            +
                )
         | 
| 415 | 
            +
                gr.on(
         | 
| 416 | 
            +
                    [demo.load, project_dd.change],
         | 
| 417 | 
            +
                    fn=update_runs,
         | 
| 418 | 
            +
                    inputs=[project_dd, run_tb],
         | 
| 419 | 
            +
                    outputs=[run_cb, run_tb],
         | 
| 420 | 
            +
                    show_progress="hidden",
         | 
| 421 | 
            +
                )
         | 
| 422 | 
            +
                gr.on(
         | 
| 423 | 
            +
                    [demo.load, project_dd.change, run_cb.change],
         | 
| 424 | 
            +
                    fn=update_x_axis_choices,
         | 
| 425 | 
            +
                    inputs=[project_dd, run_cb],
         | 
| 426 | 
            +
                    outputs=x_axis_dd,
         | 
| 427 | 
            +
                    show_progress="hidden",
         | 
| 428 | 
            +
                )
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                realtime_cb.change(
         | 
| 431 | 
            +
                    fn=toggle_timer,
         | 
| 432 | 
            +
                    inputs=realtime_cb,
         | 
| 433 | 
            +
                    outputs=timer,
         | 
| 434 | 
            +
                    api_name="toggle_timer",
         | 
| 435 | 
            +
                )
         | 
| 436 | 
            +
                run_cb.input(
         | 
| 437 | 
            +
                    fn=lambda: True,
         | 
| 438 | 
            +
                    outputs=user_interacted_with_run_cb,
         | 
| 439 | 
            +
                )
         | 
| 440 | 
            +
                run_tb.input(
         | 
| 441 | 
            +
                    fn=filter_runs,
         | 
| 442 | 
            +
                    inputs=[project_dd, run_tb],
         | 
| 443 | 
            +
                    outputs=run_cb,
         | 
| 444 | 
            +
                )
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                gr.api(
         | 
| 447 | 
            +
                    fn=upload_db_to_space,
         | 
| 448 | 
            +
                    api_name="upload_db_to_space",
         | 
| 449 | 
            +
                )
         | 
| 450 | 
            +
                gr.api(
         | 
| 451 | 
            +
                    fn=log,
         | 
| 452 | 
            +
                    api_name="log",
         | 
| 453 | 
            +
                )
         | 
| 454 | 
            +
                gr.api(
         | 
| 455 | 
            +
                    fn=bulk_log,
         | 
| 456 | 
            +
                    api_name="bulk_log",
         | 
| 457 | 
            +
                )
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                x_lim = gr.State(None)
         | 
| 460 | 
            +
                last_steps = gr.State({})
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                def update_x_lim(select_data: gr.SelectData):
         | 
| 463 | 
            +
                    return select_data.index
         | 
| 464 | 
            +
             | 
| 465 | 
            +
                def update_last_steps(project, runs):
         | 
| 466 | 
            +
                    """Update the last step from all runs to detect when new data is available."""
         | 
| 467 | 
            +
                    if not project or not runs:
         | 
| 468 | 
            +
                        return {}
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    return SQLiteStorage.get_max_steps_for_runs(project, runs)
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                timer.tick(
         | 
| 473 | 
            +
                    fn=update_last_steps,
         | 
| 474 | 
            +
                    inputs=[project_dd, run_cb],
         | 
| 475 | 
            +
                    outputs=last_steps,
         | 
| 476 | 
            +
                    show_progress="hidden",
         | 
| 477 | 
            +
                )
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                @gr.render(
         | 
| 480 | 
            +
                    triggers=[
         | 
| 481 | 
            +
                        demo.load,
         | 
| 482 | 
            +
                        run_cb.change,
         | 
| 483 | 
            +
                        last_steps.change,
         | 
| 484 | 
            +
                        smoothing_cb.change,
         | 
| 485 | 
            +
                        x_lim.change,
         | 
| 486 | 
            +
                        x_axis_dd.change,
         | 
| 487 | 
            +
                        log_scale_cb.change,
         | 
| 488 | 
            +
                        metric_filter_tb.change,
         | 
| 489 | 
            +
                    ],
         | 
| 490 | 
            +
                    inputs=[
         | 
| 491 | 
            +
                        project_dd,
         | 
| 492 | 
            +
                        run_cb,
         | 
| 493 | 
            +
                        smoothing_cb,
         | 
| 494 | 
            +
                        metrics_subset,
         | 
| 495 | 
            +
                        x_lim,
         | 
| 496 | 
            +
                        x_axis_dd,
         | 
| 497 | 
            +
                        log_scale_cb,
         | 
| 498 | 
            +
                        metric_filter_tb,
         | 
| 499 | 
            +
                    ],
         | 
| 500 | 
            +
                    show_progress="hidden",
         | 
| 501 | 
            +
                )
         | 
| 502 | 
            +
                def update_dashboard(
         | 
| 503 | 
            +
                    project,
         | 
| 504 | 
            +
                    runs,
         | 
| 505 | 
            +
                    smoothing,
         | 
| 506 | 
            +
                    metrics_subset,
         | 
| 507 | 
            +
                    x_lim_value,
         | 
| 508 | 
            +
                    x_axis,
         | 
| 509 | 
            +
                    log_scale,
         | 
| 510 | 
            +
                    metric_filter,
         | 
| 511 | 
            +
                ):
         | 
| 512 | 
            +
                    dfs = []
         | 
| 513 | 
            +
                    original_runs = runs.copy()
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                    for run in runs:
         | 
| 516 | 
            +
                        df = load_run_data(project, run, smoothing, x_axis, log_scale)
         | 
| 517 | 
            +
                        if df is not None:
         | 
| 518 | 
            +
                            dfs.append(df)
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                    if dfs:
         | 
| 521 | 
            +
                        master_df = pd.concat(dfs, ignore_index=True)
         | 
| 522 | 
            +
                    else:
         | 
| 523 | 
            +
                        master_df = pd.DataFrame()
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                    if master_df.empty:
         | 
| 526 | 
            +
                        return
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                    x_column = "step"
         | 
| 529 | 
            +
                    if dfs and not dfs[0].empty and "x_axis" in dfs[0].columns:
         | 
| 530 | 
            +
                        x_column = dfs[0]["x_axis"].iloc[0]
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                    numeric_cols = master_df.select_dtypes(include="number").columns
         | 
| 533 | 
            +
                    numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
         | 
| 534 | 
            +
                    if metrics_subset:
         | 
| 535 | 
            +
                        numeric_cols = [c for c in numeric_cols if c in metrics_subset]
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                    if metric_filter and metric_filter.strip():
         | 
| 538 | 
            +
                        numeric_cols = filter_metrics_by_regex(list(numeric_cols), metric_filter)
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                    numeric_cols = sort_metrics_by_prefix(list(numeric_cols))
         | 
| 541 | 
            +
                    color_map = get_color_mapping(original_runs, smoothing)
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                    with gr.Row(key="row"):
         | 
| 544 | 
            +
                        for metric_idx, metric_name in enumerate(numeric_cols):
         | 
| 545 | 
            +
                            metric_df = master_df.dropna(subset=[metric_name])
         | 
| 546 | 
            +
                            color = "run" if "run" in metric_df.columns else None
         | 
| 547 | 
            +
                            if not metric_df.empty:
         | 
| 548 | 
            +
                                plot = gr.LinePlot(
         | 
| 549 | 
            +
                                    downsample(
         | 
| 550 | 
            +
                                        metric_df, x_column, metric_name, color, x_lim_value
         | 
| 551 | 
            +
                                    ),
         | 
| 552 | 
            +
                                    x=x_column,
         | 
| 553 | 
            +
                                    y=metric_name,
         | 
| 554 | 
            +
                                    color=color,
         | 
| 555 | 
            +
                                    color_map=color_map,
         | 
| 556 | 
            +
                                    title=metric_name,
         | 
| 557 | 
            +
                                    key=f"plot-{metric_idx}",
         | 
| 558 | 
            +
                                    preserved_by_key=None,
         | 
| 559 | 
            +
                                    x_lim=x_lim_value,
         | 
| 560 | 
            +
                                    show_fullscreen_button=True,
         | 
| 561 | 
            +
                                    min_width=400,
         | 
| 562 | 
            +
                                )
         | 
| 563 | 
            +
                            plot.select(update_x_lim, outputs=x_lim, key=f"select-{metric_idx}")
         | 
| 564 | 
            +
                            plot.double_click(
         | 
| 565 | 
            +
                                lambda: None, outputs=x_lim, key=f"double-{metric_idx}"
         | 
| 566 | 
            +
                            )
         | 
| 567 | 
            +
             | 
| 568 | 
            +
             | 
| 569 | 
            +
            if __name__ == "__main__":
         | 
| 570 | 
            +
                demo.launch(allowed_paths=[TRACKIO_LOGO_DIR], show_api=False, show_error=True)
         | 
    	
        utils.py
    ADDED
    
    | @@ -0,0 +1,410 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import re
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import time
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import huggingface_hub
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import pandas as pd
         | 
| 9 | 
            +
            from huggingface_hub.constants import HF_HOME
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            RESERVED_KEYS = ["project", "run", "timestamp", "step", "time", "metrics"]
         | 
| 12 | 
            +
            TRACKIO_DIR = Path(HF_HOME) / "trackio"
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            TRACKIO_LOGO_DIR = Path(__file__).parent / "assets"
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def generate_readable_name(used_names: list[str]) -> str:
         | 
| 18 | 
            +
                """
         | 
| 19 | 
            +
                Generates a random, readable name like "dainty-sunset-0"
         | 
| 20 | 
            +
                """
         | 
| 21 | 
            +
                adjectives = [
         | 
| 22 | 
            +
                    "dainty",
         | 
| 23 | 
            +
                    "brave",
         | 
| 24 | 
            +
                    "calm",
         | 
| 25 | 
            +
                    "eager",
         | 
| 26 | 
            +
                    "fancy",
         | 
| 27 | 
            +
                    "gentle",
         | 
| 28 | 
            +
                    "happy",
         | 
| 29 | 
            +
                    "jolly",
         | 
| 30 | 
            +
                    "kind",
         | 
| 31 | 
            +
                    "lively",
         | 
| 32 | 
            +
                    "merry",
         | 
| 33 | 
            +
                    "nice",
         | 
| 34 | 
            +
                    "proud",
         | 
| 35 | 
            +
                    "quick",
         | 
| 36 | 
            +
                    "hugging",
         | 
| 37 | 
            +
                    "silly",
         | 
| 38 | 
            +
                    "tidy",
         | 
| 39 | 
            +
                    "witty",
         | 
| 40 | 
            +
                    "zealous",
         | 
| 41 | 
            +
                    "bright",
         | 
| 42 | 
            +
                    "shy",
         | 
| 43 | 
            +
                    "bold",
         | 
| 44 | 
            +
                    "clever",
         | 
| 45 | 
            +
                    "daring",
         | 
| 46 | 
            +
                    "elegant",
         | 
| 47 | 
            +
                    "faithful",
         | 
| 48 | 
            +
                    "graceful",
         | 
| 49 | 
            +
                    "honest",
         | 
| 50 | 
            +
                    "inventive",
         | 
| 51 | 
            +
                    "jovial",
         | 
| 52 | 
            +
                    "keen",
         | 
| 53 | 
            +
                    "lucky",
         | 
| 54 | 
            +
                    "modest",
         | 
| 55 | 
            +
                    "noble",
         | 
| 56 | 
            +
                    "optimistic",
         | 
| 57 | 
            +
                    "patient",
         | 
| 58 | 
            +
                    "quirky",
         | 
| 59 | 
            +
                    "resourceful",
         | 
| 60 | 
            +
                    "sincere",
         | 
| 61 | 
            +
                    "thoughtful",
         | 
| 62 | 
            +
                    "upbeat",
         | 
| 63 | 
            +
                    "valiant",
         | 
| 64 | 
            +
                    "warm",
         | 
| 65 | 
            +
                    "youthful",
         | 
| 66 | 
            +
                    "zesty",
         | 
| 67 | 
            +
                    "adventurous",
         | 
| 68 | 
            +
                    "breezy",
         | 
| 69 | 
            +
                    "cheerful",
         | 
| 70 | 
            +
                    "delightful",
         | 
| 71 | 
            +
                    "energetic",
         | 
| 72 | 
            +
                    "fearless",
         | 
| 73 | 
            +
                    "glad",
         | 
| 74 | 
            +
                    "hopeful",
         | 
| 75 | 
            +
                    "imaginative",
         | 
| 76 | 
            +
                    "joyful",
         | 
| 77 | 
            +
                    "kindly",
         | 
| 78 | 
            +
                    "luminous",
         | 
| 79 | 
            +
                    "mysterious",
         | 
| 80 | 
            +
                    "neat",
         | 
| 81 | 
            +
                    "outgoing",
         | 
| 82 | 
            +
                    "playful",
         | 
| 83 | 
            +
                    "radiant",
         | 
| 84 | 
            +
                    "spirited",
         | 
| 85 | 
            +
                    "tranquil",
         | 
| 86 | 
            +
                    "unique",
         | 
| 87 | 
            +
                    "vivid",
         | 
| 88 | 
            +
                    "wise",
         | 
| 89 | 
            +
                    "zany",
         | 
| 90 | 
            +
                    "artful",
         | 
| 91 | 
            +
                    "bubbly",
         | 
| 92 | 
            +
                    "charming",
         | 
| 93 | 
            +
                    "dazzling",
         | 
| 94 | 
            +
                    "earnest",
         | 
| 95 | 
            +
                    "festive",
         | 
| 96 | 
            +
                    "gentlemanly",
         | 
| 97 | 
            +
                    "hearty",
         | 
| 98 | 
            +
                    "intrepid",
         | 
| 99 | 
            +
                    "jubilant",
         | 
| 100 | 
            +
                    "knightly",
         | 
| 101 | 
            +
                    "lively",
         | 
| 102 | 
            +
                    "magnetic",
         | 
| 103 | 
            +
                    "nimble",
         | 
| 104 | 
            +
                    "orderly",
         | 
| 105 | 
            +
                    "peaceful",
         | 
| 106 | 
            +
                    "quick-witted",
         | 
| 107 | 
            +
                    "robust",
         | 
| 108 | 
            +
                    "sturdy",
         | 
| 109 | 
            +
                    "trusty",
         | 
| 110 | 
            +
                    "upstanding",
         | 
| 111 | 
            +
                    "vibrant",
         | 
| 112 | 
            +
                    "whimsical",
         | 
| 113 | 
            +
                ]
         | 
| 114 | 
            +
                nouns = [
         | 
| 115 | 
            +
                    "sunset",
         | 
| 116 | 
            +
                    "forest",
         | 
| 117 | 
            +
                    "river",
         | 
| 118 | 
            +
                    "mountain",
         | 
| 119 | 
            +
                    "breeze",
         | 
| 120 | 
            +
                    "meadow",
         | 
| 121 | 
            +
                    "ocean",
         | 
| 122 | 
            +
                    "valley",
         | 
| 123 | 
            +
                    "sky",
         | 
| 124 | 
            +
                    "field",
         | 
| 125 | 
            +
                    "cloud",
         | 
| 126 | 
            +
                    "star",
         | 
| 127 | 
            +
                    "rain",
         | 
| 128 | 
            +
                    "leaf",
         | 
| 129 | 
            +
                    "stone",
         | 
| 130 | 
            +
                    "flower",
         | 
| 131 | 
            +
                    "bird",
         | 
| 132 | 
            +
                    "tree",
         | 
| 133 | 
            +
                    "wave",
         | 
| 134 | 
            +
                    "trail",
         | 
| 135 | 
            +
                    "island",
         | 
| 136 | 
            +
                    "desert",
         | 
| 137 | 
            +
                    "hill",
         | 
| 138 | 
            +
                    "lake",
         | 
| 139 | 
            +
                    "pond",
         | 
| 140 | 
            +
                    "grove",
         | 
| 141 | 
            +
                    "canyon",
         | 
| 142 | 
            +
                    "reef",
         | 
| 143 | 
            +
                    "bay",
         | 
| 144 | 
            +
                    "peak",
         | 
| 145 | 
            +
                    "glade",
         | 
| 146 | 
            +
                    "marsh",
         | 
| 147 | 
            +
                    "cliff",
         | 
| 148 | 
            +
                    "dune",
         | 
| 149 | 
            +
                    "spring",
         | 
| 150 | 
            +
                    "brook",
         | 
| 151 | 
            +
                    "cave",
         | 
| 152 | 
            +
                    "plain",
         | 
| 153 | 
            +
                    "ridge",
         | 
| 154 | 
            +
                    "wood",
         | 
| 155 | 
            +
                    "blossom",
         | 
| 156 | 
            +
                    "petal",
         | 
| 157 | 
            +
                    "root",
         | 
| 158 | 
            +
                    "branch",
         | 
| 159 | 
            +
                    "seed",
         | 
| 160 | 
            +
                    "acorn",
         | 
| 161 | 
            +
                    "pine",
         | 
| 162 | 
            +
                    "willow",
         | 
| 163 | 
            +
                    "cedar",
         | 
| 164 | 
            +
                    "elm",
         | 
| 165 | 
            +
                    "falcon",
         | 
| 166 | 
            +
                    "eagle",
         | 
| 167 | 
            +
                    "sparrow",
         | 
| 168 | 
            +
                    "robin",
         | 
| 169 | 
            +
                    "owl",
         | 
| 170 | 
            +
                    "finch",
         | 
| 171 | 
            +
                    "heron",
         | 
| 172 | 
            +
                    "crane",
         | 
| 173 | 
            +
                    "duck",
         | 
| 174 | 
            +
                    "swan",
         | 
| 175 | 
            +
                    "fox",
         | 
| 176 | 
            +
                    "wolf",
         | 
| 177 | 
            +
                    "bear",
         | 
| 178 | 
            +
                    "deer",
         | 
| 179 | 
            +
                    "moose",
         | 
| 180 | 
            +
                    "otter",
         | 
| 181 | 
            +
                    "beaver",
         | 
| 182 | 
            +
                    "lynx",
         | 
| 183 | 
            +
                    "hare",
         | 
| 184 | 
            +
                    "badger",
         | 
| 185 | 
            +
                    "butterfly",
         | 
| 186 | 
            +
                    "bee",
         | 
| 187 | 
            +
                    "ant",
         | 
| 188 | 
            +
                    "beetle",
         | 
| 189 | 
            +
                    "dragonfly",
         | 
| 190 | 
            +
                    "firefly",
         | 
| 191 | 
            +
                    "ladybug",
         | 
| 192 | 
            +
                    "moth",
         | 
| 193 | 
            +
                    "spider",
         | 
| 194 | 
            +
                    "worm",
         | 
| 195 | 
            +
                    "coral",
         | 
| 196 | 
            +
                    "kelp",
         | 
| 197 | 
            +
                    "shell",
         | 
| 198 | 
            +
                    "pebble",
         | 
| 199 | 
            +
                    "face",
         | 
| 200 | 
            +
                    "boulder",
         | 
| 201 | 
            +
                    "cobble",
         | 
| 202 | 
            +
                    "sand",
         | 
| 203 | 
            +
                    "wavelet",
         | 
| 204 | 
            +
                    "tide",
         | 
| 205 | 
            +
                    "current",
         | 
| 206 | 
            +
                    "mist",
         | 
| 207 | 
            +
                ]
         | 
| 208 | 
            +
                number = 0
         | 
| 209 | 
            +
                name = f"{adjectives[0]}-{nouns[0]}-{number}"
         | 
| 210 | 
            +
                while name in used_names:
         | 
| 211 | 
            +
                    number += 1
         | 
| 212 | 
            +
                    adjective = adjectives[number % len(adjectives)]
         | 
| 213 | 
            +
                    noun = nouns[number % len(nouns)]
         | 
| 214 | 
            +
                    name = f"{adjective}-{noun}-{number}"
         | 
| 215 | 
            +
                return name
         | 
| 216 | 
            +
             | 
| 217 | 
            +
             | 
| 218 | 
            +
            def block_except_in_notebook():
         | 
| 219 | 
            +
                in_notebook = bool(getattr(sys, "ps1", sys.flags.interactive))
         | 
| 220 | 
            +
                if in_notebook:
         | 
| 221 | 
            +
                    return
         | 
| 222 | 
            +
                try:
         | 
| 223 | 
            +
                    while True:
         | 
| 224 | 
            +
                        time.sleep(0.1)
         | 
| 225 | 
            +
                except (KeyboardInterrupt, OSError):
         | 
| 226 | 
            +
                    print("Keyboard interruption in main thread... closing dashboard.")
         | 
| 227 | 
            +
             | 
| 228 | 
            +
             | 
| 229 | 
            +
            def simplify_column_names(columns: list[str]) -> dict[str, str]:
         | 
| 230 | 
            +
                """
         | 
| 231 | 
            +
                Simplifies column names to first 10 alphanumeric or "/" characters with unique suffixes.
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                Args:
         | 
| 234 | 
            +
                    columns: List of original column names
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                Returns:
         | 
| 237 | 
            +
                    Dictionary mapping original column names to simplified names
         | 
| 238 | 
            +
                """
         | 
| 239 | 
            +
                simplified_names = {}
         | 
| 240 | 
            +
                used_names = set()
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                for col in columns:
         | 
| 243 | 
            +
                    alphanumeric = re.sub(r"[^a-zA-Z0-9/]", "", col)
         | 
| 244 | 
            +
                    base_name = alphanumeric[:10] if alphanumeric else f"col_{len(used_names)}"
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    final_name = base_name
         | 
| 247 | 
            +
                    suffix = 1
         | 
| 248 | 
            +
                    while final_name in used_names:
         | 
| 249 | 
            +
                        final_name = f"{base_name}_{suffix}"
         | 
| 250 | 
            +
                        suffix += 1
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    simplified_names[col] = final_name
         | 
| 253 | 
            +
                    used_names.add(final_name)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                return simplified_names
         | 
| 256 | 
            +
             | 
| 257 | 
            +
             | 
| 258 | 
            +
            def print_dashboard_instructions(project: str) -> None:
         | 
| 259 | 
            +
                """
         | 
| 260 | 
            +
                Prints instructions for viewing the Trackio dashboard.
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                Args:
         | 
| 263 | 
            +
                    project: The name of the project to show dashboard for.
         | 
| 264 | 
            +
                """
         | 
| 265 | 
            +
                YELLOW = "\033[93m"
         | 
| 266 | 
            +
                BOLD = "\033[1m"
         | 
| 267 | 
            +
                RESET = "\033[0m"
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                print("* View dashboard by running in your terminal:")
         | 
| 270 | 
            +
                print(f'{BOLD}{YELLOW}trackio show --project "{project}"{RESET}')
         | 
| 271 | 
            +
                print(f'* or by running in Python: trackio.show(project="{project}")')
         | 
| 272 | 
            +
             | 
| 273 | 
            +
             | 
| 274 | 
            +
            def preprocess_space_and_dataset_ids(
         | 
| 275 | 
            +
                space_id: str | None, dataset_id: str | None
         | 
| 276 | 
            +
            ) -> tuple[str | None, str | None]:
         | 
| 277 | 
            +
                if space_id is not None and "/" not in space_id:
         | 
| 278 | 
            +
                    username = huggingface_hub.whoami()["name"]
         | 
| 279 | 
            +
                    space_id = f"{username}/{space_id}"
         | 
| 280 | 
            +
                if dataset_id is not None and "/" not in dataset_id:
         | 
| 281 | 
            +
                    username = huggingface_hub.whoami()["name"]
         | 
| 282 | 
            +
                    dataset_id = f"{username}/{dataset_id}"
         | 
| 283 | 
            +
                if space_id is not None and dataset_id is None:
         | 
| 284 | 
            +
                    dataset_id = f"{space_id}_dataset"
         | 
| 285 | 
            +
                return space_id, dataset_id
         | 
| 286 | 
            +
             | 
| 287 | 
            +
             | 
| 288 | 
            +
            def fibo():
         | 
| 289 | 
            +
                """Generator for Fibonacci backoff: 1, 1, 2, 3, 5, 8, ..."""
         | 
| 290 | 
            +
                a, b = 1, 1
         | 
| 291 | 
            +
                while True:
         | 
| 292 | 
            +
                    yield a
         | 
| 293 | 
            +
                    a, b = b, a + b
         | 
| 294 | 
            +
             | 
| 295 | 
            +
             | 
| 296 | 
            +
            COLOR_PALETTE = [
         | 
| 297 | 
            +
                "#3B82F6",
         | 
| 298 | 
            +
                "#EF4444",
         | 
| 299 | 
            +
                "#10B981",
         | 
| 300 | 
            +
                "#F59E0B",
         | 
| 301 | 
            +
                "#8B5CF6",
         | 
| 302 | 
            +
                "#EC4899",
         | 
| 303 | 
            +
                "#06B6D4",
         | 
| 304 | 
            +
                "#84CC16",
         | 
| 305 | 
            +
                "#F97316",
         | 
| 306 | 
            +
                "#6366F1",
         | 
| 307 | 
            +
            ]
         | 
| 308 | 
            +
             | 
| 309 | 
            +
             | 
| 310 | 
            +
            def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
         | 
| 311 | 
            +
                """Generate color mapping for runs, with transparency for original data when smoothing is enabled."""
         | 
| 312 | 
            +
                color_map = {}
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                for i, run in enumerate(runs):
         | 
| 315 | 
            +
                    base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    if smoothing:
         | 
| 318 | 
            +
                        color_map[f"{run}_smoothed"] = base_color
         | 
| 319 | 
            +
                        color_map[f"{run}_original"] = base_color + "4D"
         | 
| 320 | 
            +
                    else:
         | 
| 321 | 
            +
                        color_map[run] = base_color
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                return color_map
         | 
| 324 | 
            +
             | 
| 325 | 
            +
             | 
| 326 | 
            +
            def downsample(
         | 
| 327 | 
            +
                df: pd.DataFrame,
         | 
| 328 | 
            +
                x: str,
         | 
| 329 | 
            +
                y: str,
         | 
| 330 | 
            +
                color: str | None,
         | 
| 331 | 
            +
                x_lim: tuple[float, float] | None = None,
         | 
| 332 | 
            +
            ) -> pd.DataFrame:
         | 
| 333 | 
            +
                if df.empty:
         | 
| 334 | 
            +
                    return df
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                columns_to_keep = [x, y]
         | 
| 337 | 
            +
                if color is not None and color in df.columns:
         | 
| 338 | 
            +
                    columns_to_keep.append(color)
         | 
| 339 | 
            +
                df = df[columns_to_keep].copy()
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                n_bins = 100
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                if color is not None and color in df.columns:
         | 
| 344 | 
            +
                    groups = df.groupby(color)
         | 
| 345 | 
            +
                else:
         | 
| 346 | 
            +
                    groups = [(None, df)]
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                downsampled_indices = []
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                for _, group_df in groups:
         | 
| 351 | 
            +
                    if group_df.empty:
         | 
| 352 | 
            +
                        continue
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    group_df = group_df.sort_values(x)
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    if x_lim is not None:
         | 
| 357 | 
            +
                        x_min, x_max = x_lim
         | 
| 358 | 
            +
                        before_point = group_df[group_df[x] < x_min].tail(1)
         | 
| 359 | 
            +
                        after_point = group_df[group_df[x] > x_max].head(1)
         | 
| 360 | 
            +
                        group_df = group_df[(group_df[x] >= x_min) & (group_df[x] <= x_max)]
         | 
| 361 | 
            +
                    else:
         | 
| 362 | 
            +
                        before_point = after_point = None
         | 
| 363 | 
            +
                        x_min = group_df[x].min()
         | 
| 364 | 
            +
                        x_max = group_df[x].max()
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    if before_point is not None and not before_point.empty:
         | 
| 367 | 
            +
                        downsampled_indices.extend(before_point.index.tolist())
         | 
| 368 | 
            +
                    if after_point is not None and not after_point.empty:
         | 
| 369 | 
            +
                        downsampled_indices.extend(after_point.index.tolist())
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    if group_df.empty:
         | 
| 372 | 
            +
                        continue
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                    if x_min == x_max:
         | 
| 375 | 
            +
                        min_y_idx = group_df[y].idxmin()
         | 
| 376 | 
            +
                        max_y_idx = group_df[y].idxmax()
         | 
| 377 | 
            +
                        if min_y_idx != max_y_idx:
         | 
| 378 | 
            +
                            downsampled_indices.extend([min_y_idx, max_y_idx])
         | 
| 379 | 
            +
                        else:
         | 
| 380 | 
            +
                            downsampled_indices.append(min_y_idx)
         | 
| 381 | 
            +
                        continue
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                    if len(group_df) < 500:
         | 
| 384 | 
            +
                        downsampled_indices.extend(group_df.index.tolist())
         | 
| 385 | 
            +
                        continue
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    bins = np.linspace(x_min, x_max, n_bins + 1)
         | 
| 388 | 
            +
                    group_df["bin"] = pd.cut(
         | 
| 389 | 
            +
                        group_df[x], bins=bins, labels=False, include_lowest=True
         | 
| 390 | 
            +
                    )
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                    for bin_idx in group_df["bin"].dropna().unique():
         | 
| 393 | 
            +
                        bin_data = group_df[group_df["bin"] == bin_idx]
         | 
| 394 | 
            +
                        if bin_data.empty:
         | 
| 395 | 
            +
                            continue
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                        min_y_idx = bin_data[y].idxmin()
         | 
| 398 | 
            +
                        max_y_idx = bin_data[y].idxmax()
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                        downsampled_indices.append(min_y_idx)
         | 
| 401 | 
            +
                        if min_y_idx != max_y_idx:
         | 
| 402 | 
            +
                            downsampled_indices.append(max_y_idx)
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                unique_indices = list(set(downsampled_indices))
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                downsampled_df = df.loc[unique_indices].copy()
         | 
| 407 | 
            +
                downsampled_df = downsampled_df.sort_values(x).reset_index(drop=True)
         | 
| 408 | 
            +
                downsampled_df = downsampled_df.drop(columns=["bin"], errors="ignore")
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                return downsampled_df
         | 
    	
        version.txt
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            0.2.5
         | 
