Spaces:
Paused
Paused
import os | |
import sys | |
import torch | |
import hashlib | |
from typing import Literal, Optional, Tuple | |
import transformers | |
from transformers import ( | |
AutoConfig, | |
AutoModel, | |
AutoTokenizer, | |
HfArgumentParser, | |
Seq2SeqTrainingArguments | |
) | |
from transformers.utils import check_min_version | |
from transformers.utils.versions import require_version | |
from transformers.modeling_utils import PreTrainedModel | |
from transformers.tokenization_utils import PreTrainedTokenizer | |
import datasets | |
from datasets import Dataset, concatenate_datasets, load_dataset | |
from peft import ( | |
PeftModel, | |
TaskType, | |
LoraConfig, | |
get_peft_model | |
) | |
from trl import AutoModelForCausalLMWithValueHead | |
from .config import ( | |
ModelArguments, | |
DataTrainingArguments, | |
FinetuningArguments | |
) | |
from .other import ( | |
get_logger, | |
load_trainable_params, | |
load_valuehead_params, | |
print_trainable_params, | |
prepare_model_for_training, | |
IGNORE_INDEX, | |
FINETUNING_ARGS_NAME | |
) | |
# Module-level logger shared by every helper in this file.
logger = get_logger(__name__)

# Dependency version checks, executed once when the module is imported.
# The version specifiers inside the hints are quoted: pasted unquoted into a
# POSIX shell, `pip install pkg>=x.y` treats `>` as an output redirection and
# installs an unpinned `pkg` while writing a file named `=x.y`.
check_min_version("4.27.4")
require_version("datasets>=2.10.0", 'To fix: pip install "datasets>=2.10.0"')
require_version("peft>=0.3.0", 'To fix: pip install "peft>=0.3.0"')
require_version("trl>=0.4.1", 'To fix: pip install "trl>=0.4.1"')
def init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool
) -> PreTrainedModel:
    r"""
    Prepares `model` for the fine-tuning method named by `finetuning_args.finetuning_type`.

    Handled methods: "full", "freeze", "p_tuning" and "lora". Any other value
    (including "none" outside of training) returns the model untouched.

    Note that the trainable parameters must be cast to float32.

    Args:
        model: the base model to adapt.
        model_args: supplies `checkpoint_dir` (None, or a sequence of checkpoint paths).
        finetuning_args: supplies the method name and its hyper-parameters.
        is_trainable: whether the model is being prepared for training.

    Returns:
        The adapted model (a PEFT-wrapped model in the LoRA case).

    Raises:
        ValueError: if `finetuning_type` is "none" while `is_trainable` is true.
    """
    method = finetuning_args.finetuning_type

    def restore_first_checkpoint(target):
        # Non-LoRA methods only ever read the first checkpoint directory.
        if model_args.checkpoint_dir is not None:
            load_trainable_params(target, model_args.checkpoint_dir[0])

    if is_trainable and method == "none":
        raise ValueError("You cannot use finetuning_type=none while training.")

    if method == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()
        restore_first_checkpoint(model)

    elif method == "freeze":
        logger.info("Fine-tuning method: Freeze")
        for param_name, weight in model.named_parameters():
            is_selected = any(layer in param_name for layer in finetuning_args.trainable_layers)
            if is_selected:
                # Selected layers stay trainable and are promoted to float32.
                weight.data = weight.data.to(torch.float32)
            else:
                weight.requires_grad_(False)
        restore_first_checkpoint(model)

    elif method == "p_tuning":
        # Nothing to set up on the model here beyond restoring saved weights.
        logger.info("Fine-tuning method: P-Tuning v2")
        restore_first_checkpoint(model)

    elif method == "lora":
        logger.info("Fine-tuning method: LoRA")
        resume_from = None

        if model_args.checkpoint_dir is not None:
            if is_trainable and finetuning_args.resume_lora_training:
                # Keep training the newest adapter; fold all earlier ones into the base weights.
                merge_list = model_args.checkpoint_dir[:-1]
                resume_from = model_args.checkpoint_dir[-1]
            else:
                merge_list = model_args.checkpoint_dir

            for adapter_path in merge_list:
                model = PeftModel.from_pretrained(model, adapter_path)
                model = model.merge_and_unload()

            logger.info("Merged {} model checkpoint(s).".format(len(merge_list)))

            if resume_from is not None:
                model = PeftModel.from_pretrained(model, resume_from, is_trainable=True)

        if resume_from is None:
            # No adapter to resume from: attach freshly initialised LoRA weights.
            model = get_peft_model(
                model,
                LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    r=finetuning_args.lora_rank,
                    lora_alpha=finetuning_args.lora_alpha,
                    lora_dropout=finetuning_args.lora_dropout,
                    target_modules=finetuning_args.lora_target
                )
            )

    return model
