winglian committed on
Commit
125cccb
1 Parent(s): fd55bc8

Refactor train cfg cli (#499)


* wip to cleanup cfg cli options

* fix launcher

* fix cli args

Files changed (2)
  1. scripts/finetune.py +80 -46
  2. src/axolotl/utils/models.py +21 -19
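In short, the yaml config path and any yaml-key overrides are still routed through fire via do_train(config, **kwargs), while the non-training switches (debug, inference, merge_lora, prepare_ds_only, prompter, shard) move into a TrainerCliArgs dataclass parsed by transformers.HfArgumentParser. An invocation such as "python scripts/finetune.py examples/some-config.yml --inference --prompter=AlpacaPrompter" (the config path here is only a placeholder) should therefore keep working: the dataclass parser consumes the switches, and return_remaining_strings=True lets any yaml-override kwargs pass through to load_cfg untouched.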
scripts/finetune.py CHANGED
@@ -6,11 +6,13 @@ import os
 import random
 import signal
 import sys
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import fire
 import torch
+import transformers
 import yaml
 
 # add src to the pythonpath so we don't need to pip install this
@@ -22,7 +24,7 @@ from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
-from axolotl.utils.models import load_model, load_tokenizer
+from axolotl.utils.models import load_model, load_model_config, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars
@@ -37,6 +39,20 @@ LOG = logging.getLogger("axolotl.scripts")
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 
+@dataclass
+class TrainerCliArgs:
+    """
+    dataclass representing the various non-training arguments
+    """
+
+    debug: bool = field(default=False)
+    inference: bool = field(default=False)
+    merge_lora: bool = field(default=False)
+    prepare_ds_only: bool = field(default=False)
+    prompter: Optional[str] = field(default=None)
+    shard: bool = field(default=False)
+
+
 def print_axolotl_text_art():
     ascii_art = """
     dP dP dP
@@ -61,6 +77,8 @@ def get_multi_line_input() -> Optional[str]:
 
 
 def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
+    if prompter == "None":
+        prompter = None
     default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
 
     for token, symbol in default_tokens.items():
@@ -158,45 +176,20 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
 
 
 def train(
-    config: Path = Path("configs/"),
-    prepare_ds_only: bool = False,
-    **kwargs,
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
 ):
-    print_axolotl_text_art()
-    if Path(config).is_dir():
-        config = choose_config(config)
-
-    # load the config from the yaml file
-    with open(config, encoding="utf-8") as file:
-        cfg: DictDefault = DictDefault(yaml.safe_load(file))
-    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
-    # then overwrite the value
-    cfg_keys = cfg.keys()
-    for k, _ in kwargs.items():
-        # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or not cfg.strict:
-            # handle booleans
-            if isinstance(cfg[k], bool):
-                cfg[k] = bool(kwargs[k])
-            else:
-                cfg[k] = kwargs[k]
-
-    validate_config(cfg)
-
-    normalize_config(cfg)
-
-    setup_wandb_env_vars(cfg)
-
     # load the tokenizer first
     LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
     tokenizer = load_tokenizer(cfg)
 
-    if (
-        check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
+    if not (
+        cli_args.shard or cli_args.merge_lora or cli_args.inference
     ):  # don't need to load dataset for these
         train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
 
-    if cfg.debug or "debug" in kwargs:
+    if cli_args.debug or cfg.debug:
         LOG.info("check_dataset_labels...")
         check_dataset_labels(
             train_dataset.select(
@@ -205,17 +198,17 @@ def train(
             tokenizer,
         )
 
-    if prepare_ds_only:
+    if cli_args.prepare_ds_only:
         LOG.info("Finished preparing dataset. Exiting...")
         return
 
     # Load the model and tokenizer
     LOG.info("loading model and (optionally) peft_config...")
-    model, peft_config = load_model(cfg, tokenizer)
+    model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
 
     safe_serialization = cfg.save_safetensors is True
 
-    if "merge_lora" in kwargs and cfg.adapter is not None:
+    if cli_args.merge_lora and cfg.adapter is not None:
         LOG.info("running merge of LoRA with base model")
         model = model.merge_and_unload()
         model.to(dtype=torch.float16)
@@ -229,18 +222,13 @@ def train(
         tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
-    if cfg.inference:
-        LOG.info("calling do_inference function")
-        prompter: Optional[str] = "AlpacaPrompter"
-        if "prompter" in kwargs:
-            if kwargs["prompter"] == "None":
-                prompter = None
-            else:
-                prompter = kwargs["prompter"]
-        do_inference(cfg, model, tokenizer, prompter=prompter)
+    if cli_args.inference:
+        LOG.debug("Running inference on model")
+        do_inference(cfg, model, tokenizer, prompter=cli_args.prompter)
         return
 
-    if "shard" in kwargs:
+    if cli_args.shard:
+        LOG.debug("Re-saving model w/ sharding")
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
         return
 
@@ -322,5 +310,51 @@ def train(
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
 
+def load_cfg(config: Path = Path("examples/"), **kwargs):
+    if Path(config).is_dir():
+        config = choose_config(config)
+
+    # load the config from the yaml file
+    with open(config, encoding="utf-8") as file:
+        cfg: DictDefault = DictDefault(yaml.safe_load(file))
+    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
+    # then overwrite the value
+    cfg_keys = cfg.keys()
+    for k, _ in kwargs.items():
+        # if not strict, allow writing to cfg even if it's not in the yml already
+        if k in cfg_keys or not cfg.strict:
+            # handle booleans
+            if isinstance(cfg[k], bool):
+                cfg[k] = bool(kwargs[k])
+            else:
+                cfg[k] = kwargs[k]
+
+    model_config = load_model_config(cfg)
+
+    # figure out if the model is llama
+    cfg.is_llama_derived_model = (
+        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
+        or cfg.is_llama_derived_model
+        or "llama" in cfg.base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+    )
+    validate_config(cfg)
+
+    normalize_config(cfg)
+
+    setup_wandb_env_vars(cfg)
+    return cfg
+
+
+def do_train(config: Path = Path("examples/"), **kwargs):
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    train(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
 if __name__ == "__main__":
-    fire.Fire(train)
+    fire.Fire(do_train)
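For callers that drive training from Python rather than the shell, the split also makes the entry point usable programmatically: load_cfg() returns the merged DictDefault config, and train() takes the non-training switches as an explicit TrainerCliArgs instead of loose **kwargs. A minimal sketch, assuming scripts/finetune.py is importable (e.g. scripts/ on sys.path) and treating the yaml path and the micro_batch_size override below as placeholders:

from pathlib import Path

from finetune import TrainerCliArgs, load_cfg, train  # scripts/finetune.py on sys.path

# build the config the same way the CLI does, including a yaml-key override
cfg = load_cfg(Path("path/to/config.yml"), micro_batch_size=1)

# non-training behaviour is an explicit dataclass rather than loose **kwargs
cli_args = TrainerCliArgs(prepare_ds_only=True)

# with prepare_ds_only=True this tokenizes/prepares the dataset and returns early
train(cfg=cfg, cli_args=cli_args)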
src/axolotl/utils/models.py CHANGED
@@ -5,12 +5,13 @@ import logging
 import math
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
+from typing import Optional, Tuple  # noqa: F401
 
 import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
+from peft import PeftConfig
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -23,13 +24,17 @@ from transformers import (  # noqa: F401
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger("axolotl")
 
-if TYPE_CHECKING:
-    from peft import PeftConfig  # noqa: F401
 
-    from axolotl.utils.dict import DictDefault  # noqa: F401
+def load_model_config(cfg):
+    model_config_name = cfg.base_model_config or cfg.base_model
+    trust_remote_code: bool = False or cfg.trust_remote_code
+    return AutoConfig.from_pretrained(
+        model_config_name, trust_remote_code=trust_remote_code
+    )
 
 
 def load_tokenizer(cfg):
@@ -86,8 +91,10 @@ def load_tokenizer(cfg):
 
 
 def load_model(
-    cfg, tokenizer
-):  # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    cfg: DictDefault,
+    tokenizer: PreTrainedTokenizerBase,
+    inference: bool = False,
+) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
     """
     Load a model for a given configuration and tokenizer.
     """
@@ -97,14 +104,9 @@ def load_model(
 
     # TODO refactor as a kwarg
    load_in_8bit = cfg.load_in_8bit
-    cfg.is_llama_derived_model = (
-        "llama" in base_model
-        or (cfg.model_type and "llama" in cfg.model_type.lower())
-        or cfg.is_llama_derived_model
-    )
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
+        if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
             )
@@ -146,7 +148,7 @@ def load_model(
     if (
         cfg.is_llama_derived_model
         and (cfg.max_packed_sequence_len or cfg.sample_packing)
-        and not cfg.inference
+        and not inference
     ):
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
 
@@ -424,15 +426,15 @@ def load_model(
     return model, lora_config
 
 
-def load_adapter(model, cfg, adapter):
-    # type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+def load_adapter(model, cfg, adapter, inference=False):
+    # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     if adapter is None:
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
     if adapter in ["lora", "qlora"]:
-        return load_lora(model, cfg)
+        return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
 
@@ -478,8 +480,8 @@ def find_all_linear_names(model):
     return list(lora_module_names)
 
 
-def load_lora(model, cfg):
-    # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+def load_lora(model, cfg, inference=False):
+    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     from peft import LoraConfig, PeftModel, get_peft_model
 
@@ -506,7 +508,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
-            is_trainable=not cfg.inference,
+            is_trainable=(not inference),
        )
     else:
         model = get_peft_model(model, lora_config)
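On the models side, the refactor stops reading cfg.inference inside load_model and instead threads an explicit inference flag from load_model through load_adapter into load_lora, so the training-only monkeypatches are skipped and an existing LoRA adapter is loaded frozen via is_trainable=(not inference). A minimal sketch of the intended call path, reusing load_cfg from scripts/finetune.py (import path and yaml path below are placeholders/assumptions):

from pathlib import Path

from finetune import load_cfg  # scripts/finetune.py on sys.path
from axolotl.utils.models import load_model, load_tokenizer

cfg = load_cfg(Path("path/to/config.yml"))
tokenizer = load_tokenizer(cfg)

# inference=True bypasses the flash-attention / expand-mask training patches and,
# when loading an adapter from cfg.lora_model_dir, sets is_trainable=False
model, peft_config = load_model(cfg, tokenizer, inference=True)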