Merge pull request #356 from tmm1/load_model-args

Files changed:
- scripts/finetune.py            +1 -8
- src/axolotl/utils/models.py    +7 -4
scripts/finetune.py (CHANGED)

@@ -255,14 +255,7 @@ def train(
 
     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
-    model, peft_config = load_model(
-        cfg.base_model,
-        cfg.base_model_config,
-        cfg.model_type,
-        tokenizer,
-        cfg,
-        adapter=cfg.adapter,
-    )
+    model, peft_config = load_model(cfg, tokenizer)
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
         LOG.info("running merge of LoRA with base model")
src/axolotl/utils/models.py (CHANGED)

@@ -78,12 +78,15 @@ def load_tokenizer(
 
 
 def load_model(
-    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
-):
-    # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    cfg, tokenizer
+):  # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
-    Load a model
+    Load a model for a given configuration and tokenizer.
     """
+    base_model = cfg.base_model
+    base_model_config = cfg.base_model_config
+    model_type = cfg.model_type
+    adapter = cfg.adapter
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit