# Hugging Face Spaces status: Runtime error
# Commit 517420b: "refactoring part1 - minor perf updates, removed single letter names, moved functions to separate files"
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import streamlit as st

from app_helper_functions import generate_wave, local_css, user_input_features
from predict import predict_series
# Step between consecutive x samples of the generated wave (np.arange step).
points_granularity = 0.25
# Width of the x range that follows the training data and is used as test data.
test_len = 8
# TODO: Refactor this module
st.set_page_config(page_title='RNN Playground')

# Sidebar chapter titles in display order, each mapped to its position (0, 1, 2);
# the resulting index selects which chapter is rendered below.
pages = dict(zip(('Intro', 'Implementation details', 'The model'), range(3)))
selected_title = st.sidebar.radio("Select the chapter: ", tuple(pages))
choice = pages[selected_title]
# Colours shared by the plots on "The model" chapter.
TRAIN_COLOR = '#1e4a76'
TEST_COLOR = '#7dc0f7'
PREDICTION_COLOR = '#ff7c0a'

# Markdown bodies rendered with st.write; kept at column 0 so no leading
# indentation leaks into the rendered markdown.
_INTRO_PURPOSE_TEXT = """\n
The goal of this app is to allow the user to experiment easily with Recurrent Neural Networks. Thanks to that, the app helps to understand: \n
- When to use recurrent neural networks
- Which scenarios are more straightforward for the RNN to predict and which are more difficult
- How the noise interrupts the predictions
- Difference between LSTM and GRU nodes
- Understanding that the increasing number of nodes doesn't always lead to better performance
"""

_INTRO_USE_CASE_TEXT = """
1. Create a synthetic dataset, with a wide range of many choices of parameters
2. Create a Recurrent Neural Networks model by selecting a number of nodes in particular layers of the model
3. Automatically train the RNN model and make predictions
4. Compare the predicted values with the actual values
\n"""

_FRONT_END_TEXT = """\n
The front-end part was made with the use of [the Streamlit library](https://www.streamlit.io/).
The parameters from the sidebar are used to create a dataset. The dataset is visualised using the Seaborn
library and finally sent (with the parameters specifying the number of neurons in particular layers) to
the back-end part through REST API. \n
The front-end is served on Azure as a Web App. """

_BACK_END_TEXT = """\n
The backed-part is responsible for:
- Retrieving the dataset from the front-end through REST API
- Creating an RNN model using parameters passed from the user
- Training the model
- Predicting the values and returning them to the front-end \n
The most crucial requirements are:
- The neural networks setup has to be able to accurately predict the further shape of a curve for
the widest range of parameters selected by the user.
- Time execution of the back-end part must be short.
- Which means balancing over tradeoff between the time needed for the response and the accuracy of the results
- Cost efficiency
- Since the app is desired to be on-line all the time, the serverless approach has been taken.
That's why the back-end is served on Azure as a serverless Function App.
"""

_MODEL_INSTRUCTIONS_TEXT = """
1. Modify the dataset by using the sliders in the Dataset group on the left on the screen.
2. Select the number of nodes in the model by using the sliders in the RNN setup group.
3. Press the "Train and Predict" button to Train and Predict the model - note: many operations performing under the hood - please be patient.
4. The predicted values will be shown at the bottom of the page.
5. If you are not satisfied with the results - modify the model and try again!
6. Have fun!
\n"""


def _plot_series(series, y_limits=None):
    """Draw line series on a fresh figure, show it in Streamlit, then close it.

    Args:
        series: iterable of ``(x, y, label, style)`` tuples, drawn in order;
            ``style`` is a dict of extra keyword arguments for
            ``sns.lineplot`` (e.g. ``color``, ``linestyle``).
        y_limits: optional ``(bottom, top)`` pair passed to ``ax.set_ylim``;
            when ``None`` the y axis is autoscaled.
    """
    fig, ax = plt.subplots()
    try:
        for x_values, y_values, label, style in series:
            # x/y by keyword: seaborn >= 0.12 rejects positional x/y.
            # An explicit label keeps legend entries tied to the right line
            # regardless of drawing order.
            sns.lineplot(x=x_values, y=y_values, label=label, ax=ax, **style)
        if y_limits is not None:
            ax.set_ylim(*y_limits)
        ax.legend(loc='lower left')
        ax.set_xlabel('Sample number')
        ax.set_ylabel('Sample value')
        st.pyplot(fig)
    finally:
        # Streamlit re-runs the script on every interaction; closing releases
        # the figure instead of letting pyplot accumulate them.
        plt.close(fig)


