Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import yfinance as yf | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.preprocessing import MinMaxScaler | |
| import matplotlib.pyplot as plt | |
| import pickle | |
| from datetime import datetime, timedelta | |
| # Define the LSTM model architecture | |
class LSTMModel(nn.Module):
    """LSTM regressor: LSTM stack -> dropout -> linear head on the last step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden and cell state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced by the linear head.
        dropout: Dropout probability applied to the last time step's output,
            and between LSTM layers when ``num_layers > 1``.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout is only passed when there is more than one
            # layer; with a single layer it is forced to 0.
            dropout=dropout if num_layers > 1 else 0,
        )
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run a batch of sequences through the network.

        Args:
            x: Batch-first input; ``x.size(0)`` is used as the batch size and
                the last time step of the LSTM output feeds the linear head.

        Returns:
            Tensor with one row per batch element and ``output_size`` columns.
        """
        # Allocate the initial states with the input's dtype and device so the
        # model keeps working after .to(device) / .double() / .half().
        h0 = torch.zeros(
            self.num_layers,
            x.size(0),
            self.hidden_size,
            dtype=x.dtype,
            device=x.device,
        )
        c0 = torch.zeros_like(h0)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]  # keep only the final time step
        out = self.dropout(out)
        return self.linear(out)
def safe_load_model():
    """Load the trained model, its fitted scaler and the sequence length.

    The checkpoint ``lstm_stock_model.pth`` is first loaded with
    ``weights_only=True``. If that fails, it is loaded again with
    ``weights_only=False`` and the scaler is taken from the checkpoint's
    ``'scaler'`` entry when present. If no scaler was obtained that way it is
    read from ``scaler.pkl``.

    Returns:
        tuple: ``(model, scaler, sequence_length)`` where ``model`` is an
        ``LSTMModel`` in eval mode and ``sequence_length`` falls back to 60
        when the checkpoint does not store it.

    Raises:
        RuntimeError: If the checkpoint cannot be loaded by either strategy,
            or if no scaler is available.
    """
    model_path = 'lstm_stock_model.pth'
    try:
        # Try weights_only=True first (secure)
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
        scaler = None
    except Exception:
        # SECURITY: weights_only=False unpickles arbitrary objects. Only use
        # this with checkpoint files you produced or otherwise trust.
        try:
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e
        scaler = checkpoint.get('scaler', None)

    # Load model architecture (default hyperparameters) and weights
    model = LSTMModel()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Load scaler from separate file if not in checkpoint
    if scaler is None:
        try:
            # SECURITY: pickle.load executes code from the file; scaler.pkl
            # must come from a trusted source.
            with open('scaler.pkl', 'rb') as f:
                scaler = pickle.load(f)
        except Exception as e:
            raise RuntimeError(
                "Scaler not found. Please ensure scaler.pkl exists."
            ) from e

    sequence_length = checkpoint.get('sequence_length', 60)
    return model, scaler, sequence_length
def predict_stock(ticker="AAPL", days=30):
    """Forecast closing prices for ``ticker`` and return a plot of the result.

    Downloads two years of daily prices, scales the closing prices with the
    saved scaler, and rolls the model forward ``days`` steps, feeding each
    prediction back in as the newest element of the input window.

    Args:
        ticker: Ticker symbol passed to ``yf.download``.
        days: Number of steps to forecast. Coerced to ``int`` because UI
            number inputs may deliver a float.

    Returns:
        The object returned by ``create_forecast_plot`` on success, or by
        ``create_error_plot`` when anything goes wrong. Exceptions are
        reported in the plot rather than raised.
    """
    try:
        # range() below needs an int; a float such as 30.0 would raise.
        days = int(days)
        if days < 1:
            return create_error_plot("Days to forecast must be at least 1")

        ticker = str(ticker).strip()
        if not ticker:
            return create_error_plot("Please enter a stock ticker")

        # Load model safely
        model, scaler, sequence_length = safe_load_model()

        # Fetch recent stock data
        print(f"Fetching data for {ticker}...")
        stock_data = yf.download(ticker, period="2y", interval="1d")
        if stock_data.empty:
            return create_error_plot(f"No data found for {ticker}")

        # Use closing prices; drop NaN rows so they cannot propagate through
        # the scaler and turn every prediction into NaN.
        closing_prices = stock_data['Close'].dropna().to_numpy().reshape(-1, 1)
        if closing_prices.size == 0:
            return create_error_plot(f"No data found for {ticker}")

        # Scale the data
        scaled_data = scaler.transform(closing_prices)

        # Create sequence for prediction
        if len(scaled_data) >= sequence_length:
            window = scaled_data[-sequence_length:]
        else:
            # Not enough history: left-pad with the earliest available value.
            padding = np.full((sequence_length - len(scaled_data), 1), scaled_data[0, 0])
            window = np.vstack([padding, scaled_data])
        last_sequence = window.reshape(1, sequence_length, 1)

        # Generate predictions
        predictions = []
        current_sequence = torch.tensor(last_sequence, dtype=torch.float32)
        with torch.no_grad():
            for _ in range(days):
                next_pred = model(current_sequence)
                predictions.append(next_pred.item())
                # Slide the window: drop the oldest step, append the new one.
                current_sequence = torch.cat(
                    [current_sequence[:, 1:, :], next_pred.reshape(1, 1, 1)],
                    dim=1,
                )

        # Convert back to original scale
        predictions_array = np.array(predictions).reshape(-1, 1)
        predictions_original = scaler.inverse_transform(predictions_array).flatten()

        # Create plot
        return create_forecast_plot(stock_data, predictions_original, ticker, days)
    except Exception as e:
        # UI boundary: report the failure in the plot instead of crashing.
        print(f"Error in prediction: {e}")
        return create_error_plot(str(e))
