RNN_Playground / app_helper_functions.py
Krzysiek111's picture
refactoring part1 - minor perf updates, removed single letter names, moved functions to separate files
517420b
raw
history blame
No virus
1.48 kB
import numpy as np
import streamlit as st
def user_input_features(test_len, points_granularity):
    """Draw the sidebar controls and return the values the user selected.

    Two sidebar sections are rendered, in this order: the 'Dataset:'
    sliders and the 'Model setup' widgets.

    Args:
        test_len: Stop value of the range used to count prediction steps.
        points_granularity: Step of that same range.

    Returns:
        A ``(data, nn)`` tuple of dicts. ``data`` holds the slider values
        for 'length', 'period', 'growth', 'amplitude', 'amplitude_growth'
        and 'noise'. ``nn`` holds 'use_lstm' (True when 'LSTM' is picked),
        'r1_nodes', 'r2_nodes', 'fc1_nodes' and 'steps', the number of
        points in ``np.arange(0, test_len, points_granularity)``.
    """
    sidebar = st.sidebar

    # Each row: (result key, widget label, min, max, default).
    # The row order is the order in which the sliders appear in the sidebar.
    dataset_sliders = (
        ('length', 'Training data length', 20, 50, 28),
        ('period', 'Period of the wave', 0.75, 2.0, 1.0),
        ('growth', 'Values growth', -0.25, 0.25, 0.0),
        ('amplitude', 'Amplitude', 0.25, 1.75, 1.0),
        ('amplitude_growth', 'Amplitude growth', -0.01, 0.1, 0.0),
        ('noise', 'Noise', 0.0, 1.0, 0.0),
    )
    layer_sliders = (
        ('r1_nodes', 'Number of nodes in the first RNN layer', 1, 30, 13),
        ('r2_nodes', 'Number of nodes in the second RNN layer', 0, 30, 0),
        ('fc1_nodes', 'Number of nodes in the fully connected RNN layer', 0, 40, 10),
    )

    sidebar.header('Dataset:')
    data = {}
    for key, label, lowest, highest, default in dataset_sliders:
        data[key] = sidebar.slider(label, lowest, highest, default)

    sidebar.header('Model setup')
    chosen_cell = sidebar.radio('Select the type of Recurrent Neuron to use', ['LSTM', 'GRU'])
    nn = {'use_lstm': chosen_cell == 'LSTM'}
    for key, label, lowest, highest, default in layer_sliders:
        nn[key] = sidebar.slider(label, lowest, highest, default)

    # Count the sample points between 0 and test_len at the given spacing.
    prediction_points = np.arange(0, test_len, points_granularity)
    nn['steps'] = len(prediction_points)

    return data, nn
def generate_wave(x_set, params):
    """Evaluate a noisy sine wave with a linear trend at the given points.

    The result is
    ``sin(x / period) * (1 + amplitude_growth * x) * amplitude
    + x * growth + noise * N(0, 1)``,
    with one standard-normal sample drawn per element of ``x_set``.

    Args:
        x_set: Points at which to evaluate the wave. It must support NumPy
            arithmetic and ``len()`` (presumably a 1-D ndarray; confirm
            against the caller).
        params: Mapping with the keys 'period', 'amplitude_growth',
            'amplitude', 'growth' and 'noise'.

    Returns:
        The wave values, one per element of ``x_set``.
    """
    oscillation = np.sin(x_set / params['period'])
    # The envelope scales the amplitude linearly with x.
    envelope = 1 + params['amplitude_growth'] * x_set
    scaled_wave = oscillation * envelope * params['amplitude']

    linear_trend = x_set * params['growth']
    # Noise is drawn from NumPy's global random state.
    random_noise = params['noise'] * np.random.randn(len(x_set))

    return scaled_wave + linear_trend + random_noise
def local_css(file_name):
    """Read a CSS file and inject it into the Streamlit page as a <style> tag.

    Args:
        file_name: Path of the stylesheet to load.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(file_name, encoding='utf-8') as css_file:
        stylesheet = css_file.read()

    # unsafe_allow_html lets the <style> element reach the page as markup
    # rather than being escaped.
    st.markdown(f'<style>{stylesheet}</style>', unsafe_allow_html=True)