license: mit
SuperHOT Prototype 2 w/ 4-8K Context
This is the second prototype of SuperHOT, this time with 4K context and no RLHF. In my testing, it can go all the way to 6K without breaking down. I made the change with the intention of reaching 8K, so I assume it will go to 8K, although I only trained on 4K sequences.
To use the 8K context, you will need to apply the monkeypatch I have added in this repo -- the extended context will not work without it. The patch is very simple, and you can make the changes yourself:
- Increase the max_position_embeddings to 8192 to stretch the sinusoidal
- Stretch the frequency steps by a scale of 0.25
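The two changes above can be sketched in plain Python. This is a minimal illustration of scaled rotary positions, not the monkeypatch from this repo: the function names, the `base` of 10000, the interleaved pairing of dimensions, and the 2048-token pre-trained range are all assumptions made for the example.

```python
import math

def rope_angles(position, head_dim, base=10000.0, scale=1.0):
    """Rotary angles for one position, one angle per pair of dimensions.

    `scale` multiplies the position before it reaches the frequencies,
    so scale=0.25 makes position 8191 look like position 2047.75.
    """
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
    return [position * scale * f for f in inv_freq]

def apply_rope(vec, position, base=10000.0, scale=1.0):
    """Rotate consecutive pairs of `vec` by the (scaled) rotary angles."""
    out = []
    for i, theta in enumerate(rope_angles(position, len(vec), base, scale)):
        x, y = vec[2 * i], vec[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out.extend([x * c - y * s, x * s + y * c])
    return out

# Assumed sizes for illustration: 8192 stretched positions, scale 0.25.
MAX_POSITIONS = 8192
SCALE = 0.25
largest_effective_position = (MAX_POSITIONS - 1) * SCALE
```

With these assumed values, the largest effective position is 2047.75, so every position in the 8192-token window maps back inside a 2048-position range. The rotation itself is unchanged; only the position fed into it is compressed.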
The intuition is to keep the model within the learned positions of the pre-trained model, as the model may be overfit on the token-position relationship (not my idea, Ofir Press'). By interpolating the encodings, we remain within the bounds of the pre-trained model (work with the overfitting rather than against it). The monkeypatch will work for the pre-trained model without fine-tuning, but the results will not be that good, so you will need to fine-tune.
It can probably be made even better with a few other modifications which I am testing (swapping softmax for ReLU, increasing the head dimension).
In my testing, I tried random positional encoding, but I was not able to replicate the results of Jianlin Su, so maybe I did it incorrectly. I also tried shifted positions, log n scaling, log-sigmoid, and increasing the head dimension, though this dilated RoPE (DoPE :) ) is the only one which worked for me consistently. Note that these are all based on fine-tuning, since the goal is to extend the context of the pre-trained model. Pre-training will paint a different picture.
I trained the LoRA with the following configuration:
- 1200 samples (~400 samples over 2048 sequence length)
- learning rate of 3e-4
- 3 epochs
- The exported modules are:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - all bias
- Rank = 2
- Alpha = 8
- no dropout
- weight decay of 0.1
- AdamW beta1 of 0.9, beta2 of 0.99, and epsilon of 1e-5
- Trained on 4-bit base model
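
For reference, the settings listed above can be collected in one place. The sketch below is a plain Python container, not the API of any training library; the class name and field names are hypothetical, and reading "all bias" as training every bias term is an assumption. The `scaling` property uses the standard LoRA convention of scaling the low-rank update by alpha divided by rank.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class LoraRunConfig:
    """Hypothetical container for the hyperparameters listed above."""
    num_samples: int = 1200
    learning_rate: float = 3e-4
    epochs: int = 3
    target_modules: tuple = ("q_proj", "k_proj", "v_proj", "o_proj")
    bias: str = "all"  # assumption: "all bias" means every bias term is trained
    rank: int = 2
    alpha: int = 8
    dropout: float = 0.0
    weight_decay: float = 0.1
    adam_betas: tuple = (0.9, 0.99)
    adam_epsilon: float = 1e-5
    base_model_bits: int = 4

    @property
    def scaling(self) -> float:
        # Standard LoRA convention: the update is multiplied by alpha / rank.
        return self.alpha / self.rank

config = LoraRunConfig()
```

With rank 2 and alpha 8, the low-rank update is scaled by a factor of 4 under that convention.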