def create_forecast_plot(stock_data, predictions, ticker, days):
    """Plot recent closing prices followed by the forecast.

    Args:
        stock_data: Price table with a date index and a ``'Close'`` column;
            at most the last 100 rows are drawn.
        predictions: Forecast prices, one per future step.
        ticker: Symbol shown in the title.
        days: Forecast horizon shown in the title.

    Returns:
        A ``matplotlib.figure.Figure``. It is built without pyplot so that it
        is not kept alive in pyplot's global figure registry (which would
        leak one figure per request in a long-running app) and does not
        depend on pyplot's shared "current figure" state. ``gr.Plot`` accepts
        a Figure directly.
    """
    # Function-scope import keeps this block self-contained; matplotlib is
    # already a dependency of this file.
    from matplotlib.figure import Figure

    # Each forecast step is one trading day, so place the points on business
    # days rather than calendar days (which would land on weekends).
    last_date = stock_data.index[-1]
    forecast_dates = pd.bdate_range(
        start=last_date + pd.offsets.BDay(1),
        periods=len(predictions),
    )

    fig = Figure(figsize=(12, 6))
    ax = fig.subplots()

    # Historical data
    historical_days = min(100, len(stock_data))
    recent = stock_data.iloc[-historical_days:]
    ax.plot(recent.index, recent['Close'],
            label='Historical Prices', color='blue', linewidth=2)

    # Forecast
    ax.plot(forecast_dates, predictions,
            label='Forecast', color='red', linewidth=2, linestyle='--', marker='o')

    ax.set_title(f'{ticker} Stock Price Forecast - Next {days} Days')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price (USD)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='x', rotation=45)
    fig.tight_layout()
    return fig
def create_error_plot(error_message):
    """Render ``error_message`` as a figure so the UI can display failures.

    Args:
        error_message: Text shown in the centre of the figure, prefixed with
            ``"Error: "``.

    Returns:
        A ``matplotlib.figure.Figure``. It is created without pyplot so it is
        not retained in pyplot's global figure registry after the request.
    """
    # Function-scope import keeps this block self-contained; matplotlib is
    # already a dependency of this file.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.text(
        0.5, 0.5, f'Error: {error_message}',
        ha='center', va='center', transform=ax.transAxes,
        fontsize=12,
        wrap=True,  # long exception messages would otherwise run off the figure
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"),
    )
    ax.set_title('Prediction Error')
    ax.axis('off')
    return fig
# Create Gradio interface
iface = gr.Interface(
    fn=predict_stock,
    inputs=[
        gr.Textbox(value="AAPL", label="Stock Ticker"),
        # precision=0 rounds the value and passes it to predict_stock as an
        # int; the forecast loop uses it as a step count.
        gr.Number(
            value=30,
            label="Days to Forecast",
            minimum=1,
            maximum=365,
            precision=0,
        ),
    ],
    outputs=gr.Plot(label="Forecast Results"),
    title="Stock Price Forecaster (PyTorch LSTM)",
    description="Predict future stock prices using LSTM neural networks.",
)

if __name__ == "__main__":
    iface.launch()