Spaces:
Paused
Paused
# coding=utf-8
# Implements parameter-efficient PPO training of fine-tuned ChatGLM.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
import math | |
from torch.optim import AdamW | |
from transformers.optimization import get_scheduler | |
from trl import PPOConfig | |
from utils import ( | |
prepare_args, | |
prepare_data, | |
load_pretrained, | |
preprocess_data, | |
PPODataCollatorForChatGLM, | |
PPOTrainerForChatGLM, | |
plot_loss | |
) | |
def main():
    """Run the PPO stage end to end.

    Parses the argument groups, loads and preprocesses the dataset, loads the
    model and tokenizer, builds the PPO config, optimizer and learning-rate
    scheduler, runs PPO training, then saves the trainer state and the model.
    On the main process, the loss/reward curves are plotted when
    ``finetuning_args.plot_loss`` is set.
    """
    # Argument groups, raw data and the pretrained model/tokenizer pair.
    model_args, data_args, training_args, finetuning_args = prepare_args()
    raw_dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(
        model_args,
        training_args,
        finetuning_args,
        training_args.do_train,
        stage="ppo",
    )
    dataset = preprocess_data(raw_dataset, tokenizer, data_args, training_args, stage="ppo")

    # Both length bounds are set to max_source_length to avoid truncating
    # input sequences.
    source_length = data_args.max_source_length
    data_collator = PPODataCollatorForChatGLM(
        tokenizer=tokenizer,
        min_input_length=source_length,
        max_input_length=source_length,
        inference_mode=(not training_args.do_train),
    )

    # The PPO batch and mini-batch sizes both use the per-device batch size.
    per_device_batch_size = training_args.per_device_train_batch_size
    ppo_config = PPOConfig(
        model_name=model_args.model_name_or_path,
        learning_rate=training_args.learning_rate,
        mini_batch_size=per_device_batch_size,
        batch_size=per_device_batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        ppo_epochs=1,
        max_grad_norm=training_args.max_grad_norm,
    )

    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = (param for param in model.parameters() if param.requires_grad)
    optimizer = AdamW(trainable_params, lr=ppo_config.learning_rate)

    # Scheduler length: epochs times the number of effective batches per
    # epoch, where an effective batch spans accumulation steps and all
    # processes.
    total_train_batch_size = (
        training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
        * training_args.world_size
    )
    steps_per_epoch = math.ceil(len(dataset) / total_train_batch_size)
    lr_scheduler = get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=(training_args.num_train_epochs * steps_per_epoch),
    )

    # Initialize our Trainer.
    # NOTE(review): ref_model is passed as None; presumably the trainer
    # supplies the reference policy itself -- confirm in PPOTrainerForChatGLM.
    ppo_trainer = PPOTrainerForChatGLM(
        training_args=training_args,
        finetuning_args=finetuning_args,
        config=ppo_config,
        model=model,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=dataset,
        data_collator=data_collator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )

    ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
    ppo_trainer.save_state()
    ppo_trainer.save_model()

    # Plot only from the main process, and only when requested.
    should_plot = ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss
    if should_plot:
        plot_loss(training_args, keys=["loss", "reward"])
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); runs ``main`` and ignores ``index``."""
    main()
# Start PPO training only when this file is executed as a script.
if __name__ == "__main__":
    main()