Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import tensorflow as tf | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| import yfinance as yf | |
| from huggingface_hub import hf_hub_download | |
# --- 0. Force CPU-only mode for TensorFlow ---
# Hide all GPUs so TensorFlow does not try to allocate GPU memory on a CPU-only instance.
# CUDA_VISIBLE_DEVICES only takes effect if it is set before TensorFlow initializes CUDA.
# Because `tensorflow` is already imported above, also hide GPUs through TensorFlow's
# own runtime device configuration API.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
try:
    tf.config.set_visible_devices([], "GPU")
except (RuntimeError, ValueError) as e:
    # RuntimeError: the TF runtime was already initialized, so visibility can no longer change.
    print(f"Warning: could not hide GPUs from TensorFlow: {e}")
# --- 1. Define Constants and Download Model/Scalers from Hugging Face Hub ---
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.keras"
SCALER_FILENAME = "scalers.joblib"

print("--- Downloading model and scalers from Hugging Face Hub ---")
try:
    # Download the model first, then the scalers. Any failure lands in the except branch.
    model_path, scalers_path = (
        hf_hub_download(repo_id=MODEL_REPO, filename=artifact_name)
        for artifact_name in (MODEL_FILENAME, SCALER_FILENAME)
    )
except Exception as e:
    print(f"β Critical Error: Could not download files from the Hub. {e}")
    # Both paths become None, signalling to the loading step that no artifacts are available.
    model_path, scalers_path = None, None
else:
    print("β Files downloaded successfully.")
# --- 2. Load the Model and Scalers into Memory ---
def _load_or_none(path, loader, success_msg, failure_prefix):
    """Return ``loader(path)`` when ``path`` names an existing file, else None.

    Loading errors are printed (prefixed by ``failure_prefix``) and swallowed so the
    app can still start; the forecast function reports the missing artifacts later.
    """
    if not (path and os.path.exists(path)):
        return None
    try:
        artifact = loader(path)
        print(success_msg)
    except Exception as exc:
        print(f"{failure_prefix} {exc}")
        return None
    return artifact


# The lambda defers the `tf.keras` attribute lookup until a model file actually exists,
# keeping that lookup inside the try/except as in a direct call.
loaded_model_lstm = _load_or_none(
    model_path,
    lambda p: tf.keras.models.load_model(p),
    "β TensorFlow model loaded successfully.",
    "β Critical Error: Could not load the TensorFlow model.",
)
loaded_scalers = _load_or_none(
    scalers_path,
    lambda p: joblib.load(p),
    "β Scalers loaded successfully.",
    "β Critical Error: Could not load the scalers file.",
)
# --- 3. The Core Forecasting Function ---
def _close_series(data_df: pd.DataFrame) -> pd.Series:
    """Return the 'Close' column of a yfinance frame as a 1-D float Series without NaNs.

    Newer yfinance releases return MultiIndex columns even for a single ticker, so
    ``data_df['Close']`` may be a one-column DataFrame; it is collapsed to a Series here.
    """
    close = data_df["Close"]
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    return close.dropna().astype(float)


