dibend commited on
Commit
ed74d2e
1 Parent(s): ddd2d15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -27
app.py CHANGED
@@ -1,62 +1,81 @@
1
  import gradio as gr
2
  import yfinance as yf
3
  from sklearn.linear_model import LinearRegression
4
- import plotly.graph_objs as go
5
- import numpy as np
6
 
7
  def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
8
  """
9
  Downloads stock data, trains a linear regression model, and predicts future prices.
10
 
11
  Args:
12
- ticker: The ticker symbol of the stock.
13
- start_date: The start date for the data (YYYY-MM-DD format).
14
- end_date: The end date for the data (YYYY-MM-DD format).
15
- prediction_days: The number of days to predict.
16
 
17
  Returns:
18
- A plot of predicted closing prices for the next `prediction_days`.
19
  """
20
  # Download stock data
21
  data = yf.download(ticker, start=start_date, end=end_date)
22
- # Extract closing price
23
  data = data["Close"]
24
 
25
- # Prepare data for model
26
- data = data.reset_index()
27
- data['Date'] = data['Date'].map(mdates.date2num)
28
- X = np.array(data.index).reshape(-1, 1)
29
- y = data['Close'].values
30
-
31
  # Train linear regression model
 
 
32
  model = LinearRegression()
33
- model.fit(X[:-prediction_days], y[:-prediction_days])
34
 
35
  # Predict future prices
36
- future_indices = np.array(range(len(X), len(X) + prediction_days)).reshape(-1, 1)
37
- predicted_prices = model.predict(future_indices)
 
38
 
39
- # Plot
40
- dates = [mdates.num2date(date).strftime('%Y-%m-%d') for date in data['Date']]
41
- future_dates = [mdates.num2date(date).strftime('%Y-%m-%d') for date in future_indices.flatten()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  fig = go.Figure()
43
- fig.add_trace(go.Scatter(x=dates, y=y, mode='lines', name='Historical Prices'))
44
- fig.add_trace(go.Scatter(x=future_dates, y=predicted_prices, mode='lines', name='Predicted Prices'))
45
- fig.update_layout(title='Stock Price Prediction', xaxis_title='Date', yaxis_title='Price')
 
 
 
 
 
46
 
47
  return fig
48
 
 
49
  # Define Gradio interface
50
- iface = gr.Interface(
51
  fn=train_predict_wrapper,
52
  inputs=[
53
  gr.Textbox(label="Ticker Symbol"),
54
  gr.Textbox(label="Start Date (YYYY-MM-DD)"),
55
  gr.Textbox(label="End Date (YYYY-MM-DD)"),
56
- gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days")
57
  ],
58
- outputs="plot"
59
  )
60
 
61
  # Launch the app
62
- iface.launch()
 
1
  import gradio as gr
2
  import yfinance as yf
3
  from sklearn.linear_model import LinearRegression
4
+ import plotly.graph_objects as go
5
+
6
 
7
  def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
8
  """
9
  Downloads stock data, trains a linear regression model, and predicts future prices.
10
 
11
  Args:
12
+ ticker: The ticker symbol of the stock.
13
+ start_date: The start date for the data (YYYY-MM-DD format).
14
+ end_date: The end date for the data (YYYY-MM-DD format).
15
+ prediction_days: The number of days to predict.
16
 
17
  Returns:
18
+ A plotly figure with the historical and predicted prices.
19
  """
20
  # Download stock data
21
  data = yf.download(ticker, start=start_date, end=end_date)
 
22
  data = data["Close"]
23
 
 
 
 
 
 
 
24
  # Train linear regression model
25
+ X = data.index.values[:-prediction_days].reshape(-1, 1)
26
+ y = data.values[:-prediction_days]
27
  model = LinearRegression()
28
+ model.fit(X, y)
29
 
30
  # Predict future prices
31
+ future_dates = data.index.values[-prediction_days:]
32
+ X_future = future_dates.reshape(-1, 1)
33
+ predicted_prices = model.predict(X_future)
34
 
35
+ # Prepare data for plotting
36
+ historical_prices = go.Scatter(
37
+ x=data.index,
38
+ y=data.values,
39
+ mode="lines",
40
+ line_color=lambda p: "green" if p > 0 else "red",
41
+ name="Historical Prices",
42
+ )
43
+ predicted_prices_trace = go.Scatter(
44
+ x=future_dates,
45
+ y=predicted_prices,
46
+ mode="lines",
47
+ line_color="gold",
48
+ line_width=3,
49
+ marker_line_width=3,
50
+ marker_color="black",
51
+ name="Predicted Prices",
52
+ )
53
+
54
+ # Plot data
55
  fig = go.Figure()
56
+ fig.add_trace(historical_prices)
57
+ fig.add_trace(predicted_prices_trace)
58
+ fig.update_layout(
59
+ title="Stock Price Prediction",
60
+ xaxis_title="Date",
61
+ yaxis_title="Price",
62
+ legend_title_text="Data",
63
+ )
64
 
65
  return fig
66
 
67
+
68
  # Define Gradio interface
69
+ interface = gr.Interface(
70
  fn=train_predict_wrapper,
71
  inputs=[
72
  gr.Textbox(label="Ticker Symbol"),
73
  gr.Textbox(label="Start Date (YYYY-MM-DD)"),
74
  gr.Textbox(label="End Date (YYYY-MM-DD)"),
75
+ gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
76
  ],
77
+ outputs="plot",
78
  )
79
 
80
  # Launch the app
81
+ interface.launch()