TD3 Model for AAPL Stock Prediction

Model Description

This model applies the TD3 (Twin Delayed Deep Deterministic Policy Gradient) algorithm to stock price prediction and was trained specifically on AAPL (Apple Inc.) stock data. TD3 is an actor-critic reinforcement learning method for continuous action spaces. The agent interacts with a simulated stock market environment and is trained to choose trading actions that maximize the return from trading AAPL stock.
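For readers new to TD3, the part that distinguishes it from DDPG is how the critics' training target is formed. The NumPy sketch below shows target policy smoothing (clipped Gaussian noise added to the target actor's action), clipped double-Q learning (taking the minimum of two target critics), and the soft (Polyak) target-network update. The third TD3 ingredient, delayed policy updates, means the actor and the target networks are updated less often than the critics, typically once every 2 critic steps. The function names and the defaults (gamma=0.99, policy_noise=0.2, noise_clip=0.5, tau=0.005) are assumptions based on commonly used TD3 settings. This card does not state the hyperparameters used to train this model.

```python
import numpy as np


def td3_target(reward, done, next_state, actor_target, q1_target, q2_target,
               gamma=0.99, policy_noise=0.2, noise_clip=0.5, max_action=1.0,
               rng=None):
    """Critic target y = r + gamma * (1 - done) * min(Q1'(s', a~), Q2'(s', a~))."""
    rng = np.random.default_rng() if rng is None else rng
    next_action = actor_target(next_state)
    # Target policy smoothing: Gaussian noise scaled by max_action, clipped,
    # then the perturbed action is clipped back into the valid action range.
    noise = rng.normal(0.0, policy_noise * max_action, size=next_action.shape)
    noise = np.clip(noise, -noise_clip * max_action, noise_clip * max_action)
    smoothed = np.clip(next_action + noise, -max_action, max_action)
    # Clipped double-Q: use the smaller of the two target critics' estimates.
    target_q = np.minimum(q1_target(next_state, smoothed),
                          q2_target(next_state, smoothed))
    # Bootstrap only from non-terminal transitions.
    return reward + gamma * (1.0 - done) * target_q


def soft_update(target_params, online_params, tau=0.005):
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return [tau * p + (1.0 - tau) * tp for tp, p in zip(target_params, online_params)]


# Toy check with two constant critics: the target uses the smaller one (3.0).
s_next = np.zeros((2, 3))
actor = lambda s: np.zeros((s.shape[0], 1))
q1 = lambda s, a: np.full((s.shape[0], 1), 5.0)
q2 = lambda s, a: np.full((s.shape[0], 1), 3.0)
r = np.array([[1.0], [1.0]])
d = np.array([[0.0], [1.0]])  # the second transition is terminal
y = td3_target(r, d, s_next, actor, q1, q2, rng=np.random.default_rng(0))
print(y.ravel())  # first row: 1 + 0.99 * 3; terminal row: just the reward
```

Because the critics here ignore the action, the random smoothing noise does not affect `y`, which makes the clipped double-Q minimum easy to see.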

Key Features

  • Asset: AAPL (Apple Inc.) Stock
  • Model Type: TD3 (Twin Delayed DDPG)
  • Action Space: Continuous; the actor outputs values in [-max_action, max_action] that encode buy, sell, and hold decisions
  • Reward: Modeled on cumulative returns
  • Training Data: Historical stock prices and related financial indicators for AAPL stock
  • Environment: Custom stock trading environment simulating price movement and portfolio management
  • Framework: PyTorch
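The card does not document the custom environment's interface. As a hypothetical illustration of how one continuous action can encode buy, sell, and hold decisions, and how per-step rewards can add up to the cumulative return, the toy environment below makes several assumptions. Actions are normalized to [-1, 1] (the actor's output divided by max_action), a dead band around zero means hold, trades are whole shares at the current price with no fees, and the reward is the one-step change in portfolio value. `ToyTradingEnv` and all of its parameters are hypothetical and are not the author's environment.

```python
import numpy as np


class ToyTradingEnv:
    """Hypothetical single-asset trading environment (not the author's)."""

    def __init__(self, prices, initial_cash=10_000.0, max_shares=10, hold_band=0.1):
        self.prices = np.asarray(prices, dtype=float)
        self.initial_cash = float(initial_cash)
        self.max_shares = int(max_shares)  # most shares traded in one step
        self.hold_band = float(hold_band)  # |action| <= hold_band means hold
        self.reset()

    def reset(self):
        self.t = 0
        self.cash = self.initial_cash
        self.shares = 0
        return self._state()

    def _state(self):
        # Minimal state: current price, cash, and shares held.
        return np.array([self.prices[self.t], self.cash, self.shares], dtype=float)

    def _value(self):
        # Portfolio value marked to the current price.
        return self.cash + self.shares * self.prices[self.t]

    def step(self, action):
        """Apply a normalized action in [-1, 1]; return (next_state, reward, done)."""
        a = float(np.clip(np.ravel(action)[0], -1.0, 1.0))
        price = self.prices[self.t]
        value_before = self._value()
        if a > self.hold_band:  # buy, limited by available cash
            n = min(int(a * self.max_shares), int(self.cash // price))
            self.shares += n
            self.cash -= n * price
        elif a < -self.hold_band:  # sell, limited by shares held
            n = min(int(-a * self.max_shares), self.shares)
            self.shares -= n
            self.cash += n * price
        # Otherwise the action falls in the dead band: hold.
        self.t += 1
        reward = self._value() - value_before  # one-step change in portfolio value
        done = self.t == len(self.prices) - 1
        return self._state(), reward, done


env = ToyTradingEnv([10.0, 12.0, 11.0])
state = env.reset()
state, r1, done = env.step(1.0)   # buy 10 shares at 10.0; price moves to 12.0
state, r2, done = env.step(-1.0)  # sell all 10 shares at 12.0
print(r1, r2, done)  # 20.0 0.0 True
```

Summing the per-step rewards over an episode gives final minus initial portfolio value, which is one way a reward can be "modeled on cumulative returns". Calling `step` after `done` is not handled in this toy version.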

Quick Run

To use this model for stock prediction and trading, install the required dependencies (torch, huggingface_hub, numpy), then download and load the checkpoint from the Hugging Face Hub. Here is an example code snippet:

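import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hugging Face Hub
model_path = hf_hub_download(repo_id="siddheshtv/td3-stock-aapl", filename="td3_stock_prediction_model_AAPL_full.pth")

# Load the checkpoint on CPU. It stores metadata (state_dim, action_dim, max_action)
# alongside the weights; PyTorch >= 2.6 defaults to weights_only=True, which may
# reject such entries, so it is disabled here. Only do this for files you trust.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)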

# Recreate the Actor network; the architecture must match the one used in training
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 400),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(300, action_dim),
            nn.Tanh()
        )
        self.max_action = max_action

    def forward(self, state):
        return self.max_action * self.net(state)

# Instantiate the model
model = Actor(checkpoint['state_dim'], checkpoint['action_dim'], checkpoint['max_action'])
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # Evaluation mode: disables dropout for deterministic inference

# Map a single state vector to an action in [-max_action, max_action]
def select_action(state):
    with torch.no_grad():
        state = torch.as_tensor(state, dtype=torch.float32).reshape(1, -1)
        return model(state).cpu().numpy().flatten()

# Example usage
state = np.random.rand(checkpoint['state_dim'])  # Placeholder: replace with a real state vector of length state_dim
action = select_action(state)
print(f"Predicted action: {action}")
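The state passed to `select_action` must have exactly `checkpoint['state_dim']` features, but the card does not document which features the training environment used or in what order. The helper below is a hypothetical sketch. It builds a state from recent closing-price returns plus optional extra scalars, then checks the length against the model's expected dimension before inference. `build_state`, `check_state`, and the return-based feature layout are assumptions; substitute the features the model was actually trained on.

```python
import numpy as np


def build_state(closes, window, extra_features=()):
    """Hypothetical state: last `window` simple returns plus optional extra scalars."""
    closes = np.asarray(closes, dtype=float)
    if closes.size < window + 1:
        raise ValueError(f"need at least {window + 1} closing prices, got {closes.size}")
    recent = closes[-(window + 1):]
    returns = recent[1:] / recent[:-1] - 1.0  # simple one-step returns
    extras = np.asarray(extra_features, dtype=float).ravel()
    return np.concatenate([returns, extras]).astype(np.float32)


def check_state(state, state_dim):
    """Fail loudly instead of feeding a mis-shaped state to the actor."""
    state = np.asarray(state, dtype=np.float32)
    if state.shape != (state_dim,):
        raise ValueError(f"state has shape {state.shape}, model expects ({state_dim},)")
    return state


closes = [100.0, 110.0, 99.0, 104.0]
state = check_state(build_state(closes, window=2, extra_features=[0.5]), state_dim=3)
print(state)  # the last two simple returns, followed by 0.5
```

With the model loaded as above, one possible call is `select_action(check_state(build_state(closes, window=checkpoint['state_dim']), checkpoint['state_dim']))`, which uses returns only. Whether a positive output means buy is likewise an assumption the card does not confirm.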

Citation

@misc{siddheshtv-td3,
  title={TD3 Model for AAPL Stock Prediction},
  author={Siddhesh Kulthe},
  year={2024},
  howpublished={\url{https://huggingface.co/siddheshtv/td3-stock-aapl}},
  note={TD3 model for predicting stock price movements of AAPL using reinforcement learning},
}