# -*- coding: utf-8 -*-
import joblib
import pandas as pd
import gradio as gr
import logging
from datetime import datetime, timedelta

# logging to console for server-side errors (not shown in frontend)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pickled ARIMA results object, (re)loaded on every forecast request.
MODEL_PATH = "arima_model.pkl"

# Column names of the DataFrame returned by ``forecast_arima``.
FORECAST_COLUMNS = ["ds", "yhat", "yhat_lower", "yhat_upper"]

# Human-readable headers used when the forecast is shown in the UI.
DISPLAY_HEADERS = {
    "ds": "Date",
    "yhat": "Predicted",
    "yhat_lower": "Lower Bound",
    "yhat_upper": "Upper Bound",
}


# ---------------------------
# Load ARIMA Model
# ---------------------------
def load_model(pkl_file):
    """Load a pickled ARIMA results object using joblib.

    Returns the unpickled object on success, or the exception text (a ``str``)
    on failure, so callers must check ``isinstance(result, str)``.

    NOTE: joblib/pickle can execute arbitrary code while loading; only load
    model files that come from a trusted source.
    """
    try:
        return joblib.load(pkl_file)
    except Exception as e:
        return str(e)


# ---------------------------
# Forecast helpers
# ---------------------------
def _empty_forecast():
    """Return an empty forecast frame with the standard columns."""
    return pd.DataFrame(columns=FORECAST_COLUMNS)


def _na_series(length):
    """Return a Series of ``length`` missing values (used for absent bounds)."""
    return pd.Series([pd.NA] * length)


def _aligned(values, length):
    """Return ``values`` as an array if it has ``length`` items, else all-NA."""
    return values.values if len(values) == length else [pd.NA] * length


def _parse_days(days_ahead):
    """Convert ``days_ahead`` to an int >= 1, raising ``ValueError`` otherwise."""
    try:
        days = int(days_ahead)
    except (TypeError, ValueError, OverflowError):
        # e.g. None (empty input box), NaN, inf or non-numeric text.
        raise ValueError("days must be an integer.") from None
    if days <= 0:
        raise ValueError("days must be >= 1.")
    return days


def _find_column(frame, exact, keyword):
    """Return ``frame[exact]``, else the first column whose name contains
    ``keyword`` (case-insensitive), else ``None``."""
    if exact in frame.columns:
        return frame[exact]
    for column in frame.columns:
        if keyword in str(column).lower():
            return frame[column]
    return None


def _from_prediction(res):
    """Extract ``(mean, lower, upper)`` Series from a ``get_forecast`` result."""
    try:
        frame = res.summary_frame()
    except Exception:
        # No usable summary_frame: keep the point forecast, leave bounds empty.
        mean = pd.Series(res.predicted_mean if hasattr(res, "predicted_mean") else res)
        return mean, _na_series(len(mean)), _na_series(len(mean))

    mean = frame["mean"] if "mean" in frame.columns else frame.iloc[:, 0]
    # Look bounds up by name, not position: a positional guess can pick up
    # ``mean_se`` (the standard error) and mislabel it as a confidence bound.
    lower = _find_column(frame, "mean_ci_lower", "lower")
    upper = _find_column(frame, "mean_ci_upper", "upper")
    n = len(mean)
    return (
        mean,
        lower if lower is not None else _na_series(n),
        upper if upper is not None else _na_series(n),
    )


def _run_forecast(model, days):
    """Return ``(mean, lower, upper)`` Series for ``days`` steps, or ``None``.

    Tries ``model.get_forecast`` first (which can provide confidence
    intervals), then falls back to ``model.forecast`` (point forecast only).
    Failures are logged server-side.
    """
    try:
        return _from_prediction(model.get_forecast(steps=days))
    except Exception:
        logger.debug("get_forecast failed; falling back to forecast()", exc_info=True)

    try:
        mean = pd.Series(model.forecast(steps=days))
    except Exception as e:
        logger.exception("Error during forecasting: %s", e)
        return None
    return mean, _na_series(len(mean)), _na_series(len(mean))


