File size: 8,168 Bytes
07423df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
import dataclasses
import logging
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
from llm_studio.src import possible_values
from llm_studio.src.nesting import Dependency, Nesting
from llm_studio.src.order import Order
from llm_studio.src.tooltips import tooltips
logger = logging.getLogger(__name__)
def _get_bases_below_parent(cls: type, parent: type, bases=None) -> Set[type]:
if bases is None:
bases = set()
if parent not in cls.__bases__:
for base in cls.__bases__:
bases.update(_get_bases_below_parent(base, parent, bases))
else:
# don't support multiple inheritance when
# inherting directly from the parent
assert len(cls.__bases__) == 1
bases.add(cls)
return bases
@dataclass
class DefaultConfig:
    """
    Template for any configuration file.

    Subclasses declare experiment settings as dataclass fields; this base
    class maintains per-field UI metadata: possible values, visibility
    levels, display order and nesting dependencies.
    """

    def __post_init__(self):
        # per-field possible values; `None` means "unrestricted"
        self._possible_values: Dict[str, Any] = {k: None for k in self.__dict__}
        # per-field visibility levels, see `_get_visibility` for the encoding
        self._visibility = {k: 0 for k in self.__dict__}

        # go up the class hierarchy until we are one below the `DefaultConfig`
        bases = _get_bases_below_parent(self.__class__, DefaultConfig)

        # there must be exactly one unique class up the class hierarchy
        # which inherits directly from the `DefaultConfig`
        assert len(bases) == 1
        base = next(iter(bases))

        # initialize the order to the fields this class has
        self._order = Order([field.name for field in fields(base)])

        # initialize nesting dependencies
        self._nesting = Nesting()

    def _get_possible_values(
        self, field: str, value: Any, type_annotation: type, mode: str, dataset_fn=None
    ) -> Optional[Tuple[Optional[possible_values.Value], Any]]:
        """
        Returns a set of possible values for the field provided, and the current value.

        Args:
            field: the field
            value: the preliminary value of the field.
            type_annotation: Type Annotation of the field.
            mode: current mode, one of {"train", "test", "predict"}.
            dataset_fn: A function returning a tuple (dataset, value). Will be called
                if the possible values depend on the dataset.

        Returns:
            Possible values for the field, the current value.

        Raises:
            ValueError: If the possible values depend on the dataset but no
                `dataset_fn` was provided, or if a raw sequence cannot be
                interpreted as a possible-value class.
        """
        poss_values = self._possible_values.get(field, None)

        if isinstance(poss_values, possible_values.DatasetValue):
            if dataset_fn is None:
                raise ValueError(
                    f"{poss_values} needs a dataset to compute possible values!\n"
                    "`dataset_fn` must be provided."
                )
            dataset, value = dataset_fn(field, value)
            poss_values, value = poss_values.get_value(
                dataset=dataset, value=value, type_annotation=type_annotation, mode=mode
            )
        elif isinstance(poss_values, Sequence):
            # Normalize raw sequences to typed possible-value objects.
            # NOTE: an empty sequence would pass the numeric `all(...)` check
            # vacuously and crash with an IndexError below, so it is rejected
            # here with a clear error instead.
            if poss_values and all(isinstance(x, (float, int)) for x in poss_values):
                # numeric sequences are interpreted as (min, max, step)
                poss_values = possible_values.Number(
                    min=poss_values[0], max=poss_values[1], step=poss_values[2]
                )
            elif poss_values and all(isinstance(x, str) for x in poss_values):
                poss_values = possible_values.String(tuple(poss_values))
            else:
                raise ValueError(
                    f"Could not interpret {poss_values} as any possible value class."
                )

        return poss_values, value

    def _get_tooltips(self, field: str, predict: bool = False) -> Optional[str]:
        """
        Returns a tooltip for the field provided, or None if there is none.

        The `predict` flag is currently unused but kept for interface
        compatibility with callers.
        """
        return tooltips.get(f"experiments_{field}", None)

    def _get_visibility(self, field: str) -> Optional[int]:
        """Returns a visibility level for the field provided.

        0 -- visible in the Wave app
        -1 -- not visible in the Wave App
        -2 -- visible in Dataset Import, but not visible in Create Experiment

        Returns None for unknown fields.
        """
        return self._visibility.get(field, None)

    def _get_nesting_triggers(self) -> Set[str]:
        """Returns a Set of keys other elements are depending on"""
        return self._nesting.triggers

    def _get_nesting_dependencies(self, key: str) -> List[Dependency] | None:
        """Returns all dependencies for a given key, or None if it has none."""
        if key in self._nesting.dependencies:
            dependencies = self._nesting.dependencies[key]
        else:
            dependencies = None
        return dependencies

    def _get_order(self, warn_if_unset=True) -> List[str]:
        """
        Returns the order in which to show the keys in the config.

        Args:
            warn_if_unset: Whether to log a warning if order is unset for multiple keys.

        Returns:
            A list of the same length and with same elements as `self.__dict__.keys()`.
        """
        keys = self.__dict__.keys()
        ordered_keys = [key for key in self._order if key in keys]
        unordered_keys = list(set(keys) - set(ordered_keys))

        # only user-visible, non-private keys are worth warning about
        unordered_ui_keys = [
            key
            for key in unordered_keys
            if not (key.startswith("_") or self._get_visibility(key) == -1)
        ]

        # warn if there is more than one key without order.
        # one is not problematic since it will just always be last
        if warn_if_unset and len(unordered_ui_keys) > 1:
            logger.warning(f"No order set for keys: {unordered_ui_keys}.")

        return ordered_keys + unordered_keys

    @classmethod
    def get_annotations(cls):
        """Returns type annotations through all the Parent config classes"""
        d: Dict[str, Any] = {}
        # walk the MRO from the root down so subclasses override parents
        for c in cls.mro()[::-1]:
            try:
                d.update(**c.__annotations__)
            except AttributeError:
                # object, at least, has no __annotations__ attribute.
                pass
        return d

    @classmethod
    def from_dict(cls, d: dict):
        """Creates a config object from a dictionary.

        Keys not declared as annotated fields anywhere in the class hierarchy
        are dropped with a warning.
        """
        d_filtered = {k: v for k, v in d.items() if k in cls.get_annotations()}
        if len(d) != len(d_filtered):
            logger.warning(
                f"Keys {set(d.keys()) - set(d_filtered.keys())} are not in the config."
            )
        return cls(**d_filtered)  # type: ignore
