File size: 2,364 Bytes
88aba71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import os
import commentjson
import sys
from .log import logger
from .tools import dict_to_argv
def load_config(arg_type: str) -> dict:
    """Load the JSONC settings file and build the argument dict for one mode.

    The settings path is taken from the ``WECLONE_CONFIG_PATH`` environment
    variable, defaulting to ``./settings.jsonc``. Depending on ``arg_type``,
    a section of the settings is selected and (for most modes) merged with
    ``common_args``; on key collisions ``common_args`` wins because it is
    unpacked last.

    Args:
        arg_type: One of ``"cli_args"``, ``"web_demo"``, ``"api_service"``,
            ``"train_pt"``, ``"train_sft"`` or ``"make_dataset"``.

    Returns:
        The resulting config dict. As a side effect, the dict is converted
        with ``dict_to_argv`` and appended to ``sys.argv`` (presumably so a
        downstream argument parser picks it up — confirm against callers).

    Raises:
        ValueError: If ``arg_type`` is not one of the supported values.
        KeyError: If a required section/key is missing from the settings
            or from ``dataset_info.json``.
        SystemExit: If the settings file is missing or cannot be parsed.
    """
    config_path = os.environ.get("WECLONE_CONFIG_PATH", "./settings.jsonc")
    # Log which file is actually loaded, since the path can come from the env.
    logger.info(f"Loading configuration from: {config_path}")
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            s_config: dict = commentjson.load(f)
    except FileNotFoundError:
        logger.error(f"Configuration file not found: {config_path}")
        sys.exit(1)  # Nothing sensible can run without a config file.
    except Exception as e:
        logger.error(f"Error loading configuration file {config_path}: {e}")
        sys.exit(1)

    if arg_type == "cli_args":
        config = s_config["cli_args"]
    elif arg_type in ("web_demo", "api_service"):
        # Union of infer_args and common_args (common_args takes precedence).
        config = {**s_config["infer_args"], **s_config["common_args"]}
    elif arg_type == "train_pt":
        config = {**s_config["train_pt_args"], **s_config["common_args"]}
    elif arg_type == "train_sft":
        config = {**s_config["train_sft_args"], **s_config["common_args"]}
        if s_config["make_dataset_args"]["prompt_with_history"]:
            # When prompts include history, the selected dataset must declare a
            # "history" column; otherwise fall back to the history-aware dataset.
            dataset_info_path = os.path.join(config["dataset_dir"], "dataset_info.json")
            # Use a context manager so the file handle is always closed.
            with open(dataset_info_path, "r", encoding="utf-8") as f:
                dataset_info = commentjson.load(f)[config["dataset"]]
            if dataset_info["columns"].get("history") is None:
                logger.warning(f"{config['dataset']}数据集不包history字段,尝试使用wechat-sft-with-history数据集")
                config["dataset"] = "wechat-sft-with-history"
    elif arg_type == "make_dataset":
        config = {**s_config["make_dataset_args"], **s_config["common_args"]}
        # Dataset creation writes to the same location SFT training reads from.
        config["dataset"] = s_config["train_sft_args"]["dataset"]
        config["dataset_dir"] = s_config["train_sft_args"]["dataset_dir"]
        config["cutoff_len"] = s_config["train_sft_args"]["cutoff_len"]
    else:
        raise ValueError("暂不支持的参数类型")

    if "train" in arg_type:
        # Training writes the adapter to the path that inference later loads
        # it from, so the adapter path is renamed to output_dir here.
        config["output_dir"] = config.pop("adapter_name_or_path")
        config["do_train"] = True

    sys.argv += dict_to_argv(config)
    return config
|