dibend commited on
Commit
87d437c
1 Parent(s): 38f860b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -38
app.py CHANGED
@@ -4,65 +4,42 @@ import pandas as pd
4
  from sklearn.linear_model import LinearRegression
5
  import plotly.graph_objects as go
6
 
7
-
8
  def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
9
- """
10
- Downloads stock data, trains a linear regression model, and predicts future prices.
11
-
12
- Args:
13
- ticker: The ticker symbol of the stock.
14
- start_date: The start date for the data (YYYY-MM-DD format).
15
- end_date: The end date for the data (YYYY-MM-DD format).
16
- prediction_days: The number of days to predict.
17
-
18
- Returns:
19
- A plotly figure with the historical and predicted prices.
20
- """
21
  # Download stock data
22
  data = yf.download(ticker, start=start_date, end=end_date)
23
  data = data["Close"]
24
 
25
- # Convert dates to a numerical format (days since start date)
26
- start_date = pd.to_datetime(start_date)
27
- days_since_start = (data.index - start_date).days
28
 
29
  # Train linear regression model
30
- X = days_since_start.values[:-prediction_days].reshape(-1, 1)
31
  y = data.values[:-prediction_days]
32
  model = LinearRegression()
33
  model.fit(X, y)
34
 
35
  # Prepare data for prediction
36
- last_date = data.index[-1]
37
- future_dates = pd.date_range(start=last_date, periods=prediction_days + 1, closed='right')
38
- future_days_since_start = (future_dates - start_date).days
39
- X_future = future_days_since_start.values.reshape(-1, 1)
 
40
 
41
  # Predict future prices
42
  predicted_prices = model.predict(X_future)
43
 
44
- # Predict future prices
45
- future_dates = data.index.values[-prediction_days:]
46
- X_future = future_dates.reshape(-1, 1)
47
- predicted_prices = model.predict(X_future)
48
-
49
  # Prepare data for plotting
50
  historical_prices = go.Scatter(
51
- x=data.index,
52
  y=data.values,
53
  mode="lines",
54
- line_color=lambda p: "green" if p > 0 else "red",
55
- name="Historical Prices",
56
  )
57
  predicted_prices_trace = go.Scatter(
58
- x=future_dates,
59
  y=predicted_prices,
60
  mode="lines",
61
- line_color="gold",
62
- line_width=3,
63
- marker_line_width=3,
64
- marker_color="black",
65
- name="Predicted Prices",
66
  )
67
 
68
  # Plot data
@@ -73,12 +50,11 @@ def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
73
  title="Stock Price Prediction",
74
  xaxis_title="Date",
75
  yaxis_title="Price",
76
- legend_title_text="Data",
77
  )
78
 
79
  return fig
80
 
81
-
82
  # Define Gradio interface
83
  interface = gr.Interface(
84
  fn=train_predict_wrapper,
@@ -88,7 +64,7 @@ interface = gr.Interface(
88
  gr.Textbox(label="End Date (YYYY-MM-DD)"),
89
  gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
90
  ],
91
- outputs="plot",
92
  )
93
 
94
  # Launch the app
 
4
  from sklearn.linear_model import LinearRegression
5
  import plotly.graph_objects as go
6
 
 
7
  def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
 
 
 
 
 
 
 
 
 
 
 
 
8
  # Download stock data
9
  data = yf.download(ticker, start=start_date, end=end_date)
10
  data = data["Close"]
11
 
12
+ # Convert index to Unix timestamp (seconds)
13
+ data.index = (data.index - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')
 
14
 
15
  # Train linear regression model
16
+ X = data.index.values[:-prediction_days].reshape(-1, 1)
17
  y = data.values[:-prediction_days]
18
  model = LinearRegression()
19
  model.fit(X, y)
20
 
21
  # Prepare data for prediction
22
+ last_timestamp = data.index[-1]
23
+ future_timestamps = pd.date_range(start=pd.to_datetime(last_timestamp, unit='s'),
24
+ periods=prediction_days + 1, closed='right')
25
+ future_timestamps = (future_timestamps - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')
26
+ X_future = future_timestamps.values.reshape(-1, 1)
27
 
28
  # Predict future prices
29
  predicted_prices = model.predict(X_future)
30
 
 
 
 
 
 
31
  # Prepare data for plotting
32
  historical_prices = go.Scatter(
33
+ x=pd.to_datetime(data.index, unit='s'),
34
  y=data.values,
35
  mode="lines",
36
+ name="Historical Prices"
 
37
  )
38
  predicted_prices_trace = go.Scatter(
39
+ x=pd.to_datetime(future_timestamps, unit='s'),
40
  y=predicted_prices,
41
  mode="lines",
42
+ name="Predicted Prices"
 
 
 
 
43
  )
44
 
45
  # Plot data
 
50
  title="Stock Price Prediction",
51
  xaxis_title="Date",
52
  yaxis_title="Price",
53
+ legend_title_text="Data"
54
  )
55
 
56
  return fig
57
 
 
58
  # Define Gradio interface
59
  interface = gr.Interface(
60
  fn=train_predict_wrapper,
 
64
  gr.Textbox(label="End Date (YYYY-MM-DD)"),
65
  gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
66
  ],
67
+ outputs="plot"
68
  )
69
 
70
  # Launch the app