gpt1 / pre_training.py
Alexandru Gherghescu
Save pre training losses at the end
3824b4f unverified
raw
history blame contribute delete
No virus
2.2 kB
import torch
from torch.optim import Adam
from transformers import (
AutoTokenizer,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling,
get_scheduler,
)
from datasets import load_from_disk
from configuration_gpt1 import GPT1Config
from modeling_gpt1 import GPT1Model, GPT1ForCausalLM
# Register the custom GPT-1 classes with the transformers Auto* factories:
# the config with its default target, and each model class with the Auto*
# class named next to it.
GPT1Config.register_for_auto_class()
for _model_cls, _auto_name in (
    (GPT1Model, 'AutoModel'),
    (GPT1ForCausalLM, 'AutoModelForCausalLM'),
):
    _model_cls.register_for_auto_class(_auto_name)
del _model_cls, _auto_name  # keep the loop variables out of the module namespace
# Pre-tokenized dataset written by preprocessing.py, reshuffled with a fixed
# seed (42) so the ordering is deterministic for the same data.
tokenized_datasets = load_from_disk('data').shuffle(seed=42)
print(tokenized_datasets)

# Tokenizer files are read from the current directory.
tokenizer = AutoTokenizer.from_pretrained('.')

# Freshly initialised model built from the default GPT-1 configuration.
config = GPT1Config()
model = GPT1ForCausalLM(config)
print(model)

# Total number of parameter elements across all of the model's tensors.
_total_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {_total_params}")
# --- Training hyper-parameters ---
# Adam with a max learning rate of 2.5e-4, linear warm-up over the first
# 2000 optimizer updates, then cosine annealing over the rest of training.
batch_size = 16
epochs = 100
learning_rate = 2.5e-4
warmup_steps = 2000

# Reuse EOS as the padding token so the collator can pad batches.
# NOTE(review): with mlm=False, DataCollatorForLanguageModeling sets the label
# of every token equal to pad_token_id to -100, so any genuine EOS tokens in
# the data are excluded from the loss as well. This is harmless if the
# preprocessed examples contain no EOS / need no padding — confirm against
# preprocessing.py.
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Plain torch Adam: weight_decay here is an L2 penalty folded into the
# gradients (not AdamW's decoupled decay) and applies to every parameter.
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=0.01)

# The learning-rate scheduler is deliberately NOT built by hand. Sizing it as
# `epochs * len(train_dataset)` counts *examples*, not optimizer updates, so
# the cosine decay would be stretched by batch_size * gradient_accumulation
# (* number of devices) and the lr would barely decay after warm-up.
# Passing `None` as the scheduler in `optimizers` lets the Trainer build the
# 'cosine' schedule from `lr_scheduler_type` / `warmup_steps`, using the
# number of update steps it actually runs.
args = TrainingArguments(
    output_dir='checkpoints',
    # Per-device batch sizes; for a fixed global batch, divide by the number
    # of GPUs.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # NOTE: newer transformers releases rename this to `eval_strategy`.
    evaluation_strategy='epoch',
    gradient_accumulation_steps=4,
    num_train_epochs=epochs,
    # Informational only: the custom optimizer above already fixes the lr.
    learning_rate=learning_rate,
    lr_scheduler_type='cosine',
    warmup_steps=warmup_steps,
    save_total_limit=10,
    max_grad_norm=1.0,
    logging_strategy='steps',
    logging_steps=100,
    logging_first_step=True,
    logging_nan_inf_filter=False,
    fp16=False,
)

trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    # NOTE: newer transformers releases name this argument `processing_class`.
    tokenizer=tokenizer,
    # Custom optimizer; scheduler is None so the Trainer creates it from `args`.
    optimizers=(optimizer, None),
)
print("Starting training...")
trainer.train()

# Persist every logged record (e.g. the training losses logged every
# `logging_steps`, plus per-epoch eval results) for later inspection, then
# write the final model to 'trained'.
_log_history = trainer.state.log_history
torch.save(_log_history, 'trainer_history.pt')
trainer.save_model('trained')