Naphula committed on
Commit
ec5ad0d
·
verified ·
1 Parent(s): c38f701

Upload config.py

Browse files
Files changed (1) hide show
  1. config.py +264 -0
config.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Arcee AI
2
+ # SPDX-License-Identifier: LGPL-3.0-only
3
+
4
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
5
+
6
+ import yaml
7
+ from pydantic import BaseModel, model_validator
8
+ from typing_extensions import Literal, TypeAlias
9
+
10
+ from mergekit.common import ModelReference
11
+ from mergekit.tokenizer.config import TokenizerConfig
12
+
13
+ ScalarOrGradient: TypeAlias = Union[float, List[float], str, bool] # ScalarOrGradient: TypeAlias = Union[float, List[float]]
14
+
15
+
16
class ConditionalParameter(BaseModel):
    """A parameter value that only applies to tensors matching ``filter``."""

    # Scalar or gradient used when the filter matches.
    value: ScalarOrGradient
    # Substring searched for in the tensor name by evaluate_setting();
    # None or "*" matches every tensor.
    filter: Optional[str] = None
19
+
20
+
21
# Everything a parameter entry may hold: one conditional value, a list of
# conditional values, or a bare scalar/gradient.
ParameterSetting: TypeAlias = Union[
    ConditionalParameter, List[ConditionalParameter], ScalarOrGradient, str
]
24
+
25
+
26
def evaluate_setting(
    tensor_name: str, setting: "ParameterSetting", t: float = 0
) -> Optional[Union[float, bool, str]]:
    """Resolve a parameter setting to a concrete value.

    Args:
        tensor_name: Name of the tensor the value is wanted for. Only used to
            match conditional filters; a falsy name matches nothing except
            unfiltered (``None``) and ``"*"`` conditions.
        setting: The setting to resolve. May be a scalar, a list of numbers
            (a gradient), a list of strings/booleans/numbers, a single
            conditional parameter, or a list of conditional parameters.
        t: Position along a gradient; 0 selects the first entry and 1 the
            last.

    Returns:
        The resolved value, or ``None`` when the setting is an empty list or
        no conditional entry matches ``tensor_name``.

    Raises:
        RuntimeError: If ``setting`` is of an unsupported type.
    """
    # Scalars are returned untouched.
    if isinstance(setting, (float, int, bool, str)):
        return setting

    if isinstance(setting, list):
        if not setting:
            # Nothing to pick from; previously this raised IndexError.
            return None

        last = len(setting) - 1
        if all(isinstance(e, (int, float)) for e in setting):
            # Numeric gradient: linear interpolation between the two
            # entries that surround position t.
            scaled = t * last
            i0 = int(scaled)
            i1 = min(last, i0 + 1)
            frac = scaled - i0
            return (1 - frac) * setting[i0] + frac * setting[i1]
        if all(isinstance(e, (float, int, bool, str)) for e in setting):
            # Non-numeric entries cannot be blended; pick the entry at or
            # just below position t.
            return setting[int(t * last)]
        conditions = setting
    elif isinstance(setting, ConditionalParameter):
        # ParameterSetting allows a bare conditional value; treat it as a
        # one-element condition list instead of rejecting it.
        conditions = [setting]
    else:
        raise RuntimeError(f"Unexpected setting value: {setting}")

    # The first condition whose filter matches wins.
    for cond in conditions:
        if (
            cond.filter is None
            or cond.filter == "*"
            or (tensor_name and cond.filter in tensor_name)
        ):
            return evaluate_setting(tensor_name, cond.value, t)
    return None
53
+
54
+
55
class InputSliceDefinition(BaseModel):
    """One source of an output slice: a model plus a pair of layer indices."""

    model: ModelReference
    # Two layer indices delimiting the slice.
    # NOTE(review): whether the end index is inclusive is not visible in this
    # file -- confirm against the code that consumes it.
    layer_range: Tuple[int, int]
    # Parameters specific to this source; ConfigReader.parameter() consults
    # them first when asked about this source's model.
    parameters: Optional[Dict[str, ParameterSetting]] = None
59
+
60
+
61
class InputModelDefinition(BaseModel):
    """A whole model used as a merge input, with optional parameters."""

    model: ModelReference
    # Parameters attached to this input model.
    parameters: Optional[Dict[str, ParameterSetting]] = None
