kashif HF staff commited on
Commit
f1bfd9d
1 Parent(s): 59dac5b

add distribution type

Browse files
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -1,8 +1,14 @@
1
  import gradio as gr
2
  import pandas as pd
 
3
  from gluonts.dataset.pandas import PandasDataset
4
  from gluonts.dataset.split import split
5
  from gluonts.torch.model.deepar import DeepAREstimator
 
 
 
 
 
6
  from gluonts.evaluation import Evaluator, make_evaluation_predictions
7
 
8
  from make_plot import plot_forecast, plot_train_test
@@ -32,6 +38,7 @@ def train_and_forecast(
32
  prediction_length,
33
  rolling_windows,
34
  epochs,
 
35
  progress=gr.Progress(track_tqdm=True),
36
  ):
37
  if not input_data:
@@ -54,7 +61,14 @@ def train_and_forecast(
54
 
55
  training_data, test_gen = split(gluon_df, offset=row_offset)
56
 
 
 
 
 
 
 
57
  estimator = DeepAREstimator(
 
58
  prediction_length=prediction_length,
59
  freq=gluon_df.freq,
60
  trainer_kwargs=dict(max_epochs=epochs),
@@ -108,6 +122,11 @@ with gr.Blocks() as demo:
108
  )
109
  windows = gr.Number(value=3, label="Number of Windows", precision=0)
110
  epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
 
 
 
 
 
111
 
112
  with gr.Row(label="Dataset"):
113
  upload_btn = gr.UploadButton(label="Upload")
@@ -122,7 +141,7 @@ with gr.Blocks() as demo:
122
  )
123
  train_btn.click(
124
  fn=train_and_forecast,
125
- inputs=[upload_btn, prediction_length, windows, epochs],
126
  outputs=[plot, json],
127
  )
128
 
 
1
  import gradio as gr
2
  import pandas as pd
3
+
4
  from gluonts.dataset.pandas import PandasDataset
5
  from gluonts.dataset.split import split
6
  from gluonts.torch.model.deepar import DeepAREstimator
7
+ from gluonts.torch.distributions import (
8
+ NegativeBinomialOutput,
9
+ StudentTOutput,
10
+ NormalOutput,
11
+ )
12
  from gluonts.evaluation import Evaluator, make_evaluation_predictions
13
 
14
  from make_plot import plot_forecast, plot_train_test
 
38
  prediction_length,
39
  rolling_windows,
40
  epochs,
41
+ distribution,
42
  progress=gr.Progress(track_tqdm=True),
43
  ):
44
  if not input_data:
 
61
 
62
  training_data, test_gen = split(gluon_df, offset=row_offset)
63
 
64
+ if distribution == "StudentT":
65
+ distr_output = StudentTOutput()
66
+ elif distribution == "Normal":
67
+ distr_output = NormalOutput()
68
+ else:
69
+ distr_output = NegativeBinomialOutput()
70
  estimator = DeepAREstimator(
71
+ distr_output=distr_output,
72
  prediction_length=prediction_length,
73
  freq=gluon_df.freq,
74
  trainer_kwargs=dict(max_epochs=epochs),
 
122
  )
123
  windows = gr.Number(value=3, label="Number of Windows", precision=0)
124
  epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
125
+ distribution = gr.Radio(
126
+ choices=["StudentT", "Negative Binomial", "Normal"],
127
+ value="StudentT",
128
+ label="Distribution",
129
+ )
130
 
131
  with gr.Row(label="Dataset"):
132
  upload_btn = gr.UploadButton(label="Upload")
 
141
  )
142
  train_btn.click(
143
  fn=train_and_forecast,
144
+ inputs=[upload_btn, prediction_length, windows, epochs, distribution],
145
  outputs=[plot, json],
146
  )
147