import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from chronos import ChronosPipeline
from sklearn.metrics import mean_squared_error, r2_score

# Hugging Face checkpoint used for every forecast.
MODEL_NAME = "amazon/chronos-t5-small"
# Quantiles taken over the sampled forecasts: lower bound, median and upper
# bound of the 80% prediction interval drawn in the plot.
QUANTILES = (0.1, 0.5, 0.9)


@st.cache_resource
def _load_pipeline():
    """Load the Chronos pipeline once and reuse it across script reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model weights would be reloaded on every click.
    """
    return ChronosPipeline.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )


def _parse_ground_truth(text):
    """Parse comma-separated numbers; return None when *text* has none.

    Empty items (e.g. from a trailing comma or stray spaces) are ignored.

    Raises:
        ValueError: if a non-empty item is not a valid number.
    """
    values = [float(part) for part in text.split(",") if part.strip()]
    return values or None


def _plot_forecast(history, low, median, high):
    """Return a figure showing *history*, the median forecast and the 80% band.

    History occupies x = 0..len(history)-1 and the forecast continues directly
    after it, so both series share one positional x axis.
    """
    n_history = len(history)
    future_x = np.arange(n_history, n_history + len(median))

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(np.arange(n_history), history, color="royalblue", label="historical data")
    ax.plot(future_x, median, color="tomato", label="median forecast")
    ax.fill_between(
        future_x, low, high, color="tomato", alpha=0.3, label="80% prediction interval"
    )
    ax.legend()
    ax.grid()
    return fig


def _report_metrics(ground_truth, median):
    """Write R² and RMSE comparing *ground_truth* with the median forecast.

    Ground-truth value i is compared with forecast step i + 1 (alignment with
    the start of the horizon); only the overlapping steps are scored.
    """
    n_scored = min(len(ground_truth), len(median))
    if n_scored == 0:
        return
    if len(ground_truth) != len(median):
        st.warning(
            f"Got {len(ground_truth)} ground-truth values for a {len(median)}-step "
            f"forecast; scoring the first {n_scored} steps only."
        )

    actual = np.asarray(ground_truth[:n_scored], dtype=float)
    predicted = median[:n_scored]

    # R² is undefined for a single sample (sklearn would return NaN + warn).
    if n_scored >= 2:
        r2 = r2_score(actual, predicted)
        st.write(f"R-squared (R2): {r2:.4f}")
    else:
        st.write("R-squared (R2): needs at least two ground-truth values")
    rmse = np.sqrt(mean_squared_error(actual, predicted))
    st.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}")


def make_forecast(data, time_series_field, prediction_length, ground_truth=None):
    """Forecast one column of *data* with Chronos and render the result.

    Args:
        data: DataFrame holding the historical observations.
        time_series_field: Name of the column to forecast; its values must be
            convertible to float.
        prediction_length: Number of future steps to forecast.
        ground_truth: Optional sequence of observed values for the first
            ``len(ground_truth)`` forecast steps. When non-empty, R² and RMSE
            are reported for the overlapping steps.

    Returns:
        The median forecast as a list of ``prediction_length`` floats.

    Raises:
        ValueError: if the column is not numeric or contains no rows.
    """
    history = data[time_series_field].astype(float).to_numpy()
    if history.size == 0:
        raise ValueError("the selected column contains no rows")

    pipeline = _load_pipeline()
    context = torch.tensor(history)

    # Chronos returns sample paths per series; there is a single series here,
    # so take index 0 and reduce across the samples axis.
    forecast = pipeline.predict(context, int(prediction_length))
    low, median, high = np.quantile(forecast[0].numpy(), QUANTILES, axis=0)

    # Pass the figure explicitly: the global-figure form of st.pyplot and the
    # 'deprecation.showPyplotGlobalUse' option are gone from recent Streamlit.
    fig = _plot_forecast(history, low, median, high)
    st.pyplot(fig)
    plt.close(fig)  # avoid accumulating open figures across reruns

    if ground_truth:
        _report_metrics(ground_truth, median)

    return median.tolist()


def main():
    """Render the app: upload a CSV, pick a column and forecast it."""
    st.title("Time Series Forecasting App")

    uploaded_file = st.file_uploader("Upload CSV file", type=["csv"])
    if uploaded_file is None:
        return  # nothing to show until a file is uploaded

    try:
        data = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        st.error(f"Could not read the uploaded CSV: {err}")
        return

    st.subheader("Uploaded Data")
    st.write(data)

    time_series_field = st.selectbox("Select the time series field", data.columns)
    prediction_length = st.number_input("Enter the prediction length", min_value=1, value=12)

    ground_truth_str = st.text_input("Enter ground truth values (comma-separated)", "")
    try:
        ground_truth = _parse_ground_truth(ground_truth_str)
    except ValueError:
        st.error("Ground truth values must be numbers separated by commas; metrics will be skipped.")
        ground_truth = None

    if st.button("Make Forecast"):
        st.subheader("Forecast Visualization")
        try:
            forecast_values = make_forecast(
                data, time_series_field, prediction_length, ground_truth
            )
        except ValueError as err:
            st.error(f"Could not forecast column '{time_series_field}': {err}")
            return
        st.write("Forecasted Values:", forecast_values)


if __name__ == "__main__":
    main()