File size: 8,168 Bytes
07423df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import dataclasses
import logging
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple

from llm_studio.src import possible_values
from llm_studio.src.nesting import Dependency, Nesting
from llm_studio.src.order import Order
from llm_studio.src.tooltips import tooltips

logger = logging.getLogger(__name__)


def _get_bases_below_parent(cls: type, parent: type, bases=None) -> Set[type]:
    if bases is None:
        bases = set()

    if parent not in cls.__bases__:
        for base in cls.__bases__:
            bases.update(_get_bases_below_parent(base, parent, bases))
    else:
        # don't support multiple inheritance when
        # inherting directly from the parent
        assert len(cls.__bases__) == 1

        bases.add(cls)

    return bases


@dataclass
class DefaultConfig:
    """
    Template for any configuration file.

    Subclasses declare their settings as dataclass fields; this base attaches
    per-field UI metadata (possible values, visibility, display order and
    nesting dependencies) in `__post_init__` and exposes accessors for it.
    """

    def __post_init__(self):
        # Per-field metadata, keyed by attribute name. `None` possible values
        # means the field is free-form; visibility 0 means "visible in the
        # Wave app" (see `_get_visibility` for the level semantics).
        self._possible_values: Dict[str, Any] = {k: None for k in self.__dict__}
        self._visibility = {k: 0 for k in self.__dict__}

        # go up the class hierarchy until we are one below the `DefaultConfig`
        bases = _get_bases_below_parent(self.__class__, DefaultConfig)

        # there must be exactly one unique class up the class hierarchy
        # which inherits directly from the `DefaultConfig`
        assert len(bases) == 1
        base = next(iter(bases))

        # initialize the order to the fields this class has
        self._order = Order([field.name for field in fields(base)])

        # initialize nesting dependencies
        self._nesting = Nesting()

    def _get_possible_values(
        self, field: str, value: Any, type_annotation: type, mode: str, dataset_fn=None
    ) -> Tuple[Optional[possible_values.Value], Any]:
        """
        Returns a set of possible values for the field provided, and the current value.

        Args:
            field: the field
            value: the preliminary value of the field.
            type_annotation: Type Annotation of the field.
            mode: current mode, one of {"train", "test", "predict"}.
            dataset_fn: A function returning a tuple (dataset, value). Will be called
                if the possible values depend on the dataset.

        Returns:
            Possible values for the field, the current value.

        Raises:
            ValueError: If a dataset-dependent field is queried without
                `dataset_fn`, or a sequence of possible values is neither
                all-numeric nor all-string.
        """

        poss_values = self._possible_values.get(field, None)

        if isinstance(poss_values, possible_values.DatasetValue):
            # Possible values depend on the dataset; the caller-supplied
            # `dataset_fn` provides it (and may also adjust the value).
            if dataset_fn is None:
                raise ValueError(
                    f"{poss_values} needs a dataset to compute possible values!\n"
                    "`dataset_fn` must be provided."
                )

            dataset, value = dataset_fn(field, value)
            poss_values, value = poss_values.get_value(
                dataset=dataset, value=value, type_annotation=type_annotation, mode=mode
            )
        elif isinstance(poss_values, Sequence):
            # Plain sequences are shorthand: all-numeric is interpreted as
            # (min, max, step), all-string as a fixed set of choices.
            if all(isinstance(x, (float, int)) for x in poss_values):
                # NOTE(review): assumes the sequence has at least three
                # elements (min, max, step) — a shorter numeric sequence
                # would raise an IndexError here. Confirm against callers.
                poss_values = possible_values.Number(
                    min=poss_values[0], max=poss_values[1], step=poss_values[2]
                )
            elif all(isinstance(x, str) for x in poss_values):
                poss_values = possible_values.String(tuple(poss_values))
            else:
                raise ValueError(
                    f"Could not interpret {poss_values} as any possible value class."
                )

        return poss_values, value

    def _get_tooltips(self, field: str, predict: bool = False) -> Optional[str]:
        """
        Returns a tooltip for the field provided, or None if there is none.

        Looks up the `experiments_{field}` key in the global tooltips
        registry. `predict` is currently unused.
        """
        return tooltips.get(f"experiments_{field}", None)

    def _get_visibility(self, field: str) -> Optional[int]:
        """Returns a visibility level for the field provided.
         0 -- visible in the Wave app
        -1 -- not visible in the Wave App
        -2 -- visible in Dataset Import, but not visible in Create Experiment

        Returns None for unknown fields.
        """

        return self._visibility.get(field, None)

    def _get_nesting_triggers(self) -> Set[str]:
        """Returns a Set of keys other elements are depending on"""

        return self._nesting.triggers

    def _get_nesting_dependencies(self, key: str) -> List[Dependency] | None:
        """Returns all dependencies for a given key, or None if the key has
        no registered dependencies."""

        if key in self._nesting.dependencies:
            dependencies = self._nesting.dependencies[key]
        else:
            dependencies = None
        return dependencies

    def _get_order(self, warn_if_unset=True) -> List[str]:
        """
        Returns the order in which to show the keys in the config.

        Args:
            warn_if_unset: Whether to log a warning if order is unset for multiple keys.

        Returns:
            A list of the same length and with same elements as `self.__dict__.keys()`.
        """

        keys = self.__dict__.keys()

        # keys with an explicit order come first, in that order;
        # everything else is appended afterwards
        ordered_keys = [key for key in self._order if key in keys]
        unordered_keys = list(set(keys) - set(ordered_keys))

        # only user-visible keys are worth warning about
        unordered_ui_keys = [
            key
            for key in unordered_keys
            if not (key.startswith("_") or self._get_visibility(key) == -1)
        ]

        # warn if there is more than one key without order.
        # one is not problematic since it will just always be last
        if warn_if_unset and len(unordered_ui_keys) > 1:
            logger.warning(f"No order set for keys: {unordered_ui_keys}.")

        return ordered_keys + unordered_keys

    @classmethod
    def get_annotations(cls):
        """Returns type annotations through all the Parent config classes"""

        d: Dict[str, Any] = {}
        # walk the MRO from root to leaf so subclasses override parents
        for c in cls.mro()[::-1]:
            try:
                d.update(**c.__annotations__)
            except AttributeError:
                # object, at least, has no __annotations__ attribute.
                pass
        return d

    @classmethod
    def from_dict(cls, d: dict):
        """Creates a config object from a dictionary.

        Keys not matching any annotated field are dropped with a warning.
        """
        d_filtered = {k: v for k, v in d.items() if k in cls.get_annotations()}
        if len(d) != len(d_filtered):
            logger.warning(
                f"Keys {set(d.keys()) - set(d_filtered.keys())} are not in the config."
            )
        return cls(**d_filtered)  # mypy: ignore


