# Author: kuldeep0204
# Last commit: "Update app.py" (4d919ad, verified)
import os
from io import StringIO

import gradio as gr
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
# --- Model Loading ---
# Done once at import time, outside the request handler, so every forecast
# request reuses the same pipeline instead of reloading the weights.
model_name = "amazon/chronos-t5-small"
try:
    pipeline = ChronosPipeline.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="cpu",  # Force CPU usage for free tier
    )
except Exception as exc:
    # Keep the app up: forecast_time_series checks for None and reports it.
    print(f"Error loading model: {exc}")
    pipeline = None
else:
    print(f"Loaded model: {model_name}")
# --- Prediction Function ---
def _read_csv_upload(csv_file):
    """Parse whatever Gradio's ``gr.File`` delivered into a DataFrame.

    Depending on the Gradio version and the component's ``type`` setting, an
    upload arrives as a filesystem path (``str`` / ``os.PathLike``), as a
    temp-file wrapper whose ``.name`` is a path, or as a readable file-like
    object. All three are accepted.
    """
    if isinstance(csv_file, (str, os.PathLike)):
        return pd.read_csv(csv_file)
    path = getattr(csv_file, "name", None)
    if isinstance(path, str) and os.path.isfile(path):
        # Re-open by path: the wrapper's own read position may already be at EOF.
        return pd.read_csv(path)
    content = csv_file.read()
    if isinstance(content, bytes):
        content = content.decode("utf-8")
    return pd.read_csv(StringIO(content))


def forecast_time_series(csv_file, prediction_length):
    """Forecast the series in the last column of an uploaded CSV with Chronos.

    Args:
        csv_file: The ``gr.File`` upload (path, temp-file wrapper, or file-like).
        prediction_length: Number of future steps to forecast; coerced to int.

    Returns:
        On success, a dict mapping ``"Historical"`` and ``"Forecast"`` to
        ``(x_index, y_values)`` array pairs. The forecast values are the
        per-step median of the sampled paths. On any failure, a human-readable
        error string is returned instead of raising.
    """
    if pipeline is None:
        return "Model failed to load. Please check logs/dependencies."
    if csv_file is None:
        return "Please upload a CSV file."
    try:
        horizon = int(prediction_length)
        if horizon < 1:
            return "Number of future steps must be at least 1."

        df = _read_csv_upload(csv_file)
        # The series is the last column. Converting to float makes a non-numeric
        # column fail here with a clear ValueError; NaNs are not removed.
        historical_data = df.iloc[:, -1].to_numpy(dtype=np.float64)
        if len(historical_data) < 50:
            return (
                "Please upload a time series with at least 50 historical points "
                "for a good forecast."
            )

        forecast_samples = pipeline.predict(
            torch.tensor(historical_data, dtype=torch.float32),
            prediction_length=horizon,
            num_samples=20,  # Number of probabilistic paths to generate
        )
        # Per the Chronos README, predict() returns
        # [num_series, num_samples, prediction_length]. We passed a single
        # series, so select series 0 and take the median across samples.
        median_forecast = np.quantile(forecast_samples[0].numpy(), 0.5, axis=0)

        n_hist = len(historical_data)
        historical_index = np.arange(n_hist)
        forecast_index = np.arange(n_hist, n_hist + horizon)
        # NOTE(review): gr.Plot renders figure objects (matplotlib/plotly/...),
        # not this dict of arrays - confirm how the output is displayed.
        return {
            "Historical": (historical_index, historical_data),
            "Forecast": (forecast_index, median_forecast),
        }
    except Exception as e:
        return f"An error occurred: {e}"
# --- Gradio Interface Setup ---
# Illustrative CSV layout plus a forecast horizon, kept for reference only.
# It is not wired into gr.Interface: a sample input given as a string is not
# a file, so users must upload their own CSV instead.
example_data = [
    [
        "date,value\n"
        "2025-01-01,10.0\n"
        "2025-01-02,11.5\n"
        "...\n"
        "2025-03-20,15.2",
        7,
    ],
]
gr_plot = gr.Plot(label="Time Series Forecast (Historical + Predicted Median)")

# Bind the interface to ``demo`` (the name Gradio's reload mode looks for) and
# start the server only when run as a script, so importing this module does
# not block on launch().
demo = gr.Interface(
    fn=forecast_time_series,
    inputs=[
        gr.File(label="Upload a CSV file (Time series must be in the last column)"),
        gr.Slider(minimum=7, maximum=30, step=1, value=14, label="Number of Future Steps (Days) to Predict"),
    ],
    outputs=gr_plot,
    title="Chronos Time Series Forecasting Demo on Hugging Face",
    description="Upload a CSV file containing a single historical time series. This demo uses the Chronos-T5-Small Foundation Model to generate a 14-day (default) forecast.",
    examples=None,  # string samples are not files, so no gr.File examples
    live=False,  # forecast only when the user submits
)

if __name__ == "__main__":
    demo.launch()