RNN-based Language Model for Next-Token Prediction
This repository contains the implementation of an RNN-based language model for next-token prediction on English and Hausa text. It was developed as part of a midterm project on neural network architectures and is trained to predict the next token in a sequence from the preceding tokens.
Project Overview:
The goal of this project was to build a next-token prediction model using a recurrent neural network (RNN). The model was trained on a combined dataset of English and Hausa text, making it a bilingual language model.
The key tasks of the project included:
- Training a neural-network-based model for next-token prediction.
- Monitoring the model's performance by tracking both training and validation loss.
- Calculating the perplexity score to evaluate the model.
- Uploading the model and its checkpoints for further use and evaluation.
Model Architecture:
The model is based on a simple RNN architecture, which is suitable for sequential data like text. The key components of the model include:
- Embedding layer: converts token indices into dense vectors of fixed size (n_embd).
- RNN layer: processes the input sequence and captures dependencies between tokens.
- Dropout layer: helps prevent overfitting by randomly setting a fraction of its input units to zero during training.
- Fully connected layer: maps the RNN output to vocabulary-sized logits used to predict the next token.

Key Hyperparameters:
- Embedding size (n_embd): 32
- Hidden size (n_hidden): 64
- Dropout rate: 0.2
- Batch size: 35
- Block size (input sequence length): 9
- Learning rate: 0.01
- Max iterations: 3000
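The four components above can be sketched in PyTorch as follows. This is a hypothetical reconstruction, not the repository's own model file: the placement of dropout on the RNN output, the attribute names, and the forward signature (returning a (logits, loss) pair, which matches the `logits, _ = model(...)` call in the Usage section) are all assumptions. Because the attribute names may differ from the trained model's, this class is not guaranteed to load final_model.pth.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleRNNModel(nn.Module):
    """Hypothetical sketch: embedding -> RNN -> dropout -> linear (layout assumed)."""

    def __init__(self, vocab_size, n_embd=32, n_hidden=64, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embd)       # token index -> dense vector
        self.rnn = nn.RNN(n_embd, n_hidden, batch_first=True)   # inputs shaped (B, T, n_embd)
        self.dropout = nn.Dropout(dropout)                      # active only in train() mode
        self.fc = nn.Linear(n_hidden, vocab_size)               # hidden state -> vocab logits

    def forward(self, idx, targets=None):
        x = self.embedding(idx)                  # (B, T, n_embd)
        out, _ = self.rnn(x)                     # (B, T, n_hidden); final hidden state unused
        logits = self.fc(self.dropout(out))      # (B, T, vocab_size)
        loss = None
        if targets is not None:
            # Mean cross-entropy over every (batch, position) pair
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
```

With the listed batch size of 35 and block size of 9, an input of shape (35, 9) yields logits of shape (35, 9, vocab_size), one next-token distribution per position.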
Perplexity Score:
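Perplexity is a key metric for evaluating language models. It is conventionally computed as the exponential of the average per-token cross-entropy loss, so lower is better: a perplexity of k means the model is, on average, as uncertain as if it were choosing uniformly among k tokens. For this project, the final perplexity score was 8.17, so on average the model is about as uncertain as a uniform choice among roughly eight candidate tokens, which indicates reasonable confidence given the complexity of the bilingual text data.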
The model's perplexity score is moderate, and it could likely be improved by using more advanced architectures such as LSTMs or GRUs, or by tuning the hyperparameters.
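A sketch of computing perplexity from model logits, assuming the standard definition (exponential of the mean per-token cross-entropy in nats). The function name `perplexity` is hypothetical; the repository's evaluation code is not shown. A model with uniform logits over V tokens has perplexity exactly V, which makes a convenient sanity check.

```python
import math

import torch
import torch.nn.functional as F


def perplexity(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """exp(mean cross-entropy); logits shaped (B, T, V), targets shaped (B, T)."""
    # F.cross_entropy uses the natural log and averages over all B*T positions by default
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return math.exp(loss.item())


# Inverting the definition: the reported perplexity of 8.17 corresponds to a
# mean loss of ln(8.17), about 2.10 nats per token
mean_loss_for_reported_ppl = math.log(8.17)
```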
Model Checkpoints:
The model state and optimizer state were saved periodically during training, so these checkpoints can be used to resume training or to evaluate the model at specific steps.
Checkpoints are available in the repository:
- checkpoint_step_800.pth: model state at 800 steps.
- checkpoint_step_1600.pth: model state at 1600 steps.
- final_model.pth: the final model state after 3000 training iterations.
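A minimal sketch of saving and resuming from a checkpoint that holds both model and optimizer state. The helper names and the dictionary keys (step, model_state_dict, optimizer_state_dict) are assumptions; the actual layout of checkpoint_step_800.pth and checkpoint_step_1600.pth may differ, so check a file's keys with `torch.load(path).keys()` (if it is a dictionary) before relying on them.

```python
import torch


def save_checkpoint(path, model, optimizer, step):
    # Hypothetical layout: one dict with the step counter, model weights and optimizer state
    torch.save(
        {
            "step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer=None):
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    if optimizer is not None:
        # Restoring Adam's moment estimates lets training resume where it left off
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["step"]
```

Restoring the optimizer state as well as the weights matters for Adam, whose per-parameter running averages would otherwise restart from zero. The Usage code loads final_model.pth directly with `load_state_dict`, which suggests that file holds a bare model state dict rather than a dictionary like the one above.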
Usage
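To use the model for inference or further training, load it with PyTorch. The model class must be importable, and the hyperparameters must match those used in training:

```python
import torch
from my_model import SimpleRNNModel  # Make sure to include your model class

# Hyperparameters must match the values used during training
vocab_size = ...  # set to the vocabulary size used during training
n_embd, n_hidden, block_size = 32, 64, 9

# Initialize the model and load the trained weights
model = SimpleRNNModel(vocab_size, n_embd, n_hidden)
model.load_state_dict(torch.load('final_model.pth', map_location='cpu'))
model.eval()  # Set the model to evaluation mode (disables dropout)

# Example input (sequence of token indices)
input_sequence = torch.randint(0, vocab_size, (1, block_size))

# Predict the next token from the logits at the last position
with torch.no_grad():
    logits, _ = model(input_sequence)
predicted_token = torch.argmax(logits[:, -1, :], dim=-1)
print(f"Predicted token index: {predicted_token.item()}")
```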
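To generate more than one token, the single-step prediction can be repeated: append each predicted token to the sequence and keep only the last block_size tokens as context. This greedy loop is a sketch that assumes, as in the Usage code, that the model returns a (logits, _) pair with logits shaped (batch, time, vocab_size); the function name `generate_greedy` is hypothetical.

```python
import torch


@torch.no_grad()
def generate_greedy(model, idx, max_new_tokens, block_size=9):
    """Greedily extend idx (shape (B, T)) by max_new_tokens token indices."""
    model.eval()  # disable dropout during generation
    for _ in range(max_new_tokens):
        context = idx[:, -block_size:]             # crop to the training context length
        logits, _ = model(context)                 # (B, T, vocab_size)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (B, 1)
        idx = torch.cat([idx, next_token], dim=1)  # append and feed back in
    return idx
```

Greedy decoding always picks the single most likely token; sampling instead with `torch.multinomial(torch.softmax(logits[:, -1, :], dim=-1), 1)` is a common alternative when more varied output is wanted.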
Training Details
The model was trained using PyTorch. The following optimization and training strategies were applied:
- Optimizer: the Adam optimizer was used with a learning rate of 0.01.
- Training loss: cross-entropy loss was calculated to measure the difference between the predicted and actual token distributions.
- Validation: the validation loss was evaluated periodically to check whether the model was overfitting.
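The procedure above can be sketched as a standard PyTorch training loop. It assumes the data are 1-D tensors of token indices and that the model's `forward(idx, targets)` returns (logits, loss) with loss the mean cross-entropy. The helper names `get_batch`, `estimate_loss` and `train` are hypothetical, and the evaluation interval and number of evaluation batches are not stated in this README (eval_interval=200 and eval_iters=20 are placeholders). Batch size, block size, learning rate and iteration count default to the values listed under Key Hyperparameters.

```python
import torch


def get_batch(data, batch_size=35, block_size=9):
    # Random windows of block_size tokens; targets are the inputs shifted one position right
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y


@torch.no_grad()
def estimate_loss(model, data, eval_iters=20, batch_size=35, block_size=9):
    # Average loss over several random batches, with dropout disabled
    model.eval()
    losses = [model(*get_batch(data, batch_size, block_size))[1].item() for _ in range(eval_iters)]
    model.train()
    return sum(losses) / len(losses)


def train(model, train_data, val_data, max_iters=3000, lr=0.01,
          eval_interval=200, batch_size=35, block_size=9):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []  # (step, train batch loss, estimated validation loss)
    for step in range(max_iters):
        xb, yb = get_batch(train_data, batch_size, block_size)
        _, loss = model(xb, yb)  # cross-entropy on this batch
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if step % eval_interval == 0 or step == max_iters - 1:
            val_loss = estimate_loss(model, val_data, batch_size=batch_size, block_size=block_size)
            history.append((step, loss.item(), val_loss))
    return history
```

The returned history holds the values one would plot as training and validation curves; a validation loss that rises while the training loss keeps falling is the usual sign of overfitting, and the exponential of the final validation loss gives the perplexity.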
Future Improvements
While this model performed well for a basic RNN, potential improvements include:
- Switching to LSTM/GRU: LSTM or GRU architectures capture long-term dependencies better than a simple RNN.
- Hyperparameter tuning: further tuning of hyperparameters such as batch size, learning rate, and sequence length might lower the perplexity.
- More data: training on a larger or more diverse dataset could also improve the model's ability to generalize.
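Because nn.RNN, nn.GRU and nn.LSTM in PyTorch share the same core constructor arguments and all return an (output, state) pair, switching architectures can be close to a one-line change. The class below is a hypothetical variant (the name `RecurrentLM` and the `cell` argument are illustrative) that assumes the same embedding, recurrent layer, dropout and linear layout described under Model Architecture.

```python
import torch
import torch.nn as nn


class RecurrentLM(nn.Module):
    """Hypothetical variant of the RNN model with a selectable recurrent layer."""

    CELLS = {"rnn": nn.RNN, "gru": nn.GRU, "lstm": nn.LSTM}

    def __init__(self, vocab_size, n_embd=32, n_hidden=64, dropout=0.2, cell="gru"):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embd)
        # All three layers accept (input_size, hidden_size, batch_first=...)
        self.rnn = self.CELLS[cell](n_embd, n_hidden, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(n_hidden, vocab_size)

    def forward(self, idx):
        # The LSTM state is an (h, c) tuple rather than a single tensor; it is discarded here
        out, _ = self.rnn(self.embedding(idx))
        return self.fc(self.dropout(out))  # (B, T, vocab_size)
```

The trade-off is size: a GRU layer carries three sets of gate weights and an LSTM layer four, versus one for a simple RNN, so the recurrent layer grows roughly three- and fourfold in parameters.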
Conclusion
This project demonstrates the implementation and training of a simple RNN-based model for next-token prediction. The model shows reasonable performance for the task at hand, with a moderate perplexity score and smooth training curves. With more complex architectures and additional fine-tuning, the model could be improved further to handle more sophisticated natural language tasks.