
Deep Q-Network for Floorplan Navigation

Model Description

This model is a Deep Q-Network (DQN) trained to find the most efficient path through a floorplan while avoiding obstacles. Its training combines a classical pathfinding algorithm (A*) with reinforcement learning.

Model Architecture

The model is a fully connected neural network with the following architecture:

  • Input Layer: Flattened grid representation of the floorplan
  • Hidden Layers: Two hidden layers with 64 units each and ReLU activation
  • Output Layer: Four units representing the possible actions (up, down, left, right)
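As a quick sanity check on the layer sizes above, the sketch below flattens a 10x10 grid (the size used in the example code further down, giving an input of 100) and counts the weights and biases of the 100 → 64 → 64 → 4 network. The 0 = free / 1 = obstacle cell encoding and the wall position are assumptions for illustration only; the actual state encoding is defined in train.py.

```python
# Sketch: flatten a 10x10 floorplan grid into the 100-value input vector and
# count the parameters of the 100 -> 64 -> 64 -> 4 fully connected network.
# The 0 = free / 1 = obstacle encoding is an assumption for illustration.

def flatten_grid(grid):
    """Row-major flattening of a 2-D grid (list of lists) into a flat list."""
    return [float(cell) for row in grid for cell in row]


def count_linear_params(layer_sizes):
    """Weights plus biases for a chain of fully connected layers."""
    total = 0
    for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total


grid = [[0] * 10 for _ in range(10)]
grid[4][2:8] = [1] * 6  # a hypothetical wall segment of six cells
state = flatten_grid(grid)
print(len(state))                             # 100
print(count_linear_params([100, 64, 64, 4]))  # 10884
```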

Training

The model was trained using a hybrid approach:

  1. A* Algorithm: The A* algorithm was first used to find the shortest path in a static environment.
  2. Reinforcement Learning: The DQN was trained with guidance from the A* path to improve efficiency and adaptability.
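The card does not specify how the A* path guided the DQN, but the A* step itself is standard. The following is a generic sketch of A* on a 4-connected grid (matching the four actions), assuming 0 marks a free cell and 1 an obstacle, with unit move cost and the Manhattan-distance heuristic. It is not taken from train.py.

```python
import heapq


def a_star(grid, start, goal):
    """Shortest path on a 4-connected grid (0 = free, 1 = obstacle).

    Returns the list of (row, col) cells from start to goal, or None if the
    goal is unreachable. The Manhattan distance never overestimates the
    remaining cost when every move costs 1, so the first path found is optimal.
    """
    rows, cols = len(grid), len(grid[0])

    def h(cell):
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    open_heap = [(h(start), 0, start)]  # entries are (f = g + h, g, cell)
    came_from = {start: None}
    g_cost = {start: 0}
    while open_heap:
        _, g, cell = heapq.heappop(open_heap)
        if cell == goal:
            # Walk the parent links back to the start, then reverse.
            path = []
            while cell is not None:
                path.append(cell)
                cell = came_from[cell]
            return path[::-1]
        if g > g_cost[cell]:
            continue  # stale heap entry; a cheaper route was already found
        r, c = cell
        for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
            if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == 0:
                ng = g + 1
                if ng < g_cost.get((nr, nc), float("inf")):
                    g_cost[(nr, nc)] = ng
                    came_from[(nr, nc)] = cell
                    heapq.heappush(open_heap, (ng + h((nr, nc)), ng, (nr, nc)))
    return None


floor = [
    [0, 0, 0, 0],
    [1, 1, 1, 0],
    [0, 0, 0, 0],
]
path = a_star(floor, (0, 0), (2, 0))
print(len(path) - 1)  # 8 moves around the wall
```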

Hyperparameters

  • Learning Rate: 0.001
  • Batch Size: 64
  • Gamma (Discount Factor): 0.99
  • Target Update Frequency: Every 100 episodes
  • Number of Episodes: 50
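For reference, the discount factor enters the standard DQN target y = r + γ · max_a' Q_target(s', a'), with the bootstrap term dropped on terminal transitions; Q_target is the separate target network that is synchronized at the target update frequency. The sketch below is the textbook formulation with γ = 0.99, written in plain Python over a batch; it assumes, but does not confirm, that train.py computes its targets this way.

```python
GAMMA = 0.99  # discount factor from the hyperparameter list


def td_targets(rewards, next_q_values, dones, gamma=GAMMA):
    """Standard DQN targets: y = r + gamma * max_a' Q_target(s', a') * (1 - done).

    rewards:        one float per transition in the batch (e.g. 64 of them)
    next_q_values:  per transition, the target network's four Q-values
                    (one per action) for the next state
    dones:          True where the episode ended on that transition
    """
    targets = []
    for r, q_next, done in zip(rewards, next_q_values, dones):
        bootstrap = 0.0 if done else gamma * max(q_next)
        targets.append(r + bootstrap)
    return targets


batch_targets = td_targets(
    rewards=[1.0, -1.0],
    next_q_values=[[0.5, 2.0, 0.0, 1.0], [3.0, 3.0, 3.0, 3.0]],
    dones=[False, True],
)
```

The network is then trained to move Q(s, a) for the action actually taken toward these targets.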

Checkpoints

Checkpoints are saved during training for convenience:

  • checkpoint_11.pth.tar: After 11 episodes
  • checkpoint_21.pth.tar: After 21 episodes
  • checkpoint_31.pth.tar: After 31 episodes
  • checkpoint_41.pth.tar: After 41 episodes

Usage

To use this model, initialize the DQN with the same architecture used in training, then load the saved state dictionary into it. The model can then be used to navigate a floorplan toward the target by choosing, at each step, the action with the highest predicted Q-value.

Example Code

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the DQN class (same as in the training script)
class DQN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(DQN, self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size

        self.fc_layers = nn.ModuleList()
        prev_size = input_size
        for size in hidden_sizes:
            self.fc_layers.append(nn.Linear(prev_size, size))
            prev_size = size
        self.output_layer = nn.Linear(prev_size, output_size)

    def forward(self, x):
        # Flatten (batch, rows, cols) grid inputs to (batch, rows * cols)
        if len(x.shape) > 2:
            x = x.reshape(x.size(0), -1)
        for layer in self.fc_layers:
            x = F.relu(layer(x))
        x = self.output_layer(x)
        return x

# Load the model
input_size = 100  # 10x10 grid flattened
hidden_sizes = [64, 64]
output_size = 4
model = DQN(input_size, hidden_sizes, output_size)
model.load_state_dict(torch.load('dqn_model.pth', map_location='cpu'))  # map_location lets a GPU-saved file load on CPU
model.eval()

# Use the model for inference (example state)
state = ...  # Define your state here
with torch.no_grad():
    action = model(torch.tensor(state, dtype=torch.float32).unsqueeze(0)).argmax().item()
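The snippet above selects a single action. To trace a whole path, the greedy choice can be repeated until the agent reaches the target, hits an obstacle, or exceeds a step limit. The sketch below assumes the action indices follow the order listed on this card (0 = up, 1 = down, 2 = left, 3 = right) and takes a q_fn callable instead of the model directly, because the state encoding used in train.py is not documented here; encode_state in the comment is a hypothetical helper you would need to write to match it.

```python
# Hypothetical action order, following the order listed on this card:
# 0 = up, 1 = down, 2 = left, 3 = right (verify against train.py).
MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}


def greedy_rollout(q_fn, grid, start, goal, max_steps=100):
    """Follow the highest-Q action from start until goal, a blocked move, or max_steps.

    q_fn(grid, pos) must return four Q-values. Returns (path, reached_goal).
    """
    rows, cols = len(grid), len(grid[0])
    pos, path = start, [start]
    for _ in range(max_steps):
        if pos == goal:
            return path, True
        q = q_fn(grid, pos)
        action = max(range(4), key=lambda a: q[a])  # argmax over the four actions
        dr, dc = MOVES[action]
        nr, nc = pos[0] + dr, pos[1] + dc
        if not (0 <= nr < rows and 0 <= nc < cols) or grid[nr][nc] == 1:
            return path, False  # greedy policy tried to leave the grid or hit an obstacle
        pos = (nr, nc)
        path.append(pos)
    return path, pos == goal

# Wiring to the trained model, assuming a hypothetical encode_state(grid, pos)
# that reproduces the state encoding used in train.py:
#   q_fn = lambda g, p: model(
#       torch.tensor(encode_state(g, p), dtype=torch.float32).unsqueeze(0))[0].tolist()
```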

Training Script

The training script train.py is included in the repository for those who wish to reproduce the training process or continue training from a specific checkpoint.

Training Instructions

  • Clone the repository.
  • Install the required dependencies (at minimum, PyTorch).
  • Run the training script:
python train.py

To continue training from a checkpoint, modify the script to load the checkpoint before training.
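The internal layout of the checkpoint_XX.pth.tar files is not documented on this card (you can inspect one with `torch.load(path, map_location="cpu").keys()`). As a sketch of what resuming usually requires, the helpers below save and restore the model weights, the optimizer state, and the episode counter. The dictionary keys "episode", "state_dict", and "optimizer" are hypothetical and would need to match whatever train.py actually writes.

```python
import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, episode):
    """Bundle everything needed to resume training into one file (or file-like object)."""
    torch.save(
        {
            "episode": episode,                    # hypothetical key names
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer=None):
    """Restore model (and optionally optimizer) state; return the next episode index."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint.get("episode", 0) + 1
```

Restoring the optimizer state matters for Adam-style optimizers, whose running moment estimates would otherwise restart from zero.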

Evaluation

The model was evaluated based on:

  • Average Reward: The mean reward over several episodes
  • Success Rate: The proportion of episodes where the agent successfully reached the target
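Both metrics are simple aggregates over evaluation episodes. The sketch below computes them from a list of per-episode results; the (total_reward, reached_target) tuple format and the sample numbers are illustrative assumptions, not the card's evaluation data.

```python
def evaluate(episodes):
    """Compute average reward and success rate from per-episode results.

    episodes: list of (total_reward, reached_target) tuples, one per episode.
    """
    if not episodes:
        raise ValueError("need at least one episode")
    rewards = [total for total, _ in episodes]
    successes = [reached for _, reached in episodes]
    average_reward = sum(rewards) / len(rewards)
    success_rate = sum(successes) / len(successes)  # True counts as 1
    return average_reward, success_rate


print(evaluate([(9.0, True), (8.5, True), (-2.0, False), (10.5, True)]))  # (6.5, 0.75)
```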

Initial Evaluation Results

  • Average Reward: 8.84
  • Success Rate: 1.0 (the agent reached the target in every evaluation episode)

Limitations

  • The model's performance can be influenced by the complexity of the floorplan and the density of obstacles.
  • It requires a grid-based representation of the environment for accurate navigation.

Acknowledgements

This project combines reinforcement learning with a classical pathfinding algorithm to navigate complex environments efficiently.

License

This model is licensed under the Apache 2.0 License.

Citation

If you use this model in your research, please cite it as follows:

@misc{jones2024dqnfloorplan,
author = {Christopher Jones},
title = {Deep Q-Network for Floorplan Navigation},
year = {2024},
howpublished = {\url{https://huggingface.co/cajcodes/dqn-floorplan-navigator}},
note = {Accessed: YYYY-MM-DD}
}