File size: 4,021 Bytes
41e3185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import dataclasses
import pprint
from typing import Any, Dict, Tuple

import yaml

# Public API: a star-import of this module exposes only `get_config`.
__all__ = ["get_config"]


@dataclasses.dataclass
class Config:
    """Experiment configuration whose field types are validated on creation.

    Every field has a default, so any subset may be overridden by keyword.
    Once the generated ``__init__`` has run, ``__post_init__`` type-checks all
    fields and pretty-prints the resulting configuration to stdout.

    Raises:
        TypeError: if a field value (or an element of a tuple field) is not
            exactly of the annotated type.
    """

    model: str = "ActionSegmentRefinementNetwork"
    n_layers: int = 10
    n_refine_layers: int = 10
    n_stages: int = 4
    n_features: int = 64
    n_stages_asb: int = 4
    n_stages_brb: int = 4
    SFI_layer: Tuple[int, ...] = (1, 2, 3, 4, 5, 6, 7, 8, 9)

    # loss function
    ce: bool = True  # cross entropy
    ce_weight: float = 1.0

    focal: bool = False
    focal_weight: float = 1.0

    tmse: bool = False  # temporal mse
    tmse_weight: float = 0.15

    gstmse: bool = True  # gaussian similarity loss
    gstmse_weight: float = 1.0
    gstmse_index: str = "feature"  # similarity index

    # whether to use class weights when calculating cross entropy
    class_weight: bool = True

    batch_size: int = 1

    # the number of input feature channels
    in_channel: int = 2048

    num_workers: int = 0
    max_epoch: int = 50

    optimizer: str = "Adam"

    learning_rate: float = 0.0005
    momentum: float = 0.9  # momentum of SGD
    dampening: float = 0.0  # dampening for momentum of SGD
    weight_decay: float = 0.0001  # weight decay
    nesterov: bool = True  # enables Nesterov momentum

    param_search: bool = False

    # thresholds for calculating F1 Score
    iou_thresholds: Tuple[float, ...] = (0.1, 0.25, 0.5)

    # boundary regression
    tolerance: int = 5
    boundary_th: float = 0.5
    lambda_b: float = 0.1

    dataset: str = "MCFS-22"
    dataset_dir: str = "./dataset"
    csv_dir: str = "./csv"
    split: int = 1

    result_path: str = "./config"
    seed: int = 42
    device: int = 0
    refinement_method: str = "refinement_with_boundary"

    def __post_init__(self) -> None:
        """Validate the field types, then print the configuration."""
        self._type_check()

        print("-" * 10, "Experiment Configuration", "-" * 10)
        pprint.pprint(dataclasses.asdict(self), width=1)

    def _type_check(self) -> None:
        """Check that every field holds a value of exactly its annotated type.

        Reference:
        https://qiita.com/obithree/items/1c2b43ca94e4fbc3aa8d

        Raises:
            TypeError: if a value's type differs from the annotation, or if
                an element of a tuple field differs from the annotated
                element type.
        """
        # `dataclasses.fields` is the documented way to enumerate the fields
        # and does not depend on looking up `__annotations__` via an instance.
        for spec in dataclasses.fields(self):
            field, field_type = spec.name, spec.type
            value = getattr(self, field)
            element_type = None

            # if you use type annotation class provided by `typing`,
            # you should convert it to the type class used in python.
            # e.g.) Tuple[int] -> tuple
            # https://stackoverflow.com/questions/51171908/extracting-data-from-typing-types

            # check the annotation is a generic alias such as Tuple or not.
            # https://github.com/zalando/connexion/issues/739
            if hasattr(field_type, "__origin__"):
                # e.g.) Tuple[int].__args__[0] -> `int`
                element_type = field_type.__args__[0]

                # e.g.) Tuple[int].__origin__ -> `tuple`
                field_type = field_type.__origin__

            # bool is the subclass of int,
            # so need to use `type() is` instead of `isinstance`
            if type(value) is not field_type:
                raise TypeError(
                    f"The type of '{field}' field is supposed to be {field_type}."
                )

            # The container is validated before its elements so that a
            # non-iterable or string value is reported as a wrong field type
            # instead of failing (or misreporting) inside the element loop.
            if element_type is not None:
                self._type_check_element(field, value, element_type)

    def _type_check_element(
        self, field: str, vals: Tuple[Any, ...], element_type: type
    ) -> None:
        """Check that every element of ``vals`` is exactly ``element_type``.

        Args:
            field: name of the field being checked (used in the error message).
            vals: the values stored in that field.
            element_type: the type each element must have.

        Raises:
            TypeError: if any element's type is not ``element_type``.
        """
        for val in vals:
            # exact type comparison for the same reason as in `_type_check`
            if type(val) is not element_type:
                raise TypeError(
                    f"The element of '{field}' field is supposed to be {element_type}."
                )


def convert_list2tuple(_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Replace every top-level list value of ``_dict`` with a tuple, in place.

    Lists cannot be used for the dataclass fields because mutable defaults
    are not allowed, so list values are turned into tuples instead.
    Nested containers are left untouched.

    Args:
        _dict: mapping to convert; it is modified in place.

    Returns:
        The same ``_dict`` object that was passed in.
    """
    # Collect the affected keys first, then overwrite their values.
    list_keys = [name for name, entry in _dict.items() if isinstance(entry, list)]
    for name in list_keys:
        _dict[name] = tuple(_dict[name])

    return _dict


def get_config(config_path: str) -> Config:
    """Load a YAML file and build a ``Config`` from its top-level mapping.

    List values in the YAML mapping are converted to tuples before being
    passed to ``Config`` as keyword arguments. An empty YAML file yields a
    ``Config`` made entirely of default values.

    Args:
        config_path: path to the YAML configuration file.

    Returns:
        The ``Config`` instance built from the file contents.

    Raises:
        TypeError: if the top level of the YAML document is not a mapping,
            or if ``Config`` rejects a key or a value.
    """
    # Explicit encoding so the result does not depend on the locale default.
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)

    if config_dict is None:
        # `yaml.safe_load` returns None for an empty document.
        config_dict = {}
    elif not isinstance(config_dict, dict):
        raise TypeError(
            f"The top level of '{config_path}' is supposed to be a mapping, "
            f"but got {type(config_dict)}."
        )

    config_dict = convert_list2tuple(config_dict)
    return Config(**config_dict)