def _render_intro():
    """Render the "Intro" chapter: purpose, typical use case and architecture image."""
    st.title("Recurrent Neural Networks playground")
    st.subheader("The purpose")
    st.write(_INTRO_PURPOSE_TEXT)
    st.subheader("Typical use case")
    st.write(_INTRO_USE_CASE_TEXT)
    st.subheader('The architecture of the model')
    st.image('info.jpg', use_column_width=True, caption='Hover the cursor over image to see the enlarge button')
    st.write(""" \n Use the radio buttons on the left to navigate between chapters \n \n \n \n""")


def _render_implementation_details():
    """Render the "Implementation details" chapter describing front-end and back-end."""
    st.title(""" \n Implementation details""")
    st.subheader("Front-end")
    st.write(_FRONT_END_TEXT)
    st.subheader("Back-end")
    st.subheader("[Since the playground is stored on HuggingFace now - the backend is a module of the frontend]")
    st.write(_BACK_END_TEXT)


def _render_model_page():
    """Render "The model" chapter.

    Builds train/test data from the sidebar parameters and plots it. When the
    "Train and Predict" sidebar button is pressed, it calls ``predict_series``
    on the training values and plots the predictions next to the data.
    """
    st.sidebar.header('User Input Parameters')
    params, setup = user_input_features(test_len, points_granularity)
    st.header("""Refactoring & performance updates in progress!""")
    st.subheader("Instructions:")
    st.write(_MODEL_INSTRUCTIONS_TEXT)

    # Generate and present the data.
    st.subheader("Generated data:")
    train_x = np.arange(0, params['length'], points_granularity)
    test_x = np.arange(params['length'], params['length'] + test_len, points_granularity)
    train_y = generate_wave(train_x, params)
    test_y = generate_wave(test_x, params)
    # Prepend the last training point so the test curve starts where the train curve ends.
    test_x, test_y = np.append(train_x[-1], test_x), np.append(train_y[-1], test_y)

    train_line = (train_x, train_y, 'Train data', {'color': TRAIN_COLOR})
    test_line = (test_x, test_y, 'Test data', {'color': TEST_COLOR, 'linestyle': ':'})
    # Keep at least the [-2, 2] band visible, widening it if the data goes beyond.
    y_limits = (
        min(-2, np.min(train_y), np.min(test_y)),
        max(2, np.max(train_y), np.max(test_y)),
    )
    _plot_series([train_line, test_line], y_limits=y_limits)
    st.write("The plot presents generated train and test data. Use the sliders on the left to modify the curve.")
    local_css("button_style.css")

    st.subheader('Predicted data:')
    reminder = st.text('Press the train and predict button on the sidebar once you are ready with the selections.')
    if not st.sidebar.button('Train and Predict'):
        return

    setup['values'] = list(train_y)
    reminder.empty()
    # Placeholders shown while training runs; always cleared afterwards, even
    # if predict_series raises, so they do not linger above an error.
    waiters = [
        st.text('Please wait till the train and predict process is finished.'),
        st.image('wait.gif'),
        st.text("""The process should take around 20-60 seconds."""),
    ]
    try:
        # Based on usage below, result is expected to be a mapping with
        # 'result', 'epochs' and 'loss' keys.
        result = predict_series(**setup)
    finally:
        for waiter in waiters:
            waiter.empty()

    # Predictions continue from the last training sample at the same spacing.
    prediction_x = np.append(
        train_x[-1],
        np.arange(0, test_len, points_granularity) + train_x[-1] + points_granularity,
    )
    prediction_y = np.append(train_y[-1], result['result'])
    prediction_line = (prediction_x, prediction_y, 'Predicted data', {'color': PREDICTION_COLOR})
    _plot_series([train_line, test_line, prediction_line])

    # Print statistics
    st.write("The prediction isn't good enough? Try to change settings in the model setup or increase the dataset length.")
    st.write('Training took {} epochs, Mean Squared Error: {:.2e}'.format(result['epochs'], result['loss']))


# Chapter index from the sidebar -> renderer; any other index shows "The model",
# mirroring the original if/elif/else.
_CHAPTER_RENDERERS = {0: _render_intro, 1: _render_implementation_details}
_CHAPTER_RENDERERS.get(choice, _render_model_page)()