Cannot plot a loss curve for the text generation gpt2 model

#39
by workpiece - opened

#The below is my code:

import matplotlib.pyplot as plt
import numpy as np
from transformers import Trainer, TrainingArguments
from transformers import EarlyStoppingCallback

training_args = TrainingArguments(
output_dir="gpt_model",
overwrite_output_dir=True,
learning_rate=7e-5,
weight_decay=0.01,
num_train_epochs=20,
logging_steps=50,
save_total_limit=2,
per_device_train_batch_size=3,
save_steps=10_000,
evaluation_strategy='no'
)

trainer = Trainer(
model=mpm_model,
args=training_args,
data_collator=data_collator,
train_dataset=mpm_dataset['train'],
)

train your model

trainer.train()

Get the training loss

losses = [x['loss'] for x in trainer.state.log_history]

plot the training loss

steps = np.linspace(0, len(losses), len(losses))
plt.plot(steps, losses)
plt.title("Training loss")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.show()


Error generated for the above code

KeyError: 'loss'
KeyError.jpg

Sign up or log in to comment