winglian committed
Commit 6b3b271
1 Parent(s): 3a5a2d2

fix for protected model_ namespace w pydantic (#1345)

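For context: pydantic v2 reserves the "model_" prefix as a protected namespace for its own methods (model_dump, model_validate, and so on), so config fields such as model_type or model_revision emit a UserWarning when the model class is created. A minimal reproduction sketch, assuming pydantic>=2 (the Broken class is illustrative, not part of this commit):

import warnings
from typing import Optional

from pydantic import BaseModel

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    class Broken(BaseModel):
        # fields starting with "model_" collide with pydantic's
        # protected namespace and warn at class-creation time
        model_type: Optional[str] = None
        model_revision: Optional[str] = None

    assert any("protected namespace" in str(w.message) for w in caught)

The diffs below avoid the warning by renaming the offending fields and registering the old keys as pydantic aliases, so existing configs keep validating.
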
README.md CHANGED
@@ -546,7 +546,7 @@ base_model_ignore_patterns:
 # You can set that here, or leave this empty to default to base_model
 base_model_config: ./llama-7b-hf
 # You can specify to choose a specific model revision from huggingface hub
-model_revision:
+revision_of_model:
 # Optional tokenizer configuration path in case you want to use a different tokenizer
 # than the one defined in the base model
 tokenizer_config:
@@ -573,7 +573,7 @@ is_qwen_derived_model:
 is_mistral_derived_model:
 
 # optional overrides to the base model configuration
-model_config_overrides:
+overrides_of_model_config:
   # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
   rope_scaling:
     type: # linear | dynamic
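
For reference, an illustrative snippet using the renamed keys (the rope_scaling values are hypothetical examples of the linear scaling from the linked transformers PR):

base_model: huggyllama/llama-7b
# pin a specific model revision from huggingface hub
revision_of_model: main
# optional overrides to the base model configuration
overrides_of_model_config:
  rope_scaling:
    type: linear
    factor: 2.0
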
src/axolotl/utils/config/__init__.py CHANGED
@@ -124,7 +124,7 @@ def normalize_config(cfg):
         (hasattr(model_config, "model_type") and model_config.model_type == "llama")
         or cfg.is_llama_derived_model
         or "llama" in cfg.base_model.lower()
-        or (cfg.model_type and "llama" in cfg.model_type.lower())
+        or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
     )
 
     # figure out if the model is falcon
@@ -140,7 +140,7 @@ def normalize_config(cfg):
         )
         or cfg.is_falcon_derived_model
         or "falcon" in cfg.base_model.lower()
-        or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
+        or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
     )
 
     cfg.is_mistral_derived_model = (
@@ -153,7 +153,7 @@ def normalize_config(cfg):
         )
         or cfg.is_mistral_derived_model
         or "mistral" in cfg.base_model.lower().split("/")[-1]
-        or (cfg.model_type and "mistral" in cfg.model_type.lower())
+        or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
     )
 
     cfg.is_qwen_derived_model = (
@@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
             "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
         )
 
-    if cfg.gptq and cfg.model_revision:
+    if cfg.gptq and cfg.revision_of_model:
         raise ValueError(
-            "model_revision is not supported for GPTQ models. "
+            "revision_of_model is not supported for GPTQ models. "
             + "Please download the model from HuggingFace Hub manually for correct branch, "
-            + "point to its path, and remove model_revision from the config."
+            + "point to its path, and remove revision_of_model from the config."
         )
 
     # if cfg.sample_packing and cfg.sdp_attention:
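
The renamed attribute feeds the same substring heuristics as before. A standalone recreation of the mistral branch (toy values, not calling axolotl itself):

type_of_model = "MistralForCausalLM"  # hypothetical cfg.type_of_model
base_model = "some-org/custom-model"  # repo name does not contain "mistral"

# mirrors the normalize_config logic above: either the repo name or the
# configured model type can mark the model as mistral-derived
is_mistral_derived_model = (
    "mistral" in base_model.lower().split("/")[-1]
    or bool(type_of_model and "mistral" in type_of_model.lower())
)
assert is_mistral_derived_model
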
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -47,6 +47,16 @@ class DeprecatedParameters(BaseModel):
         return noisy_embedding_alpha
 
 
+class RemappedParameters(BaseModel):
+    """parameters that have been remapped to other names"""
+
+    overrides_of_model_config: Optional[Dict[str, Any]] = Field(
+        default=None, alias="model_config"
+    )
+    type_of_model: Optional[str] = Field(default=None, alias="model_type")
+    revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
+
+
 class PretrainingDataset(BaseModel):
     """pretraining dataset configuration subset"""
 
@@ -234,12 +244,8 @@ class ModelInputConfig(BaseModel):
     tokenizer_type: Optional[str] = Field(
         default=None, metadata={"help": "transformers tokenizer class"}
     )
-    model_type: Optional[str] = Field(default=None)
-    model_revision: Optional[str] = None
     trust_remote_code: Optional[bool] = None
 
-    model_config_overrides: Optional[Dict[str, Any]] = None
-
     @field_validator("trust_remote_code")
     @classmethod
     def hint_trust_remote_code(cls, trust_remote_code):
@@ -362,11 +368,17 @@ class AxolotlInputConfig(
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
+    RemappedParameters,
     DeprecatedParameters,
     BaseModel,
 ):
     """wrapper of all config options"""
 
+    class Config:
+        """Config for alias"""
+
+        populate_by_name = True
+
     strict: Optional[bool] = Field(default=False)
     resume_from_checkpoint: Optional[str] = None
     auto_resume_from_checkpoints: Optional[bool] = None
@@ -550,11 +562,11 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_gptq_w_revision(cls, data):
-        if data.get("gptq") and data.get("model_revision"):
+        if data.get("gptq") and data.get("revision_of_model"):
             raise ValueError(
-                "model_revision is not supported for GPTQ models. "
+                "revision_of_model is not supported for GPTQ models. "
                 + "Please download the model from HuggingFace Hub manually for correct branch, "
-                + "point to its path, and remove model_revision from the config."
+                + "point to its path, and remove revision_of_model from the config."
             )
         return data
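
To see the remapping in action, a self-contained sketch with a toy stand-in for RemappedParameters (same field and alias layout; the real AxolotlInputConfig carries many more fields):

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class RemapDemo(BaseModel):
    """Toy model mirroring RemappedParameters."""

    overrides_of_model_config: Optional[Dict[str, Any]] = Field(
        default=None, alias="model_config"
    )
    type_of_model: Optional[str] = Field(default=None, alias="model_type")
    revision_of_model: Optional[str] = Field(default=None, alias="model_revision")

    class Config:
        """Config for alias"""

        populate_by_name = True


# the legacy keys (now aliases) still validate ...
legacy = RemapDemo.model_validate({"model_type": "mistral", "model_revision": "main"})
assert legacy.type_of_model == "mistral"
assert legacy.revision_of_model == "main"

# ... and populate_by_name=True also accepts the new field names
renamed = RemapDemo(revision_of_model="main")
assert renamed.revision_of_model == "main"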
 
src/axolotl/utils/models.py CHANGED
@@ -86,8 +86,8 @@ def load_model_config(cfg):
     model_config_name = cfg.tokenizer_config
     trust_remote_code = cfg.trust_remote_code is True
     config_kwargs = {}
-    if cfg.model_revision:
-        config_kwargs["revision"] = cfg.model_revision
+    if cfg.revision_of_model:
+        config_kwargs["revision"] = cfg.revision_of_model
 
     try:
         model_config = AutoConfig.from_pretrained(
@@ -104,8 +104,8 @@ def load_model_config(cfg):
         )
         raise err
 
-    if cfg.model_config_overrides:
-        for key, val in cfg.model_config_overrides.items():
+    if cfg.overrides_of_model_config:
+        for key, val in cfg.overrides_of_model_config.items():
             setattr(model_config, key, val)
 
     check_model_config(cfg, model_config)
@@ -272,7 +272,7 @@ def load_model(
     Load a model for a given configuration and tokenizer.
     """
     base_model = cfg.base_model
-    model_type = cfg.model_type
+    model_type = cfg.type_of_model
     model_config = load_model_config(cfg)
 
     # TODO refactor as a kwarg
@@ -426,8 +426,8 @@ def load_model(
     if is_deepspeed_zero3_enabled():
         del model_kwargs["device_map"]
 
-    if cfg.model_revision:
-        model_kwargs["revision"] = cfg.model_revision
+    if cfg.revision_of_model:
+        model_kwargs["revision"] = cfg.revision_of_model
     if cfg.gptq:
         if not hasattr(model_config, "quantization_config"):
             LOG.warning("model config does not contain quantization_config information")
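
A condensed sketch of how the renamed key reaches transformers (hypothetical values; revision is the standard transformers kwarg for selecting a hub branch, tag, or commit):

from transformers import AutoConfig

revision_of_model = "main"  # stand-in for cfg.revision_of_model

config_kwargs = {}
if revision_of_model:
    config_kwargs["revision"] = revision_of_model

# requires network access to the huggingface hub
model_config = AutoConfig.from_pretrained("huggyllama/llama-7b", **config_kwargs)
assert model_config.model_type == "llama"
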
tests/test_validation.py CHANGED
@@ -3,6 +3,7 @@
 
 import logging
 import os
+import warnings
 from typing import Optional
 
 import pytest
@@ -14,6 +15,8 @@ from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import check_model_config
 from axolotl.utils.wandb_ import setup_wandb_env_vars
 
+warnings.filterwarnings("error")
+
 
 @pytest.fixture(name="minimal_cfg")
 def fixture_cfg():
@@ -190,6 +193,45 @@ class TestValidation(BaseValidation):
 
         assert new_cfg.learning_rate == 0.00005
 
+    def test_model_config_remap(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "model_config": {"model_type": "mistral"},
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg.overrides_of_model_config["model_type"] == "mistral"
+
+    def test_model_type_remap(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "model_type": "AutoModelForCausalLM",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg.type_of_model == "AutoModelForCausalLM"
+
+    def test_model_revision_remap(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "model_revision": "main",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg.revision_of_model == "main"
+
     def test_qlora(self, minimal_cfg):
         base_cfg = (
             DictDefault(
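
These remap tests can be run on their own with something like pytest tests/test_validation.py -k remap (assuming a standard pytest setup). Note that with warnings.filterwarnings("error") at module scope, any reappearance of pydantic's protected-namespace warning fails the suite outright.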