Update app.py
Browse files
app.py
CHANGED
@@ -1,62 +1,81 @@
|
|
1 |
import gradio as gr
|
2 |
import yfinance as yf
|
3 |
from sklearn.linear_model import LinearRegression
|
4 |
-
import plotly.
|
5 |
-
|
6 |
|
7 |
def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
|
8 |
"""
|
9 |
Downloads stock data, trains a linear regression model, and predicts future prices.
|
10 |
|
11 |
Args:
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
|
17 |
Returns:
|
18 |
-
|
19 |
"""
|
20 |
# Download stock data
|
21 |
data = yf.download(ticker, start=start_date, end=end_date)
|
22 |
-
# Extract closing price
|
23 |
data = data["Close"]
|
24 |
|
25 |
-
# Prepare data for model
|
26 |
-
data = data.reset_index()
|
27 |
-
data['Date'] = data['Date'].map(mdates.date2num)
|
28 |
-
X = np.array(data.index).reshape(-1, 1)
|
29 |
-
y = data['Close'].values
|
30 |
-
|
31 |
# Train linear regression model
|
|
|
|
|
32 |
model = LinearRegression()
|
33 |
-
model.fit(X
|
34 |
|
35 |
# Predict future prices
|
36 |
-
|
37 |
-
|
|
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
fig = go.Figure()
|
43 |
-
fig.add_trace(
|
44 |
-
fig.add_trace(
|
45 |
-
fig.update_layout(
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
return fig
|
48 |
|
|
|
49 |
# Define Gradio interface
|
50 |
-
|
51 |
fn=train_predict_wrapper,
|
52 |
inputs=[
|
53 |
gr.Textbox(label="Ticker Symbol"),
|
54 |
gr.Textbox(label="Start Date (YYYY-MM-DD)"),
|
55 |
gr.Textbox(label="End Date (YYYY-MM-DD)"),
|
56 |
-
gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days")
|
57 |
],
|
58 |
-
outputs="plot"
|
59 |
)
|
60 |
|
61 |
# Launch the app
|
62 |
-
|
|
|
1 |
import gradio as gr
|
2 |
import yfinance as yf
|
3 |
from sklearn.linear_model import LinearRegression
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
|
6 |
|
7 |
def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
|
8 |
"""
|
9 |
Downloads stock data, trains a linear regression model, and predicts future prices.
|
10 |
|
11 |
Args:
|
12 |
+
ticker: The ticker symbol of the stock.
|
13 |
+
start_date: The start date for the data (YYYY-MM-DD format).
|
14 |
+
end_date: The end date for the data (YYYY-MM-DD format).
|
15 |
+
prediction_days: The number of days to predict.
|
16 |
|
17 |
Returns:
|
18 |
+
A plotly figure with the historical and predicted prices.
|
19 |
"""
|
20 |
# Download stock data
|
21 |
data = yf.download(ticker, start=start_date, end=end_date)
|
|
|
22 |
data = data["Close"]
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
# Train linear regression model
|
25 |
+
X = data.index.values[:-prediction_days].reshape(-1, 1)
|
26 |
+
y = data.values[:-prediction_days]
|
27 |
model = LinearRegression()
|
28 |
+
model.fit(X, y)
|
29 |
|
30 |
# Predict future prices
|
31 |
+
future_dates = data.index.values[-prediction_days:]
|
32 |
+
X_future = future_dates.reshape(-1, 1)
|
33 |
+
predicted_prices = model.predict(X_future)
|
34 |
|
35 |
+
# Prepare data for plotting
|
36 |
+
historical_prices = go.Scatter(
|
37 |
+
x=data.index,
|
38 |
+
y=data.values,
|
39 |
+
mode="lines",
|
40 |
+
line_color=lambda p: "green" if p > 0 else "red",
|
41 |
+
name="Historical Prices",
|
42 |
+
)
|
43 |
+
predicted_prices_trace = go.Scatter(
|
44 |
+
x=future_dates,
|
45 |
+
y=predicted_prices,
|
46 |
+
mode="lines",
|
47 |
+
line_color="gold",
|
48 |
+
line_width=3,
|
49 |
+
marker_line_width=3,
|
50 |
+
marker_color="black",
|
51 |
+
name="Predicted Prices",
|
52 |
+
)
|
53 |
+
|
54 |
+
# Plot data
|
55 |
fig = go.Figure()
|
56 |
+
fig.add_trace(historical_prices)
|
57 |
+
fig.add_trace(predicted_prices_trace)
|
58 |
+
fig.update_layout(
|
59 |
+
title="Stock Price Prediction",
|
60 |
+
xaxis_title="Date",
|
61 |
+
yaxis_title="Price",
|
62 |
+
legend_title_text="Data",
|
63 |
+
)
|
64 |
|
65 |
return fig
|
66 |
|
67 |
+
|
68 |
# Define Gradio interface
|
69 |
+
interface = gr.Interface(
|
70 |
fn=train_predict_wrapper,
|
71 |
inputs=[
|
72 |
gr.Textbox(label="Ticker Symbol"),
|
73 |
gr.Textbox(label="Start Date (YYYY-MM-DD)"),
|
74 |
gr.Textbox(label="End Date (YYYY-MM-DD)"),
|
75 |
+
gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
|
76 |
],
|
77 |
+
outputs="plot",
|
78 |
)
|
79 |
|
80 |
# Launch the app
|
81 |
+
interface.launch()
|