| | """
|
| | RNN Model Architecture for CIFAR-10 Classification
|
| | """
|
| | import torch
|
| | import torch.nn as nn
|
| | import config
|
| |
|
| |
|
class CIFAR10RNN(nn.Module):
    """
    Recurrent Neural Network (bidirectional LSTM) for CIFAR-10 classification.

    Each 32x32 RGB image is treated as a sequence of 32 rows, with every row
    flattened to 32 pixels * 3 channels = 96 features per time step.

    Architecture:
        - Bidirectional multi-layer LSTM over the 32-row sequence
        - Two-layer fully connected head for classification

    Args:
        input_size (int): Features per time step (default 96 = 32*3).
        hidden_size (int): LSTM hidden state size per direction.
        num_layers (int): Number of stacked LSTM layers.
        num_classes (int): Number of output classes.
    """

    def __init__(self, input_size=96, hidden_size=256, num_layers=2, num_classes=10):
        super(CIFAR10RNN, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Inter-layer dropout only applies between stacked layers; passing a
        # nonzero dropout with num_layers == 1 triggers a PyTorch warning.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=config.RNN_DROPOUT if num_layers > 1 else 0
        )

        # hidden_size * 2 because the forward and backward direction outputs
        # are concatenated before the classification head.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """
        Run a batch of images through the network.

        Args:
            x (Tensor): Input batch; the permute below assumes channels-first
                CIFAR-10 tensors of shape (batch, 3, 32, 32).

        Returns:
            Tensor: Class logits of shape (batch, num_classes).
        """
        batch_size = x.size(0)

        # (batch, 3, 32, 32) -> (batch, 32, 32, 3): one image row per step.
        x = x.permute(0, 2, 1, 3).contiguous()
        # Flatten each row to 32 pixels * 3 channels = 96 features.
        x = x.view(batch_size, 32, -1)

        # out: (batch, seq_len, 2 * hidden_size) — per PyTorch's nn.LSTM,
        # the last dim holds [forward_t, backward_t] concatenated per step.
        out, _ = self.lstm(x)

        # BUG FIX: the original used out[:, -1, :], but for the backward
        # direction the value at the LAST time step is its output after
        # seeing only ONE input row. The full-sequence summaries are the
        # forward output at the last step and the backward output at the
        # FIRST step; concatenate those instead.
        forward_last = out[:, -1, :self.hidden_size]
        backward_last = out[:, 0, self.hidden_size:]
        out = torch.cat((forward_last, backward_last), dim=1)

        out = self.fc(out)
        return out
|
| |
|
| |
|
def get_model(num_classes=10, device='cpu'):
    """
    Build the RNN model and move it to the requested device.

    Args:
        num_classes (int): Number of output classes
        device (str or torch.device): Device to load the model on

    Returns:
        CIFAR10RNN: The RNN model
    """
    # Each sequence step is one image row: 32 pixels * 3 color channels.
    features_per_row = 32 * 3
    model = CIFAR10RNN(
        input_size=features_per_row,
        hidden_size=config.HIDDEN_SIZE,
        num_layers=config.NUM_LAYERS,
        num_classes=num_classes,
    )
    return model.to(device)
|
| |
|
| |
|
def count_parameters(model):
    """
    Count the number of trainable parameters in the model.
    """
    total = 0
    for param in model.parameters():
        # Only parameters that receive gradients count as trainable.
        if param.requires_grad:
            total += param.numel()
    return total
|
| |
|