# Author: kashif (HF staff)
# Commit 7cd99ed: "use Agg"
import gradio as gr
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch.model.deepar import DeepAREstimator
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def fn(upload_data):
    """Train DeepAR on an uploaded CSV and plot backtest forecasts.

    The CSV is read with its first column as a datetime index; the first
    remaining column is the forecasting target. The final
    ``prediction_length * windows`` observations are held out from training
    and forecast in ``windows`` successive evaluation windows.

    Args:
        upload_data: Value produced by ``gr.UploadButton`` -- either a
            temp-file wrapper exposing ``.name`` or a plain file-path string,
            depending on the Gradio version.

    Returns:
        A matplotlib ``Figure`` showing the true series in black and one
        coloured forecast per evaluation window.

    Raises:
        gr.Error: If the CSV contains no data column besides the index.
    """
    # Accept both the tempfile-wrapper and the filepath-string upload values.
    csv_path = getattr(upload_data, "name", upload_data)
    df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
    if df.shape[1] == 0:
        raise gr.Error("The uploaded CSV must contain at least one data column.")
    target = df.columns[0]

    prediction_length = 12
    window_colors = ("green", "blue", "purple")
    windows = len(window_colors)  # one backtest window per plot colour

    dataset = PandasDataset(df, target=target)
    # Hold out exactly the span covered by the backtest windows.
    training_data, test_gen = split(dataset, offset=-prediction_length * windows)
    model = DeepAREstimator(
        prediction_length=prediction_length,
        freq=dataset.freq,
        trainer_kwargs=dict(max_epochs=10),
    ).train(training_data=training_data)

    test_data = test_gen.generate_instances(
        prediction_length=prediction_length, windows=windows
    )
    forecasts = list(model.predict(test_data.input))

    fig, ax = plt.subplots()
    # Plot the column the model was trained on, not a hard-coded name.
    df[target].plot(ax=ax, color="black")
    # Forecast.plot draws on the current pyplot axes, which plt.subplots()
    # has just made ``ax``.
    for forecast, color in zip(forecasts, window_colors):
        forecast.plot(color=f"tab:{color}")
    ax.legend(["True values"], loc="upper left", fontsize="xx-large")
    # Remove the figure from pyplot's global registry so repeated uploads do
    # not accumulate open figures; the returned Figure can still be saved.
    plt.close(fig)
    return fig
# UI layout: components are created in this order, so the output plot is
# placed before the upload button that triggers forecasting.
with gr.Blocks() as demo:
    plot = gr.Plot()  # shows the Figure returned by fn
    upload_btn = gr.UploadButton()
    # On upload, pass the file to fn and render its result in `plot`.
    upload_btn.upload(fn=fn, inputs=upload_btn, outputs=plot)

# Launch the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()