|
import glob |
|
import logging |
|
import os |
|
import shutil |
|
import time |
|
import zipfile |
|
from pathlib import Path |
|
from typing import Callable, List, Optional, Set |
|
|
|
import accelerate |
|
import einops |
|
import huggingface_hub |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import transformers |
|
import yaml |
|
from h2o_wave import Q, data, ui |
|
from sqlitedict import SqliteDict |
|
|
|
from llm_studio.app_utils.config import default_cfg |
|
from llm_studio.app_utils.hugging_face_utils import ( |
|
get_model_card, |
|
publish_model_to_hugging_face, |
|
) |
|
from llm_studio.app_utils.sections.chat import chat_tab, load_cfg_model_tokenizer |
|
from llm_studio.app_utils.sections.common import clean_dashboard |
|
from llm_studio.app_utils.utils import ( |
|
add_model_type, |
|
flatten_dict, |
|
get_cfg_list_items, |
|
get_data_dir, |
|
get_download_link, |
|
get_experiment_status, |
|
get_experiments, |
|
get_model_types, |
|
get_problem_categories, |
|
get_problem_types, |
|
get_ui_elements, |
|
get_unique_name, |
|
hf_repo_friendly_name, |
|
parse_ui_elements, |
|
remove_model_type, |
|
set_env, |
|
start_experiment, |
|
) |
|
from llm_studio.app_utils.wave_utils import busy_dialog, ui_table_from_df, wave_theme |
|
from llm_studio.python_configs.cfg_checks import check_config_for_errors |
|
from llm_studio.src.datasets.text_utils import get_tokenizer |
|
from llm_studio.src.tooltips import tooltips |
|
from llm_studio.src.utils.config_utils import ( |
|
NON_GENERATION_PROBLEM_TYPES, |
|
load_config_py, |
|
load_config_yaml, |
|
save_config_yaml, |
|
) |
|
from llm_studio.src.utils.exceptions import LLMResourceException |
|
from llm_studio.src.utils.export_utils import ( |
|
check_available_space, |
|
get_artifact_path_path, |
|
get_logs_path, |
|
get_model_path, |
|
get_predictions_path, |
|
save_logs, |
|
save_prediction_outputs, |
|
) |
|
from llm_studio.src.utils.logging_utils import write_flag |
|
from llm_studio.src.utils.modeling_utils import unwrap_model |
|
from llm_studio.src.utils.plot_utils import PLOT_ENCODINGS |
|
from llm_studio.src.utils.utils import add_file_to_zip, kill_child_processes |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
async def experiment_start(q: Q) -> None: |
|
"""Display experiment start cards.""" |
|
|
|
await clean_dashboard(q, mode="experiment_start", exclude=["experiment/start"]) |
|
q.client["nav/active"] = "experiment/start" |
|
|
|
show_update_warnings = True |
|
is_create_experiment = False |
|
|
|
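    # Entering the experiment setup freshly (not via an in-form trigger) resets
    # cached previous selections and suppresses the "settings updated" warnings.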
    if q.args.__wave_submission_name__ in (
        "experiment/start",
        "experiment/start_experiment",
        "dataset/newexperiment",
        "dataset/newexperiment/from_current",
        "experiment/list/new",
    ):
|
q.client["experiment/start/cfg_experiment_prev"] = None |
|
q.client["experiment/start/cfg_file_prev"] = None |
|
q.client["experiment/start/prev_dataset"] = None |
|
q.client["experiment/start/cfg_sub"] = None |
|
show_update_warnings = False |
|
is_create_experiment = True |
|
|
|
|
|
df_datasets = q.client.app_db.get_datasets_df() |
|
|
|
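    # Only datasets with known train rows are selectable; default to the most
    # recently added one (or "1" if none exist yet).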
df_datasets = df_datasets.loc[df_datasets["train_rows"].notna()] |
|
if ( |
|
not q.client["experiment/start/dataset"] |
|
or q.client["experiment/start/dataset"] not in df_datasets.id.astype(str).values |
|
): |
|
if len(df_datasets) >= 1: |
|
q.client["experiment/start/dataset"] = str(df_datasets["id"].iloc[-1]) |
|
else: |
|
q.client["experiment/start/dataset"] = "1" |
|
|
|
warning_message = "Experiment settings might be updated after changing {}" |
|
|
|
items = [ |
|
ui.separator(name="general_expander", label="General settings"), |
|
ui.dropdown( |
|
name="experiment/start/dataset", |
|
label="Dataset", |
|
required=True, |
|
value=q.client["experiment/start/dataset"], |
|
choices=[ |
|
ui.choice(str(row["id"]), str(row["name"])) |
|
for _, row in df_datasets.iterrows() |
|
], |
|
trigger=True, |
|
tooltip=tooltips["experiments_dataset"], |
|
), |
|
] |
|
|
|
if ( |
|
show_update_warnings |
|
and q.client["experiment/start/dataset_prev"] |
|
!= q.client["experiment/start/dataset"] |
|
): |
|
items += [ |
|
ui.message_bar(type="warning", text=warning_message.format("Dataset")) |
|
] |
|
show_update_warnings = False |
|
|
|
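    # Derive the default problem type from the dataset's stored config file name;
    # the category is its first "_"-separated token.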
if ( |
|
q.client["experiment/start/cfg_file"] is None |
|
or q.client["experiment/start/dataset_prev"] |
|
!= q.client["experiment/start/dataset"] |
|
) and q.client["experiment/start/cfg_category"] != "experiment": |
|
dataset = q.client.app_db.get_dataset(q.client["experiment/start/dataset"]) |
|
if dataset is not None: |
|
problem_type = dataset.config_file.replace(dataset.path + "/", "").replace( |
|
".yaml", "" |
|
) |
|
else: |
|
problem_type = default_cfg.cfg_file |
|
q.client["experiment/start/cfg_file"] = problem_type |
|
q.client["experiment/start/cfg_category"] = problem_type.split("_")[0] |
|
|
|
if q.client["experiment/start/cfg_category"] == "experiment": |
|
q.client["experiment/start/cfg_file"] = "experiment" |
|
|
|
|
|
df_experiments = get_experiments(q, mode="train") |
|
|
|
|
|
choices_problem_categories = [ |
|
ui.choice(name, label) for name, label in get_problem_categories() |
|
] |
|
|
|
if len(df_experiments["id"]) > 0: |
|
choices_problem_categories += [ui.choice("experiment", "From Experiment")] |
|
|
|
|
|
if ( |
|
q.client["experiment/start/cfg_category"] |
|
not in q.client["experiment/start/cfg_file"] |
|
): |
|
if q.client["experiment/start/cfg_category"] != "experiment": |
|
q.client["experiment/start/cfg_file"] = get_problem_types( |
|
category=q.client["experiment/start/cfg_category"] |
|
)[0][0] |
|
|
|
|
|
choices_problem_types = [ |
|
ui.choice(name, label) |
|
for name, label in get_problem_types( |
|
category=q.client["experiment/start/cfg_category"] |
|
) |
|
] |
|
|
|
|
|
q.client["experiment/start/cfg_file"] = remove_model_type( |
|
q.client["experiment/start/cfg_file"] |
|
) |
|
|
|
if len(df_experiments["id"]) > 0: |
|
if q.client["experiment/start/cfg_experiment"] is None: |
|
q.client["experiment/start/cfg_experiment"] = str( |
|
df_experiments["id"].iloc[0] |
|
) |
|
|
|
if ( |
|
q.client["experiment/start/cfg_experiment_pretrained"] is None |
|
or is_create_experiment |
|
): |
|
q.client["experiment/start/cfg_experiment_pretrained"] = False |
|
|
|
if q.client["experiment/start/cfg_category"] != "experiment": |
|
items += [ |
|
ui.dropdown( |
|
name="experiment/start/cfg_file", |
|
label="Problem Type", |
|
required=True, |
|
choices=choices_problem_types, |
|
value=q.client["experiment/start/cfg_file"], |
|
trigger=True, |
|
tooltip=tooltips["experiments_problem_type"], |
|
) |
|
] |
|
|
|
model_types = get_model_types(q.client["experiment/start/cfg_file"]) |
|
if len(model_types) > 0: |
|
choices = [ui.choice(name, label) for name, label in model_types] |
|
if q.client["experiment/start/cfg_sub"] in [None, ""]: |
|
q.client["experiment/start/cfg_sub"] = model_types[0][0] |
|
items += [ |
|
ui.dropdown( |
|
name="experiment/start/cfg_sub", |
|
label="Model Type", |
|
required=True, |
|
choices=choices, |
|
value=q.client["experiment/start/cfg_sub"], |
|
trigger=True, |
|
) |
|
] |
|
else: |
|
q.client["experiment/start/cfg_sub"] = "" |
|
|
|
|
|
q.client["experiment/start/cfg_file"] = add_model_type( |
|
q.client["experiment/start/cfg_file"], q.client["experiment/start/cfg_sub"] |
|
) |
|
|
|
if ( |
|
show_update_warnings |
|
and q.client["experiment/start/cfg_file_prev"] |
|
!= q.client["experiment/start/cfg_file"] |
|
and q.client["experiment/start/cfg_category"] != "experiment" |
|
): |
|
items += [ |
|
ui.message_bar(type="warning", text=warning_message.format("Problem Type")) |
|
] |
|
show_update_warnings = False |
|
|
|
if q.client["experiment/start/cfg_category"] == "experiment": |
|
items += [ |
|
ui.dropdown( |
|
name="experiment/start/cfg_experiment", |
|
label="Experiment", |
|
required=True, |
|
choices=[ |
|
ui.choice(str(row.id), row["name"]) |
|
for _, row in df_experiments.iterrows() |
|
], |
|
value=q.client["experiment/start/cfg_experiment"], |
|
trigger=True, |
|
) |
|
] |
|
|
|
if ( |
|
show_update_warnings |
|
and q.client["experiment/start/cfg_experiment_prev"] |
|
!= q.client["experiment/start/cfg_experiment"] |
|
): |
|
items += [ |
|
ui.message_bar( |
|
type="warning", text=warning_message.format("previous Experiment") |
|
) |
|
] |
|
|
|
|
|
if ( |
|
df_experiments.loc[ |
|
df_experiments.id == int(q.client["experiment/start/cfg_experiment"]), |
|
"status", |
|
].values[0] |
|
== "finished" |
|
): |
|
items += [ |
|
ui.toggle( |
|
name="experiment/start/cfg_experiment_pretrained", |
|
label="Use previous experiment weights", |
|
value=q.client["experiment/start/cfg_experiment_pretrained"], |
|
trigger=True, |
|
) |
|
] |
|
|
|
|
|
if q.client["experiment/start/cfg_category"] != "experiment": |
|
items += [ |
|
ui.toggle( |
|
name="experiment/start/from_yaml", |
|
label="Import config from YAML", |
|
value=False, |
|
trigger=True, |
|
tooltip=tooltips["experiments_import_config_from_yaml"], |
|
) |
|
] |
|
|
|
if q.args["experiment/start/from_yaml"]: |
|
items += [ |
|
ui.file_upload( |
|
name="experiment/upload_yaml", |
|
label="Upload!", |
|
multiple=False, |
|
file_extensions=["yaml"], |
|
) |
|
] |
|
|
|
if q.args["experiment/upload_yaml"] is not None: |
|
|
|
q.client["experiment/start/cfg_file_prev"] = None |
|
await config_import_uploaded_file(q) |
|
|
|
logger.info( |
|
f"PREV {q.client['experiment/start/cfg_file_prev']} " |
|
f"{q.client['experiment/start/cfg_file']} " |
|
f"{q.client['experiment/start/dataset_prev']} " |
|
f"{q.client['experiment/start/dataset']} " |
|
f"{q.client['experiment/start/cfg_experiment_prev']} " |
|
f"{q.client['experiment/start/cfg_experiment']} " |
|
) |
|
|
|
|
|
q.client["experiment/start/cfg_mode/mode"] = "train" |
|
|
|
if q.client["experiment/start/cfg_category"] == "experiment": |
|
logger.info("Starting from experiment") |
|
|
|
|
|
q.client["experiment/start/cfg_file_prev"] = None |
|
|
|
q.client["experiment/start/experiment"] = q.client.app_db.get_experiment( |
|
q.client["experiment/start/cfg_experiment"] |
|
) |
|
|
|
parent_path = os.path.dirname(q.client["experiment/start/experiment"].path) |
|
parent_exp_name = parent_path.split("/")[-1] |
|
parent_experiment = f"{parent_exp_name}" |
|
|
|
old_config = load_config_yaml(f"{parent_path}/cfg.yaml") |
|
old_config._parent_experiment = parent_experiment |
|
|
|
q.client["experiment/start/cfg"] = old_config |
|
|
|
|
|
if q.client["experiment/start/cfg_experiment_pretrained"]: |
|
prev_weights = os.path.join( |
|
q.client["experiment/start/experiment"].path, |
|
"checkpoint.pth", |
|
) |
|
if os.path.exists(prev_weights): |
|
q.client["experiment/start/cfg"].architecture.pretrained_weights = ( |
|
prev_weights |
|
) |
|
q.client["experiment/start/cfg"].architecture._visibility[ |
|
"pretrained_weights" |
|
] = -1 |
|
|
|
experiments_df = q.client.app_db.get_experiments_df() |
|
output_dir = os.path.abspath( |
|
os.path.join(q.client["experiment/start/cfg"].output_directory, "..") |
|
) |
|
q.client["experiment/start/cfg"].experiment_name = get_unique_name( |
|
q.client["experiment/start/cfg"].experiment_name, |
|
experiments_df["name"].values, |
|
lambda x: os.path.exists(os.path.join(output_dir, x)), |
|
) |
|
|
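        # Record how the config should be (re)built, depending on whether the
        # parent experiment or the dataset selection changed.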
if ( |
|
q.client["experiment/start/cfg_experiment_prev"] |
|
!= q.client["experiment/start/cfg_experiment"] |
|
): |
|
q.client["experiment/start/cfg_mode/from_dataset"] = False |
|
q.client["experiment/start/cfg_mode/from_cfg"] = True |
|
q.client["experiment/start/cfg_mode/from_dataset_args"] = False |
|
|
|
q.client["experiment/start/dataset"] = str( |
|
q.client["experiment/start/experiment"].dataset |
|
) |
|
|
|
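            # Keep the dataset dropdown (items[1]) in sync with the parent
            # experiment's dataset.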
items[1].dropdown.value = q.client["experiment/start/dataset"] |
|
|
|
elif ( |
|
q.client["experiment/start/dataset_prev"] |
|
!= q.client["experiment/start/dataset"] |
|
): |
|
q.client["experiment/start/cfg_mode/from_dataset"] = True |
|
q.client["experiment/start/cfg_mode/from_cfg"] = True |
|
q.client["experiment/start/cfg_mode/from_dataset_args"] = False |
|
|
|
else: |
|
q.client["experiment/start/cfg_mode/from_dataset"] = False |
|
q.client["experiment/start/cfg_mode/from_cfg"] = False |
|
q.client["experiment/start/cfg_mode/from_dataset_args"] = True |
|
|
|
q.client["experiment/start/cfg_mode/from_default"] = False |
|
q.client["experiment/start/cfg_experiment_prev"] = q.client[ |
|
"experiment/start/cfg_experiment" |
|
] |
|
|
|
else: |
|
logger.info("Starting from CFG") |
|
|
|
|
|
q.client["experiment/start/cfg_experiment_prev"] = None |
|
|
|
|
|
if ( |
|
q.client["experiment/start/cfg_file_prev"] |
|
!= q.client["experiment/start/cfg_file"] |
|
) or ( |
|
q.client["experiment/start/dataset_prev"] |
|
!= q.client["experiment/start/dataset"] |
|
): |
|
q.client["experiment/start/cfg_mode/from_dataset"] = True |
|
q.client["experiment/start/cfg_mode/from_cfg"] = True |
|
q.client["experiment/start/cfg_mode/from_default"] = True |
|
q.client["experiment/start/cfg_mode/from_dataset_args"] = False |
|
|
|
else: |
|
q.client["experiment/start/cfg_mode/from_dataset"] = False |
|
q.client["experiment/start/cfg_mode/from_cfg"] = False |
|
q.client["experiment/start/cfg_mode/from_default"] = False |
|
q.client["experiment/start/cfg_mode/from_dataset_args"] = True |
|
|
|
q.client["experiment/start/cfg_file_prev"] = q.client[ |
|
"experiment/start/cfg_file" |
|
] |
|
|
|
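        # Load the default config class for the selected problem type from its
        # python_configs module.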
config_path = ( |
|
f"llm_studio/python_configs/{q.client['experiment/start/cfg_file']}" |
|
) |
|
|
|
q.client["experiment/start/cfg"] = load_config_py( |
|
config_path=config_path, config_name="ConfigProblemBase" |
|
) |
|
|
|
q.client["experiment/start/dataset_prev"] = q.client["experiment/start/dataset"] |
|
logger.info(f"From dataset {q.client['experiment/start/cfg_mode/from_dataset']}") |
|
logger.info(f"From cfg {q.client['experiment/start/cfg_mode/from_cfg']}") |
|
logger.info(f"From default {q.client['experiment/start/cfg_mode/from_default']}") |
|
logger.info(f"Config file: {q.client['experiment/start/cfg_file']}") |
|
|
|
option_items = get_ui_elements(cfg=q.client["experiment/start/cfg"], q=q) |
|
items.extend(option_items) |
|
|
|
if q.client["experiment/start/cfg_mode/from_cfg"]: |
|
q.page["experiment/start"] = ui.form_card(box="content", items=items) |
|
else: |
|
q.page["experiment/start"].items = items |
|
|
|
q.client.delete_cards.add("experiment/start") |
|
|
|
q.page["experiment/start/footer"] = ui.form_card( |
|
box="footer", |
|
items=[ |
|
ui.inline( |
|
items=[ |
|
ui.button( |
|
name="experiment/start/run", |
|
label="Run experiment", |
|
primary=True, |
|
) |
|
], |
|
justify="start", |
|
) |
|
], |
|
) |
|
q.client.delete_cards.add("experiment/start/footer") |
|
|
|
|
|
async def experiment_run(q: Q, pre: str = "experiment/start"): |
|
"""Start an experiment. |
|
|
|
Args: |
|
q: Q |
|
pre: prefix for client key |
|
""" |
|
|
|
from llm_studio.app_utils.sections.project import list_current_experiments |
|
|
|
logger.info("Starting experiment") |
|
logger.info(f"{pre}/cfg_file") |
|
logger.info(f"CFG: {q.client[f'{pre}/cfg_file']}") |
|
|
|
if q.client[f"{pre}/cfg_category"] == "experiment": |
|
q.client[f"{pre}/cfg_file"] = q.client[f"{pre}/experiment"].config_file |
|
|
|
cfg = q.client[f"{pre}/cfg"] |
|
cfg = parse_ui_elements(cfg=cfg, q=q, pre=f"{pre}/cfg/") |
|
cfg.experiment_name = cfg.experiment_name.replace("/", "-") |
|
|
|
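    # Validate the parsed config; show a dialog listing any mismatches unless the
    # user already chose to proceed anyway.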
errors = check_config_for_errors(cfg) |
|
if errors["title"] and not q.args["experiment/start/error/proceed"]: |
|
title = ( |
|
errors["title"][0] |
|
if len(errors["title"]) == 1 |
|
else "The following configuration mismatches were found:" |
|
) |
|
error_text = [ui.text(message) for message in errors["message"]] |
|
q.page["meta"].dialog = ui.dialog( |
|
title=title, |
|
name="experiment/start/error/dialog", |
|
items=error_text |
|
+ [ |
|
ui.buttons( |
|
[ |
|
ui.button( |
|
name="experiment/start/error/ok", label="Ok", primary=True |
|
), |
|
ui.button( |
|
name="experiment/start/error/proceed", |
|
label="I want to proceed anyhow", |
|
primary=False, |
|
), |
|
] |
|
) |
|
], |
|
closable=True, |
|
) |
|
q.client["keep_meta"] = True |
|
else: |
|
start_experiment(cfg=cfg, q=q, pre=pre) |
|
await list_current_experiments(q) |
|
|
|
|
|
def get_experiment_table( |
|
q, df_viz, predictions, height="calc(100vh - 245px)", actions=None |
|
): |
|
col_remove = [ |
|
"id", |
|
"path", |
|
"mode", |
|
"seed", |
|
"process_id", |
|
"gpu_list", |
|
"loss", |
|
"eta", |
|
"epoch", |
|
"config_file", |
|
] |
|
if predictions: |
|
col_remove += ["epoch", "val metric"] |
|
|
|
for col in col_remove: |
|
if col in df_viz: |
|
del df_viz[col] |
|
|
if actions == "experiment" and q.client["experiment/list/mode"] == "train": |
|
actions_dict = { |
|
"experiment/list/new": "New experiment", |
|
"experiment/list/rename": "Rename experiment", |
|
"experiment/list/stop/table": "Stop experiment", |
|
"experiment/list/delete/table/dialog": "Delete experiment", |
|
} |
|
else: |
|
actions_dict = {} |
|
|
|
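    # Minimum column widths (in pixels); scaled up by 5% below for prediction tables.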
min_widths = { |
|
"name": "350", |
|
"dataset": "150", |
|
|
|
"metric": "75", |
|
"val metric": "102", |
|
"progress": "85", |
|
"status": "90", |
|
"info": "115", |
|
"actions": "5" if predictions else "5", |
|
} |
|
|
|
if predictions: |
|
for k, v in min_widths.items(): |
|
min_widths[k] = str(int(np.ceil(int(v) * 1.05))) |
|
|
|
return ui_table_from_df( |
|
q=q, |
|
df=df_viz, |
|
name="experiment/list/table", |
|
sortables=["val metric"], |
|
filterables=["name", "dataset", "problem type", "metric", "status"], |
|
searchables=["name", "dataset"], |
|
numerics=["val metric"], |
|
tags=["status"], |
|
progresses=["progress"], |
|
min_widths=min_widths, |
|
link_col="name", |
|
height=height, |
|
actions=actions_dict, |
|
) |
|
|
|
|
|
async def experiment_list( |
|
q: Q, |
|
reset: bool = True, |
|
allowed_statuses: Optional[List[str]] = None, |
|
actions: bool = True, |
|
) -> None: |
|
"""List all experiments.""" |
|
|
|
if q.client["experiment/list/mode"] is None: |
|
q.client["experiment/list/mode"] = "train" |
|
|
|
if q.client["experiment/list/mode"] == "train": |
|
q.client["nav/active"] = "experiment/list" |
|
else: |
|
q.client["nav/active"] = "experiment/list_predictions" |
|
|
|
if reset: |
|
await clean_dashboard(q, mode="full") |
|
|
|
q.client["experiment/list/df_experiments"] = get_experiments( |
|
q, |
|
mode=q.client["experiment/list/mode"], |
|
status=allowed_statuses, |
|
) |
|
|
|
df_viz = q.client["experiment/list/df_experiments"].copy() |
|
|
|
table = get_experiment_table( |
|
q, |
|
df_viz, |
|
q.client["experiment/list/mode"] == "predict", |
|
actions="experiment" if actions else None, |
|
) |
|
|
|
message_bar = get_experiment_list_message_bar(q) |
|
|
|
items = [table, message_bar] |
|
|
|
q.page["experiment/list"] = ui.form_card(box="content", items=items) |
|
q.client.delete_cards.add("experiment/list") |
|
|
|
buttons = [ |
|
ui.button(name="experiment/list/refresh", label="Refresh", primary=True), |
|
ui.button( |
|
name="experiment/list/compare", |
|
label="Compare experiments", |
|
primary=False, |
|
), |
|
ui.button(name="experiment/list/stop", label="Stop experiments", primary=False), |
|
ui.button( |
|
name="experiment/list/delete", label="Delete experiments", primary=False |
|
), |
|
] |
|
|
|
q.page["dataset/display/footer"] = ui.form_card( |
|
box="footer", items=[ui.inline(items=buttons, justify="start")] |
|
) |
|
q.client.delete_cards.add("dataset/display/footer") |
|
|
|
|
|
def get_table_and_message_item_indices(q): |
|
table_item_idx, message_item_idx = 0, 1 |
|
return table_item_idx, message_item_idx |
|
|
|
|
|
async def experiment_compare(q: Q, selected_rows: list): |
|
if q.client["experiment/compare/tab"] is None: |
|
q.client["experiment/compare/tab"] = "experiment/compare/charts" |
|
if q.args["experiment/compare/charts"] is not None: |
|
q.client["experiment/compare/tab"] = "experiment/compare/charts" |
|
if q.args["experiment/compare/config"] is not None: |
|
q.client["experiment/compare/tab"] = "experiment/compare/config" |
|
|
|
experiment_ids = [ |
|
q.client["experiment/list/df_experiments"]["id"].iloc[int(idx)] |
|
for idx in selected_rows |
|
] |
|
|
|
await clean_dashboard(q, mode=q.client["experiment/compare/tab"]) |
|
tabs = [ |
|
ui.tab(name="experiment/compare/charts", label="Charts"), |
|
ui.tab(name="experiment/compare/config", label="Config"), |
|
] |
|
q.page["experiment/compare/tab"] = ui.tab_card( |
|
box="nav2", link=True, items=tabs, value=q.client["experiment/compare/tab"] |
|
) |
|
q.client.delete_cards.add("experiment/compare/tab") |
|
|
|
if q.client["experiment/compare/tab"] == "experiment/compare/charts": |
|
charts = [] |
|
experiment_names = [] |
|
|
|
for experiment_id in experiment_ids: |
|
experiment = q.client.app_db.get_experiment(experiment_id) |
|
experiment_path = experiment.path |
|
charts.append(load_charts(experiment_path)) |
|
current_name = f" {experiment.name}" |
|
experiment_names.append(current_name) |
|
|
|
await charts_tab(q, charts, experiment_names) |
|
|
|
elif q.client["experiment/compare/tab"] == "experiment/compare/config": |
|
if q.client["experiment/compare/diff_toggle"] is None: |
|
q.client["experiment/compare/diff_toggle"] = False |
|
|
|
settings = pd.DataFrame() |
|
for experiment_id in experiment_ids: |
|
experiment = q.client.app_db.get_experiment(experiment_id) |
|
experiment_path = experiment.path |
|
experiment_cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml")) |
|
items = get_cfg_list_items(experiment_cfg) |
|
act_df = pd.Series({item.label: item.value for item in items}) |
|
settings[experiment.name] = act_df |
|
|
|
settings.index.name = "setting" |
|
|
|
if q.client["experiment/compare/diff_toggle"]: |
|
val_counts = settings.T.nunique() |
|
drop_idx = val_counts[val_counts == 1].index.values |
|
settings = settings.drop(drop_idx) |
|
|
|
items = [ |
|
ui.toggle( |
|
name="experiment/compare/diff_toggle", |
|
label="Show differences only", |
|
value=q.client["experiment/compare/diff_toggle"], |
|
trigger=True, |
|
), |
|
ui_table_from_df( |
|
q=q, |
|
df=settings.reset_index(), |
|
name="experiment/compare/summary/table", |
|
link_col="setting", |
|
height="calc(100vh - 315px)", |
|
), |
|
] |
|
|
|
q.page["experiment/compare/config"] = ui.form_card(box="first", items=items) |
|
q.client.delete_cards.add("experiment/compare/config") |
|
|
|
buttons = [ |
|
ui.button(name="experiment/compare", label="Refresh", primary=True), |
|
ui.button(name="experiment/list/current", label="Back", primary=False), |
|
] |
|
q.page["experiment/compare/footer"] = ui.form_card( |
|
box="footer", items=[ui.inline(items=buttons, justify="start")] |
|
) |
|
q.client.delete_cards.add("experiment/compare/footer") |
|
|
|
|
|
async def experiment_rename_form(q: Q, error: str = "") -> None: |
|
experiment = q.client.app_db.get_experiment(q.client["experiment/rename/id"]) |
|
|
|
experiment_name = experiment.name |
|
items = [ |
|
ui.textbox( |
|
name="experiment/rename/name", |
|
label=f"New name for {experiment_name}", |
|
value=experiment_name, |
|
required=True, |
|
) |
|
] |
|
|
|
if error: |
|
items.append(ui.message_bar(type="error", text=error)) |
|
|
|
q.page["experiment/list"].items = items |
|
|
|
buttons = [ |
|
ui.button(name="experiment/rename/action", label="Rename", primary=True), |
|
ui.button(name="experiment/list/current", label="Abort", primary=False), |
|
] |
|
q.page["dataset/display/footer"] = ui.form_card( |
|
box="footer", items=[ui.inline(items=buttons, justify="start")] |
|
) |
|
q.client.delete_cards.add("dataset/display/footer") |
|
|
|
|
|
async def experiment_rename_ui_workflow(q: Q): |
|
selected_row = q.args["experiment/list/rename"] |
|
rename_id = q.client["experiment/list/df_experiments"]["id"].iloc[int(selected_row)] |
|
q.client["experiment/rename/id"] = rename_id |
|
await experiment_rename_form(q) |
|
|
|
|
|
async def experiment_rename_action(q, experiment, new_name): |
|
"""Rename experiment with `current_id` id in DB to `new_name`""" |
|
|
|
old_name = experiment.name |
|
old_path = experiment.path |
|
new_path = old_path.replace(old_name, new_name) |
|
|
|
if old_path != new_path: |
|
old_exp_path = f"{old_path}" |
|
exp_path = f"{new_path}" |
|
logger.info(f"Renaming {old_exp_path} to {exp_path}") |
|
shutil.move(os.path.abspath(old_exp_path), os.path.abspath(exp_path)) |
|
|
|
|
|
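        # Rewrite chart artifact paths stored in charts.db so the "df" entries
        # point into the renamed experiment directory.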
with SqliteDict(os.path.join(new_path, "charts.db")) as charts: |
|
for k1 in PLOT_ENCODINGS: |
|
if k1 == "df": |
|
|
|
df = charts[k1].copy() |
|
for k2, v2 in df.items(): |
|
logger.info( |
|
f"Renaming charts {v2} to {v2.replace(old_name, new_name)}" |
|
) |
|
df[k2] = v2.replace(old_name, new_name) |
|
charts[k1] = df |
|
charts.commit() |
|
|
|
for config_file in ["cfg.yaml"]: |
|
config_path = os.path.join(exp_path, config_file) |
|
if os.path.exists(config_path): |
|
experiment_cfg = load_config_yaml(config_path) |
|
experiment_cfg.experiment_name = new_name |
|
experiment_cfg.output_directory = new_path |
|
save_config_yaml(config_path, experiment_cfg) |
|
|
|
rename_files = ["preds"] |
|
for file in rename_files: |
|
old_file = get_artifact_path_path(old_name, exp_path, file) |
|
new_file = get_artifact_path_path(new_name, exp_path, file) |
|
if os.path.exists(old_file): |
|
logger.info(f"Renaming {old_file} to {new_file}") |
|
shutil.move(os.path.abspath(old_file), os.path.abspath(new_file)) |
|
|
|
delete_files = ["logs"] |
|
for file in delete_files: |
|
file = get_artifact_path_path(old_name, exp_path, file) |
|
if os.path.exists(file): |
|
logger.info(f"Deleting {file}") |
|
os.remove(file) |
|
|
|
q.client.app_db.rename_experiment(experiment.id, new_name, new_path) |
|
|
|
|
|
async def experiment_delete(q: Q, experiment_ids: List[int]) -> None: |
|
"""Delete selected experiments. |
|
|
|
Args: |
|
q: Q |
|
experiment_ids: list of experiment ids to delete |
|
""" |
|
|
|
for experiment_id in experiment_ids: |
|
experiment = q.client.app_db.get_experiment(experiment_id) |
|
q.client.app_db.delete_experiment(experiment.id) |
|
shutil.rmtree(f"{experiment.path}") |
|
|
|
|
|
async def experiment_stop(q: Q, experiment_ids: List[int]) -> None: |
|
"""Stop selected experiments. |
|
|
|
Args: |
|
q: Q |
|
experiment_ids: list of experiment ids to stop |
|
""" |
|
|
|
for experiment_id in experiment_ids: |
|
experiment = q.client.app_db.get_experiment(experiment_id) |
|
|
|
try: |
|
ret = kill_child_processes(int(experiment.process_id)) |
|
if ret: |
|
flag_path = os.path.join(experiment.path, "flags.json") |
|
write_flag(flag_path, "status", "stopped") |
|
except Exception as e: |
|
logger.error(f"Error while stopping the experiment: {e}") |
|
pass |
|
|
|
|
|
def load_charts(experiment_path): |
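    """Load charts.db for an experiment; return an empty dict if the charts are not available yet."""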
|
try: |
|
with SqliteDict(os.path.join(experiment_path, "charts.db")) as charts: |
|
charts = dict(charts) |
|
except Exception: |
|
charts = {} |
|
logger.warning("Too early, wait for the charts to appear") |
|
|
|
return charts |
|
|
|
|
|
async def experiment_display(q: Q) -> None: |
|
"""Display a selected experiment.""" |
|
|
|
experiment_id = q.client["experiment/list/df_experiments"]["id"].iloc[ |
|
q.client["experiment/display/id"] |
|
] |
|
q.client["experiment/display/experiment_id"] = experiment_id |
|
experiment = q.client.app_db.get_experiment(experiment_id) |
|
q.client["experiment/display/experiment"] = experiment |
|
|
|
q.client["experiment/display/experiment_path"] = experiment.path |
|
|
|
status, _ = get_experiment_status(experiment.path) |
|
|
|
charts = load_charts(q.client["experiment/display/experiment_path"]) |
|
q.client["experiment/display/charts"] = charts |
|
|
|
if experiment.mode == "train": |
|
if q.client["experiment/display/tab"] is None: |
|
q.client["experiment/display/tab"] = "experiment/display/charts" |
|
else: |
|
if q.client["experiment/display/tab"] is None: |
|
q.client["experiment/display/tab"] = "experiment/display/summary" |
|
|
|
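    # Switch tabs based on which tab link triggered this request; otherwise keep
    # the default chosen above.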
if q.args["experiment/display/charts"] is not None: |
|
q.client["experiment/display/tab"] = "experiment/display/charts" |
|
if q.args["experiment/display/summary"] is not None: |
|
q.client["experiment/display/tab"] = "experiment/display/summary" |
|
if q.args["experiment/display/train_data_insights"] is not None: |
|
q.client["experiment/display/tab"] = "experiment/display/train_data_insights" |
|
if q.args["experiment/display/validation_prediction_insights"] is not None: |
|
q.client["experiment/display/tab"] = ( |
|
"experiment/display/validation_prediction_insights" |
|
) |
|
if q.args["experiment/display/config"] is not None: |
|
q.client["experiment/display/tab"] = "experiment/display/config" |
|
if q.args["experiment/display/deployment"] is not None: |
|
q.client["experiment/display/tab"] = "experiment/display/deployment" |
|
if q.args["experiment/display/logs"] is not None: |
|
q.client["experiment/display/tab"] = "experiment/display/logs" |
|
if q.args["experiment/display/chat"] is not None: |
|
q.client["experiment/display/tab"] = "experiment/display/chat" |
|
|
|
await clean_dashboard(q, mode=q.client["experiment/display/tab"]) |
|
|
|
tabs = [ |
|
ui.tab(name="experiment/display/charts", label="Charts"), |
|
ui.tab(name="experiment/display/summary", label="Summary"), |
|
] |
|
|
|
has_train_data_insights = any( |
|
[ |
|
charts.get(plot_encoding, dict()).get("train_data") is not None |
|
for plot_encoding in PLOT_ENCODINGS |
|
] |
|
) |
|
if has_train_data_insights: |
|
tabs += [ |
|
ui.tab( |
|
name="experiment/display/train_data_insights", |
|
label="Train Data Insights", |
|
) |
|
] |
|
has_validation_prediction_insights = any( |
|
[ |
|
charts.get(plot_encoding, dict()).get("validation_predictions") is not None |
|
for plot_encoding in PLOT_ENCODINGS |
|
] |
|
) |
|
if has_validation_prediction_insights: |
|
tabs += [ |
|
ui.tab( |
|
name="experiment/display/validation_prediction_insights", |
|
label="Validation Prediction Insights", |
|
) |
|
] |
|
|
|
tabs += [ |
|
ui.tab(name="experiment/display/logs", label="Logs"), |
|
ui.tab(name="experiment/display/config", label="Config"), |
|
] |
|
|
|
if status == "finished": |
|
tabs += [ui.tab(name="experiment/display/chat", label="Chat")] |
|
|
|
q.page["experiment/display/tab"] = ui.tab_card( |
|
box="nav2", link=True, items=tabs, value=q.client["experiment/display/tab"] |
|
) |
|
q.client.delete_cards.add("experiment/display/tab") |
|
|
|
if q.client["experiment/display/tab"] == "experiment/display/charts": |
|
await charts_tab(q, [charts], [""]) |
|
elif q.client["experiment/display/tab"] in [ |
|
"experiment/display/train_data_insights", |
|
"experiment/display/validation_prediction_insights", |
|
]: |
|
await insights_tab(charts, q) |
|
elif q.client["experiment/display/tab"] in ["experiment/display/summary"]: |
|
await summary_tab(experiment_id, q) |
|
elif q.client["experiment/display/tab"] in ["experiment/display/config"]: |
|
await configs_tab(q) |
|
elif q.client["experiment/display/tab"] in ["experiment/display/logs"]: |
|
await logs_tab(q) |
|
elif q.client["experiment/display/tab"] in ["experiment/display/chat"]: |
|
await chat_tab(q) |
|
|
|
await q.page.save() |
|
|
|
buttons = [ |
|
ui.button(name="experiment/display/refresh", label="Refresh", primary=True) |
|
] |
|
|
|
buttons += [ |
|
ui.button( |
|
name="experiment/display/download_logs", |
|
label="Download logs/config", |
|
primary=False, |
|
) |
|
] |
|
|
|
if status == "finished": |
|
buttons += [ |
|
ui.button( |
|
name="experiment/display/download_predictions", |
|
label="Download predictions", |
|
primary=False, |
|
disabled=False, |
|
tooltip=None, |
|
), |
|
ui.button( |
|
name="experiment/display/download_model", |
|
label="Download model", |
|
primary=False, |
|
disabled=False, |
|
tooltip=None, |
|
), |
|
ui.button( |
|
name="experiment/display/push_to_huggingface", |
|
label="Push checkpoint to huggingface", |
|
primary=False, |
|
disabled=False, |
|
tooltip=None, |
|
), |
|
] |
|
|
|
buttons += [ui.button(name="experiment/list/current", label="Back", primary=False)] |
|
|
|
q.page["experiment/display/footer"] = ui.form_card( |
|
box="footer", |
|
items=[ |
|
ui.inline(items=buttons, justify="start"), |
|
], |
|
) |
|
q.client.delete_cards.add("experiment/display/footer") |
|
|
|
|
|
async def insights_tab(charts, q): |
|
if q.client["experiment/display/tab"] == "experiment/display/train_data_insights": |
|
key = "train_data" |
|
elif ( |
|
q.client["experiment/display/tab"] |
|
== "experiment/display/validation_prediction_insights" |
|
): |
|
key = "validation_predictions" |
|
for k1 in PLOT_ENCODINGS: |
|
if k1 not in charts: |
|
continue |
|
for k2, v2 in charts[k1].items(): |
|
if k2 != key: |
|
continue |
|
if k1 == "html": |
|
q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.markup_card( |
|
box="first", title="", content=v2 |
|
) |
|
q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}") |
|
|
|
continue |
|
|
|
elif k1 == "image": |
|
q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.image_card( |
|
box="first", title="", type="png", image=v2 |
|
) |
|
q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}") |
|
continue |
|
|
|
elif k1 == "df": |
|
df = pd.read_parquet(v2) |
|
min_widths = { |
|
col: "350" for col in df.columns if "text" in str(col).lower() |
|
} |
|
|
|
if key == "train_data": |
|
min_widths["Content"] = "800" |
|
q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.form_card( |
|
box="first", |
|
items=[ |
|
ui_table_from_df( |
|
q=q, |
|
df=df, |
|
name=f"experiment/display/charts/{k1}_{k2}", |
|
sortables=[ |
|
col for col in df.columns if col.startswith("Metric") |
|
], |
|
markdown_cells=[ |
|
col |
|
for col in df.columns |
|
if not col.startswith("Metric") |
|
], |
|
searchables=list(df.columns), |
|
downloadable=True, |
|
resettable=True, |
|
min_widths=min_widths, |
|
height="calc(100vh - 245px)", |
|
max_char_length=50_000, |
|
cell_overflow="tooltip", |
|
) |
|
], |
|
) |
|
q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}") |
|
continue |
|
|
|
|
|
async def summary_tab(experiment_id, q): |
|
experiment_df = get_experiments(q) |
|
input_dict = experiment_df[experiment_df.id == experiment_id].iloc[0].to_dict() |
|
cfg = load_config_yaml( |
|
os.path.join(q.client["experiment/display/experiment_path"], "cfg.yaml") |
|
) |
|
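    # get_tokenizer is presumably called for its side effects (populating
    # tokenizer-derived fields on cfg); the returned tokenizer itself is unused.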
_ = get_tokenizer(cfg) |
|
|
|
|
|
card_name = "experiment/display/summary/experiment" |
|
q.page[card_name] = ui.form_card( |
|
box=ui.box(zone="first"), |
|
items=[ |
|
ui.separator("Experiment"), |
|
ui.stats( |
|
[ |
|
ui.stat( |
|
value=cfg.experiment_name, |
|
label="Name", |
|
), |
|
], |
|
justify="between", |
|
inset=True, |
|
), |
|
ui.stats( |
|
[ |
|
ui.stat( |
|
value=input_dict["config_file"], |
|
label="Problem Type", |
|
), |
|
], |
|
justify="between", |
|
inset=True, |
|
), |
|
], |
|
) |
|
q.client.delete_cards.add(card_name) |
|
|
|
|
|
card_name = "experiment/display/summary/datasets" |
|
q.page[card_name] = ui.form_card( |
|
box=ui.box(zone="first"), |
|
items=[ |
|
ui.separator("Datasets"), |
|
ui.stats( |
|
[ |
|
ui.stat( |
|
value=Path(cfg.dataset.train_dataframe).stem, |
|
label="Training Dataset", |
|
), |
|
], |
|
justify="between", |
|
inset=True, |
|
), |
|
ui.stats( |
|
[ |
|
ui.stat( |
|
value=( |
|
"-" |
|
if cfg.dataset.validation_dataframe in ["", "None", None] |
|
else Path(cfg.dataset.validation_dataframe).stem |
|
), |
|
label="Validation Dataset", |
|
), |
|
], |
|
justify="between", |
|
inset=True, |
|
), |
|
], |
|
) |
|
q.client.delete_cards.add(card_name) |
|
|
|
|
|
card_name = "experiment/display/summary/score" |
|
q.page[card_name] = ui.form_card( |
|
box=ui.box(zone="first"), |
|
items=[ |
|
ui.separator("Score"), |
|
ui.stats( |
|
[ |
|
ui.stat( |
|
value=input_dict["metric"], |
|
label="Metric", |
|
), |
|
], |
|
justify="between", |
|
inset=True, |
|
), |
|
ui.stats( |
|
[ |
|
ui.stat( |
|
value=( |
|
"-" |
|
if input_dict["val metric"] in ["", "None", None] |
|
else str(input_dict["val metric"]) |
|
), |
|
label="Validation Score", |
|
), |
|
], |
|
justify="between", |
|
inset=True, |
|
), |
|
], |
|
) |
|
q.client.delete_cards.add(card_name) |
|
|
|
|
|
card_name = "experiment/display/summary/main_configs" |
|
q.page[card_name] = ui.form_card( |
|
box=ui.box(zone="second"), |
|
items=[ |
|
ui.separator("Main Configurations"), |
|
ui.stats( |
|
[ |
|
ui.stat( |
|
value=cfg.llm_backbone, |
|
label="LLM Backbone", |
|
), |
|
ui.stat( |
|
value=str(cfg.training.lora), |
|
label="Lora", |
|
), |
|
ui.stat( |
|
value=str(cfg.training.epochs), |
|
label="Epochs", |
|
), |
|
ui.stat( |
|
value=str(cfg.training.batch_size), |
|
label="Batch Size", |
|
), |
|
], |
|
justify="between", |
|
inset=True, |
|
), |
|
ui.stats( |
|
[ |
|
ui.stat( |
|
value=str(input_dict["loss"]), |
|
label="Loss Function", |
|
), |
|
ui.stat( |
|
value=cfg.architecture.backbone_dtype, |
|
label="Backbone Dtype", |
|
), |
|
ui.stat( |
|
value=str(cfg.architecture.gradient_checkpointing), |
|
label="Gradient Checkpointing", |
|
), |
|
ui.stat( |
|
value=input_dict["gpu_list"], |
|
label="GPU List", |
|
), |
|
], |
|
justify="between", |
|
inset=True, |
|
), |
|
], |
|
) |
|
q.client.delete_cards.add(card_name) |
|
|
|
|
|
card_name = "experiment/display/summary/code" |
|
content = get_experiment_summary_code_card(cfg=cfg) |
|
q.page[card_name] = ui.markdown_card( |
|
box=ui.box(zone="third"), |
|
title="", |
|
content=content, |
|
) |
|
q.client.delete_cards.add(card_name) |
|
|
|
|
|
async def configs_tab(q): |
|
experiment_cfg = load_config_yaml( |
|
os.path.join(q.client["experiment/display/experiment_path"], "cfg.yaml") |
|
) |
|
items = get_cfg_list_items(experiment_cfg) |
|
q.page["experiment/display/config"] = ui.stat_list_card( |
|
box="first", items=items, title="" |
|
) |
|
q.client.delete_cards.add("experiment/display/config") |
|
|
|
|
|
async def logs_tab(q): |
|
logs_path = f"{q.client['experiment/display/experiment_path']}/logs.log" |
|
text = "" |
|
in_pre = 0 |
|
|
|
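    # Render the log file as HTML: wrap each line in a <div> (skipped inside <pre>
    # blocks), split very long lines every 250 characters, and drop "INFO: Lock" noise.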
if os.path.exists(logs_path): |
|
with open(logs_path, "r") as f: |
|
for line in f.readlines(): |
|
if in_pre == 0: |
|
text += "<div>" |
|
if "INFO: Lock" in line: |
|
continue |
|
|
|
n = 250 |
|
chunks = [line[i : i + n] for i in range(0, len(line), n)] |
|
text += "</div><div>".join(chunks) |
|
|
|
|
|
if "<pre>" in line: |
|
in_pre += 1 |
|
if "</pre>" in line: |
|
in_pre -= 1 |
|
if in_pre == 0: |
|
text += "</div>" |
|
items = [ui.text(text)] |
|
q.page["experiment/display/logs"] = ui.form_card(box="first", items=items, title="") |
|
q.client.delete_cards.add("experiment/display/logs") |
|
|
|
|
|
def subsample(key1, key2, value, max_plot_points=1000): |
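    """Subsample a chart series to at most max_plot_points points by taking every stride-th step/value."""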
|
act_plot_points = len(value["steps"]) |
|
if act_plot_points > max_plot_points: |
|
stride = int(np.ceil(act_plot_points / max_plot_points)) |
|
value["steps"] = value["steps"][::stride] |
|
value["values"] = value["values"][::stride] |
|
logger.info( |
|
f"{key1} {key2} sampled from size {act_plot_points} to size " |
|
f"{len(value['steps'])} using stride {stride}." |
|
) |
|
return value |
|
|
|
|
|
def unite_validation_metric_charts(charts_list): |
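    """If compared experiments use differently named validation metrics, map them to a common "metric" key so they share one chart."""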
|
unique_metrics = [] |
|
for chart in charts_list: |
|
unique_metrics.extend(list(chart.get("validation", {}).keys())) |
|
|
|
unique_metrics = set([key for key in unique_metrics if key != "loss"]) |
|
|
|
if len(unique_metrics) > 1: |
|
for chart in charts_list: |
|
if "validation" in chart: |
|
for key in unique_metrics: |
|
if key in chart["validation"]: |
|
chart["validation"]["metric"] = chart["validation"][key] |
|
del chart["validation"][key] |
|
return charts_list |
|
|
|
|
|
async def charts_tab(q, charts_list, legend_labels): |
|
charts_list = unite_validation_metric_charts(charts_list) |
|
|
|
box = ["first", "first", "second", "second"] |
|
cnt = 0 |
|
for k1 in ["meta", "train", "validation"]: |
|
if all([k1 not in charts for charts in charts_list]): |
|
continue |
|
|
|
all_second_keys: Set = set() |
|
for charts in charts_list: |
|
if k1 in charts: |
|
all_second_keys = all_second_keys.union(set(charts[k1].keys())) |
|
|
|
|
|
if "loss" in all_second_keys: |
|
all_second_keys.remove("loss") |
|
list_all_second_keys = ["loss"] + list(all_second_keys) |
|
else: |
|
list_all_second_keys = list(all_second_keys) |
|
|
|
for k2 in list_all_second_keys: |
|
logger.info(f"{k1} {k2}") |
|
|
|
items = [] |
|
|
|
tooltip = "" |
|
if k1 == "meta" and k2 == "lr": |
|
tooltip = "Current learning rate throughout the training process." |
|
elif k1 == "train" and k2 == "loss": |
|
tooltip = ( |
|
"Current training loss throughout the training process. " |
|
"Loss is calculated as the average of the last ten batches." |
|
) |
|
elif k1 == "validation" and k2 == "loss": |
|
tooltip = ( |
|
"Current validation loss throughout the training process. " |
|
"Loss is calculated as the average of all validation batches. " |
|
) |
|
elif k1 == "validation" and k2 != "loss": |
|
tooltip = ( |
|
"Current validation metric throughout the training process. " |
|
"Metric is calculated on full validation set predictions." |
|
) |
|
else: |
|
continue |
|
|
|
title = f"{k1} {k2}".upper().replace("META LR", "LEARNING RATE") |
|
if k2 == "loss": |
|
title = title.replace("LOSS", "BATCH LOSS") |
|
|
|
items.append(ui.text(title, tooltip=tooltip)) |
|
|
|
rows = [] |
|
|
|
max_samples = q.client["chart_plot_max_points"] |
|
for charts, label in zip(charts_list, legend_labels): |
|
if k1 not in charts or k2 not in charts[k1]: |
|
continue |
|
|
|
v2 = charts[k1][k2] |
|
v2 = subsample(k1, k2, v2, max_samples) |
|
|
|
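                # When a differential learning rate was logged, plot it next to
                # the base learning rate in the same chart.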
if k2 == "lr" and "lr_diff" in charts["meta"]: |
|
v3 = charts["meta"]["lr_diff"] |
|
v3 = subsample("meta", "lr_diff", v3, max_samples) |
|
rows.extend( |
|
[ |
|
(v2["steps"][i], f"learning rate{label}", v2["values"][i]) |
|
for i in range(len(v2["values"])) |
|
] |
|
+ [ |
|
( |
|
v3["steps"][i], |
|
f"differential learning rate{label}", |
|
v3["values"][i], |
|
) |
|
for i in range(len(v3["values"])) |
|
] |
|
) |
|
color = "=type" |
|
fields = ["step", "type", "value"] |
|
|
|
elif len(charts_list) > 1: |
|
rows.extend( |
|
[ |
|
(v2["steps"][i], label.strip(), v2["values"][i]) |
|
for i in range(len(v2["values"])) |
|
] |
|
) |
|
color = "=type" |
|
fields = ["step", "type", "value"] |
|
else: |
|
rows.extend( |
|
[ |
|
(v2["steps"][i], v2["values"][i]) |
|
for i in range(len(v2["values"])) |
|
] |
|
) |
|
color = wave_theme.color |
|
fields = ["step", "value"] |
|
|
|
d = data(fields=fields, rows=rows, pack=True) |
|
|
|
viz = ui.visualization( |
|
plot=ui.plot( |
|
[ |
|
ui.mark( |
|
type="line", |
|
x_title="step", |
|
x_scale="linear", |
|
y_scale="linear", |
|
x="=step", |
|
y="=value", |
|
color=color, |
|
y_min=0 if k1 == "meta" and k2 == "lr" else None, |
|
color_range=wave_theme.color_range, |
|
) |
|
] |
|
), |
|
data=d, |
|
interactions=["brush"], |
|
height="calc((100vh - 275px)*0.41)", |
|
width="560px", |
|
) |
|
|
|
items.append(viz) |
|
|
|
if k1 == "validation" and k2 == "loss" and np.sum(v2["values"]) == 0: |
|
items.append( |
|
ui.message_bar( |
|
type="info", |
|
text="Validation batch loss cannot be \ |
|
calculated for this problem type.", |
|
) |
|
) |
|
|
|
q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.form_card( |
|
box=box[cnt], items=items |
|
) |
|
q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}") |
|
|
|
cnt += 1 |
|
|
|
|
|
async def experiment_artifact_build_error_dialog(q: Q, error: str): |
|
q.page["meta"].dialog = ui.dialog( |
|
"Failed to build artifact", items=[ui.text(error)], closable=True |
|
) |
|
q.client["keep_meta"] = True |
|
|
|
|
|
async def experiment_download_artifact( |
|
q: Q, |
|
get_artifact_path_fn: Callable[[str, str], str], |
|
save_artifact_fn: Callable[[str, str], str], |
|
additional_log: Optional[str] = "", |
|
min_disk_space: Optional[float] = 0.0, |
|
): |
|
"""Download specific artifact, if it does not exist, create it on demand |
|
|
|
Args: |
|
q: Q |
|
get_artifact_path_fn: function that returns path to the artifact |
|
save_artifact_fn: function that generates the artifact and returns its path |
|
additional_log: additional information to be logged |
|
        min_disk_space: minimum free disk space required to generate the artifact
|
""" |
|
|
|
experiment = q.client["experiment/display/experiment"] |
|
experiment_path = q.client["experiment/display/experiment_path"] |
|
|
|
zip_path = get_artifact_path_fn(experiment.name, experiment_path) |
|
|
|
if not os.path.exists(zip_path): |
|
try: |
|
check_available_space(experiment_path, min_disk_space) |
|
except LLMResourceException as e: |
|
error = f"Cannot create {os.path.basename(zip_path)}. {e}" |
|
await experiment_artifact_build_error_dialog(q, error) |
|
return |
|
|
|
logger.info(f"Creating {zip_path} on demand") |
|
zip_path = save_artifact_fn(experiment.name, experiment_path) |
|
|
|
if additional_log: |
|
logger.info(f"{additional_log}: {zip_path}") |
|
|
|
q.page["meta"].script = ui.inline_script( |
|
f'window.open("{get_download_link(q, zip_path)}", "_blank");' |
|
) |
|
await q.page.save() |
|
|
|
|
|
async def experiment_download_predictions(q: Q): |
|
"""Download experiment predictions.""" |
|
await experiment_download_artifact( |
|
q, get_predictions_path, save_prediction_outputs, "Predictions path", None |
|
) |
|
|
|
|
|
async def experiment_download_logs(q: Q): |
|
"""Download experiment logs.""" |
|
|
|
experiment = q.client["experiment/display/experiment"] |
|
experiment_path = q.client["experiment/display/experiment_path"] |
|
zip_path = get_logs_path(experiment.name, experiment_path) |
|
|
|
if not os.path.exists(zip_path): |
|
logs = q.client["experiment/display/charts"] |
|
logger.info(f"Creating {zip_path} on demand") |
|
zip_path = save_logs(experiment.name, experiment_path, logs) |
|
|
|
download_url = get_download_link(q, zip_path) |
|
logger.info(f"Logs URL: {download_url}") |
|
|
|
q.page["meta"].script = ui.inline_script( |
|
f'window.open("{download_url}", "_blank");' |
|
) |
|
await q.page.save() |
|
|
|
|
|
async def config_import_uploaded_file(q: Q): |
|
""" "Importing a config file from drag and drop to the filesystem""" |
|
|
|
file_url = q.args["experiment/upload_yaml"][0] |
|
file_name = file_url.split("/")[-1] |
|
path = f"{get_data_dir(q)}/{file_name}" |
|
|
|
local_path = await q.site.download(file_url, path) |
|
|
|
await q.site.unload(q.args["experiment/upload_yaml"][0]) |
|
|
|
with open(local_path, "r") as f: |
|
yaml_data = yaml.safe_load(f) |
|
|
|
yaml_data = flatten_dict(yaml_data) |
|
|
|
q.client["experiment/yaml_data"] = yaml_data |
|
|
|
|
|
async def show_message(q, msg_key, page, idx, msg_type): |
|
info = q.client[msg_key] |
|
if info: |
|
q.page[page].items[idx].message_bar.text = info |
|
q.page[page].items[idx].message_bar.type = msg_type |
|
q.client[msg_key] = "" |
|
|
|
|
|
def get_experiment_list_message_bar(q): |
|
if q.client["experiment_halt_reason"]: |
|
msg_bar = ui.message_bar(type="error", text=q.client["experiment_halt_reason"]) |
|
del q.client["experiment_halt_reason"] |
|
|
|
elif q.client["force_disable_pipelines"]: |
|
msg_bar = ui.message_bar(type="info", text=q.client["force_disable_pipelines"]) |
|
del q.client["force_disable_pipelines"] |
|
|
|
else: |
|
msg_bar = ui.message_bar(type="info", text="") |
|
|
|
return msg_bar |
|
|
|
|
|
async def experiment_download_model(q: Q): |
|
experiment = q.client["experiment/display/experiment"] |
|
experiment_path = q.client["experiment/display/experiment_path"] |
|
zip_path = get_model_path(experiment.name, experiment_path) |
|
|
|
if not os.path.exists(zip_path): |
|
logger.info(f"Creating {zip_path} on demand") |
|
cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml")) |
|
|
|
device = "cuda" |
|
experiments = get_experiments(q) |
|
num_running_queued = len( |
|
experiments[experiments["status"].isin(["queued", "running"])] |
|
) |
|
if num_running_queued > 0 or ( |
|
cfg.training.lora and cfg.architecture.backbone_dtype in ("int4", "int8") |
|
): |
|
logger.info("Preparing model on CPU. This might slow down the progress.") |
|
device = "cpu" |
|
with set_env(HUGGINGFACE_TOKEN=q.client["default_huggingface_api_token"]): |
|
cfg, model, tokenizer = load_cfg_model_tokenizer( |
|
experiment_path, merge=True, device=device |
|
) |
|
|
|
model = unwrap_model(model) |
|
checkpoint_path = cfg.output_directory |
|
|
|
model_save_time = time.time() |
|
model.backbone.save_pretrained(checkpoint_path) |
|
|
tokenizer_files = list(tokenizer.save_pretrained(checkpoint_path) or []) |
|
|
|
card = get_model_card(cfg, model, repo_id="<path_to_local_folder>") |
|
card.save(os.path.join(experiment_path, "model_card.md")) |
|
|
|
logger.info(f"Creating Zip File at {zip_path}") |
|
zf = zipfile.ZipFile(zip_path, "w") |
|
|
|
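        # Tokenizer/config files to bundle with the weights; extended with whatever
        # tokenizer.save_pretrained actually wrote.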
FILES_TO_PUSH = [ |
|
"vocab.json", |
|
"sentencepiece.bpe.model", |
|
"bpe_encoder.bin", |
|
"tokenizer_config.json", |
|
"tokenizer.json", |
|
"special_tokens_map.json", |
|
"merges.txt", |
|
"generation_config.json", |
|
"config.json", |
|
"added_tokens.json", |
|
"model_card.md", |
|
"classification_head.pth", |
|
] |
|
FILES_TO_PUSH = set( |
|
FILES_TO_PUSH |
|
+ [os.path.split(tokenizer_file)[-1] for tokenizer_file in tokenizer_files] |
|
) |
|
|
|
|
|
paths_added = [] |
|
for file in FILES_TO_PUSH: |
|
path = os.path.join(experiment_path, file) |
|
if os.path.isfile(path): |
|
paths_added.append(path) |
|
add_file_to_zip(zf=zf, path=path) |
|
|
|
|
|
weight_paths = glob.glob(os.path.join(checkpoint_path, "pytorch_model*.*")) |
|
for path in weight_paths: |
|
paths_added.append(path) |
|
add_file_to_zip(zf=zf, path=path) |
|
|
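        # Add any remaining files written after the model save that were not
        # captured above, skipping the zip archive itself.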
for file in os.listdir(checkpoint_path): |
|
file_path = os.path.join(checkpoint_path, file) |
|
if ( |
|
os.path.getmtime(file_path) > model_save_time |
|
and file_path not in paths_added |
|
and file_path != zip_path |
|
): |
|
add_file_to_zip(zf=zf, path=file_path) |
|
paths_added.append(file_path) |
|
logger.info( |
|
f"Added {file_path} to zip file as it " |
|
"was created when saving the model state." |
|
) |
|
zf.close() |
|
|
|
download_url = get_download_link(q, zip_path) |
|
logger.info(f"Logs URL: {download_url}") |
|
|
|
q.page["meta"].script = ui.inline_script( |
|
f'window.open("{download_url}", "_blank");' |
|
) |
|
await q.page.save() |
|
|
|
|
|
async def experiment_push_to_huggingface_dialog(q: Q, error: str = ""): |
|
if q.args["experiment/display/push_to_huggingface"] or error: |
|
devices = ["cpu", "cpu_shard"] + [ |
|
f"cuda:{idx}" for idx in range(torch.cuda.device_count()) |
|
] |
|
default_device = "cuda:0" |
|
|
|
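        # Export on GPU by default, but fall back to CPU if other experiments are
        # queued/running or the model was trained with DeepSpeed.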
experiments = get_experiments(q) |
|
num_running_queued = len( |
|
experiments[experiments["status"].isin(["queued", "running"])] |
|
) |
|
experiment_path = q.client["experiment/display/experiment_path"] |
|
cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml")) |
|
if num_running_queued > 0 or cfg.environment.use_deepspeed: |
|
default_device = "cpu" |
|
|
|
try: |
|
huggingface_hub.login(q.client["default_huggingface_api_token"]) |
|
user_id = huggingface_hub.whoami()["name"] |
|
except Exception: |
|
user_id = "" |
|
|
|
dialog_items = [ |
|
ui.message_bar("error", error, visible=True if error else False), |
|
ui.textbox( |
|
name="experiment/display/push_to_huggingface/account_name", |
|
label="Account Name", |
|
value=user_id, |
|
width="500px", |
|
required=False, |
|
tooltip=( |
|
"The account name on HF to push the model to. " |
|
"Leaving it empty will push it to the default user account." |
|
), |
|
), |
|
ui.textbox( |
|
name="experiment/display/push_to_huggingface/model_name", |
|
label="Model Name", |
|
value=hf_repo_friendly_name( |
|
q.client["experiment/display/experiment"].name |
|
), |
|
width="500px", |
|
required=True, |
|
tooltip="The name of the model as shown on HF.", |
|
), |
|
ui.dropdown( |
|
name="experiment/display/push_to_huggingface/device", |
|
label="Device for preparing the model", |
|
required=True, |
|
value=default_device, |
|
width="500px", |
|
choices=[ui.choice(str(d), str(d)) for d in devices], |
|
tooltip=( |
|
"The local device to prepare the model before pushing it to HF. " |
|
"CPU will never load the weights to the GPU, which can be useful " |
|
"for large models, but will be significantly slower. " |
|
"Cpu_shard will first load on CPU and then shard on all GPUs " |
|
"before pushing to HF." |
|
), |
|
), |
|
ui.textbox( |
|
name="experiment/display/push_to_huggingface/api_key", |
|
label="Huggingface API Key", |
|
value=q.client["default_huggingface_api_token"], |
|
width="500px", |
|
password=True, |
|
required=True, |
|
tooltip="HF API key, needs write access.", |
|
), |
|
ui.toggle( |
|
name="default_safe_serialization", |
|
label="Use Hugging Face safetensors for safe serialization", |
|
value=q.client["default_safe_serialization"], |
|
), |
|
ui.buttons( |
|
[ |
|
ui.button( |
|
name="experiment/display/push_to_huggingface_submit", |
|
label="Export", |
|
primary=True, |
|
), |
|
ui.button(name="cancel", label="Cancel", primary=False), |
|
] |
|
), |
|
] |
|
elif q.args["experiment/display/push_to_huggingface_submit"]: |
|
await busy_dialog( |
|
q=q, |
|
title="Exporting to HuggingFace", |
|
text="Model size can affect the export time significantly.", |
|
) |
|
|
|
experiment_path = q.client["experiment/display/experiment_path"] |
|
device = q.client["experiment/display/push_to_huggingface/device"] |
|
api_key = q.client["experiment/display/push_to_huggingface/api_key"] |
|
user_id = q.client["experiment/display/push_to_huggingface/account_name"] |
|
safe_serialization = q.client["default_safe_serialization"] |
|
model_name = q.client[ |
|
"experiment/display/push_to_huggingface/model_name" |
|
].replace(".", "-") |
|
|
|
publish_model_to_hugging_face( |
|
path_to_experiment=experiment_path, |
|
device=device, |
|
api_key=api_key, |
|
user_id=user_id, |
|
model_name=model_name, |
|
safe_serialization=safe_serialization, |
|
) |
|
|
|
dialog_items = [ |
|
ui.message_bar("success", "Success"), |
|
ui.buttons( |
|
[ |
|
ui.button(name="ok", label="OK", primary=True), |
|
] |
|
), |
|
] |
|
|
|
dialog = ui.dialog( |
|
title="Push to HuggingFace Hub", |
|
items=dialog_items, |
|
closable=True, |
|
name="push_to_huggingface_dialog", |
|
) |
|
|
|
q.page["meta"].dialog = dialog |
|
q.client["keep_meta"] = True |
|
|
|
|
|
def get_experiment_summary_code_card(cfg) -> str: |
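    """Render the summary code card from the configured template, filling in repo id, library versions, prompt tokens, and inference settings."""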
|
repo_id: Optional[str] = None |
|
hf_yaml_path = f"{cfg.output_directory}/hf.yaml" |
|
|
|
with open( |
|
os.path.join("model_cards", cfg.environment._summary_card_template), "r" |
|
) as f: |
|
text = f.read() |
|
|
|
if os.path.exists(hf_yaml_path): |
|
with open(hf_yaml_path, "r") as fp: |
|
repo_id = yaml.load(fp, Loader=yaml.FullLoader)["repo_id"] |
|
|
|
if repo_id is None: |
|
repo_id = "account/model" |
|
|
|
|
|
text = text.replace("{{repo_id}}", repo_id) |
|
|
|
|
|
text = text.replace("{{transformers_version}}", transformers.__version__) |
|
text = text.replace("{{einops_version}}", einops.__version__) |
|
text = text.replace("{{accelerate_version}}", accelerate.__version__) |
|
text = text.replace("{{torch_version}}", torch.__version__) |
|
|
|
|
|
text = text.replace("{{text_prompt_start}}", str(cfg.dataset.text_prompt_start)) |
|
text = text.replace( |
|
"{{text_answer_separator}}", str(cfg.dataset.text_answer_separator) |
|
) |
|
text = text.replace( |
|
"{{end_of_sentence}}", |
|
str(cfg._tokenizer_eos_token) if cfg.dataset.add_eos_token_to_prompt else "", |
|
) |
|
|
|
text = text.replace("{{trust_remote_code}}", str(cfg.environment.trust_remote_code)) |
|
|
|
if cfg.problem_type not in NON_GENERATION_PROBLEM_TYPES: |
|
text = text.replace( |
|
"{{min_new_tokens}}", str(cfg.prediction.min_length_inference) |
|
) |
|
text = text.replace( |
|
"{{max_new_tokens}}", str(cfg.prediction.max_length_inference) |
|
) |
|
text = text.replace("{{use_fast}}", str(cfg.tokenizer.use_fast)) |
|
text = text.replace("{{do_sample}}", str(cfg.prediction.do_sample)) |
|
text = text.replace("{{num_beams}}", str(cfg.prediction.num_beams)) |
|
text = text.replace("{{temperature}}", str(cfg.prediction.temperature)) |
|
text = text.replace( |
|
"{{repetition_penalty}}", str(cfg.prediction.repetition_penalty) |
|
) |
|
|
|
return text |
|
|