File size: 743 Bytes
e8ac98a ea238c4 e8ac98a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from warnings import filterwarnings
# NOTE(review): this runs at import time and installs a process-wide "ignore"
# filter with no category/module restriction, so every warning raised anywhere
# in the process after this module is imported is silenced -- not only warnings
# from the trainer below. Consider scoping it with warnings.catch_warnings().
filterwarnings(action='ignore')
""" Trainer """
def train(dict configuration, X_train, y_train, X_test, y_test):
    """Fit ``configuration['model']`` with early stopping and best-model checkpointing.

    Parameters
    ----------
    configuration : dict
        Required keys:
          ``'model'``       -- object whose ``fit`` is called (a Keras model).
          ``'model_file'``  -- path handed to ``ModelCheckpoint(filepath=...)``.
          ``'epochs'``      -- forwarded to ``fit(epochs=...)``.
          ``'batch_size'``  -- forwarded to ``fit(batch_size=...)``.
        Optional keys (defaults reproduce the previously hard-coded values):
          ``'monitor'``               -- quantity watched by both callbacks,
                                         default ``'val_loss'``.
          ``'mode'``                  -- ``'min'`` / ``'max'`` for that quantity,
                                         default ``'min'``.
          ``'patience'``              -- epochs without improvement before
                                         stopping, default ``5``.
          ``'restore_best_weights'``  -- passed to ``EarlyStopping``, default
                                         ``False``.
    X_train, y_train
        Training inputs and targets passed positionally to ``fit``.
    X_test, y_test
        Passed to ``fit`` as ``validation_data``; the monitored quantity is
        computed on this split.

    Returns
    -------
    object
        Whatever ``model.fit`` returns (for a Keras model, its ``History``).

    Raises
    ------
    KeyError
        If a required key is missing from ``configuration``.
    """
    # Shared by both callbacks so "when to stop" and "which epoch to keep"
    # are always judged on the same quantity in the same direction.
    cdef object monitor = configuration.get('monitor', 'val_loss')
    cdef object mode = configuration.get('mode', 'min')

    # With restore_best_weights=False (the default here, matching the old
    # behaviour), the in-memory model keeps the weights of the LAST epoch run
    # after early stopping; only the checkpoint file holds the best epoch.
    cdef object early_stopping = EarlyStopping(
        monitor = monitor,
        patience = configuration.get('patience', 5),
        mode = mode,
        restore_best_weights = configuration.get('restore_best_weights', False),
    )
    # save_best_only: the file at model_file is overwritten only when the
    # monitored quantity improves.
    cdef object model_checkpoint = ModelCheckpoint(
        filepath = configuration['model_file'],
        save_best_only = True,
        monitor = monitor,
        mode = mode,
    )
    # NOTE(review): the test split doubles as validation data, so both early
    # stopping and checkpoint selection are tuned on it. Scores later computed
    # on (X_test, y_test) are therefore optimistically biased; use a separate
    # validation split if an unbiased test estimate is needed.
    cdef object history = configuration['model'].fit(
        X_train, y_train,
        epochs = configuration['epochs'],
        batch_size = configuration['batch_size'],
        validation_data = (X_test, y_test),
        callbacks = [early_stopping, model_checkpoint],
    )
    return history
|