# st3 / app.py — "Update app.py" by uisikdag (commit 33a3d83, verified)
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
from sklearn.metrics import r2_score, mean_squared_error
import io
@st.cache_resource
def _load_pipeline(model_name="amazon/chronos-t5-small"):
    """Load a Chronos pipeline once and reuse it across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction; without this
    cache the model weights would be reloaded on every "Make Forecast" click.
    """
    return ChronosPipeline.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )


def _report_metrics(ground_truth, median):
    """Write R2 and RMSE of the median forecast against observed values.

    ``ground_truth`` is aligned with the start of the forecast horizon, i.e.
    its first value is compared with the first forecast step after the
    history. If the lengths differ, only the overlapping steps are scored.
    """
    actual = np.asarray(ground_truth, dtype=float)
    n_eval = min(len(actual), len(median))
    if len(actual) != len(median):
        st.warning(
            f"Got {len(actual)} ground-truth values for a {len(median)}-step "
            f"forecast; comparing the first {n_eval} steps only."
        )
    actual = actual[:n_eval]
    predicted = median[:n_eval]

    # R2 is undefined for fewer than two samples (sklearn returns nan + warning).
    if n_eval >= 2:
        r2 = r2_score(actual, predicted)
        st.write(f"R-squared (R2): {r2:.4f}")
    else:
        st.write("R-squared (R2): needs at least two ground-truth values")
    rmse = np.sqrt(mean_squared_error(actual, predicted))
    st.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}")


# Function to make predictions and visualize the forecast
def make_forecast(data, time_series_field, prediction_length, ground_truth=None):
    """Forecast one column of ``data`` with Chronos and render the result.

    Args:
        data: DataFrame holding the historical observations.
        time_series_field: Name of the column to forecast; its values must be
            convertible to float.
        prediction_length: Number of future steps to forecast.
        ground_truth: Optional sequence of observed values for the forecast
            horizon, starting at the first step after the history. When given
            (and non-empty), R2 and RMSE of the median forecast are displayed.

    Returns:
        The median forecast as a list of floats, one per forecast step.

    Raises:
        ValueError: If the selected column cannot be converted to float.
    """
    pipeline = _load_pipeline()

    # Convert data to a compatible type
    time_series_data = data[time_series_field].astype(float)
    history = time_series_data.to_numpy()
    context = torch.tensor(history)

    # Per the Chronos API the result is (num_series, num_samples, horizon);
    # reduce the sampled paths of our single series to per-step quantiles.
    forecast = pipeline.predict(context, prediction_length)
    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    # Plot on an explicit figure: the figure-less st.pyplot() form and the
    # 'deprecation.showPyplotGlobalUse' option are gone in recent Streamlit.
    history_x = np.arange(len(history))
    forecast_x = np.arange(len(history), len(history) + len(median))
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(history_x, history, color="royalblue", label="historical data")
    ax.plot(forecast_x, median, color="tomato", label="median forecast")
    ax.fill_between(forecast_x, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
    ax.legend()
    ax.grid()
    st.pyplot(fig)
    plt.close(fig)  # release the figure once Streamlit has rendered it

    if ground_truth is not None and len(ground_truth) > 0:
        _report_metrics(ground_truth, median)

    return median.tolist()
def _parse_ground_truth(text):
    """Parse comma-separated numbers from ``text``.

    Blank tokens (e.g. from a trailing comma or stray spaces) are ignored.
    Returns None when nothing numeric was entered, or, after showing a
    warning, when any token is not a number, so that the forecast can still
    run without metrics.
    """
    tokens = [token.strip() for token in text.split(",")]
    tokens = [token for token in tokens if token]
    if not tokens:
        return None
    try:
        return [float(token) for token in tokens]
    except ValueError:
        st.warning("Ground truth must be comma-separated numbers; metrics will be skipped.")
        return None


# Streamlit app
def main():
    """Render the app: upload a CSV, pick a column, and forecast it on demand."""
    st.title("Time Series Forecasting App")

    # Upload CSV file; nothing else is shown until one is provided.
    uploaded_file = st.file_uploader("Upload CSV file", type=["csv"])
    if uploaded_file is None:
        return

    data = pd.read_csv(uploaded_file)
    st.subheader("Uploaded Data")
    st.write(data)

    time_series_field = st.selectbox("Select the time series field", data.columns)
    prediction_length = st.number_input("Enter the prediction length", min_value=1, value=12)

    ground_truth_str = st.text_input("Enter ground truth values (comma-separated)", "")
    ground_truth = _parse_ground_truth(ground_truth_str)

    if st.button("Make Forecast"):
        st.subheader("Forecast Visualization")
        try:
            forecast_values = make_forecast(
                data, time_series_field, int(prediction_length), ground_truth
            )
        except ValueError as err:
            # Most commonly a non-numeric column failing the float conversion.
            st.error(f"Could not forecast column '{time_series_field}': {err}")
            return
        st.write("Forecasted Values:", forecast_values)


if __name__ == "__main__":
    main()