def load_pretrained(
    model_args: ModelArguments,
    training_args: Optional[Seq2SeqTrainingArguments] = None,
    finetuning_args: Optional[FinetuningArguments] = None,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Load pretrained model and tokenizer.

    Args:
        model_args: model/tokenizer locations, checkpoint directories, quantization
            bit width and (for PPO) the reward model path.
        training_args: only consulted for its `fp16` flag when cpm quantization is
            used for training; `None` is treated as "no FP16".
        finetuning_args: fine-tuning configuration. It is replaced by the arguments
            stored in the first checkpoint when `model_args.checkpoint_dir` is set,
            and by `finetuning_type="none"` at evaluation without a checkpoint.
        is_trainable: whether the model is prepared for training.
        stage: "sft", "rwd" or "ppo"; the latter two wrap the model with a value head.

    Returns:
        A `(model, tokenizer)` tuple.

    Raises:
        ValueError: if a checkpoint directory lacks the saved fine-tuning arguments,
            if no fine-tuning arguments are available at all, or if quantization is
            requested together with full-parameter training.
        AssertionError: for the unsupported stage/quantization combinations checked below.
    """
    if (not is_trainable) and (model_args.checkpoint_dir is None):
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
        for checkpoint_dir in model_args.checkpoint_dir:
            if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
                raise ValueError("The fine-tuning arguments are not found in the provided directory.")
        logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
        # NOTE(security): torch.load unpickles arbitrary objects, so only point
        # checkpoint_dir at checkpoints from a trusted source.
        finetuning_args = torch.load(os.path.join(model_args.checkpoint_dir[0], FINETUNING_ARGS_NAME))
        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
            logger.warning("Only LoRA tuning accepts multiple checkpoints.")

    if finetuning_args is None:
        # Training without a checkpoint and without explicit arguments: fail here with a
        # clear message instead of an AttributeError on `None` further down.
        raise ValueError("finetuning_args must be provided when training without a checkpoint.")

    assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."

    # Choose the quantization backend: bitsandbytes ("bnb") for Freeze/LoRA training,
    # the model's built-in method ("cpm") for P-Tuning training and for inference.
    quantization = None
    if model_args.quantization_bit is not None:
        if is_trainable:
            if finetuning_args.finetuning_type == "full":
                raise ValueError("Full parameter fine-tuning does not support quantization.")
            elif finetuning_args.finetuning_type == "p_tuning":
                quantization = "cpm" # use cpm's quantization
            else:
                quantization = "bnb" # use bnb's quantization
        else:
            quantization = "cpm"

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        **config_kwargs
    )

    # P-Tuning v2 configurations.
    # We use the built-in p-tuning method of ChatGLM, we cannot use PEFT since the attention masks of ChatGLM are unusual. >_<
    if finetuning_args.finetuning_type == "p_tuning":
        config.pre_seq_len = finetuning_args.pre_seq_len # enable this will fix other parameters automatically
        config.prefix_projection = finetuning_args.prefix_projection

    # Quantization configurations for Full, Freeze and LoRA in training (using bitsandbytes library).
    if quantization == "bnb":
        assert model_args.quantization_bit == 8, "Freeze and LoRA fine-tuning only accept 8-bit quantization."

        require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
        # Imported lazily so bitsandbytes is only needed when this path is taken.
        from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
        cuda = get_cuda_lib_handle()
        cc = get_compute_capability(cuda)
        assert is_cublasLt_compatible(cc), "The current GPU(s) is incompatible with quantization."

        config_kwargs["load_in_8bit"] = True
        config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit

    # Load and prepare pretrained models (without valuehead).
    model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs)
    model = prepare_model_for_training(model) if is_trainable else model
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if not is_trainable:
        model.requires_grad_(False) # fix all params
        model = model.half() # cast all params to float16

    # Quantization with the built-in method for P-Tuning v2 training or evaluation.
    # Model parameters should be cast to float16 in quantized P-Tuning setting.
    if quantization == "cpm":
        assert model_args.quantization_bit in [4, 8], "P-Tuning v2 and inference mode only accept 4-bit or 8-bit quantization."
        # training_args is optional; a missing value counts as "no FP16" rather than crashing.
        fp16_training = bool(is_trainable and training_args is not None and training_args.fp16)
        assert not fp16_training, "FP16 training conflicts with cpm quantization."

        model.quantize(model_args.quantization_bit) # in-place method

        for name, param in model.named_parameters():
            if "prefix_encoder" not in name:
                param.data = param.data.to(torch.float16) # convert all params in half precision except prefix_encoder

    if quantization is not None:
        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

    if stage == "rwd" or stage == "ppo": # add value head
        assert is_trainable, "Reward and PPO stages cannot be performed at evaluation."

        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "ppo": # load reward model
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            load_valuehead_params(model, model_args.reward_model)

        # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
        # To meet the compliance requirements of the transformers library
        if quantization == "bnb":
            model._is_int8_training_enabled = True

    print_trainable_params(model)

    return model, tokenizer
def prepare_args() -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
    r"""
    Parses the run configuration, configures logging and seeds the RNGs.

    Arguments come from a JSON file when the script receives exactly one
    command-line argument ending in ".json", otherwise from the command line.

    Returns:
        The tuple `(model_args, data_args, training_args, finetuning_args)`.

    Raises:
        ValueError: unless exactly one of do_train / do_eval / do_predict is set.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    argv = sys.argv
    if len(argv) == 2 and argv[1].endswith(".json"):
        # A lone *.json argument is read as a configuration file.
        parsed = parser.parse_json_file(json_file=os.path.abspath(argv[1]))
    else:
        parsed = parser.parse_args_into_dataclasses()
    model_args, data_args, training_args, finetuning_args = parsed

    # Logging setup.
    hf_logging = transformers.utils.logging
    if training_args.should_log:
        # training_args.log_level defaults to passive, so make info the default here.
        hf_logging.set_verbosity_info()

    process_log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(process_log_level)
    hf_logging.set_verbosity(process_log_level)
    hf_logging.enable_default_handler()
    hf_logging.enable_explicit_format()

    # Argument checks (finetuning_args is skipped: it may be loaded from checkpoints).
    requested_modes = (training_args.do_train, training_args.do_eval, training_args.do_predict)
    if sum(int(flag) for flag in requested_modes) != 1:
        raise ValueError("We must perform a single operation among do_train, do_eval and do_predict.")

    if model_args.quantization_bit is not None and not training_args.do_train:
        logger.warning("We do not recommend to evaluaute model in 4/8-bit mode.")

    if training_args.do_train and not training_args.fp16:
        logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.")

    # Swap the "adamw_hf" optimizer for "adamw_torch" to suppress a warning.
    if training_args.optim == "adamw_hf":
        training_args.optim = "adamw_torch"

    # Short per-process summary.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Seed before any model is initialised.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args
def prepare_data(
    model_args: ModelArguments,
    data_args: DataTrainingArguments
) -> Dataset:
    r"""
    Loads every dataset listed in `data_args.dataset_list` and merges them into one.

    Each entry is loaded from the hub ("hf_hub"), from a loading script under
    `data_args.dataset_dir` ("script"), or from a local json/jsonl/csv file ("file").
    The `data_args.split` split is taken, optionally truncated to
    `data_args.max_samples` rows, and its columns are normalised to
    "prompt", "query", "response" and "history".

    Args:
        model_args: supplies `cache_dir` and `use_auth_token`.
        data_args: supplies the dataset list, dataset directory, split and sample cap.

    Returns:
        The single dataset when one is listed, otherwise the concatenation of all of them.

    Raises:
        NotImplementedError: if an entry's `load_from` is not one of the three known sources.
    """

    def checksum(file_path, expected_sha1):
        # Hash in chunks so large data files are not read into memory at once.
        digest = hashlib.sha1()
        with open(file_path, "rb") as datafile:
            for chunk in iter(lambda: datafile.read(1 << 20), b""):
                digest.update(chunk)
        if digest.hexdigest() != expected_sha1:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    # File extensions whose builder name differs from the extension itself:
    # JSON-lines files are read by the "json" builder.
    builder_aliases = {"jsonl": "json"}

    max_samples = data_args.max_samples
    all_datasets = [] # support multiple datasets

    for dataset_info in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_info))

        if dataset_info.load_from == "hf_hub":
            raw_datasets = load_dataset(dataset_info.dataset_name, cache_dir=model_args.cache_dir)
        elif dataset_info.load_from == "script":
            raw_datasets = load_dataset(
                os.path.join(data_args.dataset_dir, dataset_info.dataset_name),
                cache_dir=model_args.cache_dir
            )
        elif dataset_info.load_from == "file":
            data_file = os.path.join(data_args.dataset_dir, dataset_info.file_name) # support json, jsonl and csv
            extension = dataset_info.file_name.split(".")[-1]
            builder = builder_aliases.get(extension, extension)

            if dataset_info.file_sha1 is not None:
                checksum(data_file, dataset_info.file_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.")

            raw_datasets = load_dataset(
                builder,
                data_files=data_file,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None
            )
        else:
            raise NotImplementedError

        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        # Placeholder values for any standard column this dataset does not provide.
        dummy_data = [None] * len(dataset)
        for column, column_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            source_column = getattr(dataset_info, column)
            if source_column != column_name:
                if source_column:
                    dataset = dataset.rename_column(source_column, column_name)
                else: # None or empty string
                    dataset = dataset.add_column(column_name, dummy_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets
def preprocess_data(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft"
) -> Dataset:
    r"""
    Tokenizes `dataset` for the given training stage and prints its first example.

    The dataset must expose the "prompt", "query", "response" and "history" columns.
    Rows with an empty prompt or response are skipped.

    Args:
        dataset: the raw dataset to convert.
        tokenizer: tokenizer used to encode prompts and answers.
        data_args: supplies the source prefix, length limits and `dataset.map` options.
        training_args: supplies `do_train` (train vs. eval format for "sft") and
            the `main_process_first` context manager.
        stage: "sft" (input_ids/labels), "rwd" (accept_ids/reject_ids) or "ppo" (input_ids only).

    Returns:
        The tokenized dataset; all original columns are removed.

    Raises:
        ValueError: if `stage` is not one of "sft", "rwd" or "ppo".
    """

    column_names = list(dataset.column_names)
    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    def format_example(examples): # support question with a single answer or multiple answers
        # Yields one (prompt, answer) pair per usable row of the batch.
        rows = zip(examples["prompt"], examples["query"], examples["response"], examples["history"])
        for instruction, extra_query, answer, history in rows:
            if not (instruction and answer):
                continue
            query = instruction
            if extra_query:
                query += extra_query
            if history:
                # Replay earlier turns, then open a new round for the current query.
                prompt = ""
                for j, (old_query, response) in enumerate(history):
                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(j, old_query, response)
                prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
            else:
                prompt = query
            yield prefix + prompt, answer

    def preprocess_function_train(examples):
        # build inputs with format `X [gMASK] [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] [BOS] Y [EOS]`
        model_inputs = {"input_ids": [], "labels": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
                source_ids = source_ids[:data_args.max_source_length - 2]
            if len(target_ids) > data_args.max_target_length - 1: # eos token
                target_ids = target_ids[:data_args.max_target_length - 1]

            input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids)

            # Everything before the first BOS token is context and is masked out of the loss.
            context_length = input_ids.index(tokenizer.bos_token_id)
            labels = [IGNORE_INDEX] * context_length + input_ids[context_length:]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
        return model_inputs

    def preprocess_function_eval(examples):
        # build inputs with format `[PAD] ... [PAD] X [gMASK] [BOS]` and labels with format `Y [gMASK] [BOS]`
        # left-padding is needed for prediction, use the built-in function of the tokenizer
        inputs, targets = [], []
        for prompt, answer in format_example(examples):
            inputs.append(prompt)
            targets.append(answer)
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
        labels = tokenizer(text_target=targets, max_length=data_args.max_target_length, truncation=True) # no padding
        if data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l_id if l_id != tokenizer.pad_token_id else IGNORE_INDEX) for l_id in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def preprocess_function_train_pair(examples):
        # build input pairs with format `X [gMASK] [BOS] Y1 [EOS]` and `X [gMASK] [BOS] Y2 [EOS]`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            # answer[0] is the accepted response, answer[1] the rejected one.
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
                source_ids = source_ids[:data_args.max_source_length - 2]
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            # Pass a fresh copy of source_ids to each call so both sequences start from the same prompt ids.
            accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], accept_ids) # avoid copying error
            reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], reject_ids)

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def preprocess_function_train_ppo(examples):
        # build inputs with format `X [gMASK] [BOS]`
        model_inputs = {"input_ids": []}
        for prompt, _ in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
                source_ids = source_ids[:data_args.max_source_length - 2]

            input_ids = tokenizer.build_inputs_with_special_tokens(source_ids)
            model_inputs["input_ids"].append(input_ids)
        return model_inputs

    def print_sft_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
        print("label_ids:\n{}".format(example["labels"]))
        # IGNORE_INDEX marks masked positions and is not a token id, so drop it before decoding.
        visible_label_ids = [token_id for token_id in example["labels"] if token_id != IGNORE_INDEX]
        print("labels:\n{}".format(tokenizer.decode(visible_label_ids)))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))

    def print_ppo_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))

    if stage == "sft":
        preprocess_function = preprocess_function_train if training_args.do_train else preprocess_function_eval
        print_function = print_sft_dataset_example
    elif stage == "rwd":
        preprocess_function = preprocess_function_train_pair
        print_function = print_pairwise_dataset_example
    elif stage == "ppo":
        preprocess_function = preprocess_function_train_ppo
        print_function = print_ppo_dataset_example
    else:
        # Reject unknown stages up front instead of failing later on an unbound local.
        raise ValueError("Unknown stage: {}".format(stage))

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

    # Show the first processed example as a sanity check.
    print_function(dataset[0])

    return dataset