def _training_index(model):
    """Return the model's training DatetimeIndex, or ``None`` if not found."""
    # statsmodels stores the training index in ``data.row_labels``; it may be
    # reachable from the results object itself or from ``results.model``.
    for owner in (model, getattr(model, "model", None)):
        labels = getattr(getattr(owner, "data", None), "row_labels", None)
        if isinstance(labels, pd.DatetimeIndex) and len(labels) > 0:
            return labels
    return None


def _future_dates(after, periods, freq):
    """Return ``periods`` dates strictly after ``after``, spaced by ``freq``."""
    from pandas.tseries.frequencies import to_offset

    offset = to_offset(freq)
    start = pd.Timestamp(after) + offset
    return pd.date_range(start=start, periods=periods, freq=offset).date


def _forecast_dates(mean, model):
    """Build the ``ds`` column for a forecast of ``len(mean)`` steps.

    Order of preference: the forecast's own DatetimeIndex; the day(s) after
    the last training date (at the training index's frequency, daily if it has
    none); the days after today. Falls back to step numbers 1..n on error.
    """
    try:
        if isinstance(mean.index, pd.DatetimeIndex):
            return mean.index.date
        history = _training_index(model)
        if history is not None:
            freq = history.freq if history.freq is not None else "D"
            return _future_dates(history[-1], len(mean), freq)
        # fallback to relative days from today
        return _future_dates(pd.Timestamp.today().normalize(), len(mean), "D")
    except Exception:
        logger.warning("Could not build forecast dates; using step numbers.", exc_info=True)
        return list(range(1, len(mean) + 1))


# ---------------------------
# Forecast function
# ---------------------------
def forecast_arima(days_ahead):
    """Forecast the next ``days_ahead`` steps using the saved ARIMA model.

    Args:
        days_ahead: Number of steps to forecast; any value ``int()`` accepts.

    Returns:
        A DataFrame with columns ``ds`` (date, or step number if no date can
        be derived), ``yhat`` (point forecast) and ``yhat_lower`` /
        ``yhat_upper`` (confidence bounds, NA when unavailable). If the model
        cannot be loaded or forecasting fails, the error is logged server-side
        and an empty DataFrame with those columns is returned.

    Raises:
        ValueError: if ``days_ahead`` is not an integer >= 1.
    """
    days = _parse_days(days_ahead)

    model = load_model(MODEL_PATH)
    if isinstance(model, str):
        logger.error("Error loading model: %s", model)
        return _empty_forecast()

    forecast = _run_forecast(model, days)
    if forecast is None:
        return _empty_forecast()
    mean, lower, upper = forecast

    n = len(mean)
    return pd.DataFrame({
        "ds": _forecast_dates(mean, model),
        "yhat": mean.values,
        "yhat_lower": _aligned(lower, n),
        "yhat_upper": _aligned(upper, n),
    })


# ---------------------------
# Gradio Interface
# ---------------------------
def _forecast_for_display(days_ahead):
    """Gradio callback: run the forecast and relabel columns for display.

    Invalid input is reported to the user via ``gr.Error``; other failures
    are only logged server-side by ``forecast_arima``.
    """
    try:
        table = forecast_arima(days_ahead)
    except ValueError as exc:
        raise gr.Error(f"Error: {exc}") from exc
    # gr.Dataframe shows the returned DataFrame's own column names rather than
    # the ``headers`` argument, so rename them to the display headers.
    return table.rename(columns=DISPLAY_HEADERS)


def create_app():
    """Build and return the Gradio Blocks UI for the forecast app."""
    with gr.Blocks() as app:
        gr.Markdown("## 📈 ARIMA Stock Price Forecast App")
        gr.Markdown("Provide the number of days to forecast using the saved ARIMA model.")
        with gr.Row():
            days_input = gr.Number(label="Days to Forecast", value=30, precision=0)
            forecast_btn = gr.Button("Forecast")
        output_table = gr.Dataframe(headers=list(DISPLAY_HEADERS.values()))
        forecast_btn.click(
            fn=_forecast_for_display,
            inputs=[days_input],
            outputs=output_table,
        )
    return app


if __name__ == "__main__":
    app = create_app()
    # share=True also publishes a temporary public share link.
    app.launch(share=True)