import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from chronos import ChronosPipeline
from sklearn.metrics import mean_squared_error, r2_score


@st.cache_resource
def _load_pipeline():
    """Load the pretrained Chronos pipeline once and reuse it across Streamlit reruns."""
    return ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-small",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )


def make_forecast(data, time_series_field, prediction_length, ground_truth=None):
    """Forecast one column with Chronos, render the plot in Streamlit, return the median.

    Args:
        data: Table-like object indexable by column name (``data[time_series_field]``
            must return a pandas Series convertible to float).
        time_series_field: Name of the column holding the historical values.
        prediction_length: Number of future steps to forecast.
        ground_truth: Optional sequence of actual values. When given, R2 and RMSE
            against the median forecast are written to the Streamlit page.

    Returns:
        The median forecast as a list of floats, one entry per forecast step.

    Raises:
        ValueError: If the selected column is empty or cannot be converted to float.
    """
    prediction_length = int(prediction_length)

    history = data[time_series_field].astype(float)
    if history.empty:
        raise ValueError(f"Column {time_series_field!r} contains no values to forecast from.")

    context = torch.tensor(history.to_numpy())
    forecast = _load_pipeline().predict(context, prediction_length)

    # forecast[0] is the set of sampled trajectories for the single input series
    # (Chronos documents the output as [num_series, num_samples, prediction_length]);
    # the quantiles are taken across samples for each forecast step.
    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    # Use positional x coordinates for both curves so the forecast always starts
    # right after the history, whatever index the input Series carries.
    history_len = len(history)
    history_x = np.arange(history_len)
    forecast_x = np.arange(history_len, history_len + prediction_length)

    # Draw on an explicit figure and hand it to Streamlit instead of relying on
    # the deprecated global-figure behaviour of st.pyplot().
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(history_x, history.to_numpy(), color="royalblue", label="historical data")
    ax.plot(forecast_x, median, color="tomato", label="median forecast")
    ax.fill_between(
        forecast_x, low, high, color="tomato", alpha=0.3, label="80% prediction interval"
    )
    ax.legend()
    ax.grid()
    st.pyplot(fig)
    # Streamlit re-runs the script on every interaction; close the figure so
    # matplotlib does not accumulate one open figure per rerun.
    plt.close(fig)

    if ground_truth is not None and len(ground_truth) > 0:
        truth_len = len(ground_truth)
        if truth_len > len(median):
            st.warning(
                f"Received {truth_len} ground truth values but only {len(median)} "
                "forecast steps; metrics were skipped."
            )
        else:
            # NOTE(review): a shorter ground truth is compared against the LAST
            # forecast steps (original behaviour) - confirm this is the intended
            # alignment rather than the first steps of the horizon.
            forecast_values = median[-truth_len:]
            r2 = r2_score(ground_truth, forecast_values)
            rmse = np.sqrt(mean_squared_error(ground_truth, forecast_values))
            st.write(f"R-squared (R2): {r2:.4f}")
            st.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}")

    return median.tolist()


def _parse_ground_truth(raw):
    """Parse comma-separated text into a list of floats.

    Blank tokens (e.g. from a trailing comma) are ignored. Returns ``None`` when
    no values remain.

    Raises:
        ValueError: If a non-blank token is not a valid float.
    """
    tokens = (token.strip() for token in raw.split(","))
    values = [float(token) for token in tokens if token]
    return values or None


def main():
    """Render the Streamlit page: upload a CSV, choose a column, and forecast it.

    The user uploads a CSV, picks the column to forecast and the horizon, may
    enter comma-separated ground truth values, and presses the button to run
    ``make_forecast`` and display the forecasted values.
    """
    st.title("Time Series Forecasting App")

    uploaded_file = st.file_uploader("Upload CSV file", type=["csv"])
    if uploaded_file is None:
        return

    data = pd.read_csv(uploaded_file)

    st.subheader("Uploaded Data")
    st.write(data)

    if data.empty:
        st.warning("The uploaded file contains no rows to forecast from.")
        return

    time_series_field = st.selectbox("Select the time series field", data.columns)

    prediction_length = st.number_input("Enter the prediction length", min_value=1, value=12)

    ground_truth_str = st.text_input("Enter ground truth values (comma-separated)", "")
    try:
        ground_truth = _parse_ground_truth(ground_truth_str)
    except ValueError:
        # Report malformed input on the page instead of surfacing a traceback.
        st.error("Ground truth values must be numbers separated by commas.")
        return

    if st.button("Make Forecast"):
        st.subheader("Forecast Visualization")
        try:
            forecast_values = make_forecast(
                data, time_series_field, prediction_length, ground_truth
            )
        except (TypeError, ValueError) as exc:
            # Typically a non-numeric column was selected and could not be
            # converted to float.
            st.error(f"Could not forecast column '{time_series_field}': {exc}")
            return
        st.write("Forecasted Values:", forecast_values)


# Start the app only when this file is executed as the main script, not on import.
if __name__ == "__main__":
    main()