# coding=utf-8
# Implements several parameter-efficient supervised fine-tuning methods for ChatGLM.
# This code is inspired by https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
from utils import (
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    plot_loss,
    Seq2SeqDataCollatorForChatGLM,
    ComputeMetrics,
    Seq2SeqTrainerForChatGLM
)
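
# Note: all helpers above are project-local (defined in utils.py); Seq2SeqTrainerForChatGLM
# is assumed to subclass transformers.Seq2SeqTrainer, which the decoding overrides below
# rely on.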

def main():
    # Prepare the pretrained model and the dataset
    model_args, data_args, training_args, finetuning_args = prepare_args()
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
    data_collator = Seq2SeqDataCollatorForChatGLM(
        tokenizer=tokenizer,
        model=model,
        ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss,
        inference_mode=(not training_args.do_train)
    )
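    # With ignore_pad_token_for_loss set, the collator presumably replaces pad tokens
    # in the labels with -100, the index ignored by PyTorch's cross-entropy loss, so
    # padding positions do not contribute to the training loss.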

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
                training_args.generation_max_length is not None else data_args.max_target_length
    training_args.generation_num_beams = data_args.num_beams if \
                data_args.num_beams is not None else training_args.generation_num_beams
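    # Seq2SeqTrainer reads generation_max_length and generation_num_beams from the
    # training arguments when predict_with_generate is enabled, so the data-side
    # values are copied over before the trainer is built.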

    # Initialize our Trainer
    trainer = Seq2SeqTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        train_dataset=dataset if training_args.do_train else None,
        eval_dataset=dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None
    )
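    # ComputeMetrics needs decoded generations to score, so it is attached only when
    # predict_with_generate is enabled; otherwise evaluation reports the loss alone.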

    # Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_length": 768,
        "temperature": 0.95
    }
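    # Nucleus sampling with top_p=0.7 and temperature=0.95 matches the default chat
    # decoding settings of the ChatGLM-6B demos; since do_sample is True, evaluation
    # and prediction outputs will vary across runs.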

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args)
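            # The is_world_process_zero() guard keeps distributed workers from each
            # writing the same loss curve; only the main process plots.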

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results, tokenizer)
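        # save_predictions is a project-local extension (not part of Seq2SeqTrainer);
        # it presumably decodes predict_results with the tokenizer and writes the
        # generated responses to the output directory.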

def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
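
# A minimal sketch of how this script might be launched. The script filename, dataset
# name, and the --dataset / --finetuning_type flags below are placeholders (the latter
# two assumed to be parsed by prepare_args()); the remaining flags are standard
# Hugging Face Seq2SeqTrainingArguments:
#
#   CUDA_VISIBLE_DEVICES=0 python finetune_sft.py \
#       --do_train \
#       --dataset my_sft_dataset \
#       --finetuning_type lora \
#       --output_dir output/sft \
#       --per_device_train_batch_size 4 \
#       --learning_rate 5e-5 \
#       --num_train_epochs 3.0 \
#       --fp16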