File size: 1,560 Bytes
b21e4a2
 
 
 
 
 
 
 
00568c1
b21e4a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48434be
 
b21e4a2
 
 
 
 
 
e50ab07
 
 
 
 
 
 
 
 
 
 
 
b21e4a2
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
shared CLI helpers: argument dataclasses and model/tokenizer loading
"""

import logging
from dataclasses import dataclass, field
from typing import Optional

import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

# logging is configured at import time, before this module's logger is created
configure_logging()
# module-level logger used by the helpers in this file
LOG = logging.getLogger("axolotl.common.cli")


@dataclass
class TrainerCliArgs:
    """
    command line options for the trainer entry point that are not training
    hyperparameters; every field has a default so the class can be built
    with no arguments
    """

    # debug-related options
    debug: bool = False
    debug_text_only: bool = False
    debug_num_examples: int = 5

    # remaining switches, all disabled / unset unless given
    inference: bool = False
    merge_lora: bool = False
    prompter: Optional[str] = field(default=None)
    shard: bool = False

@dataclass
class PreprocessCliArgs:
    """
    command line options used when only preprocessing is run; every field
    has a default so the class can be built with no arguments
    """

    # debug-related options
    debug: bool = False
    debug_text_only: bool = False
    debug_num_examples: int = 1

    # unset unless given
    prompter: Optional[str] = field(default=None)

def load_model_and_tokenizer(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    """
    load the tokenizer and then the model described by `cfg`

    Args:
        cfg: configuration passed to `load_tokenizer` and `load_model`;
            `tokenizer_config` / `base_model_config` are read for logging only.
        cli_args: CLI flags; only `inference` is read here and it is forwarded
            to `load_model`.

    Returns:
        a `(model, tokenizer)` tuple.
    """
    # lazy %-style argument instead of an f-string: interpolation is left to
    # the logging framework and the emitted message is unchanged
    LOG.info(
        "loading tokenizer... %s", cfg.tokenizer_config or cfg.base_model_config
    )
    tokenizer = load_tokenizer(cfg)

    LOG.info("loading model and (optionally) peft_config...")
    # the second value returned by load_model is discarded here
    # (presumably the peft config, going by the log message above)
    model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)

    return model, tokenizer