64
+
65
+
66
class OutputSliceDefinition(BaseModel):
    """A slice of the output, built from one or more input slices."""

    sources: List[InputSliceDefinition]
    # When set, ConfigReader.base_model prefers this over the configuration's
    # base model.
    base_model: Optional[ModelReference] = None
    # NOTE(review): not read anywhere in this file -- semantics are defined
    # by the consumer.
    residual_weight: Optional[float] = None
    # Slice-level parameters; consulted after per-source parameters.
    parameters: Optional[Dict[str, ParameterSetting]] = None
71
+
72
+
73
class OutputModuleDefinition(BaseModel):
    """Inputs and parameters for one named module of the output.

    Exactly one of ``slices`` and ``models`` must be non-empty.
    """

    slices: Optional[List[OutputSliceDefinition]] = None
    models: Optional[List[InputModelDefinition]] = None
    # Module-level parameters; consulted after slice-level parameters.
    parameters: Optional[Dict[str, ParameterSetting]] = None

    @model_validator(mode="after")
    def validate_inputs(self):
        """Reject definitions that set both or neither of slices/models.

        Raises:
            RuntimeError: If ``slices`` and ``models`` are both empty/unset
                or both populated.
        """
        # Valid only when exactly one of the two is truthy.
        if bool(self.slices) == bool(self.models):
            raise RuntimeError("Must specify either output slices or models to merge")
        return self
83
+
84
+
85
class MergeConfiguration(BaseModel):
    """Top-level description of a merge.

    Exactly one of ``modules``, ``slices`` or ``models`` must be set
    (enforced by ``validate_inputs``), and ``tokenizer_source`` and
    ``tokenizer`` are mutually exclusive (enforced by ``validate_tokenizer``).
    """

    # Input specification -- exactly one of these three must be set.
    modules: Optional[Dict[str, OutputModuleDefinition]] = None
    slices: Optional[List[OutputSliceDefinition]] = None
    models: Optional[List[InputModelDefinition]] = None

    # Name of the merge method; interpreted outside this file.
    merge_method: str
    base_model: Optional[ModelReference] = None
    dtype: Optional[str] = None
    # Either the literal "union" or "base", or a model reference.
    tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference, None] = (
        None
    )
    tokenizer: Optional[TokenizerConfig] = None
    chat_template: Optional[str] = None
    out_dtype: Optional[str] = None
    # Global parameters; the last scope consulted by ConfigReader.parameter().
    parameters: Optional[Dict[str, ParameterSetting]] = None

    def referenced_models(self) -> List[ModelReference]:
        """Return every distinct model this configuration refers to.

        Collects the base model, all entries of ``models``, all sources of
        ``slices``, and the same for every module. The result is built from a
        set, so its order is unspecified.
        """
        # NOTE(review): per-slice ``base_model`` overrides are not collected
        # here -- confirm whether callers need them.
        models = set()
        if self.base_model:
            models.add(self.base_model)

        # The configuration itself and each module expose the same
        # ``models`` / ``slices`` attributes, so gather from all of them the
        # same way.
        holders = [self]
        if self.modules:
            holders.extend(self.modules.values())

        for holder in holders:
            for model_in in holder.models or []:
                models.add(model_in.model)
            for s in holder.slices or []:
                for src in s.sources:
                    models.add(src.model)
        return list(models)

    @model_validator(mode="after")
    def validate_inputs(self):
        """Ensure exactly one input specification is present.

        Raises:
            RuntimeError: If zero or more than one of ``models``, ``slices``
                and ``modules`` is populated.
        """
        set_ct = sum(bool(x) for x in (self.modules, self.slices, self.models))
        if set_ct != 1:
            raise RuntimeError(
                "Exactly one of 'models', 'slices', or 'modules' must be present"
            )
        return self

    @model_validator(mode="after")
    def validate_tokenizer(self):
        """Ensure only one way of configuring the tokenizer is used.

        Raises:
            RuntimeError: If both ``tokenizer_source`` and ``tokenizer`` are
                set.
        """
        if self.tokenizer_source and self.tokenizer:
            raise RuntimeError("Cannot specify both tokenizer_source and tokenizer")
        return self

    def to_yaml(self) -> str:
        """Serialize the configuration to YAML, leaving out default values.

        Uses ``ConfigYamlDumper`` and strips trailing whitespace from the
        result.
        """
        return yaml.dump(
            self.model_dump(exclude_defaults=True, mode="json"),
            Dumper=ConfigYamlDumper,
        ).rstrip()
