"""Streamlit app: download TensorBoard event files from Hugging Face Hub
model repos into a local log directory, with a small DataFrame cache so
unchanged files are not re-downloaded."""

import os
import shutil

import pandas as pd
import streamlit as st
# from streamlit_tensorboard import st_tensorboard
from huggingface_hub import HfApi, list_models

# ==============================================================
st.set_page_config(layout="wide")
# ==============================================================

logdir = "/tmp/tensorboard_logs"
os.makedirs(logdir, exist_ok=True)


def clean_logdir(logdir):
    """Best-effort removal of the whole log directory tree.

    Errors are printed, not raised, so a half-cleaned state never
    crashes the app.
    """
    try:
        shutil.rmtree(logdir)
    except OSError as e:  # rmtree raises OSError subclasses only
        print(e)


@st.cache_resource
def get_models():
    """Return hahunavth's 'emofs2' models, newest first (cached per session)."""
    _author = "hahunavth"
    _filter = "emofs2"
    return list(
        list_models(author=_author, filter=_filter, sort="last_modified", direction=-1)
    )


# Standard TensorBoard event-file name prefix.
TB_FILE_PREFIX = "events.out.tfevents"


def download_repoo_tb(repo_id, api, log_dir, df):
    """Download every non-empty TensorBoard event file of *repo_id*.

    Files are placed under ``log_dir/<repo_name>/<repo_path>``.  A file is
    skipped when *df* already records the same (repo, path, size) AND the
    local copy exists on disk.

    Args:
        repo_id: Full Hub id, e.g. ``"author/repo"``.
        api: An ``HfApi`` instance.
        log_dir: Root directory for downloaded logs.
        df: Cache DataFrame with columns ``repo``/``path``/``size``.

    Returns:
        The cache DataFrame, extended with rows for newly downloaded files.

    Raises:
        ValueError: If the repo does not exist on the Hub.
    """
    repo_name = repo_id.split("/")[-1]
    # Guard clause instead of if/else: fail fast on a missing repo.
    if not api.repo_exists(repo_id):
        raise ValueError(f"Repo {repo_id} does not exist")
    files = api.list_repo_files(repo_id)

    tb_files = [f for f in files if f.split("/")[-1].startswith(TB_FILE_PREFIX)]
    tb_files_info = list(api.list_files_info(repo_id, tb_files))
    # Ignore zero-byte placeholder event files.
    tb_files_info = [f for f in tb_files_info if f.size > 0]

    for repo_file in tb_files_info:
        path = repo_file.path
        size = repo_file.size
        fname = path.split("/")[-1]
        # FIX: the original `path.replace(f"/{fname}", '')` left the filename
        # in place for root-level files (no '/'), and an unused
        # `path.split('/')[-2]` raised IndexError for them.  dirname handles
        # both cases; `or None` means "no subfolder" for root-level files.
        sub_folder = os.path.dirname(path) or None
        local_path = os.path.join(log_dir, repo_name, path)

        already_cached = (
            (df["repo"] == repo_name)
            & (df["path"] == path)
            & (df["size"] == size)
        ).any()
        if already_cached and os.path.exists(local_path):
            print(f"Skipping {repo_name}/{path}")
            continue

        print(f"Downloading {repo_name}/{path}")
        api.hf_hub_download(
            repo_id=repo_id,
            filename=fname,
            subfolder=sub_folder,
            local_dir=os.path.join(log_dir, repo_name),
        )
        new_df = pd.DataFrame([{"repo": repo_name, "path": path, "size": size}])
        df = pd.concat([df, new_df], ignore_index=True)
    return df


@st.cache_resource
def create_cache_dataframe():
    """Session-cached DataFrame tracking already-downloaded event files."""
    return pd.DataFrame(columns=["repo", "path", "size"])


# ==============================================================
api = HfApi()
df = create_cache_dataframe()
models = get_models()
model_ids = [model.id for model in models]

# select many
with st.expander("Download tf", expanded=True):
    with st.form("my_form"):
        selected_models = st.multiselect("Select models", model_ids, default=None)
        submit = st.form_submit_button("Download logs")
    if submit:
        # download tensorboard logs
        with st.spinner("Downloading logs..."):
            for model_id in selected_models:
                st.write(f"Downloading logs for {model_id}")
                df = download_repoo_tb(model_id, api, logdir, df)
            st.write("Done")
    clean_btn = st.button("Clean all")
    if clean_btn:
        # Drop downloaded files and invalidate both cached resources so the
        # next rerun rebuilds the model list and the cache DataFrame.
        clean_logdir(logdir)
        create_cache_dataframe.clear()
        get_models.clear()

# with st.expander("...", expanded=True):
#     st_tensorboard(logdir=logdir, port=6006, width=1760, scrolling=False)
#     st.text(st)