@dataclass
class DefaultConfigProblemBase(DefaultConfig):
    """
    Base class for all problem configs.
    Defines the interface for all problem configs.
    """

    # experiment-level settings
    experiment_name: str
    output_directory: str
    llm_backbone: str

    # sub-config groups; each is itself a config object
    # (typed `Any` here; concrete types are set by subclasses)
    dataset: Any
    tokenizer: Any
    architecture: Any
    training: Any
    augmentation: Any
    prediction: Any
    environment: Any
    logging: Any

    @property
    def problem_type(self) -> str:
        """
        Parse problem_type from config filename, for example:
        text_causal_language_modeling_config.py -> text_causal_language_modeling
        """
        return type(self).__dict__["__module__"].split(".")[-1].replace("_config", "")

    @classmethod
    def from_dict(cls, cfg_dict: dict):
        """
        Creates a config object from a dictionary.

        Nested config fields (whose type provides a `from_dict`) are built
        recursively; plain fields fall back to their declared default (or
        default factory) when absent from `cfg_dict`. Required fields that
        are absent are omitted so that instantiation raises a clear
        TypeError, instead of silently storing the `dataclasses.MISSING`
        sentinel as the attribute value (which the previous implementation
        did).
        """
        init_args: Dict[str, Any] = {}
        for field_obj in dataclasses.fields(cls):
            field_name = field_obj.name
            if hasattr(field_obj.type, "from_dict"):
                # nested config: recurse, defaulting to an empty dict
                init_args[field_name] = field_obj.type.from_dict(
                    cfg_dict.get(field_name, {})
                )
            elif field_name in cfg_dict:
                init_args[field_name] = cfg_dict[field_name]
            elif field_obj.default is not dataclasses.MISSING:
                init_args[field_name] = field_obj.default
            elif field_obj.default_factory is not dataclasses.MISSING:
                # bug fix: honor default_factory (the old code passed
                # `field_obj.default`, i.e. MISSING, for such fields)
                init_args[field_name] = field_obj.default_factory()
            # else: required field absent from cfg_dict -> let cls(**...)
            # raise a TypeError naming the missing argument

        return cls(**init_args)

    def check(self) -> Dict[str, List]:
        """
        Checks for errors (incompatible settings) for the specific problem type.

        Subclasses override this to add real validations; the base returns
        an empty error container.

        Returns:
        A dictionary with two keys:
        - "title": A list of error titles.
        - "message": A list of error messages.
        """
        errors: Dict[str, List] = {"title": [], "message": []}
        return errors