150
+
151
+
152
class ConfigReader(BaseModel):
    """Looks up parameter values for a given position in a merge.

    A reader pairs a configuration with a context (gradient position ``t``,
    tensor name, output slice, module). The ``for_*`` / ``with_t`` methods
    return a new reader with one part of the context replaced; the original
    reader is left unchanged.
    """

    config: MergeConfiguration
    # Gradient position handed to evaluate_setting().
    t: float
    # Tensor name used to match conditional parameter filters.
    tensor_name: Optional[str] = None
    slice_out: Optional[OutputSliceDefinition] = None
    module: Optional[OutputModuleDefinition] = None

    @property
    def base_model(self) -> Optional[ModelReference]:
        """The current slice's base model if it has one, else the config's."""
        if self.slice_out and self.slice_out.base_model:
            return self.slice_out.base_model
        return self.config.base_model

    def _derive(self, **overrides: Any) -> "ConfigReader":
        """Build a new reader from this one with some fields replaced."""
        fields = {
            "config": self.config,
            "t": self.t,
            "tensor_name": self.tensor_name,
            "slice_out": self.slice_out,
            "module": self.module,
        }
        fields.update(overrides)
        return ConfigReader(**fields)

    def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
        """Return a reader scoped to the given output slice."""
        return self._derive(slice_out=slice)

    def for_tensor(self, tensor_name: str) -> "ConfigReader":
        """Return a reader scoped to the given tensor name."""
        return self._derive(tensor_name=tensor_name)

    def with_t(self, t: float) -> "ConfigReader":
        """Return a reader at gradient position ``t``."""
        return self._derive(t=t)

    def for_module(self, module: OutputModuleDefinition) -> "ConfigReader":
        """Return a reader scoped to the given module."""
        return self._derive(module=module)

    def parameter(
        self,
        name: str,
        model: Optional[ModelReference] = None,
        default: Any = None,
        required: bool = False,
    ) -> Any:
        """Resolve parameter ``name`` from the most specific scope that has it.

        Scopes are tried in this order and the first one yielding a value
        other than ``None`` wins:

        1. sources of the current output slice whose model equals ``model``
           (only when both a slice and ``model`` are given),
        2. the current output slice,
        3. the current module,
        4. the top-level configuration.

        Args:
            name: Parameter name to look up.
            model: Model whose per-source parameters should be consulted.
            default: Value returned when nothing is found and the parameter
                is not required.
            required: Raise instead of returning ``default`` when nothing is
                found.

        Returns:
            The resolved value, or ``default``.

        Raises:
            RuntimeError: If ``required`` is true and no scope provides a
                value.
        """
        # Collect the candidate parameter dicts, most specific first.
        scopes = []
        if self.slice_out:
            if model:
                scopes.extend(
                    s.parameters for s in self.slice_out.sources if s.model == model
                )
            scopes.append(self.slice_out.parameters)
        if self.module:
            scopes.append(self.module.parameters)
        scopes.append(self.config.parameters)

        for params in scopes:
            if params and name in params:
                value = evaluate_setting(self.tensor_name, params[name], self.t)
                # None means "no applicable value here" (e.g. no filter
                # matched), so keep looking in the next scope.
                if value is not None:
                    return value

        if required:
            path_paths = [str(s) for s in [model, self.tensor_name] if s]
            p = ".".join(path_paths)
            suffix = f" for {p}" if p else ""
            raise RuntimeError(f"Missing required parameter {name}{suffix}")
        return default
252
+
253
+
254
class ConfigYamlDumper(yaml.Dumper):
    """Custom YAML dumper to format lists of numbers in flow style."""

    def represent_list(self, data: Iterable[Any]) -> yaml.SequenceNode:
        """Represent ``data`` as a YAML sequence.

        Flow style (``[a, b, c]``) is requested when every element is an int
        or float -- which includes booleans and the empty list; otherwise
        flow style is explicitly turned off.
        """
        flow_style = all(isinstance(e, (int, float)) for e in data)
        return self.represent_sequence(
            "tag:yaml.org,2002:seq", data, flow_style=flow_style
        )
262
+
263
+
264
# Route every list dumped through ConfigYamlDumper to represent_list.
ConfigYamlDumper.add_representer(list, ConfigYamlDumper.represent_list)