RNN_Playground / app.py
Krzysiek111's picture
removed unnecessary code
89c4568
raw
history blame
No virus
8.75 kB
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from predict import predict_series
#TODO: Refactor this module
# Browser-tab title for the whole app.
st.set_page_config(page_title='RNN Playground')

# Sidebar chapter picker: maps each chapter title to the index that the
# chapter if/elif/else chain dispatches on.
pages = {'Intro': 0, 'Implementation details': 1, 'The model': 2}
chapter_title = st.sidebar.radio("Select the chapter: ", tuple(pages))
choice = pages[chapter_title]
if choice == 0:
    # "Intro" chapter: what the playground is for and the typical workflow.
    st.title("Recurrent Neural Networks playground")
    st.subheader("The purpose")
    st.write("""\n
The goal of this app is to let the user experiment easily with Recurrent Neural Networks. In doing so, the app helps to understand: \n
- When to use recurrent neural networks
- Which scenarios are more straightforward for the RNN to predict and which are more difficult
- How noise disrupts the predictions
- The difference between LSTM and GRU nodes
- Why increasing the number of nodes doesn't always lead to better performance
""")
    st.subheader("Typical use case")
    st.write("""
1. Create a synthetic dataset, choosing from a wide range of parameters
2. Create a Recurrent Neural Network model by selecting the number of nodes in each layer of the model
3. Automatically train the RNN model and make predictions
4. Compare the predicted values with the actual values
\n""")
    st.subheader('The architecture of the model')
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width` — confirm against the pinned version.
    st.image('info.jpg', use_column_width=True, caption='Hover the cursor over the image to see the enlarge button')
    st.write(""" \n Use the radio buttons on the left to navigate between chapters \n \n \n \n""")
elif choice == 1:
st.title(""" \n Implementation details""")
st.subheader("Front-end")
st.write("""\n
The front-end part was made with the use of [the Streamlit library](https://www.streamlit.io/).
The parameters from the sidebar are used to create a dataset. The dataset is visualised using the Seaborn
library and finally sent (with the parameters specifying the number of neurons in particular layers) to
the back-end part through REST API. \n
The front-end is served on Azure as a Web App. """)
st.subheader("Back-end")
st.subheader("[Since the playground is stored on HuggingFace now - the backend is a module of the frontend]")
st.write("""\n
The backed-part is responsible for:
- Retrieving the dataset from the front-end through REST API
- Creating an RNN model using parameters passed from the user
- Training the model
- Predicting the values and returning them to the front-end \n
The most crucial requirements are:
- The neural networks setup has to be able to accurately predict the further shape of a curve for
the widest range of parameters selected by the user.
- Time execution of the back-end part must be short.
- Which means balancing over tradeoff between the time needed for the response and the accuracy of the results
- Cost efficiency
- Since the app is desired to be on-line all the time, the serverless approach has been taken.
That's why the back-end is served on Azure as a serverless Function App.
""")
else:
gran = 0.25
test_len = 8
st.sidebar.header('User Input Parameters')
def user_input_features():
predefined_sets = {'length': [30, ], 'period': [1.34, ], 'amplitude': [0.64, ], 'growth': [0.04, ],
'amplitude_growth': [0.03, ], 'r1_nodes': [20, ], 'r2_nodes': [20, ], 'fc1_nodes': [34, ]}
data, nn = {}, {}
st.sidebar.header('Dataset:')
data['length'] = st.sidebar.slider('Training data length', 20, 50, 28)
data['period'] = st.sidebar.slider('Period of the wave', 0.75, 2.0, 1.0)
data['growth'] = st.sidebar.slider('Values growth', -0.25, 0.25, 0.0)
data['amplitude'] = st.sidebar.slider('Amplitude', 0.25, 1.75, 1.0)
data['amplitude_growth'] = st.sidebar.slider('Amplitude growth', -0.01, 0.1, 0.0)
data['noise'] = st.sidebar.slider('Noise', 0.0, 1.0, 0.0)
st.sidebar.header('Model setup')
nn['use_lstm'] = st.sidebar.radio('Select the type of Recurrent Neuron to use', ['LSTM', 'GRU']) == 'LSTM'
nn['r1_nodes'] = st.sidebar.slider('Number of nodes in the first RNN layer', 1, 30, 13)
nn['r2_nodes'] = st.sidebar.slider('Number of nodes in the second RNN layer', 0, 30, 0)
nn['fc1_nodes'] = st.sidebar.slider('Number of nodes in the fully connected RNN layer', 0, 40, 10)
nn['steps'] = len(np.arange(0, test_len, gran))
#if st.sidebar.button('Load one of the pretested configurations'):
#i = st.sidebar.selectbox('Select:', [-1, 0])
#i = int(np.random.rand(len(predefined_sets['length']))) # Selecting one pretested configuration
#data.update({k: predefined_sets[k][i] for k in set(data) & set(predefined_sets)})
#nn.update({k: predefined_sets[k][i] for k in set(nn) & set(predefined_sets)})"""
return data, nn
params, setup = user_input_features()
st.header("Work in progress - please be back in a few days")
st.subheader("Instructions:")
st.write("""
1. Modify the dataset by using the sliders in the Dataset group on the left on the screen.
2. Select the number of nodes in the model by using the sliders in the RNN setup group.
3. Press the "Train and Predict" button to Train and Predict the model - note: many operations performing under the hood - please be patient.
4. The predicted values will be shown at the bottom of the page.
5. If you are not satisfied with the results - modify the model and try again!
6. Have fun!
\n""")
st.subheader("Generated data:")
X = np.arange(0, params['length'], gran)
X_pred = np.arange(params['length'], params['length'] + test_len, gran)
def generate_wave(x_set):
return np.sin(x_set / params['period']) * (1 + params['amplitude_growth'] * x_set) * params[
'amplitude'] + x_set * params['growth'] + params['noise']*np.random.randn(len(x_set))
Y = generate_wave(X)
Y_pred = generate_wave(X_pred)
X_pred, Y_pred = np.append(X[-1], X_pred), np.append(Y[-1], Y_pred)
c1, c2, c3 = '#1e4a76', '#7dc0f7', '#ff7c0a' # colors
# sns.scatterplot(x=X, y=Y, color=c1)
# st.pyplot()
fig, ax = plt.subplots()
sns.lineplot(x=X, y=Y, color=c1)
sns.lineplot(x=X_pred, y=Y_pred, color=c2, linestyle=':')
plt.ylim(min(-2, min(Y), min(Y_pred)), max(2, max(Y), max(Y_pred)))
plt.legend(['Train data', 'Test data'], loc=3)
plt.xlabel('Sample number')
plt.ylabel('Sample value')
st.pyplot(fig)
st.write("The plot presents generated train and test data. Use the sliders on the left to modify the curve.")
def local_css(file_name):
with open(file_name) as f:
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
local_css("button_style.css")
st.subheader('Predicted data:')
reminder = st.text('Press the train and predict button on the sidebar once you are ready with the selections.')
if st.sidebar.button('Train and Predict'):
setup['values'] = list(Y)
reminder.empty()
waiters = list()
waiters.append(st.text('Please wait till the train and predict process is finished.'))
waiters.append(st.image('wait.gif'))
waiters.append(st.text("""The process should take around 20-60 seconds."""))
result = predict_series(**setup)
fig, ax = plt.subplots()
sns.lineplot(x=X_pred, y=Y_pred, color=c2, linestyle=':')
sns.lineplot(x=X, y=Y, color=c1)
sns.lineplot(np.append(X[-1], np.arange(0, test_len, gran) + max(X) + gran), np.append(Y[-1], result['result']), color=c3)
plt.legend(['Train data', 'Test data', 'Predicted data'], loc=3)
plt.xlabel('Sample number')
plt.ylabel('Sample value')
st.pyplot(fig)
st.write("The prediction isn't good enough? Try to change settings in the model setup or increase the dataset length.")
st.write('Training took {} epochs, Mean Squared Error: {:.2e}'.format(result['epochs'], result['loss']))
#st.write('Training took {} epochs, Mean Squared Error {}, last loss {}'.format(result['epochs'], result['loss'], result['loss_last']))