JustinLin610's picture
first commit
ee21b96
|
raw
history blame
No virus
4.38 kB
# Adaptive Span
Adaptive Span is a novel self-attention mechanism that can learn its optimal
attention span. This allows us to extend significantly the maximum context size
used in Transformer, while maintaining control over their memory footprint
and computational time. It uses the Truncated BPTT technique for training,
as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
Adaptive Span was introduced by paper:
[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
which achieved state-of-the-art language modeling results at the time of publication.
We manage to reproduce their result in fairseq and keep most of the
[original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
You can refer to the their sweep file as well if any combination of hyperparameter is not clear.
##### 0. Setup
First you need to process the Enwik8 dataset, we use the pre-tokenized dataset
from [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
You can download the dataset, and then run:
```bash
fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
--validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
--destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
```
##### 1. Train a Adaptive Span model on Enwik8
We will train a 12-layer Adaptive Span model following the [hyperparameters
used in the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
The following command assumes 4 GPUs, so that the total batch size is 64
sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
--user-dir examples/adaptive_span \
--data ~/data/enwik8/data-bin/ \
--fp16 --fp16-no-flatten-grads --max-update 600000 \
--task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
--n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
--attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
--validate-interval-updates 1000 \
--lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
--lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
--seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
```
This should land around 1.05 on validation, 1.03 on test. You can lower the
--aux-loss-scaler for better performance (longer span). It gives ~0.03 bpc
improvement to the transformerXL baseline here.
If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.
You can also reproduce the transformerXL result on enwik8 using this code base.
It should land around 1.06 on test,matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
You can try by
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
--user-dir examples/truncated_bptt \
~/data/enwik8/data-bin/ \
--task truncated_bptt_lm --fp16 --max-update 400000 \
--tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
--d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
--dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
--lr-scheduler cosine --warmup-updates 0 \
--lr 0.0 --lr 0.00025 --batch-size 15 \
--update-freq 1 --seed 2 --log-format json --log-interval 25 \
--fp16
```
##### 2. Evaluate
For Adaptive Span:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
--user-dir examples/adaptive_span \
--task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
```
For Transformer-XL evaluation:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
--user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
--tokens-per-sample 80 \
--model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
--gen-subset valid
```
*Note:* During training the model saw 512 tokens of context
(``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
settings from [the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).