# stock_price / app.py
# Uploaded by anushkamantri ("Upload app.py"), commit fbb056f (verified).
# Scan result: no virus. Size: 2.67 kB.
import copy
import pickle
import warnings

import numpy as np
import streamlit as st
from sklearn.preprocessing import StandardScaler
# Module-level defaults. filename / loaded_model are placeholders that no code
# in this file reads: main() chooses and loads the model into its own locals.
filename = loaded_model = ""
# Feature scaler used by stock_price_prediction(); it is created unfitted and
# no training-time statistics are loaded into it here.
scaler = StandardScaler()
# price prediction function
def stock_price_prediction(input_data, loaded_model, feature_scaler=None):
    """Run *loaded_model* on one row of feature values and format the result.

    Args:
        input_data: Sequence of feature values (numbers or numeric strings),
            in the order the model expects. main() passes daily return %,
            daily variation, MACD, RSI and EMA.
        loaded_model: Estimator exposing ``predict``.
        feature_scaler: Optional scaler already fitted on the training
            features. Defaults to the module-level ``scaler``.

    Returns:
        The string ``'The predicted stock price is: <value>'``.

    Raises:
        ValueError: if an entry of *input_data* cannot be converted to float.
    """
    if feature_scaler is None:
        feature_scaler = scaler
    # A single sample shaped (1, n_features), as transform/predict expect.
    features = np.asarray(input_data, dtype=float).reshape(1, -1)
    if hasattr(feature_scaler, "n_features_in_"):
        # Fitted scaler: apply the statistics learned from the training data.
        model_input = feature_scaler.transform(features)
    else:
        # No fitted scaler is available. Fitting a StandardScaler on this one
        # row centres every feature on itself, so the model always receives an
        # all-zero vector and the prediction ignores the user's input. The
        # old result is kept for backward compatibility, but computed on a
        # copy so the shared scaler is not left fitted to one user's values.
        # Persist the training-time scaler and pass it in (or assign it to
        # ``scaler``) to get input-dependent predictions.
        warnings.warn(
            "stock_price_prediction: no fitted scaler available; fitting on "
            "a single sample standardises every feature to 0, so the "
            "prediction ignores the input values.",
            RuntimeWarning,
            stacklevel=2,
        )
        model_input = copy.deepcopy(feature_scaler).fit_transform(features)
    prediction = loaded_model.predict(model_input)
    return f'The predicted stock price is: {prediction[0]}'
# Sidebar option -> (pickled model file, accuracy caption), in display order.
_MODEL_CHOICES = {
    "**Random Forest**": ("Random_Forest_trained_model.pkl", "Accuracy: 98.5%"),
    "**Decision Tree**": (
        "Decision_Tree_Regression_trained_model.pkl",
        "Accuracy: 98%",
    ),
    "**CatBoost**": ("CatBoostRegressor_trained_model.pkl", "Accuracy: 97.5%"),
    "**Gradient Boost**": (
        "Gradient_Boosting_Regression_trained_model.pkl",
        "Accuracy: 97.4%",
    ),
    "**XGBoost**": ("XGBoost_trained_model.pkl", "Accuracy: 89%"),
}


@st.cache_resource
def _load_model(filename):
    """Unpickle and return the model stored in *filename*.

    Cached with ``st.cache_resource``, so each file is read once per server
    process instead of on every Streamlit rerun. The returned object is
    shared by all sessions.
    """
    # NOTE: pickle can execute arbitrary code. Only load the model files
    # bundled with this app, never a user-supplied path.
    with open(filename, 'rb') as model_file:
        return pickle.load(model_file)


def main():
    """Render the app: model picker, indicator inputs and prediction output."""
    st.title("Stock Price Prediction Models")
    # allowing user to select a model
    st.sidebar.header("**Select a model to use for the prediction**")
    with st.sidebar:
        selected_model = st.radio(
            # Streamlit warns on empty labels. Use a real label but hide it,
            # so the sidebar header above remains the visible title.
            "Model",
            list(_MODEL_CHOICES),
            captions=[caption for _, caption in _MODEL_CHOICES.values()],
            label_visibility="collapsed",
        )
    model_file, _ = _MODEL_CHOICES[selected_model]
    loaded_model = _load_model(model_file)

    # getting the input data from the user
    st.subheader("**Input required values**")
    daily_ret_pct = st.text_input('Daily Return as %')
    daily_var = st.text_input('Daily Variation')
    macd = st.text_input('MACD')
    rsi = st.text_input('RSI')
    ema = st.text_input('EMA')

    # A message is rendered only after a click. Input problems use
    # warning/error styling instead of a success box.
    if st.button('Predict Stock Price'):
        raw_values = [daily_ret_pct, daily_var, macd, rsi, ema]
        if not all(value.strip() for value in raw_values):
            st.warning('Please fill in all fields')
            return
        try:
            input_data = [float(value) for value in raw_values]
        except ValueError:
            st.error('Please enter numeric values in all fields')
            return
        st.success(stock_price_prediction(input_data, loaded_model))
# Entry point: build the UI when this file is executed as a script.
if __name__ == '__main__':
    main()