# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional

from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME


# Weight-file names taken from peft (adapter weights, safetensors and PyTorch
# formats) and transformers (full-model weights plus their shard index files).
# Presumably used to detect whether a directory holds a checkpoint -- confirm
# against callers.
CHECKPOINT_NAMES = {
    SAFE_ADAPTER_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
}

# Answer letters for multiple-choice questions (presumably used by evaluation).
CHOICES = ["A", "B", "C", "D"]

# File name of the dataset registry file.
DATA_CONFIG = "dataset_info.json"

# Presumably maps a model-group name to its default template name and is filled
# in further down this module. Being a defaultdict(str), a missing key yields "".
DEFAULT_TEMPLATE = defaultdict(str)

# Data-file extension -> loader/builder name. The values look like Hugging Face
# `datasets` packaged builder names (NOTE(review): confirm with the caller).
# Both "json" and "jsonl" map to the "json" builder.
FILEEXT2TYPE = {
    "arrow": "arrow",
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    "parquet": "parquet",
    "txt": "text",
}

# Label value presumably used to mask positions out of the loss; -100 is the
# default `ignore_index` of torch.nn.CrossEntropyLoss.
IGNORE_INDEX = -100

# Placeholder string for images inside text; overridable through the
# IMAGE_PLACEHOLDER environment variable. The default must be non-empty: the
# empty string occurs at every position of every string (e.g. "ab".count("")
# == 3), so "" could never serve as a placeholder.
IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")

# Name fragments presumably matched against module/parameter names to pick out
# normalization layers -- confirm with callers.
LAYERNORM_NAMES = {"norm", "ln"}

# File name of the LlamaBoard configuration file.
LLAMABOARD_CONFIG = "llamaboard_config.yaml"

# Fine-tuning method identifiers.
METHODS = ["full", "freeze", "lora"]

# Model types for which "MoD" is supported (MoD presumably = mixture-of-depths).
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}

# Entries of METHODS that are implemented through PEFT.
PEFT_METHODS = {"lora"}

# File name of the plain-text running log.
RUNNING_LOG = "running_log.txt"

# Subject categories for reporting evaluation scores, led by an overall
# "Average" entry.
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

# Model registry, presumably populated further down this module; OrderedDict
# keeps entries in registration order.
SUPPORTED_MODELS = OrderedDict()

# File name of the trainer log (JSON Lines).
TRAINER_LOG = "trainer_log.jsonl"

# File name of the YAML dump of training arguments.
TRAINING_ARGS = "training_args.yaml"

# Human-readable training stage name -> short stage identifier.
TRAINING_STAGES = {
    "Supervised Fine-Tuning": "sft",
    "Reward Modeling": "rm",
    "PPO": "ppo",
    "DPO": "dpo",
    "KTO": "kto",
    "Pre-Training": "pt",
}
STAGES_USE_PAIR_DATA = {"rm", "dpo"} SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = { "cohere", "falcon", "gemma", "gemma2", "llama", "mistral", "phi", "phi3", "qwen2", "starcoder2", } SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "