|
import dataclasses |
|
import pprint |
|
from typing import Any, Dict, Tuple |
|
|
|
import yaml |
|
|
|
# Public API for star-imports: only the YAML loader is exported. ``Config``
# and ``convert_list2tuple`` are still importable by explicit name.
__all__ = ["get_config"]
|
|
|
|
|
@dataclasses.dataclass
class Config:
    """Experiment configuration; ``get_config`` builds one from a YAML file.

    Every field is type-checked in ``__post_init__`` against its annotation,
    and a mismatch raises ``TypeError``. YAML writes whole numbers without a
    decimal point (``1`` rather than ``1.0``). For that reason, an ``int``
    given for a ``float`` field, or as an element of a ``Tuple[float, ...]``
    field, is accepted and converted to ``float``. ``bool`` is never accepted
    in place of ``int`` or ``float``.

    After validation, the whole configuration is printed to stdout.
    """

    # --- model ---
    model: str = "ActionSegmentRefinementNetwork"
    n_layers: int = 10
    n_refine_layers: int = 10
    n_stages: int = 4
    n_features: int = 64
    n_stages_asb: int = 4
    n_stages_brb: int = 4
    SFI_layer: Tuple[int, ...] = (1, 2, 3, 4, 5, 6, 7, 8, 9)

    # --- loss terms: each is an on/off flag plus a weight ---
    ce: bool = True
    ce_weight: float = 1.0

    focal: bool = False
    focal_weight: float = 1.0

    tmse: bool = False
    tmse_weight: float = 0.15

    gstmse: bool = True
    gstmse_weight: float = 1.0
    gstmse_index: str = "feature"

    class_weight: bool = True

    # --- data loading / training loop ---
    batch_size: int = 1
    in_channel: int = 2048
    num_workers: int = 0
    max_epoch: int = 50

    # --- optimizer ---
    optimizer: str = "Adam"
    learning_rate: float = 0.0005
    momentum: float = 0.9
    dampening: float = 0.0
    weight_decay: float = 0.0001
    nesterov: bool = True

    param_search: bool = False

    # --- evaluation ---
    iou_thresholds: Tuple[float, ...] = (0.1, 0.25, 0.5)

    tolerance: int = 5
    boundary_th: float = 0.5
    lambda_b: float = 0.1

    # --- dataset ---
    dataset: str = "MCFS-22"
    dataset_dir: str = "./dataset"
    csv_dir: str = "./csv"
    split: int = 1

    # --- run / output ---
    result_path: str = "./config"
    seed: int = 42
    device: int = 0
    refinement_method: str = "refinement_with_boundary"

    def __post_init__(self) -> None:
        """Validate all fields, then print the resulting configuration."""
        self._type_check()

        print("-" * 10, "Experiment Configuration", "-" * 10)
        pprint.pprint(dataclasses.asdict(self), width=1)

    @staticmethod
    def _coerce(value: Any, expected: type) -> Any:
        """Return ``float(value)`` for an ``int`` in a ``float`` slot, else ``value``.

        ``type(...) is int`` (not ``isinstance``) keeps ``bool`` from being
        silently turned into ``0.0`` / ``1.0``.
        """
        if expected is float and type(value) is int:
            return float(value)
        return value

    def _type_check(self) -> None:
        """Check every field against its annotation, coercing int -> float.

        Reference:
            https://qiita.com/obithree/items/1c2b43ca94e4fbc3aa8d

        Raises:
            TypeError: if a field, or an element of a generic (e.g.
                ``Tuple[int, ...]``) field, does not have the expected type.
        """
        # dataclasses.fields() includes fields inherited from base classes,
        # which self.__annotations__ (own-class annotations only) would miss.
        for f in dataclasses.fields(self):
            value = getattr(self, f.name)
            field_type = f.type

            if hasattr(field_type, "__origin__"):
                # Generic alias such as Tuple[int, ...]: __origin__ is the
                # container (tuple) and __args__[0] the element type.
                element_type = field_type.__args__[0]
                field_type = field_type.__origin__

                # Check the container first so a non-iterable value gets a
                # meaningful error instead of "object is not iterable".
                if type(value) is not field_type:
                    raise TypeError(
                        f"The type of '{f.name}' field is supposed to be {field_type}."
                    )

                if element_type is float:
                    value = field_type(self._coerce(v, float) for v in value)

                self._type_check_element(f.name, value, element_type)
            else:
                value = self._coerce(value, field_type)
                if type(value) is not field_type:
                    raise TypeError(
                        f"The type of '{f.name}' field is supposed to be {field_type}."
                    )

            # Store the (possibly coerced) value back on the instance.
            setattr(self, f.name, value)

    def _type_check_element(
        self, field: str, vals: Tuple[Any, ...], element_type: type
    ) -> None:
        """Raise ``TypeError`` unless every item of ``vals`` is exactly ``element_type``."""
        for val in vals:
            if type(val) is not element_type:
                raise TypeError(
                    f"The element of '{field}' field is supposed to be {element_type}."
                )
|
|
|
|
|
def convert_list2tuple(_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Replace every top-level ``list`` value of ``_dict`` with a ``tuple``.

    The dict is modified in place and the same object is returned. Only
    top-level values are converted; lists nested inside them are left as is.
    Key order is unchanged because only existing keys are reassigned.
    """
    _dict.update(
        {name: tuple(item) for name, item in _dict.items() if isinstance(item, list)}
    )
    return _dict
|
|
|
|
|
def get_config(config_path: str) -> Config:
    """Load a YAML file and build a ``Config`` from it.

    Top-level YAML lists are converted to tuples via ``convert_list2tuple`` so
    they match the ``Tuple[...]`` fields of ``Config``. An empty YAML file
    yields the default configuration.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The constructed ``Config``. Its ``__post_init__`` type-checks the
        values and prints them.

    Raises:
        TypeError: if the YAML document is not a mapping, contains keys that
            are not ``Config`` fields, or holds a value of the wrong type.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)

    # safe_load returns None for an empty document; treat that as "use all
    # defaults" instead of failing with AttributeError in convert_list2tuple.
    if config_dict is None:
        config_dict = {}
    elif not isinstance(config_dict, dict):
        raise TypeError(
            f"The config file '{config_path}' must contain a mapping, "
            f"got {type(config_dict).__name__}."
        )

    config_dict = convert_list2tuple(config_dict)
    return Config(**config_dict)
|
|