Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pickle | |
import numpy as np | |
from sklearn.preprocessing import StandardScaler | |
# Module-level placeholders. main() binds its own local `filename` and
# `loaded_model`, so it does not read these two module values.
filename, loaded_model = "", ""

# Shared scaler used by stock_price_prediction(). It is constructed with
# default settings; no training statistics are loaded into it here.
# TODO(review): load the scaler that was fitted during model training.
scaler = StandardScaler()
def stock_price_prediction(input_data, loaded_model, fitted_scaler=None):
    """Return a message containing the model's price prediction for one row.

    Args:
        input_data: Feature values (numbers or numeric strings) in the order
            the model expects; they are converted with ``dtype=float``.
        loaded_model: Trained model exposing a ``predict`` method.
        fitted_scaler: Scaler already fitted on the training features.
            Defaults to the module-level ``scaler``.

    Returns:
        ``'The predicted stock price is: <value>'`` built from the first
        value returned by ``loaded_model.predict``.

    Raises:
        ValueError: If a value cannot be converted to float. scikit-learn's
            ``NotFittedError`` (a ``ValueError`` subclass) is raised when the
            scaler has never been fitted.
    """
    # NOTE(review): the module-level scaler is never fitted in this file; load
    # the training-time scaler and pass it here (or assign it to `scaler`).
    active_scaler = scaler if fitted_scaler is None else fitted_scaler

    # A single sample as a (1, n_features) row, as estimators expect.
    input_arr = np.asarray(input_data, dtype=float).reshape(1, -1)

    # Only transform. Fitting on this one row would set the mean to the row
    # itself, map every feature to 0, and make the prediction ignore the input.
    scaled_input = active_scaler.transform(input_arr)

    # Predict the stock price
    prediction = loaded_model.predict(scaled_input)
    return f'The predicted stock price is: {prediction[0]}'
@st.cache_resource
def _load_model(filename):
    """Unpickle the model stored at ``filename`` and cache it.

    ``st.cache_resource`` keeps the loaded object across Streamlit reruns, so
    each pickle is read once instead of on every widget interaction.

    NOTE: ``pickle.load`` can execute arbitrary code. Only pass the model
    files bundled with this app, never a path or file supplied by a user.
    """
    # `with` guarantees the file handle is closed after unpickling.
    with open(filename, 'rb') as model_file:
        return pickle.load(model_file)


def main():
    """Render the app: a model picker, five numeric inputs, and a prediction.

    Nothing is shown in the result area until the predict button is clicked.
    """
    st.title("Stock Price Prediction Models")

    # Radio option label -> (pickled model file, caption shown under it).
    model_options = {
        "**Random Forest**": ("Random_Forest_trained_model.pkl", "Accuracy: 98.5%"),
        "**Decision Tree**": ("Decision_Tree_Regression_trained_model.pkl", "Accuracy: 98%"),
        "**CatBoost**": ("CatBoostRegressor_trained_model.pkl", "Accuracy: 97.5%"),
        "**Gradient Boost**": ("Gradient_Boosting_Regression_trained_model.pkl", "Accuracy: 97.4%"),
        "**XGBoost**": ("XGBoost_trained_model.pkl", "Accuracy: 89%"),
    }

    # allowing user to select a model
    st.sidebar.header("**Select a model to use for the prediction**")
    with st.sidebar:
        selected_model = st.radio(
            # Non-empty label for accessibility; hidden because the sidebar
            # header above already describes the choice.
            "Model",
            list(model_options),
            captions=[caption for _, caption in model_options.values()],
            label_visibility="collapsed",
        )
    filename, _ = model_options[selected_model]
    loaded_model = _load_model(filename)

    # getting the input data from the user
    st.subheader("**Input required values**")
    daily_ret_pct = st.text_input('Daily Return as %')
    daily_var = st.text_input('Daily Variation')
    macd = st.text_input('MACD')
    rsi = st.text_input('RSI')
    ema = st.text_input('EMA')

    if st.button('Predict Stock Price'):
        # Strip so that whitespace-only entries count as empty.
        input_data = [value.strip() for value in (daily_ret_pct, daily_var, macd, rsi, ema)]
        if not all(input_data):
            st.warning('Please fill in all fields')
            return
        try:
            prediction = stock_price_prediction(input_data, loaded_model)
        except ValueError as exc:
            # e.g. non-numeric text, which fails the float conversion inside
            # stock_price_prediction.
            st.error(f'Could not make a prediction: {exc}')
        else:
            st.success(prediction)
if __name__ == "__main__":
    # Start the app when this file is executed as the top-level script.
    main()