# digitalWDF/src/utils/.ipynb_checkpoints/common-checkpoint.py
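# Shared utilities for fine-tuning ChatGLM: adapter initialization (full, freeze,
# P-Tuning v2 and LoRA), model/tokenizer loading, argument parsing, and dataset
# preparation and preprocessing for the "sft", "rwd" and "ppo" stages.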
import os
import sys
import torch
import hashlib
from typing import Literal, Optional, Tuple
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
HfArgumentParser,
Seq2SeqTrainingArguments
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
import datasets
from datasets import Dataset, concatenate_datasets, load_dataset
from peft import (
PeftModel,
TaskType,
LoraConfig,
get_peft_model
)
from trl import AutoModelForCausalLMWithValueHead
from .config import (
ModelArguments,
DataTrainingArguments,
FinetuningArguments
)
from .other import (
get_logger,
load_trainable_params,
load_valuehead_params,
print_trainable_params,
prepare_model_for_training,
IGNORE_INDEX,
FINETUNING_ARGS_NAME
)
logger = get_logger(__name__)
check_min_version("4.27.4")
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
def init_adapter(
model: PreTrainedModel,
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
is_trainable: bool
) -> PreTrainedModel:
r"""
Initializes the adapters.
Note that the trainable parameters must be cast to float32.
"""
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
if finetuning_args.finetuning_type == "full":
logger.info("Fine-tuning method: Full")
model = model.float()
if model_args.checkpoint_dir is not None:
load_trainable_params(model, model_args.checkpoint_dir[0])
if finetuning_args.finetuning_type == "freeze":
logger.info("Fine-tuning method: Freeze")
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
param.requires_grad_(False)
else:
param.data = param.data.to(torch.float32)
if model_args.checkpoint_dir is not None:
load_trainable_params(model, model_args.checkpoint_dir[0])
if finetuning_args.finetuning_type == "p_tuning":
logger.info("Fine-tuning method: P-Tuning v2") # nothing to do
if model_args.checkpoint_dir is not None:
load_trainable_params(model, model_args.checkpoint_dir[0])
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
        latest_checkpoint = None
if model_args.checkpoint_dir is not None:
            if is_trainable and finetuning_args.resume_lora_training: # continually train on the LoRA weights
                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
for checkpoint in checkpoints_to_merge:
model = PeftModel.from_pretrained(model, checkpoint)
model = model.merge_and_unload()
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
            if latest_checkpoint is not None: # resume lora training
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=True)
        if latest_checkpoint is None: # create new lora weights
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=finetuning_args.lora_target
)
model = get_peft_model(model, lora_config)
return model
def load_pretrained(
model_args: ModelArguments,
training_args: Optional[Seq2SeqTrainingArguments] = None,
finetuning_args: Optional[FinetuningArguments] = None,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
r"""
Load pretrained model and tokenizer.
"""
if (not is_trainable) and (model_args.checkpoint_dir is None):
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
for checkpoint_dir in model_args.checkpoint_dir:
if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
raise ValueError("The fine-tuning arguments are not found in the provided dictionary.")
logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
finetuning_args = torch.load(os.path.join(model_args.checkpoint_dir[0], FINETUNING_ARGS_NAME))
if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
logger.warning("Only LoRA tuning accepts multiple checkpoints.")
assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
quantization = None
if model_args.quantization_bit is not None:
if is_trainable:
if finetuning_args.finetuning_type == "full":
raise ValueError("Full parameter fine-tuning does not support quantization.")
elif finetuning_args.finetuning_type == "p_tuning":
quantization = "cpm" # use cpm's quantization
else:
quantization = "bnb" # use bnb's quantization
else:
quantization = "cpm"
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
}
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
padding_side="left",
**config_kwargs
)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
**config_kwargs
)
# P-Tuning v2 configurations.
    # We use ChatGLM's built-in P-Tuning method; we cannot use PEFT since the attention masks of ChatGLM are unusual. >_<
if finetuning_args.finetuning_type == "p_tuning":
        config.pre_seq_len = finetuning_args.pre_seq_len # enabling this will freeze the other parameters automatically
config.prefix_projection = finetuning_args.prefix_projection
# Quantization configurations for Full, Freeze and LoRA in training (using bitsandbytes library).
if quantization == "bnb":
assert model_args.quantization_bit == 8, "Freeze and LoRA fine-tuning only accept 8-bit quantization."
require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
        assert is_cublasLt_compatible(cc), "The current GPU(s) are incompatible with 8-bit quantization."
config_kwargs["load_in_8bit"] = True
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
# Load and prepare pretrained models (without valuehead).
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs)
model = prepare_model_for_training(model) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
if not is_trainable:
model.requires_grad_(False) # fix all params
model = model.half() # cast all params to float16
# Quantization with the built-in method for P-Tuning v2 training or evaluation.
    # Model parameters should be cast to float16 in the quantized P-Tuning setting.
if quantization == "cpm":
assert model_args.quantization_bit in [4, 8], "P-Tuning v2 and inference mode only accept 4-bit or 8-bit quantization."
assert not (is_trainable and training_args.fp16), "FP16 training conflicts with cpm quantization."
model.quantize(model_args.quantization_bit) # in-place method
for name, param in model.named_parameters():
if "prefix_encoder" not in name:
                param.data = param.data.to(torch.float16) # cast all params except the prefix_encoder to half precision
if quantization is not None:
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
if stage == "rwd" or stage == "ppo": # add value head
assert is_trainable, "Reward and PPO stages cannot be performed at evaluation."
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
if stage == "ppo": # load reward model
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
load_valuehead_params(model, model_args.reward_model)
    # Set the _is_int8_training_enabled attribute on the AutoModelForCausalLMWithValueHead model
    # to satisfy the requirements of the transformers library.
if quantization == "bnb":
model._is_int8_training_enabled = True
print_trainable_params(model)
return model, tokenizer
def prepare_args() -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
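    r"""
    Parse the model, data, training and fine-tuning arguments (from the command line
    or a .json file), set up logging, validate the argument combination and set the seed.
    """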
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
# Setup logging
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
if int(training_args.do_train) + int(training_args.do_eval) + int(training_args.do_predict) != 1:
raise ValueError("We must perform a single operation among do_train, do_eval and do_predict.")
    if model_args.quantization_bit is not None and not training_args.do_train:
        logger.warning("We do not recommend evaluating the model in 4/8-bit mode.")
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.")
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args
def prepare_data(
model_args: ModelArguments,
data_args: DataTrainingArguments
) -> Dataset:
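    r"""
    Load the datasets listed in data_args, verify optional file checksums, align them
    to the prompt/query/response/history columns and concatenate them into one Dataset.
    """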
def checksum(file_path, hash):
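        # Compare the file's SHA-1 digest with the hash recorded in dataset_info and warn on mismatch.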
with open(file_path, "rb") as datafile:
binary_data = datafile.read()
sha1 = hashlib.sha1(binary_data).hexdigest()
if sha1 != hash:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
max_samples = data_args.max_samples
all_datasets = [] # support multiple datasets
for dataset_info in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_info))
if dataset_info.load_from == "hf_hub":
raw_datasets = load_dataset(dataset_info.dataset_name, cache_dir=model_args.cache_dir)
elif dataset_info.load_from == "script":
raw_datasets = load_dataset(
os.path.join(data_args.dataset_dir, dataset_info.dataset_name),
cache_dir=model_args.cache_dir
)
elif dataset_info.load_from == "file":
data_file = os.path.join(data_args.dataset_dir, dataset_info.file_name) # support json, jsonl and csv
extension = dataset_info.file_name.split(".")[-1]
if dataset_info.file_sha1 is not None:
checksum(data_file, dataset_info.file_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.")
raw_datasets = load_dataset(
extension,
data_files=data_file,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None
)
else:
raise NotImplementedError
dataset = raw_datasets[data_args.split]
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
dummy_data = [None] * len(dataset)
for column, column_name in [
("prompt_column", "prompt"),
("query_column", "query"),
("response_column", "response"),
("history_column", "history")
        ]: # align every dataset to the same four columns: prompt, query, response, history
if getattr(dataset_info, column) != column_name:
if getattr(dataset_info, column):
dataset = dataset.rename_column(getattr(dataset_info, column), column_name)
else: # None or empty string
dataset = dataset.add_column(column_name, dummy_data)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
all_datasets = all_datasets[0]
else:
all_datasets = concatenate_datasets(all_datasets)
return all_datasets
def preprocess_data(
dataset: Dataset,
tokenizer: PreTrainedTokenizer,
data_args: DataTrainingArguments,
training_args: Seq2SeqTrainingArguments,
stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft"
) -> Dataset:
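    r"""
    Tokenize the dataset according to the training stage: prompt/response pairs for "sft",
    accept/reject pairs for "rwd" and prompt-only inputs for "ppo".
    """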
column_names = list(dataset.column_names)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
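    # format_example builds ChatGLM-style multi-turn prompts. With a hypothetical
    # one-turn history, the rendered prompt would look roughly like:
    #   [Round 0]
    #   问:<previous query>
    #   答:<previous response>
    #   [Round 1]
    #   问:<current query>
    #   答: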
    def format_example(examples): # support questions with a single answer or multiple answers
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
if examples["query"][i]:
query += examples["query"][i]
if examples["history"][i]:
prompt = ""
history = examples["history"][i]
for j, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n问:{}\n答:{}\n".format(j, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
else:
prompt = query
prompt = prefix + prompt
yield prompt, answer
def preprocess_function_train(examples):
# build inputs with format `X [gMASK] [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] [BOS] Y [EOS]`
model_inputs = {"input_ids": [], "labels": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
source_ids = source_ids[:data_args.max_source_length - 2]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids)
context_length = input_ids.index(tokenizer.bos_token_id)
labels = [IGNORE_INDEX] * context_length + input_ids[context_length:]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_function_eval(examples):
# build inputs with format `[PAD] ... [PAD] X [gMASK] [BOS]` and labels with format `Y [gMASK] [BOS]`
        # left-padding is needed for prediction; use the tokenizer's built-in padding
inputs, targets = [], []
for prompt, answer in format_example(examples):
inputs.append(prompt)
targets.append(answer)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
labels = tokenizer(text_target=targets, max_length=data_args.max_target_length, truncation=True) # no padding
if data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l_id if l_id != tokenizer.pad_token_id else IGNORE_INDEX) for l_id in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def preprocess_function_train_pair(examples):
# build input pairs with format `X [gMASK] [BOS] Y1 [EOS]` and `X [gMASK] [BOS] Y2 [EOS]`
model_inputs = {"accept_ids": [], "reject_ids": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
source_ids = source_ids[:data_args.max_source_length - 2]
if len(accept_ids) > data_args.max_target_length - 1: # eos token
accept_ids = accept_ids[:data_args.max_target_length - 1]
if len(reject_ids) > data_args.max_target_length - 1: # eos token
reject_ids = reject_ids[:data_args.max_target_length - 1]
            accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], accept_ids) # pass a copy of source_ids so it is not modified
reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], reject_ids)
model_inputs["accept_ids"].append(accept_ids)
model_inputs["reject_ids"].append(reject_ids)
return model_inputs
def preprocess_function_train_ppo(examples):
# build inputs with format `X [gMASK] [BOS]`
model_inputs = {"input_ids": []}
for prompt, _ in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
source_ids = source_ids[:data_args.max_source_length - 2]
input_ids = tokenizer.build_inputs_with_special_tokens(source_ids)
model_inputs["input_ids"].append(input_ids)
return model_inputs
def print_sft_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(example["labels"])))
def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
print("reject_ids:\n{}".format(example["reject_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
def print_ppo_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
if stage == "sft":
preprocess_function = preprocess_function_train if training_args.do_train else preprocess_function_eval
elif stage == "rwd":
preprocess_function = preprocess_function_train_pair
elif stage == "ppo":
preprocess_function = preprocess_function_train_ppo
with training_args.main_process_first(desc="dataset map pre-processing"):
dataset = dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
if stage == "sft":
print_sft_dataset_example(dataset[0])
elif stage == "rwd":
print_pairwise_dataset_example(dataset[0])
elif stage == "ppo":
print_ppo_dataset_example(dataset[0])
return dataset
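
# Typical call order (a minimal usage sketch for illustration only; the actual entry
# points live in this repository's training scripts, and "sft" is just one of the
# supported stages):
#
#   model_args, data_args, training_args, finetuning_args = prepare_args()
#   dataset = prepare_data(model_args, data_args)
#   model, tokenizer = load_pretrained(model_args, training_args, finetuning_args,
#                                      is_trainable=training_args.do_train, stage="sft")
#   dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")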