def forecast_stock(input_name: str, input_width: int = 60) -> str:
    """
    Fetch live stock data, preprocess it, and return a next-day close prediction string.

    Args:
        input_name: Ticker symbol. Surrounding whitespace is stripped and it is upper-cased.
            ``None`` is treated as empty.
        input_width: Number of most recent closing prices fed to the LSTM as one sequence.
            Presumably this must match the window the model was trained with (default 60).
            TODO: confirm against the training code.

    Returns:
        A multi-line forecast string on success. Expected failures (missing model, bad
        ticker, download or prediction errors) are returned as a message string
        instead of being raised, so the API always answers with text.
    """
    # Fail fast if the model/scalers didn't load during startup.
    # `is None` avoids relying on the truthiness of a Keras model object.
    if loaded_model_lstm is None or not loaded_scalers:
        return "Error: Model or scalers are not loaded. The backend may have failed to start correctly. Check the Space logs."

    ticker = (input_name or "").strip().upper()
    if not ticker:
        return "Error: Please enter a stock ticker."
    if input_width < 1:
        return f"Error: input_width must be a positive integer, got {input_width}."

    print(f"\n--- Generating forecast for {ticker} ---")

    # Fetch more calendar days than needed so enough valid trading days remain.
    try:
        data_df = yf.download(ticker, period="200d", progress=False)
    except Exception as e:
        return f"Error fetching data for '{ticker}': {e}"
    if data_df is None or data_df.empty:
        return f"Error: No data found for ticker '{ticker}'. It may be delisted or an invalid symbol."

    close_series = _close_series(data_df)
    if len(close_series) < input_width:
        return f"Error: Not enough historical data for {ticker}. Need {input_width} days, but only found {len(close_series)}."

    recent_close = close_series.tail(input_width)
    close_prices = recent_close.to_numpy().reshape(-1, 1)

    # The model was trained with per-ticker scalers; fall back to ZURVY's scaler for
    # tickers that have none. Assumes `loaded_scalers` is a dict keyed by ticker.
    scaler = loaded_scalers.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = loaded_scalers.get("ZURVY")
        if scaler is None:
            return "Error: Critical failure. The default 'ZURVY' scaler could not be found."

    # Scale the window, predict one step, and map the prediction back to price units.
    try:
        scaled_data = scaler.transform(close_prices)
        X_pred = scaled_data.reshape(1, input_width, 1)  # LSTM input: [batch, timesteps, features]
        prediction_scaled = loaded_model_lstm.predict(X_pred, verbose=0)[0][0]
        prediction_actual = float(scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0])
    except Exception as e:
        return f"An error occurred during model prediction: {e}"

    last_close = float(recent_close.iloc[-1])
    result_str = (
        f"Forecast for: {ticker}\n"
        f"Last Close Price: ${last_close:.2f}\n"
        f"Predicted Next Day's Close: ${prediction_actual:.2f}"
    )
    print(result_str)
    return result_str
# --- 4. Create the Gradio Interface and API Endpoint ---
def predict_api(ticker_symbol: str) -> str:
    """Return the forecast text for ``ticker_symbol`` using the default input window.

    Thin single-argument adapter that the Gradio event handler exposes as the API.
    """
    forecast_text = forecast_stock(ticker_symbol)
    return forecast_text
with gr.Blocks(title="Stock Forecaster Backend") as app:
    gr.Markdown("## Stock Forecaster Backend\nThis Gradio app serves the API for the React frontend.")

    # Hidden components never render. Gradio still needs an input and an output component
    # to register the named API endpoint below.
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)

    # Register `predict_api` under api_name="predict". The concrete URL (e.g. /run/predict,
    # /api/predict or /call/predict) depends on the installed Gradio version.
    # TODO: confirm against the client used by the React frontend.
    ticker_input.submit(
        predict_api,
        [ticker_input],
        [output_text],
        api_name="predict",
    )
# --- 5. Expose the static React build directory ---
# Gradio has no `gr.mount_static_directory`. Calling it raised AttributeError at import time,
# so the Space never started. Instead, the build directory is passed to `launch` via
# `allowed_paths`, which lets Gradio serve files inside it through its file route.
# The URL prefix differs between Gradio versions; see the Gradio file-access docs.
# Note that this does not serve the React index at the site root.
STATIC_BUILD_DIR = "build"
_static_paths = [STATIC_BUILD_DIR] if os.path.isdir(STATIC_BUILD_DIR) else []
if not _static_paths:
    print(f"Warning: static directory '{STATIC_BUILD_DIR}' not found; serving the API only.")

# --- 6. Launch the Gradio App ---
if __name__ == "__main__":
    app.launch(allowed_paths=_static_paths)