thesven commited on
Commit
040e502
1 Parent(s): d265dde

update to interactive plot

Browse files
Files changed (2) hide show
  1. app.py +24 -16
  2. requirements.txt +2 -0
app.py CHANGED
@@ -9,6 +9,8 @@ import gradio as gr
9
  import spaces
10
  import torch
11
  import pandas as pd
 
 
12
 
13
  # External imports
14
 
@@ -210,8 +212,6 @@ def create_test_dataloader(
210
  )
211
 
212
  def plot(ts_index, test_dataset, forecasts, prediction_length):
213
- fig, ax = plt.subplots(figsize=(12, 8), facecolor='white')
214
-
215
  # Length of the target data
216
  target_length = len(test_dataset[ts_index]['target'])
217
 
@@ -223,26 +223,34 @@ def plot(ts_index, test_dataset, forecasts, prediction_length):
223
  ).to_timestamp()
224
 
225
  # Plotting actual data
226
- ax.plot(
227
- index[:target_length],
228
- test_dataset[ts_index]['target'],
229
- label="Actual"
 
230
  )
231
 
232
  # Plotting the forecast data
233
- # Forecast starts right after the last actual data point
234
- forecast_start_index = target_length
235
- ax.plot(
236
- index[forecast_start_index:],
237
- forecasts[ts_index][0][:prediction_length], # Use forecasts[ts_index][0][:prediction_length] to slice the forecast values
238
- label="Prediction"
239
  )
240
 
241
- ax.set_ylim(0, 140000)
242
- ax.xaxis.set_major_locator(mdates.MonthLocator(bymonth=(1, 7)))
243
- ax.xaxis.set_minor_locator(mdates.MonthLocator())
 
 
 
 
 
 
 
 
 
244
 
245
- plt.legend()
246
  return fig
247
 
248
  def do_prediction(days_to_predict: int):
 
9
  import spaces
10
  import torch
11
  import pandas as pd
12
+ import plotly.graph_objects as go
13
+ from plotly.subplots import make_subplots
14
 
15
  # External imports
16
 
 
212
  )
213
 
214
  def plot(ts_index, test_dataset, forecasts, prediction_length):
 
 
215
  # Length of the target data
216
  target_length = len(test_dataset[ts_index]['target'])
217
 
 
223
  ).to_timestamp()
224
 
225
  # Plotting actual data
226
+ actual_data = go.Scatter(
227
+ x=index[:target_length],
228
+ y=test_dataset[ts_index]['target'],
229
+ name="Actual",
230
+ mode='lines',
231
  )
232
 
233
  # Plotting the forecast data
234
+ forecast_data = go.Scatter(
235
+ x=index[target_length:],
236
+ y=forecasts[ts_index][0][:prediction_length],
237
+ name="Prediction",
238
+ mode='lines',
 
239
  )
240
 
241
+ # Create the figure
242
+ fig = make_subplots(rows=1, cols=1)
243
+ fig.add_trace(actual_data, row=1, col=1)
244
+ fig.add_trace(forecast_data, row=1, col=1)
245
+
246
+ # Set layout and title
247
+ fig.update_layout(
248
+ xaxis_title="Date",
249
+ yaxis_title="Value",
250
+ title="Actual vs. Predicted Values",
251
+ xaxis_rangeslider_visible=True,
252
+ )
253
 
 
254
  return fig
255
 
256
  def do_prediction(days_to_predict: int):
requirements.txt CHANGED
@@ -118,6 +118,7 @@ parso==0.8.4
118
  pexpect==4.9.0
119
  pillow==10.3.0
120
  platformdirs==4.2.1
 
121
  pretty-errors==1.2.25
122
  prometheus_client==0.20.0
123
  prompt-toolkit==3.0.43
@@ -165,6 +166,7 @@ spaces==0.28.0
165
  stack-data==0.6.3
166
  starlette==0.37.2
167
  sympy==1.12
 
168
  tensorboard==2.16.2
169
  tensorboard-data-server==0.7.2
170
  terminado==0.18.1
 
118
  pexpect==4.9.0
119
  pillow==10.3.0
120
  platformdirs==4.2.1
121
+ plotly==5.22.0
122
  pretty-errors==1.2.25
123
  prometheus_client==0.20.0
124
  prompt-toolkit==3.0.43
 
166
  stack-data==0.6.3
167
  starlette==0.37.2
168
  sympy==1.12
169
+ tenacity==8.3.0
170
  tensorboard==2.16.2
171
  tensorboard-data-server==0.7.2
172
  terminado==0.18.1