# coding=utf-8
# Implements parameter-efficient training of a reward model based on ChatGLM.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from utils import (
    prepare_args,
    prepare_data,
    load_pretrained,
    preprocess_data,
    PairwiseDataCollatorForChatGLM,
    PairwiseTrainerForChatGLM,
    plot_loss
)
def main():
    """Run parameter-efficient reward-model training for ChatGLM.

    Parses arguments, loads the dataset and pretrained model, builds the
    pairwise collator/trainer, and (when ``do_train``) trains, saves the
    model/state/metrics, and optionally plots the loss curve.
    """
    # Argument parsing, dataset loading, model loading, and preprocessing
    # are all delegated to the project-level helpers in `utils`.
    model_args, data_args, training_args, finetuning_args = prepare_args()
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(
        model_args, training_args, finetuning_args, training_args.do_train, stage="rwd"
    )
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rwd")

    collator = PairwiseDataCollatorForChatGLM(
        tokenizer=tokenizer,
        inference_mode=(not training_args.do_train)
    )

    # Keep all dataset columns — important for the pairwise dataset.
    training_args.remove_unused_columns = False

    trainer = PairwiseTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        train_dataset=dataset if training_args.do_train else None,
        eval_dataset=dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=collator
    )

    # Training loop plus checkpoint/metric persistence.
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args)
def _mp_fn(index):
    """Per-process entry point for xla_spawn (TPU training); delegates to main()."""
    main()
# Script entry point.
if __name__ == "__main__":
    main()