|
|
"""Shared functions for the Trackio UI.""" |
|
|
|
|
|
import os |
|
|
|
|
|
import gradio as gr |
|
|
import huggingface_hub as hf |
|
|
|
|
|
try: |
|
|
import trackio.utils as utils |
|
|
from trackio.sqlite_storage import SQLiteStorage |
|
|
from trackio.ui.helpers.run_selection import RunSelection |
|
|
except ImportError: |
|
|
import utils |
|
|
from sqlite_storage import SQLiteStorage |
|
|
from ui.helpers.run_selection import RunSelection |
|
|
|
|
|
CONFIG_COLUMN_MAPPINGS = { |
|
|
"_Username": "Username", |
|
|
"_Created": "Created", |
|
|
"_Group": "Group", |
|
|
} |
|
|
CONFIG_COLUMN_MAPPINGS_REVERSE = {v: k for k, v in CONFIG_COLUMN_MAPPINGS.items()} |
|
|
|
|
|
|
|
|
HfApi = hf.HfApi() |
|
|
|
|
|
|
|
|
def get_project_info() -> str | None: |
|
|
dataset_id = os.environ.get("TRACKIO_DATASET_ID") |
|
|
space_id = utils.get_space() |
|
|
if utils.persistent_storage_enabled(): |
|
|
return "✨ Persistent Storage is enabled, logs are stored directly in this Space." |
|
|
if dataset_id: |
|
|
sync_status = utils.get_sync_status(SQLiteStorage.get_scheduler()) |
|
|
upgrade_message = f"New changes are synced every 5 min <span class='info-container'><input type='checkbox' class='info-checkbox' id='upgrade-info'><label for='upgrade-info' class='info-icon'>ⓘ</label><span class='info-expandable'> To avoid losing data between syncs, <a href='https://huggingface.co/spaces/{space_id}/settings' class='accent-link'>click here</a> to open this Space's settings and add Persistent Storage. Make sure data is synced prior to enabling.</span></span>" |
|
|
if sync_status is not None: |
|
|
info = f"↻ Backed up {sync_status} min ago to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank' class='accent-link'>{dataset_id}</a> | {upgrade_message}" |
|
|
else: |
|
|
info = f"↻ Not backed up yet to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank' class='accent-link'>{dataset_id}</a> | {upgrade_message}" |
|
|
return info |
|
|
return None |
|
|
|
|
|
|
|
|
def get_projects(request: gr.Request): |
|
|
projects = SQLiteStorage.get_projects() |
|
|
if project := request.query_params.get("project"): |
|
|
interactive = False |
|
|
else: |
|
|
interactive = True |
|
|
if selected_project := request.query_params.get("selected_project"): |
|
|
project = selected_project |
|
|
else: |
|
|
project = projects[0] if projects else None |
|
|
|
|
|
return gr.Dropdown( |
|
|
label="Project", |
|
|
choices=projects, |
|
|
value=project, |
|
|
allow_custom_value=True, |
|
|
interactive=interactive, |
|
|
info=get_project_info(), |
|
|
) |
|
|
|
|
|
|
|
|
def update_navbar_value(project_dd, request: gr.Request): |
|
|
write_token = None |
|
|
if hasattr(request, "query_params") and request.query_params: |
|
|
write_token = request.query_params.get("write_token") |
|
|
|
|
|
metrics_url = f"?selected_project={project_dd}" |
|
|
runs_url = f"runs?selected_project={project_dd}" |
|
|
|
|
|
if write_token: |
|
|
metrics_url += f"&write_token={write_token}" |
|
|
runs_url += f"&write_token={write_token}" |
|
|
|
|
|
return gr.Navbar( |
|
|
value=[ |
|
|
("Metrics", metrics_url), |
|
|
("Runs", runs_url), |
|
|
] |
|
|
) |
|
|
|
|
|
|
|
|
def check_hf_token_has_write_access(hf_token: str | None) -> None: |
|
|
""" |
|
|
Checks to see if the provided hf_token is valid and has write access to the Space |
|
|
that Trackio is running in. If the hf_token is valid or if Trackio is not running |
|
|
on a Space, this function does nothing. Otherwise, it raises a PermissionError. |
|
|
""" |
|
|
if os.getenv("SYSTEM") == "spaces": |
|
|
|
|
|
if hf_token is None: |
|
|
raise PermissionError( |
|
|
"Expected a HF_TOKEN to be provided when logging to a Space" |
|
|
) |
|
|
who = HfApi.whoami(hf_token) |
|
|
owner_name = os.getenv("SPACE_AUTHOR_NAME") |
|
|
repo_name = os.getenv("SPACE_REPO_NAME") |
|
|
|
|
|
|
|
|
orgs = [o["name"] for o in who["orgs"]] |
|
|
if owner_name != who["name"] and owner_name not in orgs: |
|
|
raise PermissionError( |
|
|
"Expected the provided hf_token to be the user owner of the space, or be a member of the org owner of the space" |
|
|
) |
|
|
|
|
|
access_token = who["auth"]["accessToken"] |
|
|
if access_token["role"] == "fineGrained": |
|
|
matched = False |
|
|
for item in access_token["fineGrained"]["scoped"]: |
|
|
if ( |
|
|
item["entity"]["type"] == "space" |
|
|
and item["entity"]["name"] == f"{owner_name}/{repo_name}" |
|
|
and "repo.write" in item["permissions"] |
|
|
): |
|
|
matched = True |
|
|
break |
|
|
if ( |
|
|
( |
|
|
item["entity"]["type"] == "user" |
|
|
or item["entity"]["type"] == "org" |
|
|
) |
|
|
and item["entity"]["name"] == owner_name |
|
|
and "repo.write" in item["permissions"] |
|
|
): |
|
|
matched = True |
|
|
break |
|
|
if not matched: |
|
|
raise PermissionError( |
|
|
"Expected the provided hf_token with fine grained permissions to provide write access to the space" |
|
|
) |
|
|
|
|
|
elif access_token["role"] != "write": |
|
|
raise PermissionError( |
|
|
"Expected the provided hf_token to provide write permissions" |
|
|
) |
|
|
|
|
|
|
|
|
def check_oauth_token_has_write_access(oauth_token: str | None) -> None: |
|
|
""" |
|
|
Checks to see if the oauth token provided via Gradio's OAuth is valid and has write access |
|
|
to the Space that Trackio is running in. If the oauth token is valid or if Trackio is not running |
|
|
on a Space, this function does nothing. Otherwise, it raises a PermissionError. |
|
|
""" |
|
|
if not os.getenv("SYSTEM") == "spaces": |
|
|
return |
|
|
if oauth_token is None: |
|
|
raise PermissionError( |
|
|
"Expected an oauth to be provided when logging to a Space" |
|
|
) |
|
|
who = HfApi.whoami(oauth_token) |
|
|
user_name = who["name"] |
|
|
owner_name = os.getenv("SPACE_AUTHOR_NAME") |
|
|
if user_name == owner_name: |
|
|
return |
|
|
|
|
|
for org in who["orgs"]: |
|
|
if org["name"] == owner_name and org["roleInOrg"] == "write": |
|
|
return |
|
|
raise PermissionError( |
|
|
"Expected the oauth token to be the user owner of the space, or be a member of the org owner of the space" |
|
|
) |
|
|
|
|
|
|
|
|
def get_group_by_fields(project: str): |
|
|
configs = SQLiteStorage.get_all_run_configs(project) if project else {} |
|
|
keys = set() |
|
|
for config in configs.values(): |
|
|
keys.update(config.keys()) |
|
|
keys.discard("_Created") |
|
|
keys = [CONFIG_COLUMN_MAPPINGS.get(key, key) for key in keys] |
|
|
choices = [None] + sorted(keys) |
|
|
return gr.Dropdown( |
|
|
choices=choices, |
|
|
value=None, |
|
|
interactive=True, |
|
|
) |
|
|
|
|
|
|
|
|
def group_runs_by_config( |
|
|
project: str, config_key: str, filter_text: str | None = None |
|
|
) -> dict[str, list[str]]: |
|
|
if not project or not config_key: |
|
|
return {} |
|
|
display_key = config_key |
|
|
config_key = CONFIG_COLUMN_MAPPINGS_REVERSE.get(config_key, config_key) |
|
|
configs = SQLiteStorage.get_all_run_configs(project) |
|
|
groups: dict[str, list[str]] = {} |
|
|
for run_name, config in configs.items(): |
|
|
if filter_text and filter_text not in run_name: |
|
|
continue |
|
|
group_name = config.get(config_key, "None") |
|
|
label = f"{display_key}: {group_name}" |
|
|
groups.setdefault(label, []).append(run_name) |
|
|
for label in groups: |
|
|
groups[label].sort() |
|
|
sorted_groups = dict(sorted(groups.items(), key=lambda kv: kv[0].lower())) |
|
|
return sorted_groups |
|
|
|
|
|
|
|
|
def run_checkbox_update(selection: RunSelection, **kwargs) -> gr.CheckboxGroup: |
|
|
return gr.CheckboxGroup( |
|
|
choices=selection.choices, |
|
|
value=selection.selected, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
|
|
|
def handle_run_checkbox_change( |
|
|
selected_runs: list[str] | None, selection: RunSelection |
|
|
) -> RunSelection: |
|
|
selection.select(selected_runs or []) |
|
|
return selection |
|
|
|
|
|
|
|
|
def handle_group_checkbox_change( |
|
|
group_selected: list[str] | None, |
|
|
selection: RunSelection, |
|
|
group_runs: list[str] | None, |
|
|
): |
|
|
subset, _ = selection.replace_group(group_runs or [], group_selected or []) |
|
|
return ( |
|
|
selection, |
|
|
gr.CheckboxGroup(value=subset), |
|
|
run_checkbox_update(selection), |
|
|
) |
|
|
|
|
|
|
|
|
def handle_group_toggle( |
|
|
select_all: bool, |
|
|
selection: RunSelection, |
|
|
group_runs: list[str] | None, |
|
|
): |
|
|
target = list(group_runs or []) if select_all else [] |
|
|
subset, _ = selection.replace_group(group_runs or [], target) |
|
|
return ( |
|
|
selection, |
|
|
gr.CheckboxGroup(value=subset), |
|
|
run_checkbox_update(selection), |
|
|
) |
|
|
|