SimpleNN / app.py
ricardo-lsantos's picture
Added message
fcd7754
import streamlit as st
from nn import NeuralNetwork
import json
from utils import sigmoid, sigmoid_prime
# XOR truth table used for both training and evaluation:
# every pair of bits, and the XOR of that pair wrapped in a one-element list.
INPUTS = [[first, second] for first in (0, 1) for second in (0, 1)]
OUTPUTS = [[first ^ second] for first, second in INPUTS]
def resetSession():
    """Forget the current model and zero its training counter in the session."""
    state = st.session_state
    state["nn"] = None
    state["train_count"] = 0
## Controller Function
def runNN():
    """Evaluate the session's network on every XOR sample and show a results table.

    Reads the model from ``st.session_state.nn`` and assumes one has already
    been created (it is ``None`` otherwise). For each sample it records the
    input pair, the expected bit, the raw prediction, the rounded prediction
    and whether the rounded value matches the expectation, then renders the
    columns with ``st.dataframe``.
    """
    nn = st.session_state.nn
    # Column-oriented mapping: each key becomes one column of the table.
    table = {
        "input": [],
        "expected": [],
        "predicted": [],
        "rounded": [],
        "correct": [],
    }
    # Iterate the data itself rather than a hard-coded count so the table
    # always covers exactly the samples defined in INPUTS/OUTPUTS.
    for sample, target in zip(INPUTS, OUTPUTS):
        a, b = sample[0], sample[1]
        expected = target[0]
        result = nn.predict(a, b, activation=sigmoid)
        rounded = round(result)  # computed once, reused for the correctness check
        table["input"].append(f"{a} xor {b}")
        table["expected"].append(expected)
        table["predicted"].append(result)
        table["rounded"].append(rounded)
        table["correct"].append('correct' if rounded == expected else 'incorrect')
    st.dataframe(table)
def sidebar():
    """Render the sidebar controls and dispatch their button actions.

    Epoch count and learning-rate sliders are always shown, together with
    'New Model' and 'Reset Model' buttons. The 'Train Model', 'Run Neural
    Network' and 'Save Model' controls appear only once a model exists in
    the session, because each of them dereferences ``st.session_state.nn``.
    """
    # Neural network controls
    st.sidebar.header('Neural Network Controls')
    st.sidebar.text('Number of epochs')
    epochs = st.sidebar.slider('Epochs', 1, 10000, 500)
    st.sidebar.text('Learning rate')
    alphas = st.sidebar.slider('Alphas', 1, 100, 20)
    col1, col2 = st.sidebar.columns(2)
    if col1.button('New Model'):
        btnNewModel()
    if col2.button('Reset Model'):
        # Go through the dedicated handler (mirroring 'New Model') so the
        # user sees the "Model reset" confirmation; resetSession() is silent.
        btnResetModel()
    if "nn" in st.session_state and st.session_state.nn is not None:
        if st.sidebar.button('Train Model'):
            btnTrainModel(epochs, alphas)
        if st.sidebar.button('Run Neural Network'):
            btnRunModel()
        st.sidebar.download_button(
            label="Save Model",
            data=json.dumps(st.session_state.nn.getModelJson()),
            file_name="model.json",
            mime="application/json",
        )
def btnNewModel():
    """Handler for 'New Model': clear the session, store a new NeuralNetwork, confirm."""
    resetSession()
    model = NeuralNetwork()
    st.session_state.nn = model
    confirmation = "New model created"
    st.sidebar.text(confirmation)
def btnTrainModel(epochs, alphas):
    """Handler for 'Train Model': run one training call on the XOR data.

    Args:
        epochs: value forwarded to ``nn.train(epochs=...)``.
        alphas: value forwarded to ``nn.train(alpha=...)``.
    """
    state = st.session_state
    state.nn.train(inputs=INPUTS, outputs=OUTPUTS, epochs=epochs, alpha=alphas)
    state.train_count = state.train_count + 1
    st.sidebar.text(f"Model trained {state.train_count} times")
def btnRunModel():
    """Handler for 'Run Neural Network': evaluate the model and show the table."""
    return runNN()
def btnResetModel():
    """Handler for resetting: clear the session model and confirm in the sidebar."""
    resetSession()
    notice = "Model reset"
    st.sidebar.text(notice)
def app():
    """Build the page: title, description, network diagram, sidebar and references."""
    from pathlib import Path

    st.title('Simple Neural Network App')
    st.write('I followed a tutorial in the reference and changed to apply good programming practices.')
    st.write('This is the Neural Network image we are trying to implement!')
    # Resolve the diagram next to this file instead of the current working
    # directory, so the app still finds it when launched from elsewhere.
    st.image(str(Path(__file__).with_name('nn.png')), width=550)
    sidebar()
    st.markdown('''
### References
* https://www.codingame.com/playgrounds/59631/neural-network-xor-example-from-scratch-no-libs
''')
# Script entry point: build the page only when executed directly, not on import.
if __name__ == "__main__":
    app()