|
import os
|
|
import commentjson
|
|
import sys
|
|
|
|
from .log import logger
|
|
from .tools import dict_to_argv
|
|
|
|
|
|
def load_config(arg_type: str):
|
|
config_path = os.environ.get("WECLONE_CONFIG_PATH", "./settings.jsonc")
|
|
logger.info(f"Loading configuration from: {config_path}")
|
|
try:
|
|
with open(config_path, "r", encoding="utf-8") as f:
|
|
s_config: dict = commentjson.load(f)
|
|
except FileNotFoundError:
|
|
logger.error(f"Configuration file not found: {config_path}")
|
|
sys.exit(1)
|
|
except Exception as e:
|
|
logger.error(f"Error loading configuration file {config_path}: {e}")
|
|
sys.exit(1)
|
|
|
|
if arg_type == "cli_args":
|
|
config = s_config["cli_args"]
|
|
elif arg_type == "web_demo" or arg_type == "api_service":
|
|
|
|
config = {**s_config["infer_args"], **s_config["common_args"]}
|
|
elif arg_type == "train_pt":
|
|
config = {**s_config["train_pt_args"], **s_config["common_args"]}
|
|
elif arg_type == "train_sft":
|
|
config = {**s_config["train_sft_args"], **s_config["common_args"]}
|
|
if s_config["make_dataset_args"]["prompt_with_history"]:
|
|
dataset_info_path = os.path.join(config["dataset_dir"], "dataset_info.json")
|
|
dataset_info = commentjson.load(open(dataset_info_path, "r", encoding="utf-8"))[config["dataset"]]
|
|
if dataset_info["columns"].get("history") is None:
|
|
logger.warning(f"{config['dataset']}数据集不包history字段,尝试使用wechat-sft-with-history数据集")
|
|
config["dataset"] = "wechat-sft-with-history"
|
|
|
|
elif arg_type == "make_dataset":
|
|
config = {**s_config["make_dataset_args"], **s_config["common_args"]}
|
|
config["dataset"] = s_config["train_sft_args"]["dataset"]
|
|
config["dataset_dir"] = s_config["train_sft_args"]["dataset_dir"]
|
|
config["cutoff_len"] = s_config["train_sft_args"]["cutoff_len"]
|
|
else:
|
|
raise ValueError("暂不支持的参数类型")
|
|
|
|
if "train" in arg_type:
|
|
config["output_dir"] = config["adapter_name_or_path"]
|
|
config.pop("adapter_name_or_path")
|
|
config["do_train"] = True
|
|
|
|
sys.argv += dict_to_argv(config)
|
|
|
|
return config
|
|
|