dibend commited on
Commit
a2d3623
1 Parent(s): 1cd8044

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -36
app.py CHANGED
@@ -1,40 +1,62 @@
1
- from gradio_app_builder import app
2
  import yfinance as yf
3
  from sklearn.linear_model import LinearRegression
4
- import json
 
5
 
6
- @app.route("/")
7
  def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
8
- """
9
- Downloads stock data, trains a linear regression model, and predicts future prices.
10
-
11
- Args:
12
- ticker: The ticker symbol of the stock.
13
- start_date: The start date for the data (YYYY-MM-DD format).
14
- end_date: The end date for the data (YYYY-MM-DD format).
15
- prediction_days: The number of days to predict.
16
-
17
- Returns:
18
- A list of predicted closing prices for the next `prediction_days`.
19
- """
20
- # Download stock data
21
- data = yf.download(ticker, start=start_date, end=end_date)
22
- # Extract closing price and set as index
23
- data = data["Close"].to_frame().set_index(data.index.values)
24
-
25
- # Train linear regression model
26
- X = data.index.values[:-prediction_days].reshape(-1, 1)
27
- y = data.values[:-prediction_days]
28
- model = LinearRegression()
29
- model.fit(X, y)
30
-
31
- # Predict future prices
32
- future_dates = data.index.values[-prediction_days:]
33
- X_future = future_dates.reshape(-1, 1)
34
- predicted_prices = model.predict(X_future).tolist()
35
-
36
- # Return predicted prices as JSON
37
- return json.dumps(predicted_prices)
38
-
39
- # Launch the Gradio application
40
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import yfinance as yf
3
  from sklearn.linear_model import LinearRegression
4
+ import plotly.graph_objs as go
5
+ import numpy as np
6
 
 
7
  def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
8
+ """
9
+ Downloads stock data, trains a linear regression model, and predicts future prices.
10
+
11
+ Args:
12
+ ticker: The ticker symbol of the stock.
13
+ start_date: The start date for the data (YYYY-MM-DD format).
14
+ end_date: The end date for the data (YYYY-MM-DD format).
15
+ prediction_days: The number of days to predict.
16
+
17
+ Returns:
18
+ A plot of predicted closing prices for the next `prediction_days`.
19
+ """
20
+ # Download stock data
21
+ data = yf.download(ticker, start=start_date, end=end_date)
22
+ # Extract closing price
23
+ data = data["Close"]
24
+
25
+ # Prepare data for model
26
+ data = data.reset_index()
27
+ data['Date'] = data['Date'].map(mdates.date2num)
28
+ X = np.array(data.index).reshape(-1, 1)
29
+ y = data['Close'].values
30
+
31
+ # Train linear regression model
32
+ model = LinearRegression()
33
+ model.fit(X[:-prediction_days], y[:-prediction_days])
34
+
35
+ # Predict future prices
36
+ future_indices = np.array(range(len(X), len(X) + prediction_days)).reshape(-1, 1)
37
+ predicted_prices = model.predict(future_indices)
38
+
39
+ # Plot
40
+ dates = [mdates.num2date(date).strftime('%Y-%m-%d') for date in data['Date']]
41
+ future_dates = [mdates.num2date(date).strftime('%Y-%m-%d') for date in future_indices.flatten()]
42
+ fig = go.Figure()
43
+ fig.add_trace(go.Scatter(x=dates, y=y, mode='lines', name='Historical Prices'))
44
+ fig.add_trace(go.Scatter(x=future_dates, y=predicted_prices, mode='lines', name='Predicted Prices'))
45
+ fig.update_layout(title='Stock Price Prediction', xaxis_title='Date', yaxis_title='Price')
46
+
47
+ return fig
48
+
49
+ # Define Gradio interface
50
+ iface = gr.Interface(
51
+ fn=train_predict_wrapper,
52
+ inputs=[
53
+ gr.inputs.Textbox(label="Ticker Symbol"),
54
+ gr.inputs.Textbox(label="Start Date (YYYY-MM-DD)"),
55
+ gr.inputs.Textbox(label="End Date (YYYY-MM-DD)"),
56
+ gr.inputs.Slider(minimum=1, maximum=30, step=1, default=5, label="Prediction Days")
57
+ ],
58
+ outputs="plot"
59
+ )
60
+
61
+ # Launch the app
62
+ iface.launch()