import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
from sklearn.metrics import r2_score, mean_squared_error
import io
@st.cache_resource
def _load_pipeline(model_name):
    """Load a Chronos pipeline once per model id and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model weights would be reloaded on each forecast.
    """
    # NOTE(review): the original from_pretrained(...) call had no visible
    # arguments; the model id default and device/dtype choices follow the
    # Chronos README usage example -- confirm they match the intended setup.
    return ChronosPipeline.from_pretrained(
        model_name,
        device_map="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.bfloat16,
    )


def make_forecast(data, time_series_field, prediction_length, ground_truth=None,
                  model_name="amazon/chronos-t5-small"):
    """Forecast one column of ``data`` with Chronos and render the result.

    Plots the history, the median forecast and the 10%-90% sample band in
    the Streamlit page, and optionally reports R2 / RMSE against observed
    values.

    Args:
        data: DataFrame holding the historical observations.
        time_series_field: Column to forecast; its values must be convertible
            to float.
        prediction_length: Number of future steps to forecast.
        ground_truth: Optional observed values for the forecast horizon, in
            order, starting at the first step after the history. They are
            compared with the matching leading forecast steps.
        model_name: Hugging Face id of the Chronos model to load.

    Returns:
        The median forecast as a list of ``prediction_length`` floats.

    Raises:
        ValueError: if the selected column cannot be converted to float.
    """
    pipeline = _load_pipeline(model_name)

    history = data[time_series_field].astype(float)
    n_hist = len(history)
    context = torch.tensor(history.to_numpy())

    # Samples come back as [series, samples, horizon]; take the single series
    # and summarise the sample distribution at each horizon step.
    forecast = pipeline.predict(context, prediction_length)
    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    # Plot on explicit positions so history and forecast share one x-axis
    # regardless of the DataFrame's index.
    history_x = np.arange(n_hist)
    horizon_x = np.arange(n_hist, n_hist + prediction_length)
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(history_x, history.to_numpy(), color="royalblue", label="historical data")
    ax.plot(horizon_x, median, color="tomato", label="median forecast")
    ax.fill_between(horizon_x, low, high, color="tomato", alpha=0.3,
                    label="80% prediction interval")
    ax.legend()
    st.pyplot(fig)
    plt.close(fig)  # release the figure; Streamlit reruns would otherwise accumulate them

    if ground_truth is not None and len(ground_truth) > 0:
        actual = np.asarray(ground_truth, dtype=float)
        n_eval = min(len(actual), prediction_length)
        if len(actual) != prediction_length:
            st.warning(
                f"{len(actual)} ground truth values given for a horizon of "
                f"{prediction_length}; evaluating the first {n_eval} steps."
            )
        actual = actual[:n_eval]
        predicted = median[:n_eval]
        # R2 is undefined for a single observation.
        if n_eval >= 2:
            r2 = r2_score(actual, predicted)
            st.write(f"R-squared (R2): {r2:.4f}")
        else:
            st.write("R-squared (R2): n/a (needs at least two values)")
        rmse = np.sqrt(mean_squared_error(actual, predicted))
        st.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}")

    return median.tolist()
def main():
    """Render the forecasting UI: upload a CSV, pick a column, and forecast it."""
    st.title("Time Series Forecasting App")

    uploaded_file = st.file_uploader("Upload CSV file", type=["csv"])
    if uploaded_file is None:
        return  # nothing to do until the user provides a file

    try:
        data = pd.read_csv(uploaded_file)
    except ValueError as err:  # EmptyDataError, ParserError and decode errors subclass ValueError
        st.error(f"Could not read the uploaded CSV: {err}")
        return

    st.subheader("Uploaded Data")
    st.dataframe(data)

    time_series_field = st.selectbox("Select the time series field", data.columns)
    prediction_length = st.number_input("Enter the prediction length", min_value=1, value=12)

    # Optional observed values for the forecast horizon, used for R2 / RMSE.
    ground_truth_str = st.text_input("Enter ground truth values (comma-separated)", "")
    ground_truth = None
    if ground_truth_str.strip():
        try:
            # Skip empty items so a trailing comma does not break parsing.
            parsed = [float(item) for item in ground_truth_str.split(",") if item.strip()]
        except ValueError:
            st.error("Ground truth values must be numbers separated by commas.")
        else:
            ground_truth = parsed or None

    if st.button("Make Forecast"):
        st.subheader("Forecast Visualization")
        try:
            forecast_values = make_forecast(data, time_series_field, prediction_length, ground_truth)
        except ValueError as err:  # e.g. a non-numeric column selected
            st.error(f"Could not forecast column '{time_series_field}': {err}")
            return
        st.write("Forecasted Values:", forecast_values)
if __name__ == "__main__":