| import argparse |
| import sys |
| import os |
| import importlib |
| from easydict import EasyDict |
|
|
class ParamManager:
    """Merge CLI arguments, output paths and method hyper-parameters.

    After construction, ``self.args`` is an ``EasyDict`` holding every
    attribute of the ``args`` object passed in, plus the output directory
    entries built by :meth:`add_output_path_param` and the hyper-parameters
    returned by :meth:`get_method_param`.
    """

    def __init__(self, args):
        """Build the merged parameter dictionary.

        Args:
            args: Parsed argument object (anything ``vars()`` accepts). It
                must carry the attributes read by
                :meth:`add_output_path_param` and :meth:`get_method_param`.
        """
        output_path_param = self.add_output_path_param(args)
        method_param = self.get_method_param(args)

        # Keyword entries override same-named keys coming from ``vars(args)``;
        # both mappings are unpacked as keywords, so their keys must be
        # strings and must not collide with each other.
        self.args = EasyDict(
            dict(
                vars(args),
                **output_path_param,
                **method_param
            )
        )

    def get_method_param(self, args):
        """Load the method hyper-parameters from a module in ``configs``.

        ``args.config_file_name`` names a module inside the ``configs``
        package; a trailing ``.py`` is stripped if present. The module must
        expose a ``Param`` callable, which is invoked with ``args``.

        Args:
            args: Parsed argument object providing ``config_file_name``.

        Returns:
            The ``hyper_param`` attribute of the ``Param(args)`` instance.
        """
        config_name = args.config_file_name
        stem = config_name[:-3] if config_name.endswith('.py') else config_name

        # The leading dot makes this a relative import resolved against the
        # ``configs`` package.
        config = importlib.import_module('.' + stem, 'configs')

        method_args = config.Param(args)
        return method_args.hyper_param

    def add_output_path_param(self, args):
        """Create the output directory tree and return its paths.

        The layout is::

            <output_dir>/<type>/
                <method>_<dataset>_<known_cls_ratio>_<labeled_ratio>_<backbone>_<seed>/
                    <model_dir>/

        Args:
            args: Parsed argument object providing ``output_dir``, ``type``,
                ``method``, ``dataset``, ``known_cls_ratio``,
                ``labeled_ratio``, ``backbone``, ``seed`` and ``model_dir``.

        Returns:
            dict: ``method_output_dir`` and ``model_output_dir`` paths.

        Raises:
            FileExistsError: If one of the target paths exists but is not a
                directory.
        """
        task_output_dir = os.path.join(args.output_dir, args.type)

        concat_names = [
            args.method,
            args.dataset,
            args.known_cls_ratio,
            args.labeled_ratio,
            args.backbone,
            args.seed,
        ]
        method_output_name = "_".join(str(name) for name in concat_names)

        method_output_dir = os.path.join(task_output_dir, method_output_name)
        model_output_dir = os.path.join(method_output_dir, args.model_dir)

        # exist_ok=True avoids the check-then-create race: concurrent runs
        # sharing a parent directory no longer crash when another process
        # creates it between the existence check and the makedirs call.
        for directory in (task_output_dir, method_output_dir, model_output_dir):
            os.makedirs(directory, exist_ok=True)

        return {
            'method_output_dir': method_output_dir,
            'model_output_dir': model_output_dir,
        }