"""Gradio app that predicts a stock's next opening price from recent history."""

from datetime import datetime, timedelta

import gradio as gr
import yfinance as yf
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import MinMaxScaler
from xgboost import XGBRegressor

# Length of the price-history window downloaded before the requested date.
HISTORY_DAYS = 365 * 2

# One row is held out as the prediction input, and KNeighborsRegressor's
# default of 5 neighbours needs at least 5 training rows.
MIN_ROWS = 6

# Dropdown label -> regressor class; each is built with default hyperparameters.
MODEL_FACTORIES = {
    "KNN": KNeighborsRegressor,
    "XGBoost": XGBRegressor,
    # Extra Trees chosen for no particular reason
    "Extra Trees Regressor": ExtraTreesRegressor,
}


def get_stock_data(date, stock_symbol, model_type):
    """Train the selected model on recent history and predict the next open.

    Downloads about two years of rows for ``stock_symbol`` ending at ``date``,
    fits a regressor that maps each downloaded row to the ``Open`` value of
    the row that follows it, and applies the model to the most recent row.

    Args:
        date: End of the history window, formatted ``YYYY-MM-DD``.
        stock_symbol: Ticker symbol handed to ``yfinance.download``.
        model_type: One of the keys of ``MODEL_FACTORIES``.

    Returns:
        The predicted opening price, converted to ``str``.

    Raises:
        ValueError: If ``date`` does not match ``YYYY-MM-DD``, ``model_type``
            is not a known model, or too few rows were downloaded.
    """
    # Validate the model choice first so a bad value fails before any download.
    try:
        model_factory = MODEL_FACTORIES[model_type]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unknown model type {model_type!r}; "
            f"expected one of {list(MODEL_FACTORIES)}"
        ) from None

    end_date = datetime.strptime(date.strip(), "%Y-%m-%d")
    start_date = end_date - timedelta(days=HISTORY_DAYS)
    data = yf.download(
        stock_symbol.strip(),
        start=start_date.strftime("%Y-%m-%d"),
        end=end_date.strftime("%Y-%m-%d"),
    )

    # Drop incomplete rows: not every supported model accepts NaN features.
    if data is not None:
        data = data.dropna()
    if data is None or len(data) < MIN_ROWS:
        raise ValueError(
            f"Not enough price data for {stock_symbol!r} before {date}; "
            f"need at least {MIN_ROWS} rows"
        )

    features = data.to_numpy()
    # ravel() keeps the target 1-D whether data["Open"] is a Series or a
    # single-column frame.
    # NOTE(review): presumably some yfinance versions return multi-level
    # columns for a single ticker -- confirm against the installed version.
    open_prices = data["Open"].to_numpy().ravel()

    # Pair every row with the *next* row's open; the final row has no known
    # target yet, so it becomes the prediction input instead.
    train_features = features[:-1]
    train_targets = open_prices[1:]
    latest_row = features[-1:]

    # Features and target need independent scalers: re-fitting one scaler on
    # the target would throw away the feature ranges required to transform
    # the prediction input.
    feature_scaler = MinMaxScaler()
    target_scaler = MinMaxScaler()
    train_features_scaled = feature_scaler.fit_transform(train_features)
    train_targets_scaled = target_scaler.fit_transform(
        train_targets.reshape(-1, 1)
    ).ravel()
    latest_row_scaled = feature_scaler.transform(latest_row)

    model = model_factory()
    model.fit(train_features_scaled, train_targets_scaled)

    # inverse_transform expects a 2-D array, hence the reshape to (1, 1).
    prediction_scaled = model.predict(latest_row_scaled).reshape(1, -1)
    prediction = target_scaler.inverse_transform(prediction_scaled)[0][0]
    return str(prediction)


def stock_data_interface(date, stock_symbol, model_type):
    """Gradio callback: forward the form values to ``get_stock_data``."""
    prediction = get_stock_data(date, stock_symbol, model_type)
    return prediction


# Components are referenced from the top-level ``gr`` namespace; the older
# ``gr.inputs`` / ``gr.outputs`` namespaces are not available in current
# Gradio releases.
iface = gr.Interface(
    fn=stock_data_interface,
    inputs=[
        gr.Textbox(label="Date (YYYY-MM-DD)"),
        gr.Textbox(label="Stock Symbol"),
        gr.Dropdown(choices=list(MODEL_FACTORIES), label="Type of model"),
    ],
    outputs=[gr.Textbox(label="Prediction")],
    title="Stock Data Interface",
    description="Enter a date and a stock symbol to retrieve the stock data for the past two years and predict on the latest data.",
)

iface.launch()