Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,7 +19,7 @@ def get_stock_data(ticker, start_date, end_date):
|
|
| 19 |
|
| 20 |
# Function to preprocess data
|
| 21 |
def preprocess_data(data):
|
| 22 |
-
data = data.
|
| 23 |
scaler = MinMaxScaler()
|
| 24 |
scaled_data = scaler.fit_transform(data)
|
| 25 |
scaled_data = pd.DataFrame(scaled_data, columns=data.columns)
|
|
@@ -41,7 +41,7 @@ def create_features(data):
|
|
| 41 |
return rsi
|
| 42 |
|
| 43 |
data['RSI'] = calculate_rsi(data, 14)
|
| 44 |
-
data = data.
|
| 45 |
data = data.fillna(0) # Fill any remaining NaN values
|
| 46 |
return data
|
| 47 |
|
|
@@ -65,8 +65,8 @@ class LSTMModel(nn.Module):
|
|
| 65 |
self.fc = nn.Linear(hidden_size, output_size)
|
| 66 |
|
| 67 |
def forward(self, x):
|
| 68 |
-
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
|
| 69 |
-
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
|
| 70 |
out, _ = self.lstm(x, (h0, c0))
|
| 71 |
out = self.fc(out[:, -1, :])
|
| 72 |
return out
|
|
@@ -128,27 +128,48 @@ def predict_stock(ticker, start_date, end_date):
|
|
| 128 |
# Inverse transform only the 'Close' column
|
| 129 |
predicted_stock_price = scaler.inverse_transform(predicted_df)[:, 3]
|
| 130 |
|
| 131 |
-
#
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
plt.plot(range(seq_length, len(predicted_stock_price) + seq_length), predicted_stock_price, label='Predicted Stock Price')
|
| 135 |
-
plt.legend()
|
| 136 |
-
plt.show()
|
| 137 |
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
except ValueError as e:
|
| 141 |
-
return str(e)
|
| 142 |
except Exception as e:
|
| 143 |
-
return f"An error occurred: {str(e)}"
|
| 144 |
|
| 145 |
# Gradio Interface
|
|
|
|
| 146 |
inputs = [
|
| 147 |
-
gr.
|
| 148 |
-
gr.Textbox(label="Start Date (YYYY-MM-DD)"),
|
| 149 |
-
gr.Textbox(label="End Date (YYYY-MM-DD)")
|
| 150 |
]
|
| 151 |
|
| 152 |
-
outputs =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
gr.Interface(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# Function to preprocess data
|
| 21 |
def preprocess_data(data):
|
| 22 |
+
data = data.ffill()
|
| 23 |
scaler = MinMaxScaler()
|
| 24 |
scaled_data = scaler.fit_transform(data)
|
| 25 |
scaled_data = pd.DataFrame(scaled_data, columns=data.columns)
|
|
|
|
| 41 |
return rsi
|
| 42 |
|
| 43 |
data['RSI'] = calculate_rsi(data, 14)
|
| 44 |
+
data = data.ffill()
|
| 45 |
data = data.fillna(0) # Fill any remaining NaN values
|
| 46 |
return data
|
| 47 |
|
|
|
|
| 65 |
self.fc = nn.Linear(hidden_size, output_size)
|
| 66 |
|
| 67 |
def forward(self, x):
|
| 68 |
+
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
|
| 69 |
+
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
|
| 70 |
out, _ = self.lstm(x, (h0, c0))
|
| 71 |
out = self.fc(out[:, -1, :])
|
| 72 |
return out
|
|
|
|
| 128 |
# Inverse transform only the 'Close' column
|
| 129 |
predicted_stock_price = scaler.inverse_transform(predicted_df)[:, 3]
|
| 130 |
|
| 131 |
+
# Calculate the latest actual price and predicted next price
|
| 132 |
+
latest_actual_price = stock_data['Close'].values[-1]
|
| 133 |
+
predicted_next_price = predicted_stock_price[-1]
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
# Determine whether to buy or sell the stock
|
| 136 |
+
recommendation = "Buy" if predicted_next_price > latest_actual_price else "Sell"
|
| 137 |
+
|
| 138 |
+
# Calculate performance summary
|
| 139 |
+
three_years_ago_price = stock_data['Close'].values[0]
|
| 140 |
+
price_change = latest_actual_price - three_years_ago_price
|
| 141 |
+
percentage_change = (price_change / three_years_ago_price) * 100
|
| 142 |
+
|
| 143 |
+
return latest_actual_price, predicted_next_price, recommendation, price_change, percentage_change
|
| 144 |
|
| 145 |
except ValueError as e:
|
| 146 |
+
return str(e), '', '', '', ''
|
| 147 |
except Exception as e:
|
| 148 |
+
return f"An error occurred: {str(e)}", '', '', '', ''
|
| 149 |
|
| 150 |
# Gradio Interface
|
| 151 |
+
tickers = ['AAPL', 'GOOGL', 'MSFT', 'AMZN', 'TSLA', 'FB', 'NFLX', 'NVDA', 'BRK.B', 'JPM']
|
| 152 |
inputs = [
|
| 153 |
+
gr.Dropdown(choices=tickers, label="Stock Ticker"),
|
| 154 |
+
gr.Textbox(label="Start Date (YYYY-MM-DD)", type="text"),
|
| 155 |
+
gr.Textbox(label="End Date (YYYY-MM-DD)", type="text")
|
| 156 |
]
|
| 157 |
|
| 158 |
+
outputs = [
|
| 159 |
+
gr.Textbox(label="Latest Actual Price"),
|
| 160 |
+
gr.Textbox(label="Predicted Next Closing Price"),
|
| 161 |
+
gr.Textbox(label="Recommendation"),
|
| 162 |
+
gr.Textbox(label="Price Change"),
|
| 163 |
+
gr.Textbox(label="Percentage Change")
|
| 164 |
+
]
|
| 165 |
|
| 166 |
+
gr.Interface(
|
| 167 |
+
fn=predict_stock,
|
| 168 |
+
inputs=inputs,
|
| 169 |
+
outputs=outputs,
|
| 170 |
+
title="Stock Price Prediction",
|
| 171 |
+
examples=[
|
| 172 |
+
["AAPL", "2020-01-01", "2023-01-01"],
|
| 173 |
+
["GOOGL", "2020-01-01", "2023-01-01"]
|
| 174 |
+
]
|
| 175 |
+
).launch()
|