| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import logging | |
| import json | |
| from typing import Dict | |
| from omegaconf import OmegaConf | |
| from bliva.common.registry import registry | |


class Config:
    def __init__(self, args):
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)

        # Validate the user-provided runner configuration.
        # Model and dataset configurations are expected to be validated by the respective classes.
        # [TODO] validate the model/dataset configuration
        # self._validate_runner_config(runner_config)

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, user_config
        )
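    # Illustrative sketch (hypothetical paths and keys, not part of the original module):
    # `args` only needs to expose `cfg_path` and `options`, e.g. an argparse.Namespace like
    #
    #     args = argparse.Namespace(
    #         cfg_path="path/to/config.yaml",
    #         options=["run.seed=42"],  # optional dot-list overrides
    #     )
    #     cfg = Config(args)
    #
    # after which the merged config is reachable via cfg.run_cfg / cfg.model_cfg / cfg.datasets_cfg.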
    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user-specified options are valid;
            2) no type mismatches between the user-specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)
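    # Illustrative sketch (hypothetical keys, not part of the original module): given
    # command-line overrides, `_build_opt_list` yields a nested OmegaConf object, e.g.
    #
    #     OmegaConf.from_dotlist(["run.batch_size_train=16", "model.model_type=vicuna7b"])
    #     # -> cfg.run.batch_size_train == 16, cfg.model.model_type == "vicuna7b"
    #
    # Actual keys come from the YAML config being overridden.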
    @staticmethod
    def build_model_config(config, **kwargs):
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            model_type = model.get("model_type", None)
        # else use the model type selected by the user.

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override: customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config
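    # Illustrative sketch (hypothetical keys, not taken from a specific config file):
    # the loaded YAML is expected to carry a `model` block such as
    #
    #     model:
    #       arch: <registered_model_name>   # looked up via registry.get_model_class()
    #       model_type: <model_type>        # selects the default config via default_config_path()
    #
    # which is then merged on top of the model's default config.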
    @staticmethod
    def build_runner_config(config):
        return {"run": config.run}
    @staticmethod
    def build_dataset_config(config):
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            print(f"Building dataset config for {dataset_name}")
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override: customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config
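    # Illustrative sketch (hypothetical dataset name): the YAML is expected to list
    # each dataset under a `datasets` root key, e.g.
    #
    #     datasets:
    #       coco_caption:        # must match a registered builder class
    #         type: default      # picks the builder's default config variant
    #
    # Each dataset's user-provided block overrides the builder's default config.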
    def _convert_to_dot_list(self, opts):
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
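    # Illustrative sketch (hypothetical keys): options may arrive either already in
    # dot-list form, e.g. ["run.seed=42"], which is returned unchanged, or as
    # alternating key/value pairs, e.g. ["run.seed", "42"], which is zipped into
    # ["run.seed=42"].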
    def get_config(self):
        return self.config

    @property
    def run_cfg(self):
        return self.config.run

    @property
    def datasets_cfg(self):
        return self.config.datasets

    @property
    def model_cfg(self):
        return self.config.model
    def pretty_print(self):
        logging.info("\n=====  Running Parameters  =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        datasets = self.config.datasets

        for dataset in datasets:
            if dataset in self.config.datasets:
                logging.info(f"\n======== {dataset} =======")
                dataset_config = self.config.datasets[dataset]
                logging.info(self._convert_node_to_json(dataset_config))
            else:
                logging.warning(f"No dataset named '{dataset}' in config. Skipping")

        logging.info("\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))
    def _convert_node_to_json(self, node):
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        return OmegaConf.to_container(self.config)


def node_to_dict(node):
    return OmegaConf.to_container(node)


class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    It may be altered in the future.

    A helper class to validate configurations from a YAML file.

    This serves the following purposes:
        1. Ensure that all the options in the YAML are defined; raise an error if not.
        2. Raise an error when type mismatches are found.
        3. Provide a central place to store and display helpful messages for supported configurations.
    """
    class _Argument:
        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s
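    # Illustrative sketch: an argument renders as "name=value, (type), choices: [...], (help)",
    # e.g. _Argument("seed", type=int, help="Random seed.") prints as
    # "seed=None, (<class 'int'>), (Random seed.)" before any value has been parsed.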
    def __init__(self, description):
        self.description = description

        self.arguments = dict()

        self.parsed_args = None

    def __getitem__(self, key):
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)
    def validate(self, config=None):
        """
        Validate the given config (dict-like) against the registered arguments:
        every key must be a known argument, values must be convertible to the
        declared type, and values must be among the declared choices when given.
        """
        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Supported arguments are {self.format_arguments()}."""

            if self.arguments[k].type is not None:
                try:
                    self.arguments[k].val = self.arguments[k].type(v)
                except ValueError:
                    raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")

            if self.arguments[k].choices is not None:
                assert (
                    v in self.arguments[k].choices
                ), f"""{k} must be one of {self.arguments[k].choices}."""

        return config
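    # Illustrative sketch (hypothetical values): with a registered "seed" argument of
    # type int, validate({"seed": "42"}) coerces the value and passes, while
    # validate({"sed": 42}) fails the assertion and reports the supported arguments,
    # and a non-convertible value raises ValueError.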
    def format_arguments(self):
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())


def create_runner_config_validator():
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. "runner_base" trains epoch by epoch, while "runner_iter"
            runs for a fixed number of iterations. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        type=Dict[str, float],
        help="""Ratios of the training datasets. This is used by the iteration-based runner.
            It is not supported by the epoch-based runner because defining an epoch becomes tricky.
            Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warmup steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether to only run evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, validation will be skipped.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, testing will be skipped.",
    )
    # add arguments for gradient accumulation iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )
    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Only 'cuda' and 'cpu' are supported for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments for whether to use a distributed sampler during evaluation
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use a distributed sampler during evaluation or not.",
    )
    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments for number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )
    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        choices=["generate", "rank"],
        help="""Inference method to use for question answering. If "rank", an answer list is required.""",
    )
    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator
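# Illustrative sketch (hypothetical values, not part of the original module): the
# validator could be applied to the "run" section of a loaded config, e.g.
#
#     validator = create_runner_config_validator()
#     validator.validate({"max_epoch": 10, "init_lr": 1e-4, "seed": 42})
#
# Keys must be registered arguments; values are type-checked and, where choices are
# declared (e.g. "task" or "lr_sched"), must be among the registered choices.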