# Benchmark01 / app.py
# Author: CausaLInference
# Last commit: fe68c82 ("Update app.py", verified)
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from statsmodels.tsa.api import Holt, ExponentialSmoothing
from statsmodels.tsa.arima.model import ARIMA
def _drift_forecast(train, h):
    """Return the drift forecast for ``h`` steps ahead.

    The drift method extends the straight line through the first and last
    observations: step ``k`` is ``last + k * (last - first) / (len(train) - 1)``.
    With a single observation there is no change to average, so the slope is
    taken as zero (the forecast then equals the naive one) instead of
    dividing by zero.
    """
    last = train.iloc[-1]
    if len(train) < 2:
        slope = 0.0
    else:
        slope = (last - train.iloc[0]) / (len(train) - 1)
    return last + np.arange(1, h + 1) * slope


def forecast_methods(train, h, methods, seasonal_periods=12, arima_order=(1, 1, 1)):
    """Run the selected baseline forecasters on ``train``.

    Args:
        train: Observed history, accessed positionally via ``.iloc``
            (the app passes the first CSV column as a pandas Series).
        h: Number of steps ahead to forecast.
        methods: Mapping from method key (``"naive"``, ``"mean"``,
            ``"drift"``, ``"holt"``, ``"hw"``, ``"arima"``) to a flag.
            Keys that are missing count as "not selected".
        seasonal_periods: Season length passed to the additive
            Holt-Winters model (defaults to 12).
        arima_order: ``(p, d, q)`` order of the ARIMA model
            (defaults to ``(1, 1, 1)``).

    Returns:
        ``(forecasts, titles)``: two parallel lists with one entry per
        selected method, in the fixed order listed above. Naive, mean and
        drift forecasts are NumPy arrays of length ``h``; Holt, Holt-Winters
        and ARIMA return whatever the fitted statsmodels result's
        ``forecast`` produces.
    """
    # Each entry is (methods key, plot title, zero-arg builder). Builders
    # are lambdas so that only the selected models are ever fitted.
    runners = (
        ("naive", "Naive", lambda: np.tile(train.iloc[-1], h)),
        ("mean", "Mean", lambda: np.tile(train.mean(), h)),
        ("drift", "Drift", lambda: _drift_forecast(train, h)),
        ("holt", "Holt", lambda: Holt(train).fit().forecast(h)),
        (
            "hw",
            "HW Additive",
            lambda: ExponentialSmoothing(
                train, seasonal="additive", seasonal_periods=seasonal_periods
            ).fit().forecast(h),
        ),
        (
            "arima",
            "ARIMA",
            lambda: ARIMA(train, order=arima_order).fit().forecast(steps=h),
        ),
    )

    forecast = []
    titles = []
    for key, title, build in runners:
        if methods.get(key):
            forecast.append(build())
            titles.append(title)
    return forecast, titles
def _plot_forecasts(train, forecasts, titles):
    """Plot the history followed by each forecast; return the saved PNG path.

    Uses an explicit Figure (not pyplot's global "current figure") and a
    unique temporary file, so concurrent requests cannot draw on or
    overwrite each other's plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(train, label="Dados Atuais")
        # Forecasts start right after the last observed position.
        start = len(train)
        for forecast, title in zip(forecasts, titles):
            ax.plot(np.arange(start, start + len(forecast)), forecast, label=title)
        ax.legend()
        ax.set_title("Benchmark de Séries Temporais")
        ax.grid(True)
        # delete=False: the file must outlive this function so the UI can read it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png")
    finally:
        plt.close(fig)
    return tmp.name


def forecast_benchmark(file, forecast_horizon, naive, mean, drift, holt, hw, arima):
    """Load a one-column CSV, run the selected forecasters and plot them.

    Args:
        file: Value delivered by the ``gr.File`` input. Either an object with
            a ``.name`` attribute holding the path (older Gradio) or the path
            string itself (newer Gradio); both are accepted.
        forecast_horizon: Number of steps to forecast; cast to ``int``.
        naive, mean, drift, holt, hw, arima: Checkbox flags selecting the
            methods to run.

    Returns:
        Path to a PNG file with the history and the selected forecasts.

    Raises:
        gr.Error: If no file was uploaded or the first column contains no
            numeric values (shown to the user by Gradio).
    """
    if file is None:
        raise gr.Error("Please upload a CSV file.")
    path = getattr(file, "name", file)
    data = pd.read_csv(path, header=None)
    # Coerce to numbers so a header row or stray text cells become NaN and
    # are dropped, instead of yielding an object column the models cannot
    # fit. reset_index keeps positions 0..n-1, which the plot x-axis assumes.
    train = pd.to_numeric(data.iloc[:, 0], errors="coerce").dropna().reset_index(drop=True)
    if train.empty:
        raise gr.Error("The first column of the CSV contains no numeric values.")

    methods = {
        'naive': naive,
        'mean': mean,
        'drift': drift,
        'holt': holt,
        'hw': hw,
        'arima': arima,
    }
    forecasts, titles = forecast_methods(train, int(forecast_horizon), methods)
    return _plot_forecasts(train, forecasts, titles)
# Gradio interface.
# The order of `inputs` must match the positional parameters of
# forecast_benchmark(file, forecast_horizon, naive, mean, drift, holt, hw, arima).
iface = gr.Interface(
    fn=forecast_benchmark,
    inputs=[
        gr.File(label="Upload CSV file"),
        gr.Slider(minimum=1, maximum=60, step=1, value=24, label="Forecast Horizon"),
        gr.Checkbox(label="Naive"),
        gr.Checkbox(label="Mean"),
        gr.Checkbox(label="Drift"),
        gr.Checkbox(label="Holt"),
        gr.Checkbox(label="Holt-Winters (Additive)"),
        gr.Checkbox(label="ARIMA"),
    ],
    outputs=gr.Image(label="Forecast Plot"),
    title="Time Series Forecasting Benchmark",
    description="Upload a CSV com dados de série temporal (1 coluna). Escolha os métodos de previsão para comparar.",
)

# Launch only when run as a script, so importing this module (e.g. for
# tests or reuse of forecast_methods) does not start a web server.
if __name__ == "__main__":
    iface.launch()