RNN_Playground / app_helper_functions.py
Krzysiek111's picture
refactoring part1 - minor perf updates, removed single letter names, moved functions to separate files
517420b
raw
history blame contribute delete
No virus
1.48 kB
import numpy as np
import streamlit as st
def user_input_features(test_len, points_granularity):
    """Render the sidebar controls and collect the user's choices.

    Draws two groups of Streamlit sidebar widgets in a fixed order: first
    the dataset parameters, then the model setup. Returns their current
    values.

    Args:
        test_len: Upper bound (exclusive) of the test x-range.
        points_granularity: Spacing between consecutive test points.

    Returns:
        A ``(data, nn)`` pair of dicts.
        ``data`` holds the wave parameters: length, period, growth,
        amplitude, amplitude_growth and noise.
        ``nn`` holds the network settings: use_lstm, r1_nodes, r2_nodes and
        fc1_nodes. It also holds ``steps``, the number of points in
        ``np.arange(0, test_len, points_granularity)``.
    """
    # (key, label, min, max, default) for each dataset slider, in display order.
    dataset_sliders = (
        ('length', 'Training data length', 20, 50, 28),
        ('period', 'Period of the wave', 0.75, 2.0, 1.0),
        ('growth', 'Values growth', -0.25, 0.25, 0.0),
        ('amplitude', 'Amplitude', 0.25, 1.75, 1.0),
        ('amplitude_growth', 'Amplitude growth', -0.01, 0.1, 0.0),
        ('noise', 'Noise', 0.0, 1.0, 0.0),
    )
    # Same layout for the layer-size sliders shown under "Model setup".
    layer_sliders = (
        ('r1_nodes', 'Number of nodes in the first RNN layer', 1, 30, 13),
        ('r2_nodes', 'Number of nodes in the second RNN layer', 0, 30, 0),
        ('fc1_nodes', 'Number of nodes in the fully connected RNN layer', 0, 40, 10),
    )
    sidebar = st.sidebar

    sidebar.header('Dataset:')
    # Comprehension order == widget creation order == dict insertion order.
    data = {
        key: sidebar.slider(label, low, high, default)
        for key, label, low, high, default in dataset_sliders
    }

    sidebar.header('Model setup')
    cell_type = sidebar.radio('Select the type of Recurrent Neuron to use', ['LSTM', 'GRU'])
    nn = {'use_lstm': cell_type == 'LSTM'}
    for key, label, low, high, default in layer_sliders:
        nn[key] = sidebar.slider(label, low, high, default)

    # Count of test points; np.arange decides the exact (float-step) count.
    nn['steps'] = np.arange(0, test_len, points_granularity).size
    return data, nn
def generate_wave(x_set, params, *, rng=None):
    """Evaluate a growing, optionally noisy sine wave at the points ``x_set``.

    The value at each point ``x`` is::

        sin(x / period) * (1 + amplitude_growth * x) * amplitude
            + x * growth + noise * N(0, 1)

    Args:
        x_set: Sample points. Any array-like is accepted: a scalar, a list,
            or a 1-D or N-D array.
        params: Mapping with the keys 'period', 'amplitude_growth',
            'amplitude', 'growth' and 'noise'. 'noise' scales the added
            standard-normal noise.
        rng: Optional ``numpy.random.Generator`` used to draw the noise.
            When None, NumPy's legacy global RNG (``np.random.randn``) is
            used.

    Returns:
        Values with the same shape as ``x_set``. A scalar input gives a
        NumPy scalar.
    """
    x_set = np.asarray(x_set)
    # Draw one independent noise sample per element, using the full shape.
    # len() would only size the first axis, which breaks or mis-broadcasts
    # for N-D input.
    shape = x_set.shape
    if rng is None:
        standard_noise = np.random.randn(*shape)
    else:
        standard_noise = rng.standard_normal(shape)
    wave = (np.sin(x_set / params['period'])
            * (1 + params['amplitude_growth'] * x_set)
            * params['amplitude'])
    return wave + x_set * params['growth'] + params['noise'] * standard_noise
def local_css(file_name, encoding='utf-8'):
    """Inject the contents of a local CSS file into the Streamlit page.

    The file is read as text, wrapped in a ``<style>`` tag, and emitted with
    ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        file_name: Path to the CSS file.
        encoding: Text encoding used to decode the file. Defaults to UTF-8
            rather than the platform locale encoding, so the same file
            decodes identically on every OS.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        UnicodeDecodeError: If the file is not valid in ``encoding``.
    """
    with open(file_name, encoding=encoding) as css_file:
        css = css_file.read()
    # Render after the file handle is closed. The CSS is inserted without
    # escaping, so only pass trusted, local stylesheets.
    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)