NGen 3
When loading NGen 3 through the transformers library, only the 15M variant is currently supported.
NGen 3 is an advanced Transformer model training pipeline that supports multiple model variants, ranging from a nano variant (approximately 120M parameters) to a foundational variant (approximately 1B parameters). The pipeline incorporates modern architectural improvements such as rotary positional embeddings, RMSNorm, and GEGLU activations to improve performance and training efficiency.
Note: Although NGen 3 is designed to train a 1B-parameter model, its advanced architecture pushes its performance closer to that of much larger models.
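For quick inference with transformers, loading should look roughly like the sketch below. The model ID TNSA-AI/ngen3 and the use of the standard causal-LM auto classes are assumptions; adjust them to match the actual published checkpoint.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Assumption: the released checkpoint (currently the 15M variant) is published
    # as "TNSA-AI/ngen3" and is compatible with the causal-LM auto classes.
    model_id = "TNSA-AI/ngen3"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("NGen 3 is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))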
Model Variants
NGen 3 supports the following variants via the --variant flag:
- nano: ~120M parameters
- small: ~300M parameters
- medium: ~500M parameters
- large: ~700M parameters
- foundational: ~1B parameters
Each variant adjusts key hyperparameters such as the number of layers, the model dimension (d_model), the number of attention heads (n_heads), and the feed-forward dimension (d_ff).
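For illustration only, the variant presets can be pictured as a simple lookup table like the one below. The numbers are placeholders, not the pipeline's actual settings, which are defined in train.py.

    # Placeholder values for illustration only: the real presets live in train.py.
    VARIANT_CONFIGS = {
        "nano":         dict(n_layers=12, d_model=768,  n_heads=12, d_ff=3072),
        "small":        dict(n_layers=16, d_model=1024, n_heads=16, d_ff=4096),
        "medium":       dict(n_layers=24, d_model=1280, n_heads=20, d_ff=5120),
        "large":        dict(n_layers=28, d_model=1536, n_heads=24, d_ff=6144),
        "foundational": dict(n_layers=32, d_model=1792, n_heads=28, d_ff=7168),
    }

    def get_variant_config(variant: str) -> dict:
        """Return the hyperparameters selected by the --variant flag."""
        return VARIANT_CONFIGS[variant]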
Requirements
- Python 3.8+
- PyTorch
- Transformers
- Datasets
- DeepSpeed (optional, for efficient training)
- Azure ML SDK (for distributed training on Azure)
Install dependencies using pip (adjust as needed):
pip install torch transformers datasets deepspeed azureml-core
Usage
1. Data Preparation
First, download and preprocess the OpenWebText dataset:
python prepare.py --output_dir ./data --max_length 4096
This script downloads, tokenizes, and saves the dataset in Arrow format to the ./data directory.
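In outline, the preparation step amounts to something like the sketch below; the GPT-2 tokenizer and the "text" column name are assumptions, and prepare.py may differ in detail.

    from datasets import load_dataset
    from transformers import AutoTokenizer

    # Assumption: a GPT-2 tokenizer stands in for whatever prepare.py actually uses.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def tokenize(batch, max_length=4096):
        return tokenizer(batch["text"], truncation=True, max_length=max_length)

    raw = load_dataset("openwebtext", split="train")
    tokenized = raw.map(tokenize, batched=True, remove_columns=["text"])
    tokenized.save_to_disk("./data")  # Arrow format, matching --output_dir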
2. Local Training
The main training script is train.py. It loads the processed dataset (by default from ./data), instantiates the desired model variant, and starts training.
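To sanity-check that the processed data is where train.py expects it, you can load it back with the datasets library (a small illustrative check, not part of the pipeline itself):

    from datasets import load_from_disk

    # train.py reads the Arrow dataset written by prepare.py (default: ./data).
    dataset = load_from_disk("./data")
    print(dataset)                        # number of examples and column names
    print(dataset[0]["input_ids"][:10])   # first few token ids of one example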
Example CLI Commands
- Train the nano (120M) variant:
python train.py --dataset_dir ./data --output_dir ./checkpoints_nano --batch_size 4 --epochs 3 --variant nano
- Train the small (300M) variant:
python train.py --dataset_dir ./data --output_dir ./checkpoints_small --batch_size 4 --epochs 3 --variant small
- Train the medium (500M) variant:
python train.py --dataset_dir ./data --output_dir ./checkpoints_medium --batch_size 4 --epochs 3 --variant medium
- Train the large (700M) variant:
python train.py --dataset_dir ./data --output_dir ./checkpoints_large --batch_size 4 --epochs 3 --variant large
- Train the foundational (1B) variant with rotary embeddings enabled:
python train.py --dataset_dir ./data --output_dir ./checkpoints_foundational --batch_size 4 --epochs 3 --variant foundational --use_rotary
3. Training on Azure ML
- Step 1: Set Up Azure ML Resources
Use azure_setup.py to create or connect to your Azure ML workspace and set up a compute cluster:
python azure_setup.py \
--workspace_name MyWorkspace \
--resource_group MyResourceGroup \
--subscription_id YOUR_SUBSCRIPTION_ID \
--location eastus \
--compute_name gpu-cluster \
--vm_size Standard_NC6 \
--max_nodes 4 \
--min_nodes 0
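Under the hood this is standard azureml-core usage. A rough sketch of the equivalent SDK calls, assuming azure_setup.py follows the usual Workspace/AmlCompute pattern:

    from azureml.core import Workspace
    from azureml.core.compute import AmlCompute, ComputeTarget

    # Connect to the workspace named in the CLI arguments (create it beforehand,
    # or via Workspace.create, if it does not exist yet).
    ws = Workspace.get(
        name="MyWorkspace",
        resource_group="MyResourceGroup",
        subscription_id="YOUR_SUBSCRIPTION_ID",
    )

    # Provision an auto-scaling GPU cluster matching the compute arguments above.
    config = AmlCompute.provisioning_configuration(
        vm_size="Standard_NC6",
        min_nodes=0,
        max_nodes=4,
    )
    cluster = ComputeTarget.create(ws, "gpu-cluster", config)
    cluster.wait_for_completion(show_output=True)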
- Step 2: Submit a Training Job to Azure ML
Use submit_train.py to submit your training script to Azure ML:
python submit_train.py \
--experiment_name ngen3-experiment \
--compute_target gpu-cluster \
--script train.py \
--dataset_dir ./data \
--output_dir ./checkpoints_foundational \
--batch_size 4 \
--epochs 3 \
--variant foundational \
--use_rotary
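Again as a sketch, assuming submit_train.py wraps the standard Experiment/ScriptRunConfig flow:

    from azureml.core import Experiment, ScriptRunConfig, Workspace

    ws = Workspace.from_config()  # assumes a local workspace config.json

    run_config = ScriptRunConfig(
        source_directory=".",
        script="train.py",
        arguments=[
            "--dataset_dir", "./data",
            "--output_dir", "./checkpoints_foundational",
            "--batch_size", "4",
            "--epochs", "3",
            "--variant", "foundational",
            "--use_rotary",
        ],
        compute_target="gpu-cluster",
    )

    run = Experiment(ws, "ngen3-experiment").submit(run_config)
    run.wait_for_completion(show_output=True)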
4. DeepSpeed Integration
The deepspeed.json file configures mixed-precision training and ZeRO optimizations. To leverage DeepSpeed, ensure it is installed and adjust your training script or submission command to enable DeepSpeed support.
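For reference, a minimal configuration of this kind typically looks like the snippet below (illustrative values, not the project's actual deepspeed.json):

    {
      "train_micro_batch_size_per_gpu": 4,
      "gradient_accumulation_steps": 1,
      "fp16": {
        "enabled": true
      },
      "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true
      }
    }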
License
The NGen 3 project is developed and maintained by TNSA AI. The licensing model is dual:
- The nano and small variants are open source and released under the MIT License.
- The medium, large, and foundational variants are proprietary and are not open source. Use of these proprietary components is subject to TNSA AI's proprietary licensing terms.
Copyright
© 2023 TNSA AI. All rights reserved. For usage terms, see the LICENSE file.