File size: 2,364 Bytes
88aba71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import commentjson
import sys

from .log import logger
from .tools import dict_to_argv


def load_config(arg_type: str):
    """Load the settings section matching ``arg_type`` and append it to ``sys.argv``.

    The settings file is read from the path in the ``WECLONE_CONFIG_PATH``
    environment variable, falling back to ``./settings.jsonc``. It is parsed
    with ``commentjson``.

    Args:
        arg_type: Which configuration to assemble. Supported values are
            ``"cli_args"``, ``"web_demo"``, ``"api_service"``, ``"train_pt"``,
            ``"train_sft"`` and ``"make_dataset"``.

    Returns:
        The assembled configuration mapping for ``arg_type``.

    Raises:
        ValueError: If ``arg_type`` is not one of the supported values.
        SystemExit: If the settings file is missing or cannot be parsed
            (the error is logged and the process exits with status 1).
        KeyError: If a section or key this function indexes directly is
            absent from the settings file.

    Side effects:
        Extends ``sys.argv`` in place with ``dict_to_argv(config)``.
        NOTE(review): calling this more than once in a process looks like it
        would append arguments repeatedly -- confirm against callers.
    """
    config_path = os.environ.get("WECLONE_CONFIG_PATH", "./settings.jsonc")
    logger.info(f"Loading configuration from: {config_path}")
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            s_config: dict = commentjson.load(f)
    except FileNotFoundError:
        logger.error(f"Configuration file not found: {config_path}")
        sys.exit(1)  # Nothing useful can be done without a settings file.
    except Exception as e:
        # Top-level boundary: any other read/parse failure is fatal as well.
        logger.error(f"Error loading configuration file {config_path}: {e}")
        sys.exit(1)

    # In every merge below, common_args is unpacked last, so its values win
    # when a key appears in both sections.
    if arg_type == "cli_args":
        config = s_config["cli_args"]
    elif arg_type == "web_demo" or arg_type == "api_service":
        # Merge infer_args with common_args.
        config = {**s_config["infer_args"], **s_config["common_args"]}
    elif arg_type == "train_pt":
        config = {**s_config["train_pt_args"], **s_config["common_args"]}
    elif arg_type == "train_sft":
        config = {**s_config["train_sft_args"], **s_config["common_args"]}
        if s_config["make_dataset_args"]["prompt_with_history"]:
            dataset_info_path = os.path.join(config["dataset_dir"], "dataset_info.json")
            # Use a context manager so the file handle is always closed
            # (the previous bare open() call leaked it).
            with open(dataset_info_path, "r", encoding="utf-8") as info_file:
                dataset_info = commentjson.load(info_file)[config["dataset"]]
            # Switch to the history-enabled dataset when the configured one
            # declares no "history" column.
            if dataset_info["columns"].get("history") is None:
                logger.warning(f"{config['dataset']}数据集不包history字段,尝试使用wechat-sft-with-history数据集")
                config["dataset"] = "wechat-sft-with-history"

    elif arg_type == "make_dataset":
        config = {**s_config["make_dataset_args"], **s_config["common_args"]}
        # Dataset building reuses the dataset location and cutoff length
        # configured for SFT training.
        config["dataset"] = s_config["train_sft_args"]["dataset"]
        config["dataset_dir"] = s_config["train_sft_args"]["dataset_dir"]
        config["cutoff_len"] = s_config["train_sft_args"]["cutoff_len"]
    else:
        raise ValueError("暂不支持的参数类型")

    if "train" in arg_type:
        # For training runs, the adapter path is passed on as output_dir
        # rather than as adapter_name_or_path.
        config["output_dir"] = config.pop("adapter_name_or_path")
        config["do_train"] = True

    sys.argv += dict_to_argv(config)

    return config