SuriRaja commited on
Commit
bfe9e47
·
verified ·
1 Parent(s): ae0bb66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -18
app.py CHANGED
@@ -19,7 +19,7 @@ def get_stock_data(ticker, start_date, end_date):
19
 
20
  # Function to preprocess data
21
  def preprocess_data(data):
22
- data = data.fillna(method='ffill') # Replace this with `data.ffill()` if needed
23
  scaler = MinMaxScaler()
24
  scaled_data = scaler.fit_transform(data)
25
  scaled_data = pd.DataFrame(scaled_data, columns=data.columns)
@@ -41,7 +41,7 @@ def create_features(data):
41
  return rsi
42
 
43
  data['RSI'] = calculate_rsi(data, 14)
44
- data = data.fillna(method='ffill') # Replace this with `data.ffill()` if needed
45
  data = data.fillna(0) # Fill any remaining NaN values
46
  return data
47
 
@@ -65,8 +65,8 @@ class LSTMModel(nn.Module):
65
  self.fc = nn.Linear(hidden_size, output_size)
66
 
67
  def forward(self, x):
68
- h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
69
- c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
70
  out, _ = self.lstm(x, (h0, c0))
71
  out = self.fc(out[:, -1, :])
72
  return out
@@ -128,27 +128,48 @@ def predict_stock(ticker, start_date, end_date):
128
  # Inverse transform only the 'Close' column
129
  predicted_stock_price = scaler.inverse_transform(predicted_df)[:, 3]
130
 
131
- # Plot actual vs. predicted stock prices
132
- plt.figure(figsize=(14, 5))
133
- plt.plot(stock_data['Close'].values, label='Actual Stock Price')
134
- plt.plot(range(seq_length, len(predicted_stock_price) + seq_length), predicted_stock_price, label='Predicted Stock Price')
135
- plt.legend()
136
- plt.show()
137
 
138
- return predicted_stock_price[-1]
 
 
 
 
 
 
 
 
139
 
140
  except ValueError as e:
141
- return str(e)
142
  except Exception as e:
143
- return f"An error occurred: {str(e)}"
144
 
145
  # Gradio Interface
 
146
  inputs = [
147
- gr.Textbox(label="Stock Ticker"),
148
- gr.Textbox(label="Start Date (YYYY-MM-DD)"),
149
- gr.Textbox(label="End Date (YYYY-MM-DD)")
150
  ]
151
 
152
- outputs = gr.Textbox(label="Predicted Next Closing Price")
 
 
 
 
 
 
153
 
154
- gr.Interface(fn=predict_stock, inputs=inputs, outputs=outputs, title="Stock Price Prediction").launch()
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Function to preprocess data
21
  def preprocess_data(data):
22
+ data = data.ffill()
23
  scaler = MinMaxScaler()
24
  scaled_data = scaler.fit_transform(data)
25
  scaled_data = pd.DataFrame(scaled_data, columns=data.columns)
 
41
  return rsi
42
 
43
  data['RSI'] = calculate_rsi(data, 14)
44
+ data = data.ffill()
45
  data = data.fillna(0) # Fill any remaining NaN values
46
  return data
47
 
 
65
  self.fc = nn.Linear(hidden_size, output_size)
66
 
67
  def forward(self, x):
68
+ h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
69
+ c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
70
  out, _ = self.lstm(x, (h0, c0))
71
  out = self.fc(out[:, -1, :])
72
  return out
 
128
  # Inverse transform only the 'Close' column
129
  predicted_stock_price = scaler.inverse_transform(predicted_df)[:, 3]
130
 
131
+ # Calculate the latest actual price and predicted next price
132
+ latest_actual_price = stock_data['Close'].values[-1]
133
+ predicted_next_price = predicted_stock_price[-1]
 
 
 
134
 
135
+ # Determine whether to buy or sell the stock
136
+ recommendation = "Buy" if predicted_next_price > latest_actual_price else "Sell"
137
+
138
+ # Calculate performance summary
139
+ three_years_ago_price = stock_data['Close'].values[0]
140
+ price_change = latest_actual_price - three_years_ago_price
141
+ percentage_change = (price_change / three_years_ago_price) * 100
142
+
143
+ return latest_actual_price, predicted_next_price, recommendation, price_change, percentage_change
144
 
145
  except ValueError as e:
146
+ return str(e), '', '', '', ''
147
  except Exception as e:
148
+ return f"An error occurred: {str(e)}", '', '', '', ''
149
 
150
  # Gradio Interface
151
+ tickers = ['AAPL', 'GOOGL', 'MSFT', 'AMZN', 'TSLA', 'FB', 'NFLX', 'NVDA', 'BRK.B', 'JPM']
152
  inputs = [
153
+ gr.Dropdown(choices=tickers, label="Stock Ticker"),
154
+ gr.Textbox(label="Start Date (YYYY-MM-DD)", type="text"),
155
+ gr.Textbox(label="End Date (YYYY-MM-DD)", type="text")
156
  ]
157
 
158
+ outputs = [
159
+ gr.Textbox(label="Latest Actual Price"),
160
+ gr.Textbox(label="Predicted Next Closing Price"),
161
+ gr.Textbox(label="Recommendation"),
162
+ gr.Textbox(label="Price Change"),
163
+ gr.Textbox(label="Percentage Change")
164
+ ]
165
 
166
+ gr.Interface(
167
+ fn=predict_stock,
168
+ inputs=inputs,
169
+ outputs=outputs,
170
+ title="Stock Price Prediction",
171
+ examples=[
172
+ ["AAPL", "2020-01-01", "2023-01-01"],
173
+ ["GOOGL", "2020-01-01", "2023-01-01"]
174
+ ]
175
+ ).launch()