import transformers
from transformers import Trainer

from llm_finetune.arguments import (
    ModelArguments,
    DataArguments,
    TrainingArguments,
)
from llm_finetune.dataset import make_supervised_data_module

def train():
    # Parse command-line flags into the three argument dataclasses.
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Load the base causal-LM weights.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    # Slow tokenizer, right-side padding; max length comes from the training args.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Many causal-LM tokenizers ship without a pad token; reuse EOS for padding
    # (guarded so an existing pad token is not overwritten).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Presumably returns the dataset(s) and collator that Trainer consumes below.
    data_module = make_supervised_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
    )
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )
    # Name the argument: resume_from_checkpoint is what Trainer.train() expects.
    trainer.train(resume_from_checkpoint=training_args.checkpoint)
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
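
# A minimal usage sketch rather than part of the script: HfArgumentParser turns
# each dataclass field into a CLI flag, so an invocation could look like the
# following (the script filename, model id, and flag values are placeholders;
# data flags depend on the fields DataArguments actually defines):
#
#   python train.py \
#       --model_name_or_path <hf-model-id-or-path> \
#       --model_max_length 2048 \
#       --output_dir ./checkpoints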