Commit
•
97ab62b
0
Parent(s):
Duplicate from derek-thomas/probabilistic-forecast
Browse files- .gitattributes +34 -0
- .gitignore +2 -0
- AirPassengers.csv +1 -0
- README.md +14 -0
- app.py +74 -0
- make_plot.py +114 -0
- packages.txt +0 -0
- requirements.txt +3 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.idea
|
2 |
+
lightning_logs
|
AirPassengers.csv
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Month,#Passengers
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Probablistic Forecasting
|
3 |
+
emoji: 🐨
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: gray
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.27.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
duplicated_from: derek-thomas/probabilistic-forecast
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
from gluonts.dataset.pandas import PandasDataset
|
4 |
+
from gluonts.dataset.split import split
|
5 |
+
from gluonts.torch.model.deepar import DeepAREstimator
|
6 |
+
|
7 |
+
from make_plot import plot_forecast, plot_train_test
|
8 |
+
|
9 |
+
|
10 |
+
def offset_calculation(prediction_length, rolling_windows, length):
|
11 |
+
row_offset = -1 * prediction_length * rolling_windows
|
12 |
+
if abs(row_offset) > 0.95 * length:
|
13 |
+
raise gr.Error("Reduce prediction_length * rolling_windows")
|
14 |
+
return row_offset
|
15 |
+
|
16 |
+
|
17 |
+
def preprocess(input_data, prediction_length, rolling_windows, progress=gr.Progress(track_tqdm=True)):
|
18 |
+
df = pd.read_csv(input_data.name, index_col=0, parse_dates=True)
|
19 |
+
row_offset = offset_calculation(prediction_length, rolling_windows, len(df))
|
20 |
+
return plot_train_test(df.iloc[:row_offset], df.iloc[row_offset:])
|
21 |
+
|
22 |
+
|
23 |
+
def train_and_forecast(input_data, prediction_length, rolling_windows, epochs, progress=gr.Progress(track_tqdm=True)):
|
24 |
+
if not input_data:
|
25 |
+
raise gr.Error("Upload a file with the Upload button")
|
26 |
+
try:
|
27 |
+
df = pd.read_csv(input_data.name, index_col=0, parse_dates=True)
|
28 |
+
except AttributeError:
|
29 |
+
raise gr.Error("Upload a file with the Upload button")
|
30 |
+
|
31 |
+
row_offset = offset_calculation(prediction_length, rolling_windows, len(df))
|
32 |
+
|
33 |
+
gluon_df = PandasDataset(df, target=df.columns[0])
|
34 |
+
|
35 |
+
training_data, test_gen = split(gluon_df, offset=row_offset)
|
36 |
+
|
37 |
+
model = DeepAREstimator(
|
38 |
+
prediction_length=prediction_length,
|
39 |
+
freq=gluon_df.freq,
|
40 |
+
trainer_kwargs=dict(max_epochs=epochs),
|
41 |
+
).train(
|
42 |
+
training_data=training_data,
|
43 |
+
)
|
44 |
+
|
45 |
+
test_data = test_gen.generate_instances(prediction_length=prediction_length, windows=rolling_windows)
|
46 |
+
forecasts = list(model.predict(test_data.input))
|
47 |
+
return plot_forecast(df, forecasts)
|
48 |
+
|
49 |
+
|
50 |
+
with gr.Blocks() as demo:
|
51 |
+
gr.Markdown("""
|
52 |
+
# How to use
|
53 |
+
Upload a univariate csv with the first column showing your dates and the second column having your data
|
54 |
+
|
55 |
+
# How it works
|
56 |
+
1. Click **Upload** to upload your data
|
57 |
+
2. Click **Run**
|
58 |
+
- This app will visualize your data and then train an estimator and show its predictions
|
59 |
+
""")
|
60 |
+
with gr.Accordion(label='Hyperparameters'):
|
61 |
+
with gr.Row():
|
62 |
+
prediction_length = gr.Number(value=12, label='Prediction Length', precision=0)
|
63 |
+
windows = gr.Number(value=3, label='Number of Windows', precision=0)
|
64 |
+
epochs = gr.Number(value=10, label='Number of Epochs', precision=0)
|
65 |
+
with gr.Row():
|
66 |
+
upload_btn = gr.UploadButton(label="Upload")
|
67 |
+
train_btn = gr.Button(label="Train and Forecast")
|
68 |
+
plot = gr.Plot()
|
69 |
+
|
70 |
+
upload_btn.upload(fn=preprocess, inputs=[upload_btn, prediction_length, windows], outputs=plot)
|
71 |
+
train_btn.click(fn=train_and_forecast, inputs=[upload_btn, prediction_length, epochs, windows], outputs=plot)
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
demo.queue().launch()
|
make_plot.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
|
7 |
+
|
8 |
+
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
|
9 |
+
"""
|
10 |
+
Plot the training and test datasets using Plotly.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
df1 (pd.DataFrame): Train dataset
|
14 |
+
df2 (pd.DataFrame): Test dataset
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
None
|
18 |
+
"""
|
19 |
+
|
20 |
+
# Create a Plotly figure
|
21 |
+
fig = go.Figure()
|
22 |
+
|
23 |
+
# Add the first scatter plot with steelblue color
|
24 |
+
fig.add_trace(go.Scatter(
|
25 |
+
x=df1.index,
|
26 |
+
y=df1.iloc[:, 0],
|
27 |
+
mode='lines',
|
28 |
+
name='Training Data',
|
29 |
+
line=dict(color='steelblue'),
|
30 |
+
marker=dict(color='steelblue')
|
31 |
+
))
|
32 |
+
|
33 |
+
# Add the second scatter plot with yellow color
|
34 |
+
fig.add_trace(go.Scatter(
|
35 |
+
x=df2.index,
|
36 |
+
y=df2.iloc[:, 0],
|
37 |
+
mode='lines',
|
38 |
+
name='Test Data',
|
39 |
+
line=dict(color='gold'),
|
40 |
+
marker=dict(color='gold')
|
41 |
+
))
|
42 |
+
|
43 |
+
# Customize the layout
|
44 |
+
fig.update_layout(
|
45 |
+
title='Univariate Time Series',
|
46 |
+
xaxis=dict(title='Date'),
|
47 |
+
yaxis=dict(title='Value'),
|
48 |
+
showlegend=True,
|
49 |
+
template='plotly_white'
|
50 |
+
)
|
51 |
+
return fig
|
52 |
+
|
53 |
+
|
54 |
+
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
|
55 |
+
"""
|
56 |
+
Plot the true values and forecasts using Plotly.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
df (pd.DataFrame): DataFrame with the true values. Assumed to have an index and columns.
|
60 |
+
forecasts (List[pd.DataFrame]): List of DataFrames containing the forecasts.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
go.Figure: Plotly figure object.
|
64 |
+
"""
|
65 |
+
|
66 |
+
# Create a Plotly figure
|
67 |
+
fig = go.Figure()
|
68 |
+
|
69 |
+
# Add the true values trace
|
70 |
+
fig.add_trace(go.Scatter(
|
71 |
+
x=pd.to_datetime(df.index),
|
72 |
+
y=df.iloc[:, 0],
|
73 |
+
mode='lines',
|
74 |
+
name='True values',
|
75 |
+
line=dict(color='black')
|
76 |
+
))
|
77 |
+
|
78 |
+
# Add the forecast traces
|
79 |
+
colors = ["green", "blue", "purple"]
|
80 |
+
for i, forecast in enumerate(forecasts):
|
81 |
+
color = colors[i]
|
82 |
+
for sample in forecast.samples:
|
83 |
+
fig.add_trace(go.Scatter(
|
84 |
+
x=forecast.index.to_timestamp(),
|
85 |
+
y=sample,
|
86 |
+
mode='lines',
|
87 |
+
opacity=0.15, # Adjust opacity to control visibility of individual samples
|
88 |
+
name=f'Forecast {i + 1}',
|
89 |
+
showlegend=False, # Hide the individual forecast series from the legend
|
90 |
+
hoverinfo='none', # Disable hover information for the forecast series
|
91 |
+
line=dict(color=color)
|
92 |
+
))
|
93 |
+
# Add the average
|
94 |
+
mean_forecast = np.mean(forecast.samples, axis=0)
|
95 |
+
fig.add_trace(go.Scatter(
|
96 |
+
x=forecast.index.to_timestamp(),
|
97 |
+
y=mean_forecast,
|
98 |
+
mode='lines',
|
99 |
+
name=f'Mean Forecast',
|
100 |
+
line=dict(color='red', dash='dash')
|
101 |
+
))
|
102 |
+
|
103 |
+
# Customize the layout
|
104 |
+
fig.update_layout(
|
105 |
+
title='Passenger Forecast',
|
106 |
+
xaxis=dict(title='Index'),
|
107 |
+
yaxis=dict(title='Passenger Count'),
|
108 |
+
showlegend=True,
|
109 |
+
legend=dict(x=0, y=1, font=dict(size=16)),
|
110 |
+
hovermode='x' # Enable x-axis hover for better interactivity
|
111 |
+
)
|
112 |
+
|
113 |
+
# Return the figure
|
114 |
+
return fig
|
packages.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
gluonts[torch,pro]
|
2 |
+
pandas
|
3 |
+
plotly
|