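# Scraped Hugging Face file-page header (kept as comments so the file parses):
#   author: kashif (HF staff)
#   last commit: "use Agg" (7cd99ed)
#   page actions: raw · history · blame · contribute · delete
#   scan status: No virus
#   file size: 1.24 kB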
import gradio as gr
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch.model.deepar import DeepAREstimator
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def fn(upload_data, *, prediction_length=12, windows=3, max_epochs=10):
    """Train DeepAR on an uploaded CSV and plot back-test forecasts.

    The CSV is read with its first column as a parsed datetime index. The
    first remaining column is the series that is modelled and plotted. The
    last ``prediction_length * windows`` observations are held out. A DeepAR
    model is trained on the rest, then ``windows`` consecutive forecasts of
    length ``prediction_length`` are drawn over the held-out span.

    Args:
        upload_data: Value handed over by ``gr.UploadButton``. This is either
            a file path string or a tempfile-like object exposing ``.name``,
            depending on the Gradio version.
        prediction_length: Forecast horizon (time steps) of each window.
        windows: Number of back-test windows to forecast and draw.
        max_epochs: Training epochs, forwarded via ``trainer_kwargs``.

    Returns:
        A matplotlib ``Figure`` with the observed series in black and one
        forecast per window overlaid.
    """
    from itertools import cycle

    # Older Gradio passes a tempfile wrapper (with ``.name``); newer
    # releases pass the path string directly.
    path = getattr(upload_data, "name", upload_data)
    df = pd.read_csv(path, index_col=0, parse_dates=True)
    target = df.columns[0]
    dataset = PandasDataset(df, target=target)

    # Hold out exactly the span the back-test windows cover, so the offset
    # cannot drift out of sync with the horizon or the window count.
    training_data, test_gen = split(dataset, offset=-prediction_length * windows)
    model = DeepAREstimator(
        prediction_length=prediction_length,
        freq=dataset.freq,
        trainer_kwargs=dict(max_epochs=max_epochs),
    ).train(training_data=training_data)

    test_data = test_gen.generate_instances(
        prediction_length=prediction_length, windows=windows
    )
    forecasts = list(model.predict(test_data.input))

    # Draw on an explicit Axes instead of pyplot's implicit "current" one.
    fig, ax = plt.subplots()
    # Plot the same column the model was trained on, so any uploaded CSV
    # works rather than only one containing a "#Passengers" column.
    df[target].plot(ax=ax, color="black")
    # cycle() keeps every forecast coloured even when windows > 3.
    for forecast, color in zip(forecasts, cycle(["green", "blue", "purple"])):
        forecast.plot(ax=ax, color=f"tab:{color}")
    ax.legend(["True values"], loc="upper left", fontsize="xx-large")
    # Drop pyplot's global reference so figures do not pile up across
    # requests. The Figure object stays usable for Gradio to render.
    plt.close(fig)
    return fig
with gr.Blocks() as demo:
    # Output pane showing the matplotlib figure returned by ``fn``.
    plot = gr.Plot()
    # Picking a file with this button fires its ``upload`` event. The event
    # passes the uploaded file to ``fn`` and renders the result in ``plot``.
    upload_btn = gr.UploadButton()
    upload_btn.upload(fn=fn, inputs=[upload_btn], outputs=[plot])

if __name__ == "__main__":
    # Start the server only when run as a script, not on import.
    demo.launch()