amitpress commited on
Commit
0c0d46a
·
1 Parent(s): e709e21
Files changed (2) hide show
  1. app.py +77 -0
  2. best.keras +3 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import yfinance as yf
3
+ import numpy as np
4
+ import pandas as pd
5
+ import tensorflow as tf
6
+ import matplotlib.pyplot as plt
7
+ from sklearn.preprocessing import MinMaxScaler
8
+ # Load your pre-trained Keras model
9
+ model = tf.keras.models.load_model("./best.keras")
10
+
11
+ # scale the data
12
+ def create_scaler(df):
13
+ scaler = MinMaxScaler(feature_range=(0,1))
14
+ scaled_df = scaler.fit_transform(df['Close'].values.reshape(-1, 1))
15
+ return scaler, scaled_df
16
+ # create input output sequence
17
+ def create_sequence(scaled_df):
18
+ X, y = [], []
19
+ window = 60
20
+ n_future = 1
21
+ for i in range(len(scaled_df) - window - n_future - 1):
22
+ X.append(scaled_df[i:i+window])
23
+ y.append(scaled_df[i+window+n_future])
24
+ X = np.array(X)
25
+ y = np.array(y)
26
+ return X, y
27
+
28
+ def fetch_and_predict(ticker, period):
29
+ # Fetch historical stock data using yfinance
30
+ try:
31
+ df = yf.download(ticker, period=period)
32
+ if isinstance(df.columns, pd.MultiIndex):
33
+ df.columns = df.columns.get_level_values(0)
34
+ except Exception as e:
35
+ print("check 2")
36
+ return f"Error downloading data: {e}"
37
+
38
+ # Check if we have enough data for predictions
39
+
40
+ if df.shape[0] < 60:
41
+ return "Not enough data for predictions. Please select a longer period."
42
+
43
+ # prepare data
44
+ scaler, df = create_scaler(df)
45
+ X, y = create_sequence(df)
46
+ # Predicting stock prices
47
+ try:
48
+ print("fine")
49
+ yhat = model.predict(X)
50
+ except Exception as e:
51
+ return f"Error during prediction: {e}"
52
+ # Plot the predicted prices
53
+ plt.figure(figsize=(14, 7))
54
+ plt.plot(y, label='Actual Prices')
55
+ plt.plot(yhat, label='Predicted Prices')
56
+ plt.title(f'Stock Price Prediction (LSTM) - [{str(ticker)}]')
57
+ plt.xlabel('Time')
58
+ plt.ylabel('Stock Price')
59
+ plt.legend()
60
+ plt.xticks(rotation=45)
61
+ return plt.gcf()
62
+
63
+ interface = gr.Interface(
64
+ fn=fetch_and_predict,
65
+ inputs=[
66
+ gr.Textbox(label="Stock Ticker", placeholder="Enter stock ticker (e.g., DAL, AAPL)"),
67
+ gr.Textbox(label="Period", placeholder="Enter period (e.g., '1y')")
68
+ ],
69
+ outputs=gr.Plot(),
70
+ live=False,
71
+ allow_flagging="never",
72
+ title="Stock Price Prediction",
73
+ description="Enter the stock ticker and period, then click the button to fetch data and predict prices.",
74
+ theme="huggingface",
75
+ )
76
+
77
+ interface.launch()
best.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9f5adbd0e6c4bc1bfe8b553596f050976ab95fcd19a4cd4b4f53914441650c3
3
+ size 430225