|
import pandas as pd |
|
import wandb |
|
|
|
|
|
def get_wandb_data(entity: str, project: str, api_key: str, job_type: str) -> pd.DataFrame:
    """Fetch every run of a W&B project with a given job type as one DataFrame.

    Each matching run becomes one row. The columns are the run's summary
    (flattened by ``pd.json_normalize`` to at most one nested level) followed
    by its config (flattened to at most two nested levels). The row index is
    the run name.

    Args:
        entity: W&B entity (user or team) that owns the project.
        project: Project name inside ``entity``.
        api_key: API key used to authenticate the ``wandb.Api`` client.
        job_type: Only runs whose ``jobType`` equals this value are returned.

    Returns:
        A DataFrame with one row per matching run, indexed by run name.
        Summary columns come first and config columns second. If a summary key
        and a config key flatten to the same name, the frame will contain
        duplicate column labels. Run names are used as-is, so the index may
        contain duplicates if several runs share a name.
    """
    client = wandb.Api(api_key=api_key)
    matching_runs = client.runs(f"{entity}/{project}", filters={"jobType": job_type})

    # The runs collection is consumed in a single pass, so results are
    # fetched from the server only once. Each tuple holds
    # (summary dict, config copy, run name) for one run.
    # NOTE(review): `_json_dict` is a private wandb attribute; confirm it is
    # still available in the pinned wandb version.
    records = [
        (run.summary._json_dict, dict(run.config.items()), run.name)
        for run in matching_runs
    ]
    summaries = [summary for summary, _, _ in records]
    configs = [config for _, config, _ in records]
    run_names = [name for _, _, name in records]

    # Both frames start with the same 0..n-1 RangeIndex, so a column-wise
    # concat lines up the summary and config of each run on the same row.
    combined = pd.concat(
        [
            pd.json_normalize(summaries, max_level=1),
            pd.json_normalize(configs, max_level=2),
        ],
        axis=1,
    )
    combined.index = run_names
    return combined
|
|