import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from predict import predict_series

# TODO: Refactor this module

st.set_page_config(page_title='RNN Playground')

pages = {'Intro': 0, 'Implementation details': 1, 'The model': 2}
choice = pages[st.sidebar.radio("Select the chapter: ", tuple(pages.keys()))]
if choice == 0:
    st.title("Recurrent Neural Networks playground")
    st.subheader("The purpose")
    st.write("""\n
The goal of this app is to let the user experiment easily with Recurrent Neural Networks. In doing so, the app helps the user understand: \n
- When to use recurrent neural networks
- Which scenarios are more straightforward for an RNN to predict and which are more difficult
- How noise affects the predictions
- The difference between LSTM and GRU nodes
- That increasing the number of nodes doesn't always lead to better performance
""")
    st.subheader("Typical use case")
    st.write("""
1. Create a synthetic dataset with a wide range of parameters
2. Create a Recurrent Neural Network model by selecting the number of nodes in each layer of the model
3. Automatically train the RNN model and make predictions
4. Compare the predicted values with the actual values
\n""")
    st.subheader('The architecture of the model')
    st.image('info.jpg', use_column_width=True, caption='Hover the cursor over the image to see the enlarge button')
    st.write(""" \n Use the radio buttons on the left to navigate between chapters \n \n \n \n""")
elif choice == 1:
    st.title(""" \n Implementation details""")
    st.subheader("Front-end")
    st.write("""\n
The front-end part was built with [the Streamlit library](https://www.streamlit.io/).
The parameters from the sidebar are used to create a dataset. The dataset is visualised using the Seaborn
library and finally sent (together with the parameters specifying the number of neurons in particular layers) to
the back-end part through a REST API. \n
The front-end is served on Azure as a Web App. """)
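    # Added illustration (not part of the original back-end code): a rough sketch of the kind of
    # request the front-end could send over such a REST API. The endpoint URL and the exact JSON
    # schema are assumptions; only the parameter names mirror the ones used elsewhere in this app.
    st.write("For illustration, a call to such a REST back-end could look roughly like this "
             "(the endpoint and the schema are simplified):")
    st.code('''
import requests

payload = {
    'values': [0.0, 0.7, 1.0],                       # generated training samples
    'use_lstm': True,                                # LSTM vs GRU switch
    'r1_nodes': 13, 'r2_nodes': 0, 'fc1_nodes': 10,  # layer sizes chosen in the sidebar
    'steps': 32,                                     # number of points to predict
}
# hypothetical endpoint - the real Function App URL is not part of this repository
response = requests.post('https://<function-app>.azurewebsites.net/api/predict', json=payload)
predictions = response.json()['result']
''', language='python')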
st.subheader("Back-end") | |
st.subheader("[Since the playground is stored on HuggingFace now - the backend is a module of the frontend]") | |
st.write("""\n | |
The backed-part is responsible for: | |
- Retrieving the dataset from the front-end through REST API | |
- Creating an RNN model using parameters passed from the user | |
- Training the model | |
- Predicting the values and returning them to the front-end \n | |
The most crucial requirements are: | |
- The neural networks setup has to be able to accurately predict the further shape of a curve for | |
the widest range of parameters selected by the user. | |
- Time execution of the back-end part must be short. | |
- Which means balancing over tradeoff between the time needed for the response and the accuracy of the results | |
- Cost efficiency | |
- Since the app is desired to be on-line all the time, the serverless approach has been taken. | |
That's why the back-end is served on Azure as a serverless Function App. | |
""") | |
else:
    gran = 0.25
    test_len = 8
    st.sidebar.header('User Input Parameters')

    def user_input_features():
        predefined_sets = {'length': [30, ], 'period': [1.34, ], 'amplitude': [0.64, ], 'growth': [0.04, ],
                           'amplitude_growth': [0.03, ], 'r1_nodes': [20, ], 'r2_nodes': [20, ], 'fc1_nodes': [34, ]}
        data, nn = {}, {}
        st.sidebar.header('Dataset:')
        data['length'] = st.sidebar.slider('Training data length', 20, 50, 28)
        data['period'] = st.sidebar.slider('Period of the wave', 0.75, 2.0, 1.0)
        data['growth'] = st.sidebar.slider('Values growth', -0.25, 0.25, 0.0)
        data['amplitude'] = st.sidebar.slider('Amplitude', 0.25, 1.75, 1.0)
        data['amplitude_growth'] = st.sidebar.slider('Amplitude growth', -0.01, 0.1, 0.0)
        data['noise'] = st.sidebar.slider('Noise', 0.0, 1.0, 0.0)
        st.sidebar.header('Model setup')
        nn['use_lstm'] = st.sidebar.radio('Select the type of Recurrent Neuron to use', ['LSTM', 'GRU']) == 'LSTM'
        nn['r1_nodes'] = st.sidebar.slider('Number of nodes in the first RNN layer', 1, 30, 13)
        nn['r2_nodes'] = st.sidebar.slider('Number of nodes in the second RNN layer', 0, 30, 0)
        nn['fc1_nodes'] = st.sidebar.slider('Number of nodes in the fully connected layer', 0, 40, 10)
        nn['steps'] = len(np.arange(0, test_len, gran))
        # if st.sidebar.button('Load one of the pretested configurations'):
        #     i = st.sidebar.selectbox('Select:', [-1, 0])
        #     i = int(np.random.rand(len(predefined_sets['length'])))  # Selecting one pretested configuration
        #     data.update({k: predefined_sets[k][i] for k in set(data) & set(predefined_sets)})
        #     nn.update({k: predefined_sets[k][i] for k in set(nn) & set(predefined_sets)})
        return data, nn
    params, setup = user_input_features()
    st.header("""Refactoring & performance updates in progress!
Please check back in a few days""")
    st.subheader("Instructions:")
    st.write("""
1. Modify the dataset by using the sliders in the Dataset group on the left side of the screen.
2. Select the number of nodes in the model by using the sliders in the Model setup group.
3. Press the "Train and Predict" button to train the model and make predictions - note: many operations are performed under the hood, so please be patient.
4. The predicted values will be shown at the bottom of the page.
5. If you are not satisfied with the results - modify the model and try again!
6. Have fun!
\n""")
st.subheader("Generated data:") | |
X = np.arange(0, params['length'], gran) | |
X_pred = np.arange(params['length'], params['length'] + test_len, gran) | |
def generate_wave(x_set): | |
return np.sin(x_set / params['period']) * (1 + params['amplitude_growth'] * x_set) * params[ | |
'amplitude'] + x_set * params['growth'] + params['noise']*np.random.randn(len(x_set)) | |
Y = generate_wave(X) | |
Y_pred = generate_wave(X_pred) | |
X_pred, Y_pred = np.append(X[-1], X_pred), np.append(Y[-1], Y_pred) | |
c1, c2, c3 = '#1e4a76', '#7dc0f7', '#ff7c0a' # colors | |
# sns.scatterplot(x=X, y=Y, color=c1) | |
# st.pyplot() | |
fig, ax = plt.subplots() | |
sns.lineplot(x=X, y=Y, color=c1) | |
sns.lineplot(x=X_pred, y=Y_pred, color=c2, linestyle=':') | |
plt.ylim(min(-2, min(Y), min(Y_pred)), max(2, max(Y), max(Y_pred))) | |
plt.legend(['Train data', 'Test data'], loc=3) | |
plt.xlabel('Sample number') | |
plt.ylabel('Sample value') | |
st.pyplot(fig) | |
st.write("The plot presents generated train and test data. Use the sliders on the left to modify the curve.") | |
    def local_css(file_name):
        # Inject a local CSS file into the page
        with open(file_name) as f:
            st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)

    local_css("button_style.css")
    st.subheader('Predicted data:')
    reminder = st.text('Press the "Train and Predict" button in the sidebar once you are ready with your selections.')
    if st.sidebar.button('Train and Predict'):
        setup['values'] = list(Y)
        reminder.empty()
        waiters = list()
        waiters.append(st.text('Please wait until the training and prediction process is finished.'))
        waiters.append(st.image('wait.gif'))
        waiters.append(st.text("""The process should take around 20-60 seconds."""))
        result = predict_series(**setup)
        _ = [waiter.empty() for waiter in waiters]
        fig, ax = plt.subplots()
        sns.lineplot(x=X, y=Y, color=c1)
        sns.lineplot(x=X_pred, y=Y_pred, color=c2, linestyle=':')
        sns.lineplot(x=np.append(X[-1], np.arange(0, test_len, gran) + max(X) + gran),
                     y=np.append(Y[-1], result['result']), color=c3)
        plt.legend(['Train data', 'Test data', 'Predicted data'], loc=3)
        plt.xlabel('Sample number')
        plt.ylabel('Sample value')
        st.pyplot(fig)
        st.write("Prediction not good enough? Try changing the settings in the model setup or increasing the dataset length.")
        st.write('Training took {} epochs, Mean Squared Error: {:.2e}'.format(result['epochs'], result['loss']))
        # st.write('Training took {} epochs, Mean Squared Error {}, last loss {}'.format(result['epochs'], result['loss'], result['loss_last']))