Spaces:
Runtime error
Runtime error
update to interactive plot
Browse files- app.py +24 -16
- requirements.txt +2 -0
app.py
CHANGED
@@ -9,6 +9,8 @@ import gradio as gr
|
|
9 |
import spaces
|
10 |
import torch
|
11 |
import pandas as pd
|
|
|
|
|
12 |
|
13 |
# External imports
|
14 |
|
@@ -210,8 +212,6 @@ def create_test_dataloader(
|
|
210 |
)
|
211 |
|
212 |
def plot(ts_index, test_dataset, forecasts, prediction_length):
|
213 |
-
fig, ax = plt.subplots(figsize=(12, 8), facecolor='white')
|
214 |
-
|
215 |
# Length of the target data
|
216 |
target_length = len(test_dataset[ts_index]['target'])
|
217 |
|
@@ -223,26 +223,34 @@ def plot(ts_index, test_dataset, forecasts, prediction_length):
|
|
223 |
).to_timestamp()
|
224 |
|
225 |
# Plotting actual data
|
226 |
-
|
227 |
-
index[:target_length],
|
228 |
-
test_dataset[ts_index]['target'],
|
229 |
-
|
|
|
230 |
)
|
231 |
|
232 |
# Plotting the forecast data
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
label="Prediction"
|
239 |
)
|
240 |
|
241 |
-
|
242 |
-
|
243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
-
plt.legend()
|
246 |
return fig
|
247 |
|
248 |
def do_prediction(days_to_predict: int):
|
|
|
9 |
import spaces
|
10 |
import torch
|
11 |
import pandas as pd
|
12 |
+
import plotly.graph_objects as go
|
13 |
+
from plotly.subplots import make_subplots
|
14 |
|
15 |
# External imports
|
16 |
|
|
|
212 |
)
|
213 |
|
214 |
def plot(ts_index, test_dataset, forecasts, prediction_length):
|
|
|
|
|
215 |
# Length of the target data
|
216 |
target_length = len(test_dataset[ts_index]['target'])
|
217 |
|
|
|
223 |
).to_timestamp()
|
224 |
|
225 |
# Plotting actual data
|
226 |
+
actual_data = go.Scatter(
|
227 |
+
x=index[:target_length],
|
228 |
+
y=test_dataset[ts_index]['target'],
|
229 |
+
name="Actual",
|
230 |
+
mode='lines',
|
231 |
)
|
232 |
|
233 |
# Plotting the forecast data
|
234 |
+
forecast_data = go.Scatter(
|
235 |
+
x=index[target_length:],
|
236 |
+
y=forecasts[ts_index][0][:prediction_length],
|
237 |
+
name="Prediction",
|
238 |
+
mode='lines',
|
|
|
239 |
)
|
240 |
|
241 |
+
# Create the figure
|
242 |
+
fig = make_subplots(rows=1, cols=1)
|
243 |
+
fig.add_trace(actual_data, row=1, col=1)
|
244 |
+
fig.add_trace(forecast_data, row=1, col=1)
|
245 |
+
|
246 |
+
# Set layout and title
|
247 |
+
fig.update_layout(
|
248 |
+
xaxis_title="Date",
|
249 |
+
yaxis_title="Value",
|
250 |
+
title="Actual vs. Predicted Values",
|
251 |
+
xaxis_rangeslider_visible=True,
|
252 |
+
)
|
253 |
|
|
|
254 |
return fig
|
255 |
|
256 |
def do_prediction(days_to_predict: int):
|
requirements.txt
CHANGED
@@ -118,6 +118,7 @@ parso==0.8.4
|
|
118 |
pexpect==4.9.0
|
119 |
pillow==10.3.0
|
120 |
platformdirs==4.2.1
|
|
|
121 |
pretty-errors==1.2.25
|
122 |
prometheus_client==0.20.0
|
123 |
prompt-toolkit==3.0.43
|
@@ -165,6 +166,7 @@ spaces==0.28.0
|
|
165 |
stack-data==0.6.3
|
166 |
starlette==0.37.2
|
167 |
sympy==1.12
|
|
|
168 |
tensorboard==2.16.2
|
169 |
tensorboard-data-server==0.7.2
|
170 |
terminado==0.18.1
|
|
|
118 |
pexpect==4.9.0
|
119 |
pillow==10.3.0
|
120 |
platformdirs==4.2.1
|
121 |
+
plotly==5.22.0
|
122 |
pretty-errors==1.2.25
|
123 |
prometheus_client==0.20.0
|
124 |
prompt-toolkit==3.0.43
|
|
|
166 |
stack-data==0.6.3
|
167 |
starlette==0.37.2
|
168 |
sympy==1.12
|
169 |
+
tenacity==8.3.0
|
170 |
tensorboard==2.16.2
|
171 |
tensorboard-data-server==0.7.2
|
172 |
terminado==0.18.1
|