# -*- coding: utf-8 -*-
# Base CLI to parse Arguments
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import argparse | |
import logging | |
from abc import ABC, abstractmethod | |
from typing import Tuple, Union | |
import yaml | |
from pydantic import BaseModel | |
class ABCParser(ABC):
    """Blueprint for argument parsers.

    All methods are no-op placeholders that return ``None``. None of them is
    declared abstract, so subclasses may override any subset of them.
    """

    def __init__(self) -> None:
        """Placeholder constructor; performs no setup."""

    def get_config(self) -> Tuple[Union[BaseModel, dict], logging.Logger]:
        """Load the configuration and create a logger.

        Returns:
            Tuple[Union[BaseModel, dict], logging.Logger]: Configuration and logger, as
                an implementation is expected to provide them. This placeholder
                itself returns ``None``.
        """
        return None

    def store_config(self) -> None:
        """Store the config file in the logging directory to keep track of the configuration.

        This placeholder does nothing.
        """
        return None
class ExperimentBaseParser:
    """Configuration Parser for Machine Learning Experiments

    Builds an ``argparse`` command line interface with a required ``--config``
    YAML file, an optional ``--gpu`` ID and the mutually exclusive run-mode
    options ``--sweep``, ``--agent`` and ``--checkpoint``.

    Attributes:
        parser (argparse.ArgumentParser): The configured CLI parser.
        config (dict): Merged configuration; only set after ``parse_arguments`` was called.
    """

    def __init__(self) -> None:
        """Create the argument parser and register all CLI options."""
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            description="Start an experiment with given configuration file.",
        )
        required_named = parser.add_argument_group("required named arguments")
        required_named.add_argument(
            "--config", type=str, help="Path to a config file", required=True
        )
        parser.add_argument("--gpu", type=int, help="Cuda-GPU ID")

        # sweep, agent and checkpoint select the run mode, so at most one may be given
        group = parser.add_mutually_exclusive_group(required=False)
        group.add_argument(
            "--sweep",
            action="store_true",
            help="Starting a sweep. For this the configuration file must be structured according to WandB sweeping. "
            "Compare https://docs.wandb.ai/guides/sweeps and https://community.wandb.ai/t/nested-sweep-configuration/3369/3 "
            "for further information. This parameter cannot be set in the config file!",
        )
        group.add_argument(
            "--agent",
            type=str,
            help="Add a new agent to the sweep. "
            "Please pass the sweep ID as argument in the way entity/project/sweep_id, e.g., user1/test_project/v4hwbijh. "
            "The agent configuration can be found in the WandB dashboard for the running sweep in the sweep overview tab "
            "under launch agent. Just paste the entity/project/sweep_id given there. The provided config file must be a sweep config file. "
            "This parameter cannot be set in the config file!",
        )
        group.add_argument(
            "--checkpoint",
            type=str,
            help="Path to a PyTorch checkpoint file. "
            "The file is loaded and continued to train with the provided settings. "
            "If this is passed, no sweeps are possible. "
            "This parameter cannot be set in the config file!",
        )
        self.parser = parser

    def parse_arguments(self) -> dict:
        """Parse the arguments from CLI and load yaml config

        The YAML file given via ``--config`` is loaded into a dictionary and then
        adjusted with the CLI values:

        * ``gpu`` is overwritten if ``--gpu`` was passed.
        * ``run_sweep``, ``agent`` and ``checkpoint`` are removed from the file
          content, because they may only be set via CLI.
        * ``run_sweep`` is always set (``True`` only if ``--sweep`` was passed).
        * ``agent`` is always set to the CLI value, which is ``None`` if not passed.
        * ``checkpoint`` is only set if ``--checkpoint`` was passed.

        An empty config file is treated as an empty configuration.

        Returns:
            dict: Merged configuration, also stored in ``self.config``
        """
        # parse the arguments
        opt = self.parser.parse_args()

        # decode explicitly as UTF-8 instead of relying on the platform default encoding
        with open(opt.config, "r", encoding="utf-8") as config_file:
            yaml_config = yaml.safe_load(config_file)

        # safe_load yields None for an empty document; dict(None) would raise a TypeError
        yaml_config_dict = {} if yaml_config is None else dict(yaml_config)
        opt_dict = vars(opt)

        # check for gpu to overwrite with cli argument
        if opt_dict.get("gpu") is not None:
            yaml_config_dict["gpu"] = opt_dict["gpu"]

        # check if either training, sweep, checkpoint or start agent should be called
        # first step: remove such keys from the config file
        for cli_only_key in ("run_sweep", "agent", "checkpoint"):
            yaml_config_dict.pop(cli_only_key, None)

        # select one of the options
        yaml_config_dict["run_sweep"] = opt_dict.get("sweep") is True
        if "agent" in opt_dict:
            # stored even when None, so the key is always present in the result
            yaml_config_dict["agent"] = opt_dict["agent"]
        if opt_dict.get("checkpoint") is not None:
            yaml_config_dict["checkpoint"] = opt_dict["checkpoint"]

        self.config = yaml_config_dict
        return self.config