RNN_Playground / predict.py
Krzysiek111's picture
refactoring part1 - minor perf updates, removed single letter names, moved functions to separate files
517420b
raw
history blame
No virus
2.67 kB
import numpy as np
from sklearn.preprocessing import StandardScaler
from tensorflow import keras
# Keras verbosity level forwarded to ``model.fit`` in ``predict_series``
# (0 = silent); set it to 1 or 2 to see per-epoch training logs.
verbose = 0
# TODO: Refactor this module


def _make_windows(series, seq_length):
    """Split an (n, 1) column into sliding input windows and next-step targets."""
    n_windows = len(series) - seq_length
    inputs = np.array([series[t:t + seq_length] for t in range(n_windows)])
    # The target of window t is the element right after it, i.e. series[t + seq_length].
    targets = series[seq_length:]
    return inputs.reshape(-1, seq_length, 1), targets


def _build_model(r1_nodes, r2_nodes, fc1_nodes, use_lstm):
    """Assemble and compile the (optionally stacked) recurrent regressor with one output."""
    # TODO: Add SimpleRNN
    rnn_layer = keras.layers.LSTM if use_lstm else keras.layers.GRU
    model = keras.Sequential()
    # The first recurrent layer must emit the whole sequence only when a second
    # recurrent layer follows it; otherwise it emits just its final state.
    model.add(rnn_layer(r1_nodes, return_sequences=bool(r2_nodes)))
    if r2_nodes:
        model.add(rnn_layer(r2_nodes))
    if fc1_nodes:
        model.add(keras.layers.Dense(fc1_nodes, activation='relu'))
    model.add(keras.layers.Dense(1))
    # `learning_rate` replaces the deprecated `lr` alias, which recent Keras
    # optimizer implementations reject with a ValueError.
    model.compile(
        loss='mse',
        optimizer=keras.optimizers.Adamax(learning_rate=0.2))
    return model


def _forecast(model, seed_window, steps):
    """Roll the model forward `steps` times, feeding each prediction back as input."""
    window = np.array(seed_window, dtype=float).reshape(-1)
    predictions = []
    for _ in range(steps):
        # predict_on_batch avoids model.predict's per-call setup and progress
        # bar; Keras documents predict() as meant for large batches, not
        # single-sample loops like this one.
        step = model.predict_on_batch(window.reshape(1, -1, 1))[0, 0]
        predictions.append(step)
        window = np.roll(window, -1)
        window[-1] = step
    return np.array(predictions, dtype=float)


def predict_series(values, r1_nodes=10, r2_nodes=0, fc1_nodes=0, steps=20, use_lstm=True, seq_length=15, *args, **kwargs):
    """Train a small recurrent network on `values` and forecast `steps` future points.

    The series is differenced and standardised, cut into sliding windows of
    `seq_length` differences, and fitted with an LSTM or GRU model. The
    forecast is generated autoregressively in difference space and then
    integrated back onto the last observed value.

    Args:
        values: sequence of observations (assumed 1-D).
        r1_nodes: units in the first recurrent layer.
        r2_nodes: units in an optional second recurrent layer (0 disables it).
        fc1_nodes: units in an optional ReLU dense layer (0 disables it).
        steps: number of future points to forecast (<= 0 gives an empty forecast).
        use_lstm: use LSTM layers if True, GRU layers otherwise.
        seq_length: number of past differences fed to the model per prediction.
        *args, **kwargs: accepted and ignored, so callers may pass extra options.

    Returns:
        dict with keys 'result' (list of forecast values), 'epochs' (number of
        epochs actually run), 'loss' (lowest training loss seen) and
        'loss_last' (training loss of the final epoch).

    Raises:
        ValueError: if `values` has fewer than `seq_length + 2` points, so not
            even one training window (input plus target) can be formed.
    """
    series = np.asarray(values, dtype=float)
    if len(series) < seq_length + 2:
        raise ValueError(
            f'need at least seq_length + 2 = {seq_length + 2} values, got {len(series)}')
    last_value = series[-1]

    # Model first differences rather than raw levels, standardised so the
    # network sees zero-mean, unit-variance inputs.
    scaler = StandardScaler()
    diffs = scaler.fit_transform(np.diff(series).reshape(-1, 1))

    inputs, targets = _make_windows(diffs, seq_length)
    model = _build_model(r1_nodes, r2_nodes, fc1_nodes, use_lstm)
    # TODO: optimize execution time
    callbacks = [keras.callbacks.EarlyStopping(patience=150, monitor='loss', restore_best_weights=True)]
    history = model.fit(
        inputs, targets,
        epochs=500,
        callbacks=callbacks,
        verbose=verbose,
        validation_split=0.0)

    # Seed with the most recent seq_length differences. The last training
    # window ends one step before the data does, so seeding from it would only
    # re-predict the final, already known difference.
    forecast_diffs = _forecast(model, diffs[-seq_length:], steps)
    if forecast_diffs.size:
        forecast_diffs = scaler.inverse_transform(forecast_diffs.reshape(-1, 1)).reshape(-1)
    # Integrate the predicted differences back into levels.
    forecast = last_value + np.cumsum(forecast_diffs)

    return {
        'result': list(forecast),
        'epochs': history.epoch[-1] + 1,
        'loss': min(history.history['loss']),
        'loss_last': history.history['loss'][-1],
    }
if __name__ == "__main__":
    # Code for debugging/testing: fit a noiseless sine wave, report timing and
    # training stats, and plot the last 30 observations (red) followed by the
    # forecast (blue).
    from time import perf_counter
    t1 = perf_counter()
    # Uncomment to see Keras per-epoch training logs:
    # verbose = 2
    data = np.sin(np.arange(0.0, 28.0, 0.35) * 2)
    result = predict_series(data, steps=66, r1_nodes=14, r2_nodes=14, fc1_nodes=20)
    print('exec time: {:8.3f}'.format(perf_counter() - t1))
    print(result['epochs'], result['loss'])
    import seaborn as sns
    sns.lineplot(x=range(30), y=data[-30:], color='r')
    # The forecast starts at x=30, the first step after the observed data ends.
    sns.lineplot(x=range(30, 30 + len(result['result'])), y=result['result'], color='b')
    # NOTE(review): there is no plt.show(), so the figure presumably only
    # appears under an interactive/notebook backend.