@dataclass
class DefaultConfigProblemBase(DefaultConfig):
    """
    Base class for all problem configs.
    Defines the interface for all problem configs.
    """

    experiment_name: str
    output_directory: str
    llm_backbone: str

    # nested sub-configs; concrete problem configs narrow these types
    dataset: Any
    tokenizer: Any
    architecture: Any
    training: Any
    augmentation: Any
    prediction: Any
    environment: Any
    logging: Any

    @property
    def problem_type(self) -> str:
        """
        Parse problem_type from config filename,
        for example: text_causal_language_modeling_config.py -> causal_language_modeling
        """
        return type(self).__dict__["__module__"].split(".")[-1].replace("_config", "")

    @classmethod
    def from_dict(cls, cfg_dict: dict):
        """Creates a config object from a (possibly partial) dictionary.

        Fields whose declared type provides a `from_dict` attribute (i.e.
        nested sub-configs) are built recursively from their sub-dict; all
        other fields are taken from `cfg_dict` or fall back to the field's
        default / default_factory.

        Raises:
            TypeError: if a required field (no default) is missing from
                `cfg_dict`.
        """
        # Prepare arguments for creating a new dataclass instance
        init_args: Dict[str, Any] = {}
        for field_obj in dataclasses.fields(cls):
            field_name = field_obj.name
            if hasattr(field_obj.type, "from_dict"):
                # nested config: build recursively from its sub-dict
                attr_value = cfg_dict.get(field_name, {})
                init_args[field_name] = field_obj.type.from_dict(attr_value)
            elif field_name in cfg_dict:
                init_args[field_name] = cfg_dict[field_name]
            elif field_obj.default is not dataclasses.MISSING:
                init_args[field_name] = field_obj.default
            elif field_obj.default_factory is not dataclasses.MISSING:
                # previously ignored: fields with only a default_factory would
                # have silently received the MISSING sentinel
                init_args[field_name] = field_obj.default_factory()
            # else: required field absent from cfg_dict -- omit it so the
            # dataclass constructor raises a clear TypeError instead of
            # silently storing the `dataclasses.MISSING` sentinel on the
            # instance (the previous behavior)
        return cls(**init_args)

    def check(self) -> Dict[str, List]:
        """
        Checks for errors (incompatible settings) for the specific problem type.

        Returns:
            A dictionary with two keys:
            - "title": A list of error titles.
            - "message": A list of error messages.
        """
        errors: Dict[str, List] = {"title": [], "message": []}
        return errors
|