zaydzuhri committed on
Commit
a0806ea
1 parent: 979c3a8

Training in progress, step 5000

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. README.md +165 -0
  2. config.json +42 -0
  3. configs/deepspeed.yaml +10 -0
  4. configs/ds_config.json +19 -0
  5. configs/gla_16M.json +26 -0
  6. configs/gla_1B.json +26 -0
  7. configs/gla_340M.json +26 -0
  8. configs/gla_7B.json +29 -0
  9. configs/gsa_16M.json +27 -0
  10. configs/scan_16M.json +29 -0
  11. configs/scan_16M_8192.json +29 -0
  12. configs/scan_20M.json +29 -0
  13. configs/scan_340M.json +29 -0
  14. configs/transformer_16M.json +26 -0
  15. configs/transformer_16M_8192.json +26 -0
  16. fla/__init__.py +58 -0
  17. fla/layers/__init__.py +31 -0
  18. fla/layers/abc.py +207 -0
  19. fla/layers/attn.py +182 -0
  20. fla/layers/based.py +105 -0
  21. fla/layers/bitattn.py +183 -0
  22. fla/layers/delta_net.py +267 -0
  23. fla/layers/gla.py +280 -0
  24. fla/layers/gsa.py +233 -0
  25. fla/layers/hgrn.py +153 -0
  26. fla/layers/hgrn2.py +207 -0
  27. fla/layers/linear_attn.py +171 -0
  28. fla/layers/multiscale_retention.py +282 -0
  29. fla/layers/rebased.py +136 -0
  30. fla/layers/rwkv6.py +291 -0
  31. fla/layers/scan.py +237 -0
  32. fla/layers/simple_gla.py +252 -0
  33. fla/models/__init__.py +39 -0
  34. fla/models/abc/__init__.py +13 -0
  35. fla/models/abc/configuration_abc.py +84 -0
  36. fla/models/abc/modeling_abc.py +403 -0
  37. fla/models/bitnet/__init__.py +13 -0
  38. fla/models/bitnet/configuration_bitnet.py +68 -0
  39. fla/models/bitnet/modeling_bitnet.py +428 -0
  40. fla/models/delta_net/__init__.py +13 -0
  41. fla/models/delta_net/configuration_delta_net.py +87 -0
  42. fla/models/delta_net/modeling_delta_net.py +439 -0
  43. fla/models/gla/__init__.py +13 -0
  44. fla/models/gla/configuration_gla.py +90 -0
  45. fla/models/gla/modeling_gla.py +418 -0
  46. fla/models/gsa/__init__.py +13 -0
  47. fla/models/gsa/configuration_gsa.py +94 -0
  48. fla/models/gsa/modeling_gsa.py +442 -0
  49. fla/models/hgrn/__init__.py +13 -0
  50. fla/models/hgrn/configuration_hgrn.py +74 -0
README.md ADDED
@@ -0,0 +1,165 @@
1
+ <div align="center">
2
+
3
+ # 🔥 Flame
4
+
5
+ </div>
6
+
7
+ A minimal framework for training FLA models, whether from scratch or through finetuning.
8
+
9
+ Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code:
10
+ we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training.
11
+
12
+ In this README, we will guide you through the process of using `flame` to train GLA models.
13
+
14
+ ## Setup
15
+
16
+ To get started, you'll need to install the required packages.
17
+ Both `fla` and `flame` have minimal dependencies.
18
+ Clone the `fla` repository and install the necessary packages as follows:
19
+
20
+ ```bash
21
+ git clone https://github.com/sustcsonglin/flash-linear-attention.git
22
+ cd flash-linear-attention
+ pip install .
23
+ pip install accelerate wandb
24
+ pip install deepspeed
25
+ ```
26
+
27
+ > [!CAUTION]
28
+ > The 🤗 `tokenizers` library has known [memory leak issues](https://github.com/huggingface/tokenizers/issues/1539) when processing very long documents.
29
+ > To address this, please ensure you install `tokenizers>=0.20.4`.
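+
+ As a quick sanity check of the setup above, something like the following should run cleanly; it is only a sketch and assumes the packages were installed into the active environment (`packaging` ships with `transformers`):
+
+ ```python
+ # Optional: verify the environment set up above.
+ from importlib.metadata import version
+ from packaging.version import Version
+
+ import fla  # fails here if flash-linear-attention was not installed correctly
+
+ # tokenizers >= 0.20.4 avoids the memory-leak issue mentioned in the caution note
+ assert Version(version("tokenizers")) >= Version("0.20.4"), "please upgrade tokenizers"
+ print(f"fla {fla.__version__}, tokenizers {version('tokenizers')}")
+ ```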
30
+
31
+ ## Preprocessing
32
+
33
+ Before training, you need to download and pre-tokenize your dataset.
34
+ We provide a straightforward script for this.
35
+ For instance, to tokenize the 10B-token sample (`sample-10BT`) of the `fineweb-edu` dataset, run:
36
+
37
+ ```bash
38
+ python preprocess.py \
39
+ --dataset HuggingFaceFW/fineweb-edu \
40
+ --name sample-10BT \
41
+ --split train \
42
+ --context_length 2048
43
+ ```
44
+ or, for an even smaller example suitable for quick testing:
45
+ ```bash
46
+ python preprocess.py \
47
+ --dataset alturing/gutenberg-texts \
48
+ --split train \
49
+ --context_length 2048
50
+ ```
51
+
52
+ This will cache the processed dataset at `data/HuggingFaceFW/fineweb-edu/sample-10BT/train`.
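+
+ The cached output is a pre-tokenized 🤗 `datasets` dataset. Assuming it is stored in the standard `save_to_disk` layout (and that `preprocess.py` produces an `input_ids` column, which is an assumption here), you can sanity-check it before training:
+
+ ```python
+ from datasets import load_from_disk
+
+ # Path produced by the fineweb-edu example above; adjust to your --dataset/--name.
+ dataset = load_from_disk("data/HuggingFaceFW/fineweb-edu/sample-10BT/train")
+ print(dataset)  # row count and column names
+ print(len(dataset[0]["input_ids"]))  # should match --context_length, i.e. 2048
+ ```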
53
+
54
+ In [the paper](https://proceedings.mlr.press/v235/yang24ab.html), GLA is pretrained on a subset of SlimPajama.
55
+ Given the size of the dataset, the fastest way to download it is using `git lfs` (refer to [this issue](https://huggingface.co/datasets/cerebras/SlimPajama-627B/discussions/2)).
56
+ ```bash
57
+ git lfs install
58
+ git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B
59
+ python preprocess.py \
60
+ --dataset SlimPajama-627B \
61
+ --split train \
62
+ --context_length 2048
63
+ ```
64
+
65
+ ## Training from scratch
66
+
67
+ To train a 340M GLA model from scratch, execute the following command:
68
+
69
+ ```bash
70
+ bash train.sh \
71
+ type=gla \
72
+ lr=3e-4 \
73
+ steps=20480 \
74
+ batch=8 \
75
+ update=1 \
76
+ warmup=1024 \
77
+ context=2048 \
78
+ path=exp/gla-340M-10B \
79
+ project=fla \
80
+ model=configs/gla_340M.json \
81
+ data=HuggingFaceFW/fineweb-edu \
82
+ name=sample-10BT \
83
+ cache=data/HuggingFaceFW/fineweb-edu/sample-10BT/train
84
+ ```
85
+ or, to run a quick SCAN test:
86
+ ```bash
87
+ bash train.sh \
88
+ type=scan \
89
+ lr=3e-4 \
90
+ steps=1000 \
91
+ batch=8 \
92
+ update=1 \
93
+ warmup=100 \
94
+ context=2048 \
95
+ path=exp/scan-340M-test \
96
+ project=fla \
97
+ model=configs/scan_340M.json \
98
+ data=alturing/gutenberg-texts \
99
+ name=sample-10BT \
100
+ cache=data/alturing/gutenberg-texts/train
101
+ ```
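+
+ The `model=` flag points at one of the JSON configs under `configs/`. If you want to inspect the architecture a config produces before launching a full run, a sketch along these lines should work; it assumes that importing `fla` registers its model classes with the 🤗 auto classes, as the bundled `fla/models` package does:
+
+ ```python
+ import fla  # registers GLA/GSA/... configs and models with the Auto* APIs
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ config = AutoConfig.from_pretrained("configs/gla_340M.json")
+ model = AutoModelForCausalLM.from_config(config)
+ n_params = sum(p.numel() for p in model.parameters())
+ print(config.model_type, f"{n_params / 1e6:.1f}M parameters")
+ ```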
102
+
103
+ `flame` also supports resuming interrupted training by specifying the checkpoint path.
104
+ Simply use the following command to resume training:
105
+
106
+ ```bash
107
+ bash train.sh \
108
+ type=gla \
109
+ lr=3e-4 \
110
+ steps=20480 \
111
+ batch=8 \
112
+ update=1 \
113
+ warmup=1024 \
114
+ context=2048 \
115
+ path=exp/gla-340M-10B \
116
+ project=fla \
117
+ model=configs/gla_340M.json \
118
+ data=HuggingFaceFW/fineweb-edu \
119
+ name=sample-10BT \
120
+ cache=data/HuggingFaceFW/fineweb-edu/sample-10BT/train \
121
+ checkpoint=exp/gla-340M-10B/checkpoint-8192
122
+ ```
123
+
124
+ You can also use `wandb` to monitor your training process effectively.
125
+
126
+ ![wandb](https://github.com/user-attachments/assets/05ca031c-1cae-41c9-bfcb-5b6b6d0df729)
127
+
128
+ ## Continual Pretraining
129
+
130
+ `flame` supports continual training from a pretrained checkpoint.
131
+ Below, we provide an example of how to finetune Mistral-7B to GLA.
132
+ You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146):
133
+
134
+ 1. Initialize a brand-new GLA-7B model from the config and copy the matched pretrained weights from Mistral-7B (a sketch of this weight transfer is shown after this list):
135
+ ```bash
136
+ cd ../utils
137
+ python convert_from_llama.py \
138
+ --model mistralai/Mistral-7B-v0.1 \
139
+ --config ../training/configs/gla_7B.json \
140
+ --output ../training/converted/gla-7B
141
+ cd -
142
+ ```
143
+
144
+ 2. Directly launch training from the converted checkpoint:
145
+ ```bash
146
+ bash train.sh \
147
+ type=gla \
148
+ lr=3e-5 \
149
+ steps=10240 \
150
+ batch=4 \
151
+ update=8 \
152
+ warmup=512 \
153
+ context=2048 \
154
+ path=exp/gla-7B-20B \
155
+ project=fla \
156
+ model=converted/gla-7B \
157
+ data=SlimPajama-627B \
158
+ cache=data/SlimPajama-627B/train
159
+ ```
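+
+ For reference, the weight transfer in step 1 boils down to copying every tensor whose name and shape match between the two models, while GLA-specific parameters keep their fresh initialization. The snippet below is only an illustrative sketch of that idea, not the actual `convert_from_llama.py` script, and it assumes both models fit in CPU memory:
+
+ ```python
+ import torch
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ import fla  # makes the GLA model classes visible to the Auto* APIs
+
+ donor = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16)
+ gla = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("configs/gla_7B.json"))
+
+ donor_state = donor.state_dict()
+ copied, kept = 0, 0
+ with torch.no_grad():
+     for name, param in gla.state_dict().items():
+         if name in donor_state and donor_state[name].shape == param.shape:
+             param.copy_(donor_state[name])  # embeddings, attention/MLP projections, norms, ...
+             copied += 1
+         else:
+             kept += 1  # GLA-specific tensors (e.g. gates) stay freshly initialized
+ print(f"copied {copied} tensors, kept {kept} newly initialized")
+ gla.save_pretrained("converted/gla-7B")
+ ```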
160
+
161
+ Please be aware that finetuning on a single node may not be the most efficient approach.
162
+ If available, consider training across multiple GPU nodes for better throughput.
163
+ You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh).
164
+
165
+ [^1]: The `accelerate` library supports various distributed backends, such as `deepspeed` and `megatron`, for large-scale training; we use `deepspeed` here.
config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "_name_or_path": "configs/scan_16M_8192.json",
3
+ "architectures": [
4
+ "SCANForCausalLM"
5
+ ],
6
+ "attn": null,
7
+ "attn_mode": "parallel",
8
+ "bos_token_id": 1,
9
+ "clamp_max": null,
10
+ "clamp_min": null,
11
+ "elementwise_affine": true,
12
+ "eos_token_id": 2,
13
+ "expand_k": 1,
14
+ "expand_v": 1,
15
+ "fuse_cross_entropy": true,
16
+ "fuse_norm": true,
17
+ "gate_act": "softmax",
18
+ "gate_logit_normalizer": 8,
19
+ "hidden_act": "swish",
20
+ "hidden_ratio": 4,
21
+ "hidden_size": 256,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": null,
24
+ "max_position_embeddings": 8192,
25
+ "model_type": "scan",
26
+ "norm_eps": 1e-06,
27
+ "norm_first": true,
28
+ "num_heads": 4,
29
+ "num_hidden_layers": 10,
30
+ "num_kv_heads": null,
31
+ "state_size": 16,
32
+ "tie_word_embeddings": true,
33
+ "torch_dtype": "bfloat16",
34
+ "transformers_version": "4.47.0",
35
+ "use_cache": true,
36
+ "use_gk": true,
37
+ "use_gv": false,
38
+ "use_norm": true,
39
+ "use_output_gate": false,
40
+ "vocab_size": 32000,
41
+ "window_size": 128
42
+ }
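
For context, `config.json` above is the exported model configuration of the checkpoint uploaded in this commit: a `SCANForCausalLM` built from `configs/scan_16M_8192.json` (hidden size 256, 10 layers, 8192-token context). A hedged sketch of loading it, assuming the bundled `fla` package (which defines the `scan` model type) is importable and with the repository id left as a placeholder:

```python
import fla  # assumed to register the `scan` model type with the transformers Auto* classes
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<user>/<this-repo>"  # placeholder for the Hub id of this checkpoint
config = AutoConfig.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto")
print(config.model_type, config.hidden_size, config.num_hidden_layers)  # scan 256 10
```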
configs/deepspeed.yaml ADDED
@@ -0,0 +1,10 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ distributed_type: DEEPSPEED
3
+ deepspeed_config:
4
+ deepspeed_config_file: configs/ds_config.json
5
+ zero3_init_flag: true
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ num_machines: 1
9
+ num_processes: 1
10
+ use_cpu: false
configs/ds_config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "train_batch_size": "auto",
3
+ "train_micro_batch_size_per_gpu": "auto",
4
+ "gradient_accumulation_steps": "auto",
5
+ "gradient_clipping": "auto",
6
+ "zero_allow_untested_optimizer": true,
7
+ "bf16": {
8
+ "enabled": true
9
+ },
10
+ "zero_optimization": {
11
+ "stage": 2,
12
+ "allgather_partitions": true,
13
+ "allgather_bucket_size": 5e8,
14
+ "reduce_scatter": true,
15
+ "reduce_bucket_size": 5e8,
16
+ "overlap_comm": false,
17
+ "contiguous_gradients": true
18
+ }
19
+ }
configs/gla_16M.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 256,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": null,
15
+ "max_position_embeddings": 2048,
16
+ "model_type": "gla",
17
+ "num_heads": 4,
18
+ "num_hidden_layers": 10,
19
+ "norm_eps": 1e-06,
20
+ "tie_word_embeddings": true,
21
+ "transformers_version": "4.38.2",
22
+ "use_cache": true,
23
+ "use_gk": true,
24
+ "use_gv": false,
25
+ "vocab_size": 32000
26
+ }
configs/gla_1B.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 2048,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": null,
15
+ "max_position_embeddings": 2048,
16
+ "model_type": "gla",
17
+ "num_heads": 4,
18
+ "num_hidden_layers": 24,
19
+ "norm_eps": 1e-06,
20
+ "tie_word_embeddings": false,
21
+ "transformers_version": "4.38.2",
22
+ "use_cache": true,
23
+ "use_gk": true,
24
+ "use_gv": false,
25
+ "vocab_size": 32000
26
+ }
configs/gla_340M.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": null,
15
+ "max_position_embeddings": 2048,
16
+ "model_type": "gla",
17
+ "num_heads": 4,
18
+ "num_hidden_layers": 24,
19
+ "norm_eps": 1e-06,
20
+ "tie_word_embeddings": true,
21
+ "transformers_version": "4.38.2",
22
+ "use_cache": true,
23
+ "use_gk": true,
24
+ "use_gv": false,
25
+ "vocab_size": 32000
26
+ }
configs/gla_7B.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "feature_map": "relu",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 4096,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 14336,
16
+ "max_position_embeddings": 32768,
17
+ "model_type": "gla",
18
+ "num_heads": 32,
19
+ "num_kv_heads": 8,
20
+ "num_hidden_layers": 32,
21
+ "norm_eps": 1e-05,
22
+ "tie_word_embeddings": false,
23
+ "transformers_version": "4.40.0",
24
+ "use_cache": true,
25
+ "use_output_gate": false,
26
+ "use_gk": true,
27
+ "use_gv": false,
28
+ "vocab_size": 32000
29
+ }
configs/gsa_16M.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 256,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": null,
15
+ "max_position_embeddings": 2048,
16
+ "model_type": "gsa",
17
+ "num_slots": 16,
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 10,
20
+ "norm_eps": 1e-06,
21
+ "tie_word_embeddings": true,
22
+ "transformers_version": "4.38.2",
23
+ "use_cache": true,
24
+ "use_gk": true,
25
+ "use_gv": false,
26
+ "vocab_size": 32000
27
+ }
configs/scan_16M.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn_mode": "parallel",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "gate_act": "softmax",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 256,
14
+ "window_size": 128,
15
+ "state_size": 16,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": null,
18
+ "max_position_embeddings": 2048,
19
+ "model_type": "scan",
20
+ "num_heads": 4,
21
+ "num_hidden_layers": 10,
22
+ "norm_eps": 1e-06,
23
+ "tie_word_embeddings": true,
24
+ "transformers_version": "4.38.2",
25
+ "use_cache": true,
26
+ "use_gk": true,
27
+ "use_gv": false,
28
+ "vocab_size": 32000
29
+ }
configs/scan_16M_8192.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn_mode": "parallel",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "gate_act": "softmax",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 256,
14
+ "window_size": 128,
15
+ "state_size": 16,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": null,
18
+ "max_position_embeddings": 8192,
19
+ "model_type": "scan",
20
+ "num_heads": 4,
21
+ "num_hidden_layers": 10,
22
+ "norm_eps": 1e-06,
23
+ "tie_word_embeddings": true,
24
+ "transformers_version": "4.38.2",
25
+ "use_cache": true,
26
+ "use_gk": true,
27
+ "use_gv": false,
28
+ "vocab_size": 32000
29
+ }
configs/scan_20M.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn_mode": "parallel",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "gate_act": "softmax",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 384,
14
+ "window_size": 128,
15
+ "state_size": 16,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": null,
18
+ "max_position_embeddings": 2048,
19
+ "model_type": "scan",
20
+ "num_heads": 6,
21
+ "num_hidden_layers": 10,
22
+ "norm_eps": 1e-06,
23
+ "tie_word_embeddings": true,
24
+ "transformers_version": "4.38.2",
25
+ "use_cache": true,
26
+ "use_gk": true,
27
+ "use_gv": false,
28
+ "vocab_size": 32000
29
+ }
configs/scan_340M.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "attn_mode": "parallel",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "gate_act": "softmax",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 1024,
14
+ "window_size": 128,
15
+ "state_size": 32,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": null,
18
+ "max_position_embeddings": 2048,
19
+ "model_type": "scan",
20
+ "num_heads": 4,
21
+ "num_hidden_layers": 24,
22
+ "norm_eps": 1e-06,
23
+ "tie_word_embeddings": true,
24
+ "transformers_version": "4.38.2",
25
+ "use_cache": true,
26
+ "use_gk": true,
27
+ "use_gv": false,
28
+ "vocab_size": 32000
29
+ }
configs/transformer_16M.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "model_type": "transformer",
3
+ "attention_bias": false,
4
+ "bos_token_id": 1,
5
+ "clamp_min": null,
6
+ "eos_token_id": 2,
7
+ "fuse_cross_entropy": true,
8
+ "fuse_norm": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 256,
12
+ "state_size": 16,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": null,
15
+ "max_position_embeddings": 2048,
16
+ "num_heads": 4,
17
+ "num_kv_heads": 4,
18
+ "num_hidden_layers": 10,
19
+ "norm_eps": 1e-06,
20
+ "tie_word_embeddings": true,
21
+ "transformers_version": "4.38.2",
22
+ "use_cache": true,
23
+ "use_gk": true,
24
+ "use_gv": false,
25
+ "vocab_size": 32000
26
+ }
configs/transformer_16M_8192.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "model_type": "transformer",
3
+ "attention_bias": false,
4
+ "bos_token_id": 1,
5
+ "clamp_min": null,
6
+ "eos_token_id": 2,
7
+ "fuse_cross_entropy": true,
8
+ "fuse_norm": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 256,
12
+ "state_size": 16,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": null,
15
+ "max_position_embeddings": 8192,
16
+ "num_heads": 4,
17
+ "num_kv_heads": 4,
18
+ "num_hidden_layers": 10,
19
+ "norm_eps": 1e-06,
20
+ "tie_word_embeddings": true,
21
+ "transformers_version": "4.38.2",
22
+ "use_cache": true,
23
+ "use_gk": true,
24
+ "use_gv": false,
25
+ "vocab_size": 32000
26
+ }
fla/__init__.py ADDED
@@ -0,0 +1,58 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.layers import (ABCAttention, Attention, BasedLinearAttention,
4
+ BitAttention, DeltaNet, GatedLinearAttention,
5
+ GatedSlotAttention, HGRN2Attention, HGRNAttention,
6
+ LinearAttention, MultiScaleRetention,
7
+ ReBasedLinearAttention)
8
+ from fla.models import (ABCForCausalLM, ABCModel, BitNetForCausalLM,
9
+ BitNetModel, DeltaNetForCausalLM, DeltaNetModel,
10
+ GLAForCausalLM, GLAModel, GSAForCausalLM, GSAModel,
11
+ HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM,
12
+ LinearAttentionForCausalLM, LinearAttentionModel,
13
+ RetNetForCausalLM, RetNetModel, RWKV6ForCausalLM,
14
+ RWKV6Model, TransformerForCausalLM, TransformerModel)
15
+
16
+ __all__ = [
17
+ 'ABCAttention',
18
+ 'Attention',
19
+ 'BasedLinearAttention',
20
+ 'BitAttention',
21
+ 'DeltaNet',
22
+ 'HGRNAttention',
23
+ 'HGRN2Attention',
24
+ 'GatedLinearAttention',
25
+ 'GatedSlotAttention',
26
+ 'LinearAttention',
27
+ 'MultiScaleRetention',
28
+ 'ReBasedLinearAttention',
29
+ 'ABCForCausalLM',
30
+ 'ABCModel',
31
+ 'BitNetForCausalLM',
32
+ 'BitNetModel',
33
+ 'DeltaNetForCausalLM',
34
+ 'DeltaNetModel',
35
+ 'HGRNForCausalLM',
36
+ 'HGRNModel',
37
+ 'HGRN2ForCausalLM',
38
+ 'HGRN2Model',
39
+ 'GLAForCausalLM',
40
+ 'GLAModel',
41
+ 'GSAForCausalLM',
42
+ 'GSAModel',
43
+ 'LinearAttentionForCausalLM',
44
+ 'LinearAttentionModel',
45
+ 'RetNetForCausalLM',
46
+ 'RetNetModel',
47
+ 'RWKV6ForCausalLM',
48
+ 'RWKV6Model',
49
+ 'TransformerForCausalLM',
50
+ 'TransformerModel',
51
+ 'chunk_gla',
52
+ 'chunk_retention',
53
+ 'fused_chunk_based',
54
+ 'fused_chunk_gla',
55
+ 'fused_chunk_retention'
56
+ ]
57
+
58
+ __version__ = '0.1'
fla/layers/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .abc import ABCAttention
4
+ from .attn import Attention
5
+ from .based import BasedLinearAttention
6
+ from .bitattn import BitAttention
7
+ from .delta_net import DeltaNet
8
+ from .gla import GatedLinearAttention
9
+ from .gsa import GatedSlotAttention
10
+ from .hgrn import HGRNAttention
11
+ from .hgrn2 import HGRN2Attention
12
+ from .linear_attn import LinearAttention
13
+ from .multiscale_retention import MultiScaleRetention
14
+ from .rebased import ReBasedLinearAttention
15
+ from .rwkv6 import RWKV6Attention
16
+
17
+ __all__ = [
18
+ 'ABCAttention',
19
+ 'Attention',
20
+ 'BasedLinearAttention',
21
+ 'BitAttention',
22
+ 'DeltaNet',
23
+ 'GatedLinearAttention',
24
+ 'GatedSlotAttention',
25
+ 'HGRNAttention',
26
+ 'HGRN2Attention',
27
+ 'LinearAttention',
28
+ 'MultiScaleRetention',
29
+ 'ReBasedLinearAttention',
30
+ 'RWKV6Attention',
31
+ ]
fla/layers/abc.py ADDED
@@ -0,0 +1,207 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules import (FusedRMSNormSwishGate, RMSNorm, RotaryEmbedding,
14
+ ShortConvolution)
15
+ from fla.modules.activations import swiglu, swish
16
+ from fla.ops.abc.chunk import chunk_abc
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class ABCAttention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ hidden_size: int = 1024,
27
+ expand_k: float = 0.5,
28
+ expand_v: float = 1.0,
29
+ num_heads: int = 4,
30
+ use_short_conv: bool = False,
31
+ conv_size: int = 4,
32
+ conv_bias: bool = False,
33
+ num_slots: Optional[int] = None,
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_eps: float = 1e-5,
36
+ gate_low_rank_dim: int = 16,
37
+ gate_logit_normalizer: int = 16,
38
+ use_input_gate: bool = False,
39
+ use_output_gate: bool = True,
40
+ use_norm: bool = True,
+ use_rope: bool = True,
41
+ clamp_min: Optional[float] = -32,
42
+ clamp_max: Optional[float] = 32,
43
+ layer_idx: Optional[int] = None,
44
+ **kwargs
45
+ ) -> ABCAttention:
46
+ super().__init__()
47
+
48
+ self.hidden_size = hidden_size
49
+ self.expand_k = expand_k
50
+ self.expand_v = expand_v
51
+ self.num_heads = num_heads
52
+ self.key_dim = int(self.hidden_size * self.expand_k)
53
+ self.value_dim = int(self.hidden_size * self.expand_v)
54
+ self.head_k_dim = self.key_dim // self.num_heads
55
+ self.head_v_dim = self.value_dim // self.num_heads
56
+
57
+ self.use_short_conv = use_short_conv
58
+ self.conv_size = conv_size
59
+ self.conv_bias = conv_bias
60
+
61
+ self.gate_low_rank_dim = gate_low_rank_dim
62
+ self.gate_logit_normalizer = gate_logit_normalizer
63
+
64
+ self.use_input_gate = use_input_gate
65
+ self.use_output_gate = use_output_gate
66
+ self.use_norm = use_norm
+ # `self.use_rope` is read below and in `forward` but was never stored; the default of True is an assumption
+ self.use_rope = use_rope
67
+
68
+ if num_slots is None:
69
+ num_slots = self.head_k_dim
70
+ self.num_slots = num_slots
71
+
72
+ self.norm_eps = norm_eps
73
+
74
+ self.clamp_min = clamp_min
75
+ self.clamp_max = clamp_max
76
+ self.layer_idx = layer_idx
77
+
78
+ if layer_idx is None:
79
+ warnings.warn(
80
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
81
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
82
+ "when creating this class."
83
+ )
84
+
85
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
86
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
87
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
88
+
89
+ if use_output_gate:
90
+ self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
91
+ self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
92
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
93
+
94
+ if use_short_conv:
95
+ self.conv_size = conv_size
96
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
97
+ self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
98
+ self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
99
+
100
+ if self.use_norm:
101
+ if self.use_output_gate:
102
+ self.g_norm = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
103
+ else:
104
+ self.g_norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
105
+
106
+ if self.use_rope:
107
+ self.rotary = RotaryEmbedding(self.head_k_dim)
108
+
109
+ self.apply(self._initialize_weights)
110
+
111
+ def _initialize_weights(self, module: nn.Module):
112
+ if getattr(module, "_is_hf_initialized", False):
113
+ return
114
+ if isinstance(module, nn.Linear):
115
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
116
+ if module.bias is not None:
117
+ nn.init.zeros_(module.bias)
118
+ module._is_hf_initialized = True
119
+
120
+ def forward(
121
+ self,
122
+ hidden_states: torch.Tensor,
123
+ attention_mask: Optional[torch.Tensor] = None,
124
+ past_key_values: Optional[Cache] = None,
125
+ use_cache: Optional[bool] = False,
126
+ output_attentions: Optional[bool] = False,
127
+ **kwargs
128
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
129
+ if attention_mask is not None:
130
+ assert len(attention_mask.shape) == 2, (
131
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
132
+ "for padding purposes (0 indicating padding). "
133
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
134
+ )
135
+
136
+ last_state = None
137
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
138
+ last_state = past_key_values[self.layer_idx]
139
+
140
+ if self.use_short_conv:
141
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
142
+ if last_state is not None:
143
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
144
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
145
+ q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
146
+ mask=conv_mask,
147
+ cache=conv_state_q,
148
+ output_final_state=use_cache)
149
+ k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
150
+ mask=conv_mask,
151
+ cache=conv_state_k,
152
+ output_final_state=use_cache)
153
+ v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
154
+ mask=conv_mask,
155
+ cache=conv_state_v,
156
+ output_final_state=use_cache)
157
+ else:
158
+ q = self.q_proj(hidden_states)
159
+ k = self.k_proj(hidden_states)
160
+ v = self.v_proj(hidden_states)
161
+
162
+ if self.use_input_gate:
163
+ q, k, v = map(lambda x: swish(x), (q, k, v))
164
+ # dealing with left-padding
165
+ if attention_mask is not None:
166
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
167
+
168
+ q, k, v = map(lambda x: rearrange(x, '... (h d) -> ... h d', h=self.num_heads), (q, k, v))
169
+ if self.use_rope:
170
+ seqlen_offset = 0
171
+ if past_key_values is not None:
172
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
173
+ q, k = self.rotary(q, k, seqlen_offset)
174
+
175
+ s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', h=self.num_heads)
176
+ s = s.clamp_(self.clamp_min, self.clamp_max)
177
+
178
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
179
+ o, recurrent_state = chunk_abc(
180
+ q=q,
181
+ k=k,
182
+ v=v,
183
+ s=s,
184
+ initial_state=recurrent_state,
185
+ output_final_state=use_cache,
186
+ head_first=False
187
+ )
188
+ if past_key_values is not None:
189
+ past_key_values.update(
190
+ recurrent_state=recurrent_state,
191
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
192
+ layer_idx=self.layer_idx,
193
+ offset=q.shape[2]
194
+ )
195
+
196
+ if self.use_norm and not self.use_output_gate:
197
+ o = self.g_norm(o)
198
+ elif self.use_output_gate:
199
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
200
+ o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
201
+ o = rearrange(o, '... h d -> ... (h d)')
202
+ o = self.o_proj(o)
203
+
204
+ return o, None, past_key_values
205
+
206
+ def state_size(self, seq_len: int = 2048):
207
+ return self.num_heads * self.key_dim * self.head_v_dim
fla/layers/attn.py ADDED
@@ -0,0 +1,182 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RMSNorm, RotaryEmbedding
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+ try:
22
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
23
+ from flash_attn.bert_padding import (index_first_axis, pad_input,
24
+ unpad_input)
25
+ except ImportError:
26
+ warnings.warn(
27
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
28
+ category=ImportWarning
29
+ )
30
+ flash_attn_func = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class Attention(nn.Module):
36
+
37
+ def __init__(
38
+ self,
39
+ hidden_size: int = 2048,
40
+ num_heads: int = 32,
41
+ num_kv_heads: Optional[int] = None,
42
+ window_size: Optional[int] = None,
43
+ rope_theta: Optional[float] = 10000.,
44
+ max_position_embeddings: Optional[int] = None,
45
+ norm_first: bool = False,
46
+ norm_eps: float = 1e-5,
47
+ layer_idx: int = None
48
+ ):
49
+ super().__init__()
50
+
51
+ self.num_heads = num_heads
52
+ if num_kv_heads is None:
53
+ self.num_kv_heads = self.num_heads
54
+ else:
55
+ self.num_kv_heads = num_kv_heads
56
+ self.num_kv_groups = num_heads // self.num_kv_heads
57
+ self.hidden_size = hidden_size
58
+ self.head_dim = self.hidden_size // self.num_heads
59
+ self.kv_dim = self.num_kv_heads * self.head_dim
60
+ self.kv_dim = self.num_kv_heads * self.head_dim
61
+ self.window_size = window_size
62
+ self.rope_theta = rope_theta
63
+ self.max_position_embeddings = max_position_embeddings
64
+ self.norm_first = norm_first
65
+ self.layer_idx = layer_idx
66
+
67
+ if norm_first:
68
+ self.norm = RMSNorm(self.hidden_size, eps=norm_eps)
69
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
70
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
71
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
72
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
73
+
74
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
75
+
76
+ def forward(
77
+ self,
78
+ hidden_states: torch.Tensor,
79
+ attention_mask: Optional[torch.LongTensor] = None,
80
+ past_key_values: Optional[Cache] = None,
81
+ output_attentions: bool = False,
82
+ use_cache: bool = False,
83
+ **kwargs,
84
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
85
+ if attention_mask is not None:
86
+ assert len(attention_mask.shape) == 2, (
87
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
88
+ "for padding purposes (0 indicating padding). "
89
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
90
+ )
91
+
92
+ batch_size, q_len, _ = hidden_states.size()
93
+
94
+ if self.norm_first:
95
+ hidden_states = self.norm(hidden_states)
96
+
97
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
98
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
99
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
100
+
101
+ seqlen_offset, max_seqlen = 0, q_len
102
+ if past_key_values is not None:
103
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
104
+ max_seqlen = q.shape[1] + seqlen_offset
105
+
106
+ if attention_mask is not None:
107
+ # adjust the offsets to discount padding tokens (assumes left padding)
108
+ seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
109
+ max_seqlen = q.shape[1] + max(seqlen_offset)
110
+
111
+ if self.max_position_embeddings is not None:
112
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
113
+ q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
114
+
115
+ if past_key_values is not None:
116
+ k, v = past_key_values.update(
117
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
118
+ layer_idx=self.layer_idx,
119
+ offset=q_len,
120
+ cache_kwargs=dict(window_size=self.window_size)
121
+ )['attn_state']
122
+ k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
123
+ v = rearrange(v, '... (h d) -> ... h d', h=self.num_kv_heads)
124
+
125
+ if flash_attn_func is None:
126
+ raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
127
+
128
+ # Contains at least one padding token in the sequence
129
+ if attention_mask is not None:
130
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
131
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
132
+ max_seqlen_q, max_seqlen_k = max_seq_lens
133
+ o = flash_attn_varlen_func(
134
+ q, k, v,
135
+ cu_seqlens_q=cu_seqlens_q,
136
+ cu_seqlens_k=cu_seqlens_k,
137
+ max_seqlen_q=max_seqlen_q,
138
+ max_seqlen_k=max_seqlen_k,
139
+ causal=True,
140
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
141
+ )
142
+ o = pad_input(o, indices_q, batch_size, q_len)
143
+ else:
144
+ o = flash_attn_func(
145
+ q, k, v,
146
+ causal=True,
147
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
148
+ )
149
+ o = o.reshape(batch_size, q_len, self.hidden_size)
150
+ o = self.o_proj(o)
151
+
152
+ if not output_attentions:
153
+ attentions = None
154
+
155
+ return o, attentions, past_key_values
156
+
157
+ def _upad_input(self, q, k, v, attention_mask, q_len):
158
+ seqlens = attention_mask.sum(-1, dtype=torch.int32)
159
+ indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
160
+ max_seqlen_k = seqlens.max().item()
161
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
162
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
163
+
164
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
165
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
166
+ if q_len == seq_len:
167
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
168
+ cu_seqlens_q = cu_seqlens_k
169
+ max_seqlen_q = max_seqlen_k
170
+ indices_q = indices_k
171
+ elif q_len == 1:
172
+ max_seqlen_q = 1
173
+ # There is a memcpy here, that is very bad.
174
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
175
+ indices_q = cu_seqlens_q[:-1]
176
+ q = q.squeeze(1)
177
+ else:
178
+ # The -q_len: slice assumes left padding.
179
+ attention_mask = attention_mask[:, -q_len:]
180
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
181
+
182
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
fla/layers/based.py ADDED
@@ -0,0 +1,105 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Linear attention in Based.
6
+ https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules.feature_map import TaylorFeatureMap
14
+ from fla.ops.based import parallel_based
15
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
16
+
17
+
18
+ class BasedLinearAttention(nn.Module):
19
+
20
+ def __init__(
21
+ self,
22
+ hidden_size: int,
23
+ feature_dim: int = 16,
24
+ num_key_value_heads: int = 12,
25
+ num_heads: int = 12,
26
+ feature_name: str = "taylor_exp",
27
+ eps: float = 1e-12,
28
+ causal: bool = True,
29
+ mode: str = "parallel",
30
+ ):
31
+ super().__init__()
32
+
33
+ self.hidden_size = hidden_size
34
+ self.mode = mode
35
+ self.feature_name = feature_name
36
+ self.feature_dim = feature_dim
37
+ self.num_key_value_heads = num_key_value_heads
38
+ self.num_heads = num_heads
39
+ self.head_dim = self.hidden_size // self.num_key_value_heads
40
+ self.causal = causal
41
+
42
+ self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
43
+ self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
44
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
45
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
46
+ self.dropout = nn.Identity()
47
+ self.feature_map = TaylorFeatureMap(feature_dim)
48
+ self.eps = eps
49
+
50
+ self.apply(self._initialize_weights)
51
+
52
+ def _initialize_weights(self, module: nn.Module):
53
+ if getattr(module, "_is_hf_initialized", False):
54
+ return
55
+ if isinstance(module, nn.Linear):
56
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
57
+ if module.bias is not None:
58
+ nn.init.zeros_(module.bias)
59
+ module._is_hf_initialized = True
60
+
61
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
62
+ mode = self.mode
63
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
64
+ q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", h=self.num_heads), [q, k, v])
65
+ if mode == "fused_chunk":
66
+ q, k = self.feature_map(q), self.feature_map(k)
67
+ o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
68
+ elif mode == 'chunk':
69
+ q, k = self.feature_map(q), self.feature_map(k)
70
+ o = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
71
+ elif mode == 'parallel':
72
+ assert q.shape[-1] <= 128
73
+ o = parallel_based(q, k, v, True, True, head_first=False)
74
+ o = self.o_proj(o)
75
+ o = self.dropout(o)
76
+ return o
77
+
78
+ # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
79
+
80
+ def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
81
+ """
82
+ x (torch.Tensor): tensor of shape (b, d, t)
83
+ y (torch.Tensor): tensor of shape (b, d, t)
84
+ """
85
+ # hidden_states = hidden_states.transpose(1, 2)
86
+ b, t, _ = hidden_states.size()
87
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
88
+
89
+ q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2)
90
+ k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
91
+ v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2)
92
+
93
+ # Linear attention
94
+ q, k = self.feature_map(q), self.feature_map(k)
95
+ q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
96
+
97
+ # Compute attention
98
+ if self.causal:
99
+ y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
100
+ else:
101
+ y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
102
+ y = rearrange(y, 'b h t d -> b t (h d)')
103
+ y = self.o_proj(y.to(hidden_states.dtype))
104
+ y = self.dropout(y)
105
+ return y.to(hidden_states.dtype)
fla/layers/bitattn.py ADDED
@@ -0,0 +1,183 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RMSNorm, RotaryEmbedding
17
+ from fla.modules.fused_bitlinear import FusedBitLinear
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ try:
23
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
24
+ from flash_attn.bert_padding import (index_first_axis, pad_input,
25
+ unpad_input)
26
+ except ImportError:
27
+ warnings.warn(
28
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
29
+ category=ImportWarning
30
+ )
31
+ flash_attn_func = None
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class BitAttention(nn.Module):
37
+
38
+ def __init__(
39
+ self,
40
+ hidden_size: int = 2048,
41
+ num_heads: int = 32,
42
+ num_kv_heads: Optional[int] = None,
43
+ window_size: Optional[int] = None,
44
+ rope_theta: Optional[float] = 10000.,
45
+ max_position_embeddings: Optional[int] = None,
46
+ norm_first: bool = False,
47
+ norm_eps: float = 1e-5,
48
+ layer_idx: int = None
49
+ ):
50
+ super().__init__()
51
+
52
+ self.num_heads = num_heads
53
+ if num_kv_heads is None:
54
+ self.num_kv_heads = self.num_heads
55
+ else:
56
+ self.num_kv_heads = num_kv_heads
57
+ self.num_kv_groups = num_heads // self.num_kv_heads
58
+ self.hidden_size = hidden_size
59
+ self.head_dim = self.hidden_size // self.num_heads
60
+ self.kv_dim = self.num_kv_heads * self.head_dim
61
+ self.kv_dim = self.num_kv_heads * self.head_dim
62
+ self.window_size = window_size
63
+ self.rope_theta = rope_theta
64
+ self.max_position_embeddings = max_position_embeddings
65
+ self.norm_first = norm_first
66
+ self.layer_idx = layer_idx
67
+
68
+ if norm_first:
69
+ self.norm = RMSNorm(self.hidden_size, eps=norm_eps)
70
+ self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
71
+ self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
72
+ self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
73
+ self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
74
+
75
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
76
+
77
+ def forward(
78
+ self,
79
+ hidden_states: torch.Tensor,
80
+ attention_mask: Optional[torch.LongTensor] = None,
81
+ past_key_values: Optional[Cache] = None,
82
+ output_attentions: bool = False,
83
+ use_cache: bool = False,
84
+ **kwargs,
85
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
86
+ if attention_mask is not None:
87
+ assert len(attention_mask.shape) == 2, (
88
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
89
+ "for padding purposes (0 indicating padding). "
90
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
91
+ )
92
+
93
+ batch_size, q_len, _ = hidden_states.size()
94
+
95
+ if self.norm_first:
96
+ hidden_states = self.norm(hidden_states)
97
+
98
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
99
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
100
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
101
+
102
+ seqlen_offset, max_seqlen = 0, q_len
103
+ if past_key_values is not None:
104
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
105
+ max_seqlen = q.shape[1] + seqlen_offset
106
+
107
+ if attention_mask is not None:
108
+ # adjust the offsets to discount padding tokens (assumes left padding)
109
+ seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
110
+ max_seqlen = q.shape[1] + max(seqlen_offset)
111
+
112
+ if self.max_position_embeddings is not None:
113
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
114
+ q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
115
+
116
+ if past_key_values is not None:
117
+ k, v = past_key_values.update(
118
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
119
+ layer_idx=self.layer_idx,
120
+ offset=q_len,
121
+ cache_kwargs=dict(window_size=self.window_size)
122
+ )['attn_state']
123
+ k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
124
+ v = rearrange(v, '... (h d) -> ... h d', h=self.num_kv_heads)
125
+
126
+ if flash_attn_func is None:
127
+ raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
128
+
129
+ # Contains at least one padding token in the sequence
130
+ if attention_mask is not None:
131
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
132
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
133
+ max_seqlen_q, max_seqlen_k = max_seq_lens
134
+ o = flash_attn_varlen_func(
135
+ q, k, v,
136
+ cu_seqlens_q=cu_seqlens_q,
137
+ cu_seqlens_k=cu_seqlens_k,
138
+ max_seqlen_q=max_seqlen_q,
139
+ max_seqlen_k=max_seqlen_k,
140
+ causal=True,
141
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
142
+ )
143
+ o = pad_input(o, indices_q, batch_size, q_len)
144
+ else:
145
+ o = flash_attn_func(
146
+ q, k, v,
147
+ causal=True,
148
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
149
+ )
150
+ o = o.reshape(batch_size, q_len, self.hidden_size)
151
+ o = self.o_proj(o)
152
+
153
+ if not output_attentions:
154
+ attentions = None
155
+
156
+ return o, attentions, past_key_values
157
+
158
+ def _upad_input(self, q, k, v, attention_mask, q_len):
159
+ seqlens = attention_mask.sum(-1, dtype=torch.int32)
160
+ indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
161
+ max_seqlen_k = seqlens.max().item()
162
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
163
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
164
+
165
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
166
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
167
+ if q_len == seq_len:
168
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
169
+ cu_seqlens_q = cu_seqlens_k
170
+ max_seqlen_q = max_seqlen_k
171
+ indices_q = indices_k
172
+ elif q_len == 1:
173
+ max_seqlen_q = 1
174
+ # There is a memcpy here, that is very bad.
175
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
176
+ indices_q = cu_seqlens_q[:-1]
177
+ q = q.squeeze(1)
178
+ else:
179
+ # The -q_len: slice assumes left padding.
180
+ attention_mask = attention_mask[:, -q_len:]
181
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
182
+
183
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
fla/layers/delta_net.py ADDED
@@ -0,0 +1,267 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ # Sect4.2 of Linear Transformers Are Secretly Fast Weight Programmers https://arxiv.org/abs/2102.11174
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from torch.nn import functional as F
13
+
14
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
15
+ from fla.modules.l2norm import l2_norm
16
+ from fla.ops.delta_rule import (chunk_delta_rule, fused_chunk_delta_rule,
17
+ fused_recurrent_delta_rule)
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+
23
+ def elu_p1(x):
24
+ return (F.elu(x, 1., False) + 1.).to(x)
25
+
26
+
27
+ def sum_norm(x):
28
+ return (x / x.sum(-1, keepdim=True)).to(x)
29
+
30
+ # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
31
+
32
+
33
+ class DeltaNet(nn.Module):
34
+ def __init__(
35
+ self,
36
+ d_model: int = None,
37
+ hidden_size: int = 1024,
38
+ expand_k: float = 1.0,
39
+ expand_v: float = 1.0,
40
+ num_heads: int = 4,
41
+ mode: str = 'chunk',
42
+ use_beta: bool = True,
43
+ use_gate: bool = False,
44
+ use_output_norm: bool = True,
45
+ use_elu: bool = False,
46
+ use_short_conv: bool = True,
47
+ conv_size: int = 4,
48
+ conv_bias: bool = False,
49
+ layer_idx: int = None,
50
+ qk_activation: str = 'silu',
51
+ qk_norm: str = 'l2',
52
+ norm_first: bool = False,
53
+ norm_eps: float = 1e-5,
54
+ **kwargs
55
+ ) -> DeltaNet:
56
+ super().__init__()
57
+
58
+ self.mode = mode
59
+ self.qk_activation = qk_activation
60
+ self.qk_norm = qk_norm
61
+
62
+ assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
63
+ assert self.qk_norm in ['l2', 'sum']
64
+
65
+ if d_model is not None:
66
+ hidden_size = d_model
67
+ self.hidden_size = hidden_size
68
+ self.expand_k = expand_k
69
+ self.expand_v = expand_v
70
+ self.num_heads = num_heads
71
+ self.use_gate = use_gate
72
+ self.use_output_norm = use_output_norm
73
+ self.use_short_conv = use_short_conv
74
+ self.conv_size = conv_size
75
+ self.conv_bias = conv_bias
76
+
77
+ self.key_dim = int(hidden_size * expand_k)
78
+ self.value_dim = int(hidden_size * expand_v)
79
+ self.head_qk_dim = self.key_dim // num_heads
80
+ self.head_v_dim = self.value_dim // num_heads
81
+ self.norm_first = norm_first
82
+ self.layer_idx = layer_idx
83
+
84
+ self.silu = nn.SiLU()
85
+
86
+ assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
87
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
88
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
89
+
90
+ if norm_first:
91
+ self.norm = RMSNorm(self.hidden_size, eps=norm_eps)
92
+
93
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
94
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
95
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
96
+
97
+ self.use_beta = use_beta
98
+ self.use_elu = use_elu
99
+ if self.use_beta:
100
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
101
+ if use_short_conv:
102
+ self.conv_size = conv_size
103
+ self.q_conv1d = ShortConvolution(
104
+ hidden_size=self.key_dim,
105
+ kernel_size=conv_size,
106
+ activation='silu' if qk_activation == 'silu' else None
107
+ )
108
+ self.k_conv1d = ShortConvolution(
109
+ hidden_size=self.key_dim,
110
+ kernel_size=conv_size,
111
+ activation='silu' if qk_activation == 'silu' else None
112
+ )
113
+ self.v_conv1d = ShortConvolution(
114
+ hidden_size=self.value_dim,
115
+ kernel_size=conv_size,
116
+ activation='silu'
117
+ )
118
+ else:
119
+ raise UserWarning(
120
+ "ShortConvolution is crucial to the performance. "
121
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
122
+ )
123
+ if use_gate:
124
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
125
+ self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
126
+ else:
127
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
128
+
129
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
130
+
131
+ self.apply(self._initialize_weights)
132
+
133
+ def _initialize_weights(self, module: nn.Module):
134
+ if getattr(module, "_is_hf_initialized", False):
135
+ return
136
+ if isinstance(module, nn.Linear):
137
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
138
+ if module.bias is not None:
139
+ nn.init.zeros_(module.bias)
140
+ module._is_hf_initialized = True
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: torch.Tensor,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ past_key_values: Optional[Cache] = None,
147
+ use_cache: Optional[bool] = False,
148
+ output_attentions: Optional[bool] = False,
149
+ **kwargs
150
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
151
+ if attention_mask is not None:
152
+ assert len(attention_mask.shape) == 2, (
153
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
154
+ "for padding purposes (0 indicating padding). "
155
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
156
+ )
157
+
158
+ # change to inference mode.
159
+ mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode
160
+
161
+ if self.norm_first:
162
+ hidden_states = self.norm(hidden_states)
163
+
164
+ last_state = None
165
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
166
+ last_state = past_key_values[self.layer_idx]
167
+
168
+ if self.use_short_conv:
169
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
170
+ if last_state is not None:
171
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
172
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
173
+ q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
174
+ mask=conv_mask,
175
+ cache=conv_state_q,
176
+ output_final_state=use_cache)
177
+ k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
178
+ mask=conv_mask,
179
+ cache=conv_state_k,
180
+ output_final_state=use_cache)
181
+ v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
182
+ mask=conv_mask,
183
+ cache=conv_state_v,
184
+ output_final_state=use_cache)
185
+ else:
186
+ q = self.q_proj(hidden_states)
187
+ k = self.k_proj(hidden_states)
188
+ v = self.silu(self.v_proj(hidden_states))
189
+
190
+ q, k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q, k, v))
191
+ if self.qk_activation != 'silu':
192
+ if self.qk_activation == 'relu':
193
+ q, k = q.relu(), k.relu()
194
+ elif self.qk_activation == 'elu':
195
+ q, k = elu_p1(q), elu_p1(k)
196
+ elif self.qk_activation == 'identity':
197
+ pass
198
+ else:
199
+ raise NotImplementedError
200
+
201
+ if self.qk_norm is not None:
202
+ if self.qk_norm == 'l2':
203
+ q = l2_norm(q)
204
+ k = l2_norm(k)
205
+ elif self.qk_norm == 'sum':
206
+ q = sum_norm(q).to(q)
207
+ k = sum_norm(k).to(k)
208
+
209
+ if self.use_beta:
210
+ beta = self.b_proj(hidden_states).sigmoid()
211
+ else:
212
+ beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
213
+
214
+ # dealing with padding
215
+ if attention_mask is not None:
216
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
217
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
218
+ if mode == 'fused_recurrent':
219
+ o, recurrent_state = fused_recurrent_delta_rule(
220
+ q=q,
221
+ k=k,
222
+ v=v,
223
+ beta=beta,
224
+ initial_state=recurrent_state,
225
+ output_final_state=use_cache,
226
+ head_first=False
227
+ )
228
+ elif mode == 'fused_chunk':
229
+ o, recurrent_state = fused_chunk_delta_rule(
230
+ q=q,
231
+ k=k,
232
+ v=v,
233
+ beta=beta,
234
+ initial_state=recurrent_state,
235
+ output_final_state=use_cache,
236
+ head_first=False
237
+ )
238
+ elif mode == 'chunk':
239
+ o, recurrent_state = chunk_delta_rule(
240
+ q=q,
241
+ k=k,
242
+ v=v,
243
+ beta=beta,
244
+ initial_state=recurrent_state,
245
+ output_final_state=use_cache,
246
+ head_first=False
247
+ )
248
+ else:
249
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
250
+
251
+ if past_key_values is not None:
252
+ past_key_values.update(
253
+ recurrent_state=recurrent_state,
254
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
255
+ layer_idx=self.layer_idx,
256
+ offset=q.shape[2]
257
+ )
258
+
259
+ if self.use_gate:
260
+ g = rearrange(self.g_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads)
261
+ o = self.o_norm(o, g)
262
+ else:
263
+ o = self.o_norm(o)
264
+ o = rearrange(o, 'b t h d -> b t (h d)')
265
+ o = self.o_proj(o)
266
+
267
+ return o, None, past_key_values
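
A minimal usage sketch for the delta-rule layer above. The class name and constructor arguments (`DeltaNet` in `fla.layers.delta_net` with `hidden_size`/`num_heads`/`layer_idx`) are assumptions, since the top of this file lies outside the hunk shown here; the Triton kernels are assumed to need a CUDA device.

```python
# Hedged sketch: `DeltaNet` and its constructor signature are inferred, not shown in this hunk.
import torch
from fla.layers.delta_net import DeltaNet  # assumed export

layer = DeltaNet(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)  # [batch, seq_len, hidden_size]
o, _, _ = layer(x)  # forward returns (output, attentions, past_key_values)
print(o.shape)      # torch.Size([2, 128, 1024])
```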
fla/layers/gla.py ADDED
@@ -0,0 +1,280 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class GatedLinearAttention(nn.Module):
23
+ r"""
24
+ The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
25
+
26
+ Args:
27
+ mode (str, Optional):
28
+ Which GLA kernel to use.
29
+ Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
30
+ Default: `chunk`.
31
+ hidden_size (int, Optional):
32
+ The hidden size of the input. Default: 1024.
33
+ expand_k (float, Optional):
34
+ The expansion ratio for the key dim. Default: 0.5.
35
+ expand_v (float, Optional):
36
+ The expansion ratio for the value dim. Default: 1.0.
37
+ num_heads (int, Optional):
38
+ The number of heads. Default: 4.
39
+ num_kv_heads (int, Optional):
40
+ The number of key/value heads, used for MQA. Default: None.
41
+ feature_map (str, Optional):
42
+ Feature map function applied to queries/keys. Default: None.
43
+ use_short_conv (bool, Optional):
44
+ Whether to use short convolutions. Default: `False`.
45
+ conv_size (int, Optional):
46
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
47
+ conv_bias (bool, Optional):
48
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
49
+ use_output_gate (bool, Optional):
50
+ Whether to use output gate. Default: `True`.
51
+ gate_fn (str, Optional):
52
+ The activation function for the output gate. Default: `swish`.
53
+ elementwise_affine (bool, Optional):
54
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
55
+ norm_eps (float, Optional):
56
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
57
+ gate_logit_normalizer (int, Optional):
58
+ The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
59
+ gate_low_rank_dim (int, Optional):
60
+ The low rank dim for the gate projection. Default: 16.
61
+ clamp_min (float, Optional):
62
+ The minimum value for the gate logits. Default: None.
63
+ fuse_norm (bool, Optional):
64
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
65
+ layer_idx (int, Optional):
66
+ The index of the layer. Default: None.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ mode: str = 'chunk',
72
+ hidden_size: int = 1024,
73
+ expand_k: float = 0.5,
74
+ expand_v: float = 1.0,
75
+ num_heads: int = 4,
76
+ num_kv_heads: Optional[int] = None,
77
+ feature_map: Optional[str] = None,
78
+ use_short_conv: bool = False,
79
+ conv_size: int = 4,
80
+ conv_bias: bool = False,
81
+ use_output_gate: bool = True,
82
+ gate_fn: str = 'swish',
83
+ elementwise_affine: Optional[bool] = True,
84
+ norm_eps: float = 1e-5,
85
+ gate_logit_normalizer: int = 16,
86
+ gate_low_rank_dim: int = 16,
87
+ clamp_min: Optional[float] = None,
88
+ fuse_norm: bool = True,
89
+ layer_idx: int = None,
90
+ ) -> GatedLinearAttention:
91
+ super().__init__()
92
+
93
+ self.mode = mode
94
+ self.hidden_size = hidden_size
95
+ self.expand_k = expand_k
96
+ self.expand_v = expand_v
97
+ self.num_heads = num_heads
98
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
99
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
100
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
101
+
102
+ self.use_short_conv = use_short_conv
103
+ self.conv_size = conv_size
104
+ self.conv_bias = conv_bias
105
+ self.use_output_gate = use_output_gate
106
+
107
+ self.key_dim = int(hidden_size * expand_k)
108
+ self.value_dim = int(hidden_size * expand_v)
109
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
110
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
111
+ self.clamp_min = clamp_min
112
+ self.layer_idx = layer_idx
113
+
114
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
115
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
116
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
117
+
118
+ self.head_qk_dim = self.key_dim // num_heads
119
+ self.head_v_dim = self.value_dim // num_heads
120
+
121
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
122
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
123
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
124
+ if self.use_output_gate:
125
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
126
+
127
+ if use_short_conv:
128
+ self.conv_size = conv_size
129
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
130
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
131
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
132
+
133
+ self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
134
+ nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
135
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
136
+
137
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
138
+ self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
139
+ self.fuse_norm_and_gate = True
140
+ else:
141
+ self.fuse_norm_and_gate = False
142
+ self.g_norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
143
+ self.gate_fn = ACT2FN[gate_fn]
144
+
145
+ self.gate_logit_normalizer = gate_logit_normalizer
146
+
147
+ self.apply(self._initialize_weights)
148
+
149
+ def _initialize_weights(self, module: nn.Module):
150
+ if getattr(module, "_is_hf_initialized", False):
151
+ return
152
+ if isinstance(module, nn.Linear):
153
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
154
+ if module.bias is not None:
155
+ nn.init.zeros_(module.bias)
156
+ module._is_hf_initialized = True
157
+
158
+ def forward(
159
+ self,
160
+ hidden_states: torch.Tensor,
161
+ attention_mask: Optional[torch.Tensor] = None,
162
+ past_key_values: Optional[Cache] = None,
163
+ use_cache: Optional[bool] = False,
164
+ output_attentions: Optional[bool] = False,
165
+ **kwargs
166
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
167
+ if attention_mask is not None:
168
+ assert len(attention_mask.shape) == 2, (
169
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
170
+ "for padding purposes (0 indicating padding). "
171
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
172
+ )
173
+
174
+ # launching the triton kernel for just one token will actually be slower
175
+ mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
176
+
177
+ last_state = None
178
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
179
+ last_state = past_key_values[self.layer_idx]
180
+ if self.use_short_conv:
181
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
182
+ if last_state is not None:
183
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
184
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
185
+ q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
186
+ mask=conv_mask,
187
+ cache=conv_state_q,
188
+ output_final_state=use_cache)
189
+ k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
190
+ mask=conv_mask,
191
+ cache=conv_state_k,
192
+ output_final_state=use_cache)
193
+ v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
194
+ mask=conv_mask,
195
+ cache=conv_state_v,
196
+ output_final_state=use_cache)
197
+ else:
198
+ q = self.q_proj(hidden_states)
199
+ k = self.k_proj(hidden_states)
200
+ v = self.v_proj(hidden_states)
201
+ gk = self.gk_proj(hidden_states)
202
+
203
+ if self.feature_map_fn is not None:
204
+ q, k = map(self.feature_map_fn, (q, k))
205
+ # dealing with left-padding
206
+ if attention_mask is not None:
207
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
208
+ q = rearrange(q, 'b t (h d) -> b t h d', h=self.num_heads)
209
+ if self.num_kv_groups > 1:
210
+ k, v, gk = (repeat(x, 'b t (h d) -> b t (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v, gk))
211
+ else:
212
+ k, v, gk = (rearrange(x, 'b t (h d) -> b t h d', h=self.num_kv_heads) for x in (k, v, gk))
213
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
214
+
215
+ if self.clamp_min is not None:
216
+ gk = torch.clamp_min(gk, self.clamp_min)
217
+
218
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
219
+ if mode == 'fused_recurrent':
220
+ o, recurrent_state = fused_recurrent_gla(
221
+ q=q,
222
+ k=k,
223
+ v=v,
224
+ gk=gk,
225
+ initial_state=recurrent_state,
226
+ output_final_state=use_cache,
227
+ head_first=False
228
+ )
229
+ elif mode == 'fused_chunk':
230
+ o, recurrent_state = fused_chunk_gla(
231
+ q=q,
232
+ k=k,
233
+ v=v,
234
+ g=gk,
235
+ initial_state=recurrent_state,
236
+ output_final_state=use_cache,
237
+ head_first=False
238
+ )
239
+ elif mode == 'chunk':
240
+ o, recurrent_state = chunk_gla(
241
+ q=q,
242
+ k=k,
243
+ v=v,
244
+ g=gk,
245
+ initial_state=recurrent_state,
246
+ output_final_state=use_cache,
247
+ head_first=False
248
+ )
249
+ else:
250
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
251
+
252
+ if past_key_values is not None:
253
+ past_key_values.update(
254
+ recurrent_state=recurrent_state,
255
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
256
+ layer_idx=self.layer_idx,
257
+ offset=q.shape[2]
258
+ )
259
+
260
+ if self.use_output_gate:
261
+ g = self.g_proj(hidden_states)
262
+ if self.fuse_norm_and_gate:
263
+ g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
264
+ o = self.g_norm_swish_gate(o, g)
265
+ o = rearrange(o, 'b t h d -> b t (h d)')
266
+ else:
267
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
268
+ o = o * self.gate_fn(g)
269
+ else:
270
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
271
+ o = self.o_proj(o)
272
+
273
+ return o, None, past_key_values
274
+
275
+ def state_size(self, **kwargs) -> int:
276
+ state_size = self.key_dim * self.head_v_dim
277
+ for module in self.children():
278
+ if isinstance(module, ShortConvolution):
279
+ state_size += module.state_size
280
+ return state_size
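
A minimal usage sketch for `GatedLinearAttention` as defined above (assuming the repository's `fla` package is installed and a CUDA device backs the Triton kernels; `bfloat16` is a typical but not required choice).

```python
import torch
from fla.layers.gla import GatedLinearAttention

# defaults: mode='chunk', expand_k=0.5, expand_v=1.0, so key_dim=512 and value_dim=1024
layer = GatedLinearAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)  # [batch, seq_len, hidden_size]
o, _, past_key_values = layer(x)  # linear-attention layers do not return attention weights
print(o.shape)  # torch.Size([2, 128, 1024])
```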
fla/layers/gsa.py ADDED
@@ -0,0 +1,233 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ from fla.modules import RMSNorm, ShortConvolution
15
+ from fla.modules.activations import swish
16
+ from fla.modules.feature_map import (ReLUFeatureMap, SwishFeatureMap,
17
+ T2RFeatureMap)
18
+ from fla.modules.layernorm import rms_norm_linear
19
+ from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa
20
+
21
+ if TYPE_CHECKING:
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class GatedSlotAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ expand_k: float = 1.,
32
+ expand_v: float = 1.,
33
+ num_heads: int = 4,
34
+ num_kv_heads: Optional[int] = None,
35
+ use_short_conv: bool = False,
36
+ conv_size: int = 4,
37
+ conv_bias: bool = False,
38
+ num_slots: Optional[int] = None,
39
+ elementwise_affine: Optional[bool] = True,
40
+ norm_first: bool = True,
41
+ norm_eps: float = 1e-5,
42
+ gate_logit_normalizer: int = 8,
43
+ feature_map: str = 'swish',
44
+ use_output_gate: bool = False,
45
+ use_norm: bool = True,
46
+ layer_idx: Optional[int] = None,
47
+ scale: Optional[float] = 1.,
48
+ **kwargs
49
+ ) -> GatedSlotAttention:
50
+ super().__init__()
51
+
52
+ self.mode = mode
53
+ self.hidden_size = hidden_size
54
+ self.expand_k = expand_k
55
+ self.expand_v = expand_v
56
+ self.num_heads = num_heads
57
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
58
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
59
+ self.key_dim = int(hidden_size * expand_k)
60
+ self.value_dim = int(hidden_size * expand_v)
61
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
62
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
63
+ self.head_k_dim = self.key_dim // self.num_heads
64
+ self.head_v_dim = self.value_dim // self.num_heads
65
+
66
+ self.use_short_conv = use_short_conv
67
+ self.conv_size = conv_size
68
+ self.conv_bias = conv_bias
69
+
70
+ self.gate_logit_normalizer = gate_logit_normalizer
71
+
72
+ self.use_output_gate = use_output_gate
73
+ self.use_norm = use_norm
74
+ self.scale = scale
75
+
76
+ if num_slots is None:
77
+ num_slots = self.head_k_dim
78
+ self.num_slots = num_slots
79
+ self.norm_first = norm_first
80
+
81
+ self.layer_idx = layer_idx
82
+
83
+ if layer_idx is None:
84
+ warnings.warn(
85
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
86
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
87
+ "when creating this class."
88
+ )
89
+
90
+ if norm_first:
91
+ self.norm = RMSNorm(self.hidden_size, eps=norm_eps)
92
+ self.register_module('feature_map', None)
93
+ if feature_map == 'swish':
94
+ self.feature_map = SwishFeatureMap()
95
+ elif feature_map == 'relu':
96
+ self.feature_map = ReLUFeatureMap()
97
+ elif feature_map == 't2r':
98
+ self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim)
99
+ else:
100
+ raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.")
101
+
102
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
103
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
104
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
105
+ self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)
106
+
107
+ if use_short_conv:
108
+ self.conv_size = conv_size
109
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
110
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
111
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
112
+
113
+ self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
114
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
115
+
116
+ self.apply(self._initialize_weights)
117
+
118
+ def _initialize_weights(self, module: nn.Module):
119
+ if getattr(module, "_is_hf_initialized", False):
120
+ return
121
+ if isinstance(module, nn.Linear):
122
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
123
+ if module.bias is not None:
124
+ nn.init.zeros_(module.bias)
125
+ module._is_hf_initialized = True
126
+
127
+ def forward(
128
+ self,
129
+ hidden_states: torch.Tensor,
130
+ attention_mask: Optional[torch.Tensor] = None,
131
+ past_key_values: Optional[Cache] = None,
132
+ use_cache: Optional[bool] = False,
133
+ output_attentions: Optional[bool] = False,
134
+ **kwargs
135
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
136
+ if attention_mask is not None:
137
+ assert len(attention_mask.shape) == 2, (
138
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
139
+ "for padding purposes (0 indicating padding). "
140
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
141
+ )
142
+
143
+ # launching the triton kernel for just one token will actually be slower
144
+ mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
145
+
146
+ if self.norm_first:
147
+ hidden_states = self.norm(hidden_states)
148
+
149
+ last_state = None
150
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
151
+ last_state = past_key_values[self.layer_idx]
152
+
153
+ if self.use_short_conv:
154
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
155
+ if last_state is not None:
156
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
157
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
158
+ q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
159
+ mask=conv_mask,
160
+ cache=conv_state_q,
161
+ output_final_state=use_cache)
162
+ k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
163
+ mask=conv_mask,
164
+ cache=conv_state_k,
165
+ output_final_state=use_cache)
166
+ v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
167
+ mask=conv_mask,
168
+ cache=conv_state_v,
169
+ output_final_state=use_cache)
170
+ else:
171
+ q = self.q_proj(hidden_states)
172
+ k = self.k_proj(hidden_states)
173
+ v = self.v_proj(hidden_states)
174
+ f = self.f_proj(hidden_states)
175
+
176
+ q = rearrange(q, 'b t (h d) -> b t h d', h=self.num_heads)
177
+ k = rearrange(k, 'b t (h d) -> b t h d', h=self.num_kv_heads)
178
+ v = rearrange(v, 'b t (h d) -> b t h d', h=self.num_kv_heads)
179
+ f = rearrange(f, 'b t (h m) -> b t h m', h=self.num_kv_heads)
180
+
181
+ if self.feature_map is not None:
182
+ q, k = map(lambda x: self.feature_map(x), (q, k))
183
+ v = swish(v)
184
+
185
+ f = F.logsigmoid(f) / self.gate_logit_normalizer
186
+ s = (1 - f.exp()).to(f.dtype)
187
+ # dealing with left-padding
188
+ if attention_mask is not None:
189
+ s = s.mul_(attention_mask[:, -s.shape[1]:, None, None])
190
+ v = v.mul_(attention_mask[:, -v.shape[1]:, None, None])
191
+
192
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
193
+ if mode == 'fused_recurrent':
194
+ o, recurrent_state = fused_recurrent_gsa(
195
+ q=q,
196
+ k=k,
197
+ v=v,
198
+ s=s,
199
+ g=f,
200
+ initial_state=recurrent_state,
201
+ output_final_state=use_cache,
202
+ scale=self.scale,
203
+ head_first=False
204
+ )
205
+ elif mode == 'chunk':
206
+ o, recurrent_state = chunk_gsa(
207
+ q=q,
208
+ k=k,
209
+ v=v,
210
+ s=s,
211
+ g=f,
212
+ initial_state=recurrent_state,
213
+ output_final_state=use_cache,
214
+ scale=self.scale,
215
+ head_first=False
216
+ )
217
+ else:
218
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
219
+
220
+ if past_key_values is not None:
221
+ past_key_values.update(
222
+ recurrent_state=recurrent_state,
223
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
224
+ layer_idx=self.layer_idx,
225
+ offset=q.shape[2]
226
+ )
227
+
228
+ o = rearrange(o, 'b t h d -> b t (h d)')
229
+ o = rms_norm_linear(swish(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
230
+ return o, None, past_key_values
231
+
232
+ def state_size(self, *args, **kwargs) -> int:
233
+ return 2 * self.num_slots * self.hidden_size
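
A minimal usage sketch for `GatedSlotAttention` (same assumptions as above: the repository's `fla` package is installed and a CUDA device is available for the chunk/fused kernels).

```python
import torch
from fla.layers.gsa import GatedSlotAttention

layer = GatedSlotAttention(hidden_size=1024, num_heads=4, num_slots=64, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)  # output keeps the input shape [batch, seq_len, hidden_size]
```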
fla/layers/hgrn.py ADDED
@@ -0,0 +1,153 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ # "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from fla.modules import FusedRMSNormSwishGate, ShortConvolution
15
+ from fla.modules.activations import swiglu
16
+ from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class HGRNAttention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ mode: str = 'chunk',
27
+ hidden_size: int = 1024,
28
+ expand_ratio: Optional[int] = 1,
29
+ use_short_conv: bool = False,
30
+ conv_size: int = 4,
31
+ conv_bias: bool = False,
32
+ elementwise_affine: Optional[bool] = True,
33
+ norm_eps: float = 1e-5,
34
+ layer_idx: int = None
35
+ ) -> HGRNAttention:
36
+ super().__init__()
37
+
38
+ self.mode = mode
39
+ self.hidden_size = hidden_size
40
+ self.expand_ratio = expand_ratio
41
+ self.input_dim = int(hidden_size * expand_ratio)
42
+
43
+ self.use_short_conv = use_short_conv
44
+ self.conv_size = conv_size
45
+ self.conv_bias = conv_bias
46
+
47
+ self.layer_idx = layer_idx
48
+
49
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
50
+
51
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
52
+ self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
53
+ self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
54
+
55
+ if use_short_conv:
56
+ self.conv_size = conv_size
57
+ self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
58
+ self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
59
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
60
+
61
+ self.g_norm = FusedRMSNormSwishGate(self.input_dim, elementwise_affine, norm_eps)
62
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
63
+
64
+ self.apply(self._initialize_weights)
65
+
66
+ def _initialize_weights(self, module: nn.Module):
67
+ if getattr(module, "_is_hf_initialized", False):
68
+ return
69
+ if isinstance(module, nn.Linear):
70
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
71
+ if module.bias is not None:
72
+ nn.init.zeros_(module.bias)
73
+ module._is_hf_initialized = True
74
+
75
+ def forward(
76
+ self,
77
+ hidden_states: torch.Tensor,
78
+ attention_mask: Optional[torch.Tensor] = None,
79
+ past_key_values: Optional[Cache] = None,
80
+ use_cache: Optional[bool] = False,
81
+ output_attentions: Optional[bool] = False,
82
+ lower_bound: Optional[torch.Tensor] = None,
83
+ **kwargs
84
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
85
+ if attention_mask is not None:
86
+ assert len(attention_mask.shape) == 2, (
87
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
88
+ "for padding purposes (0 indicating padding). "
89
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
90
+ )
91
+
92
+ # launching the triton kernel for just one token will actually be slower
93
+ mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
94
+
95
+ last_state = None
96
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
97
+ last_state = past_key_values[self.layer_idx]
98
+
99
+ if self.use_short_conv:
100
+ conv_state_i, conv_state_f = None, None
101
+ if last_state is not None:
102
+ conv_state_i, conv_state_f = last_state['conv_state']
103
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
104
+ i, conv_state_i = self.i_conv1d(x=self.i_proj(hidden_states),
105
+ mask=conv_mask,
106
+ cache=conv_state_i,
107
+ output_final_state=use_cache)
108
+ f, conv_state_f = self.f_conv1d(x=self.f_proj(hidden_states),
109
+ mask=conv_mask,
110
+ cache=conv_state_f,
111
+ output_final_state=use_cache)
112
+ else:
113
+ i = self.i_proj(hidden_states)
114
+ f = self.f_proj(hidden_states)
115
+
116
+ # the lower bound for the first layer is zero
117
+ if lower_bound is None or self.layer_idx == 0:
118
+ i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
119
+ else:
120
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
121
+ i, f = swiglu(i, 1 - g), g.log()
122
+
123
+ # dealing with left-padding
124
+ if attention_mask is not None:
125
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
126
+
127
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
128
+ if mode == 'chunk':
129
+ o, recurrent_state = chunk_hgrn(i, f, recurrent_state, use_cache)
130
+ elif mode == 'fused_recurrent':
131
+ o, recurrent_state = fused_recurrent_hgrn(i, f, recurrent_state, use_cache)
132
+ else:
133
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
134
+
135
+ if past_key_values is not None:
136
+ past_key_values.update(
137
+ recurrent_state=recurrent_state,
138
+ conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
139
+ layer_idx=self.layer_idx,
140
+ offset=i.shape[2]
141
+ )
142
+
143
+ o = self.g_norm(o, self.g_proj(hidden_states))
144
+ o = self.o_proj(o)
145
+
146
+ return o, None, past_key_values
147
+
148
+ def state_size(self, **kwargs) -> int:
149
+ state_size = self.hidden_size
150
+ for module in self.children():
151
+ if isinstance(module, ShortConvolution):
152
+ state_size += module.state_size
153
+ return state_size
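
A minimal usage sketch for `HGRNAttention`; the optional `lower_bound` tensor is normally supplied per layer by the surrounding model, so it is omitted here.

```python
import torch
from fla.layers.hgrn import HGRNAttention

layer = HGRNAttention(hidden_size=1024, expand_ratio=1, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)  # without `lower_bound`, the layer falls back to the plain forget-gate parameterization
print(o.shape)      # torch.Size([2, 128, 1024])
```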
fla/layers/hgrn2.py ADDED
@@ -0,0 +1,207 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ # "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import RMSNorm, ShortConvolution
16
+ from fla.modules.activations import swish
17
+ from fla.modules.layernorm import rms_norm_linear
18
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
19
+
20
+ if TYPE_CHECKING:
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class HGRN2Attention(nn.Module):
25
+
26
+ def __init__(
27
+ self,
28
+ mode: str = 'chunk',
29
+ hidden_size: int = 1024,
30
+ num_heads: Optional[int] = None,
31
+ expand_ratio: Optional[int] = 128,
32
+ use_short_conv: bool = False,
33
+ conv_size: int = 4,
34
+ conv_bias: bool = False,
35
+ elementwise_affine: Optional[bool] = True,
36
+ norm_eps: float = 1e-5,
37
+ layer_idx: int = None
38
+ ) -> HGRN2Attention:
39
+ super().__init__()
40
+
41
+ self.mode = mode
42
+ self.hidden_size = hidden_size
43
+
44
+ if expand_ratio is None and num_heads is not None:
45
+ expand_ratio = hidden_size // num_heads
46
+ elif expand_ratio is not None and num_heads is None:
47
+ num_heads = hidden_size // expand_ratio
48
+ elif expand_ratio is None and num_heads is None:
49
+ raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
50
+ self.num_heads = num_heads
51
+ self.expand_ratio = expand_ratio
52
+
53
+ self.use_short_conv = use_short_conv
54
+ self.conv_size = conv_size
55
+ self.conv_bias = conv_bias
56
+
57
+ self.forget_dim = int(self.num_heads * self.expand_ratio)
58
+ self.input_dim = hidden_size
59
+ self.layer_idx = layer_idx
60
+
61
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
62
+ assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
63
+ assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"
64
+
65
+ self.head_f_dim = self.expand_ratio
66
+ self.head_i_dim = self.hidden_size // num_heads
67
+
68
+ self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
69
+ self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
70
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
71
+
72
+ if use_short_conv:
73
+ self.conv_size = conv_size
74
+ self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
75
+ self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
76
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
77
+
78
+ self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps)
79
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
80
+
81
+ self.apply(self._initialize_weights)
82
+
83
+ def _initialize_weights(self, module: nn.Module):
84
+ if getattr(module, "_is_hf_initialized", False):
85
+ return
86
+ if isinstance(module, nn.Linear):
87
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
88
+ if module.bias is not None:
89
+ nn.init.zeros_(module.bias)
90
+ module._is_hf_initialized = True
91
+
92
+ def forward(
93
+ self,
94
+ hidden_states: torch.Tensor,
95
+ attention_mask: Optional[torch.Tensor] = None,
96
+ past_key_values: Optional[Cache] = None,
97
+ use_cache: Optional[bool] = False,
98
+ output_attentions: Optional[bool] = False,
99
+ lower_bound: Optional[torch.Tensor] = None,
100
+ **kwargs
101
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
102
+ if attention_mask is not None:
103
+ assert len(attention_mask.shape) == 2, (
104
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
105
+ "for padding purposes (0 indicating padding). "
106
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
107
+ )
108
+
109
+ # launching the triton kernel for just one token will actually be slower
110
+ mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
111
+
112
+ last_state = None
113
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
114
+ last_state = past_key_values[self.layer_idx]
115
+
116
+ if self.use_short_conv:
117
+ conv_state_q, conv_state_f, conv_state_i = None, None, None
118
+ if last_state is not None:
119
+ conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
120
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
121
+ q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
122
+ mask=conv_mask,
123
+ cache=conv_state_q,
124
+ output_final_state=use_cache)
125
+ f, conv_state_f = self.f_conv1d(x=self.f_proj(hidden_states),
126
+ mask=conv_mask,
127
+ cache=conv_state_f,
128
+ output_final_state=use_cache)
129
+ i, conv_state_i = self.i_conv1d(x=self.i_proj(hidden_states),
130
+ mask=conv_mask,
131
+ cache=conv_state_i,
132
+ output_final_state=use_cache)
133
+ else:
134
+ q = self.q_proj(hidden_states)
135
+ f = self.f_proj(hidden_states)
136
+ i = self.i_proj(hidden_states)
137
+
138
+ # dealing with left-padding
139
+ if attention_mask is not None:
140
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
141
+
142
+ q = swish(q)
143
+
144
+ # improve precision
145
+ f = f.float()
146
+
147
+ # the lower bound for the first layer is zero
148
+ if lower_bound is None or self.layer_idx == 0:
149
+ k, g = 1 - f.sigmoid(), F.logsigmoid(f)
150
+ else:
151
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
152
+ k, g = 1 - g, g.log()
153
+
154
+ q, k, i, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', h=self.num_heads), (q, k.to(i), i, g))
155
+
156
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
157
+ if mode == 'fused_recurrent':
158
+ o, recurrent_state = fused_recurrent_gla(
159
+ q=q,
160
+ k=k,
161
+ v=i,
162
+ gk=g,
163
+ initial_state=recurrent_state,
164
+ output_final_state=use_cache,
165
+ head_first=False
166
+ )
167
+ elif mode == 'fused_chunk':
168
+ o, recurrent_state = fused_chunk_gla(
169
+ q=q,
170
+ k=k,
171
+ v=i,
172
+ g=g,
173
+ initial_state=recurrent_state,
174
+ output_final_state=use_cache,
175
+ head_first=False
176
+ )
177
+ elif mode == 'chunk':
178
+ o, recurrent_state = chunk_gla(
179
+ q=q,
180
+ k=k,
181
+ v=i,
182
+ g=g,
183
+ initial_state=recurrent_state,
184
+ output_final_state=use_cache,
185
+ head_first=False
186
+ )
187
+ else:
188
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
189
+
190
+ if past_key_values is not None:
191
+ past_key_values.update(
192
+ recurrent_state=recurrent_state,
193
+ conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
194
+ layer_idx=self.layer_idx,
195
+ offset=q.shape[2]
196
+ )
197
+
198
+ o = rearrange(o, '... h d -> ... (h d)')
199
+ o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
200
+ return o, None, past_key_values
201
+
202
+ def state_size(self, **kwargs) -> int:
203
+ state_size = self.forget_dim * self.head_i_dim
204
+ for module in self.children():
205
+ if isinstance(module, ShortConvolution):
206
+ state_size += module.state_size
207
+ return state_size
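
A minimal usage sketch for `HGRN2Attention`; with `hidden_size=1024` and the default `expand_ratio=128`, the layer derives `num_heads=8` on its own.

```python
import torch
from fla.layers.hgrn2 import HGRN2Attention

layer = HGRN2Attention(hidden_size=1024, expand_ratio=128, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)  # forget gates are built internally from `f_proj`; no extra inputs are needed
print(o.shape)      # torch.Size([2, 128, 1024])
```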
fla/layers/linear_attn.py ADDED
@@ -0,0 +1,171 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+ from fla.modules import RMSNorm
11
+ from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap,
12
+ HedgehogFeatureMap, T2RFeatureMap)
13
+ from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn,
14
+ fused_recurrent_linear_attn)
15
+
16
+
17
+ class LinearAttention(nn.Module):
18
+ def __init__(
19
+ self,
20
+ mode: str = 'chunk',
21
+ hidden_size: int = 1024,
22
+ expand_k: float = 1.0,
23
+ expand_v: float = 1.0,
24
+ num_heads: int = 8,
25
+ num_kv_heads: Optional[int] = None,
26
+ feature_map: str = 'elementwise_product',
27
+ tie_feature_map_qk: bool = False,
28
+ output_norm: str = 'rmsnorm',
29
+ norm_q: bool = False,
30
+ norm_k: bool = False,
31
+ # standard linear attention normalization
32
+ do_feature_map_norm: bool = False,
33
+ elementwise_affine: bool = True,
34
+ norm_eps: float = 1e-5,
35
+ **kwargs
36
+ ):
37
+ super().__init__()
38
+
39
+ self.hidden_size = hidden_size
40
+ self.mode = mode
41
+ self.num_heads = num_heads
42
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
43
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
44
+ self.key_dim = int(hidden_size * expand_k)
45
+ self.value_dim = int(hidden_size * expand_v)
46
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
47
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
48
+
49
+ assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
50
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
51
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
52
+
53
+ self.head_qk_dim = self.key_dim // num_heads
54
+ self.head_v_dim = self.value_dim // num_heads
55
+ self.do_feature_map_norm = do_feature_map_norm
56
+
57
+ if feature_map == 'hedgehog':
58
+ if tie_feature_map_qk:
59
+ self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
60
+ else:
61
+ self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim)
62
+ self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
63
+
64
+ elif feature_map == 't2r':
65
+ if tie_feature_map_qk:
66
+ self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
67
+ else:
68
+ self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim)
69
+ self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
70
+
71
+ elif feature_map == 'elementwise_product':
72
+ if tie_feature_map_qk:
73
+ self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
74
+ else:
75
+ self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim)
76
+ self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
77
+
78
+ elif feature_map == 'dpfp':
79
+ self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim)
80
+ self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim)
81
+
82
+ elif feature_map == 'elu':
83
+ def elu(x):
84
+ return F.elu(x) + 1
85
+ self.feature_map_q = elu
86
+ self.feature_map_k = elu
87
+
88
+ elif feature_map == 'relu':
89
+ self.feature_map_q = nn.ReLU()
90
+ self.feature_map_k = nn.ReLU()
91
+
92
+ elif feature_map == 'identity':
93
+ self.feature_map_q = nn.Identity()
94
+ self.feature_map_k = nn.Identity()
95
+ else:
96
+ raise NotImplementedError(f"Not supported feature map `{feature_map}`.")
97
+
98
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
99
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
100
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
101
+
102
+ if output_norm == 'rmsnorm':
103
+ self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
104
+ elif output_norm == 'identity':
105
+ self.norm = nn.Identity()
106
+ else:
107
+ raise NotImplementedError(f"Not supported output norm `{output_norm}`.")
108
+
109
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
110
+
111
+ self.norm_q = norm_q
112
+ self.norm_k = norm_k
113
+
114
+ self.apply(self._initialize_weights)
115
+
116
+ def _initialize_weights(self, module: nn.Module):
117
+ if getattr(module, "_is_hf_initialized", False):
118
+ return
119
+ if isinstance(module, nn.Linear):
120
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
121
+ if module.bias is not None:
122
+ nn.init.zeros_(module.bias)
123
+ module._is_hf_initialized = True
124
+
125
+ def forward(self, x):
126
+ mode = self.mode
127
+ q = self.q_proj(x)
128
+ k = self.k_proj(x)
129
+ v = self.v_proj(x)
130
+
131
+ q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
132
+ if self.num_kv_groups > 1:
133
+ k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
134
+ else:
135
+ k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v))
136
+
137
+ q = self.feature_map_q(q)
138
+ k = self.feature_map_k(k)
139
+
140
+ if self.norm_q:
141
+ q = q / (q.sum(-1, True) + 1e-4)
142
+ if self.norm_k:
143
+ k = k / (k.sum(-1, True) + 1e-4)
144
+
145
+ if mode == 'chunk':
146
+ o, final_state = chunk_linear_attn(
147
+ q=q,
148
+ k=k,
149
+ v=v,
150
+ normalize=self.do_feature_map_norm,
151
+ head_first=False
152
+ )
153
+ elif mode == 'fused_chunk':
154
+ o, final_state = fused_chunk_linear_attn(
155
+ q=q,
156
+ k=k,
157
+ v=v,
158
+ normalize=self.do_feature_map_norm,
159
+ )
160
+ elif mode == 'fused_recurrent':
161
+ o, final_state = fused_recurrent_linear_attn(
162
+ q=q,
163
+ k=k,
164
+ v=v,
165
+ normalize=self.do_feature_map_norm,
166
+ )
167
+ else:
168
+ raise NotImplementedError
169
+ o = rearrange(self.norm(o), 'b t h d -> b t (h d)')  # merge heads so the shape matches o_proj's input dim
170
+ o = self.o_proj(o)
171
+ return o
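
A minimal usage sketch for `LinearAttention`; unlike the recurrent layers above, its `forward` takes only the input tensor and returns only the output. It assumes the per-head output is merged back to `(h d)` before `o_proj`, as in the sibling layers above.

```python
import torch
from fla.layers.linear_attn import LinearAttention

layer = LinearAttention(hidden_size=1024, num_heads=8, feature_map='elementwise_product').cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o = layer(x)    # no cache or attention weights are returned
print(o.shape)  # torch.Size([2, 128, 1024])
```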
fla/layers/multiscale_retention.py ADDED
@@ -0,0 +1,282 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from transformers.activations import ACT2FN
12
+
13
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
14
+ from fla.modules.rotary import RotaryEmbedding
15
+ from fla.ops.retention import (chunk_retention, fused_chunk_retention,
16
+ fused_recurrent_retention, parallel_retention)
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class MultiScaleRetention(nn.Module):
23
+ r"""
24
+ The layer implementation for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa
25
+
26
+ Args:
27
+ mode (str, Optional):
28
+ Which Retention kernel to use.
29
+ Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
30
+ Default: `fused_chunk`.
31
+ hidden_size (int, Optional):
32
+ The hidden size of the input. Default: 1024.
33
+ expand_k (float, Optional):
34
+ The expansion ratio for the key dim. Default: 1.0.
35
+ expand_v (float, Optional):
36
+ The expansion ratio for the value dim. Default: 2.0.
37
+ num_heads (int, Optional):
38
+ The number of heads. Default: 8.
39
+ num_kv_heads (int, Optional):
40
+ The number of key/value heads, used for MQA. Default: None.
41
+ feature_map (str, Optional):
42
+ Feature map function applied to queries/keys. Default: None.
43
+ use_short_conv (bool, Optional):
44
+ Whether to use short convolutions. Default: `False`.
45
+ conv_size (int, Optional):
46
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
47
+ conv_bias (bool, Optional):
48
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
49
+ use_output_gate (bool, Optional):
50
+ Whether to use output gate. Default: `True`.
51
+ gate_fn (str, Optional):
52
+ The activation function for the output gate. Default: `swish`.
53
+ elementwise_affine (bool, Optional):
54
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
55
+ norm_eps (float, Optional):
56
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
57
+ fuse_norm (bool, Optional):
58
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
59
+ layer_idx (int, Optional):
60
+ The index of the layer. Default: None.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ mode: str = 'chunk',
66
+ hidden_size: int = 1024,
67
+ expand_k: float = 1.0,
68
+ expand_v: float = 2.0,
69
+ num_heads: int = 8,
70
+ num_kv_heads: Optional[int] = None,
71
+ feature_map: Optional[str] = None,
72
+ use_short_conv: bool = False,
73
+ conv_size: int = 4,
74
+ conv_bias: bool = False,
75
+ use_output_gate: bool = True,
76
+ gate_fn: str = 'swish',
77
+ elementwise_affine: Optional[bool] = True,
78
+ norm_eps: float = 1e-5,
79
+ fuse_norm: bool = True,
80
+ layer_idx: int = None,
81
+ **kwargs
82
+ ) -> MultiScaleRetention:
83
+ super().__init__()
84
+
85
+ self.mode = mode
86
+ self.hidden_size = hidden_size
87
+ self.expand_k = expand_k
88
+ self.expand_v = expand_v
89
+ self.num_heads = num_heads
90
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
91
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
92
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
93
+
94
+ self.use_short_conv = use_short_conv
95
+ self.conv_size = conv_size
96
+ self.conv_bias = conv_bias
97
+ self.use_output_gate = use_output_gate
98
+
99
+ self.key_dim = int(hidden_size * expand_k)
100
+ self.value_dim = int(hidden_size * expand_v)
101
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
102
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
103
+ self.layer_idx = layer_idx
104
+
105
+ assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not supported mode `{mode}`."
106
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
107
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
108
+
109
+ self.head_qk_dim = self.key_dim // num_heads
110
+ self.head_v_dim = self.value_dim // num_heads
111
+
112
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
113
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
114
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
115
+ if self.use_output_gate:
116
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
117
+
118
+ if use_short_conv:
119
+ self.conv_size = conv_size
120
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
121
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
122
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
123
+
124
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
125
+
126
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
127
+ self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
128
+ self.fuse_norm_and_gate = True
129
+ else:
130
+ self.fuse_norm_and_gate = False
131
+ self.g_norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
132
+ self.gate_fn = ACT2FN[gate_fn]
133
+
134
+ # TODO: fix this issue
135
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
136
+ # Ideally, we would want to support arbitrary d_head_qk
137
+ assert self.head_qk_dim <= 256, "head_qk_dim must be less than or equal to 256"
138
+ self.rotary = RotaryEmbedding(dim=self.head_qk_dim)
139
+
140
+ self.apply(self._initialize_weights)
141
+
142
+ def _initialize_weights(self, module: nn.Module):
143
+ if getattr(module, "_is_hf_initialized", False):
144
+ return
145
+ if isinstance(module, nn.Linear):
146
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
147
+ if module.bias is not None:
148
+ nn.init.zeros_(module.bias)
149
+ module._is_hf_initialized = True
150
+
151
+ def forward(
152
+ self,
153
+ hidden_states: torch.Tensor,
154
+ attention_mask: Optional[torch.Tensor] = None,
155
+ past_key_values: Optional[Cache] = None,
156
+ use_cache: Optional[bool] = False,
157
+ output_attentions: Optional[bool] = False,
158
+ **kwargs
159
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
160
+ if attention_mask is not None:
161
+ assert len(attention_mask.shape) == 2, (
162
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
163
+ "for padding purposes (0 indicating padding). "
164
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
165
+ )
166
+
167
+ # launching the triton kernel for just one token will actually be slower
168
+ mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
169
+
170
+ last_state = None
171
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
172
+ last_state = past_key_values[self.layer_idx]
173
+
174
+ if self.use_short_conv:
175
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
176
+ if last_state is not None:
177
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
178
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
179
+ q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
180
+ mask=conv_mask,
181
+ cache=conv_state_q,
182
+ output_final_state=use_cache)
183
+ k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
184
+ mask=conv_mask,
185
+ cache=conv_state_k,
186
+ output_final_state=use_cache)
187
+ v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
188
+ mask=conv_mask,
189
+ cache=conv_state_v,
190
+ output_final_state=use_cache)
191
+ else:
192
+ q = self.q_proj(hidden_states)
193
+ k = self.k_proj(hidden_states)
194
+ v = self.v_proj(hidden_states)
195
+
196
+ # dealing with left-padding
197
+ if attention_mask is not None:
198
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
199
+ q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
200
+ k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
201
+ if self.feature_map_fn is not None:
202
+ q, k = map(self.feature_map_fn, (q, k))
203
+
204
+ seqlen_offset, max_seqlen = 0, q.shape[1]
205
+ if past_key_values is not None:
206
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
207
+ max_seqlen = q.shape[1] + seqlen_offset
208
+
209
+ if attention_mask is not None:
210
+ # discount left-padding tokens when computing the per-sequence rotary offsets
211
+ seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
212
+ max_seqlen = q.shape[1] + max(seqlen_offset)
213
+
214
+ q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
215
+ if self.num_kv_groups > 1:
216
+ k = repeat(k, 'b t h d -> b t (h g) d', h=self.num_kv_heads, g=self.num_kv_groups)
217
+ v = repeat(v, 'b t (h d) -> b t (h g) d', h=self.num_kv_heads, g=self.num_kv_groups)
218
+ else:
219
+ k, v = rearrange(k, 'b t h d -> b t h d'), rearrange(v, 'b t (h d) -> b t h d', h=self.num_kv_heads)
220
+
221
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
222
+ if mode == 'chunk':
223
+ o, recurrent_state = chunk_retention(
224
+ q=q,
225
+ k=k,
226
+ v=v,
227
+ initial_state=recurrent_state,
228
+ output_final_state=use_cache,
229
+ head_first=False
230
+ )
231
+ elif mode == 'fused_chunk':
232
+ o, recurrent_state = fused_chunk_retention(
233
+ q=q,
234
+ k=k,
235
+ v=v,
236
+ initial_state=recurrent_state,
237
+ output_final_state=use_cache,
238
+ head_first=False
239
+ )
240
+ elif mode == 'parallel':
241
+ o, recurrent_state = parallel_retention(q, k, v, head_first=False)
242
+ elif mode == 'fused_recurrent':
243
+ o, recurrent_state = fused_recurrent_retention(
244
+ q=q,
245
+ k=k,
246
+ v=v,
247
+ initial_state=recurrent_state,
248
+ output_final_state=use_cache,
249
+ head_first=False
250
+ )
251
+ else:
252
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
253
+
254
+ if past_key_values is not None:
255
+ past_key_values.update(
256
+ recurrent_state=recurrent_state,
257
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
258
+ layer_idx=self.layer_idx,
259
+ offset=q.shape[2]
260
+ )
261
+
262
+ if self.use_output_gate:
263
+ g = self.g_proj(hidden_states)
264
+ if self.fuse_norm_and_gate:
265
+ g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
266
+ o = self.g_norm_swish_gate(o, g)
267
+ o = rearrange(o, 'b t h d -> b t (h d)')
268
+ else:
269
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
270
+ o = o * self.gate_fn(g)
271
+ else:
272
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
273
+ o = self.o_proj(o)
274
+
275
+ return o, None, past_key_values
276
+
277
+ def state_size(self, **kwargs) -> int:
278
+ state_size = self.key_dim * self.head_v_dim
279
+ for module in self.children():
280
+ if isinstance(module, ShortConvolution):
281
+ state_size += module.state_size
282
+ return state_size
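
A minimal usage sketch for `MultiScaleRetention` (with the default `expand_v=2.0` the internal value dim is `2 * hidden_size`, but the output projection maps back to `hidden_size`).

```python
import torch
from fla.layers.multiscale_retention import MultiScaleRetention

layer = MultiScaleRetention(hidden_size=1024, num_heads=8, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)  # rotary embeddings are applied to q/k internally
print(o.shape)      # torch.Size([2, 128, 1024])
```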
fla/layers/rebased.py ADDED
@@ -0,0 +1,136 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+
16
+ from fla.modules.feature_map import RebasedFeatureMap
17
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
18
+ from fla.ops.rebased import parallel_rebased
19
+
20
+
21
+ class ReBasedLinearAttention(nn.Module):
22
+ def __init__(
23
+ self,
24
+ hidden_size: int,
25
+ l_max: int = 2048,
26
+ feature_dim: int = 16,
27
+ num_key_value_heads: int = 16,
28
+ num_heads: int = 16,
29
+ use_gamma: Optional[bool] = True,
30
+ use_beta: Optional[bool] = True,
31
+ normalize: Optional[bool] = True,
32
+ causal: bool = True,
33
+ eps: float = 1e-5,
34
+ mode: str = "parallel",
35
+ layer_idx: Optional[int] = None,
36
+ **kwargs
37
+ ) -> ReBasedLinearAttention:
38
+ super().__init__()
39
+ self.hidden_size = hidden_size
40
+ self.l_max = l_max
41
+ self.mode = mode
42
+ assert self.mode in ["fused_chunk", "parallel", 'chunk']
43
+
44
+ # linear attention
45
+ self.feature_dim = feature_dim
46
+ self.num_key_value_heads = num_key_value_heads
47
+ self.num_heads = num_heads
48
+ self.head_dim = self.hidden_size // self.num_key_value_heads
49
+ self.use_gamma = use_gamma
50
+ self.use_beta = use_beta
51
+ self.normalize = normalize
52
+ self.causal = causal
53
+
54
+ self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
55
+ self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
56
+ self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
57
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
58
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
59
+ self.dropout = nn.Identity()
60
+ self.eps = eps
61
+
62
+ self.apply(self._initialize_weights)
63
+
64
+ def _initialize_weights(self, module: nn.Module):
65
+ if getattr(module, "_is_hf_initialized", False):
66
+ return
67
+ if isinstance(module, nn.Linear):
68
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
69
+ if module.bias is not None:
70
+ nn.init.zeros_(module.bias)
71
+ module._is_hf_initialized = True
72
+
73
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
74
+ mode = self.mode
75
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
76
+ q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", h=self.num_heads), [q, k, v])
77
+ q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
78
+ if mode == "fused_chunk":
79
+ o = fused_chunk_linear_attn(
80
+ q=q,
81
+ k=k,
82
+ v=v,
83
+ normalize=True,
84
+ scale=1,
85
+ head_first=False
86
+ )
87
+ elif mode == 'chunk':
88
+ o = chunk_linear_attn(
89
+ q=q,
90
+ k=k,
91
+ v=v,
92
+ normalize=True,
93
+ scale=1,
94
+ head_first=False
95
+ )
96
+ elif mode == 'parallel':
97
+ assert q.shape[-1] <= 128
98
+ o = parallel_rebased(
99
+ q=q,
100
+ k=k,
101
+ v=v,
102
+ eps=self.eps,
103
+ use_scale=True,
104
+ use_normalize=True,
105
+ head_first=False
106
+ )
107
+ o = self.o_proj(o)
108
+ o = self.dropout(o)
109
+ return o
110
+
111
+ # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
112
+ def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
113
+ """
114
+ x (torch.Tensor): tensor of shape (b, d, t)
115
+ y (torch.Tensor): tensor of shape (b, d, t)
116
+ """
117
+ b, t, _ = hidden_states.size()
118
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
119
+
120
+ q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2)
121
+ k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
122
+ v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2)
123
+
124
+ # Linear attention
125
+ q, k = self.feature_map(q), self.feature_map(k)
126
+ q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
127
+
128
+ # Compute attention
129
+ if self.causal:
130
+ y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
131
+ else:
132
+ y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
133
+ y = rearrange(y, 'b h t d -> b t (h d)')
134
+ y = self.o_proj(y.to(hidden_states.dtype))
135
+ y = self.dropout(y)
136
+ return y.to(hidden_states.dtype)
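For intuition, the quadratic reference path (`forward_reference`) reduces to a running normalized sum over past key/value products. Below is a minimal, self-contained PyTorch sketch with illustrative tensor sizes; it mirrors the cumulative-sum formulation above and does not call the fused `parallel_rebased` kernel.

```python
import torch

# Sketch of the causal normalization in forward_reference:
# y_t = sum_{s<=t} (q_t·k_s) v_s / (sum_{s<=t} q_t·k_s + eps)
b, h, t, d, e, eps = 2, 4, 8, 16, 32, 1e-5   # illustrative sizes
q = torch.randn(b, h, t, d)
k = torch.randn(b, h, t, d)
v = torch.randn(b, h, t, e)

q_, k_, v_ = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
num = (q_ * (k_ * v_).cumsum(2)).sum(-1)   # (b, h, t, e): running sum of (q·k_s) v_s
den = (q_ * k_.cumsum(2)).sum(-1) + eps    # (b, h, t, 1): running sum of q·k_s
y = num / den
print(y.shape)  # torch.Size([2, 4, 8, 32])
```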
fla/layers/rwkv6.py ADDED
@@ -0,0 +1,291 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ # "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+
14
+ from fla.modules import GroupNorm
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class RWKV6Attention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ mode: str = 'chunk',
27
+ hidden_size: int = 1024,
28
+ expand_k: float = 0.5,
29
+ expand_v: float = 1.0,
30
+ num_heads: int = 4,
31
+ gate_fn: str = 'swish',
32
+ proj_low_rank_dim: int = 32,
33
+ gate_low_rank_dim: int = 64,
34
+ fuse_norm: bool = True,
35
+ elementwise_affine: Optional[bool] = True,
36
+ norm_eps: float = 1e-5,
37
+ layer_idx: int = None,
38
+ **kwargs
39
+ ) -> RWKV6Attention:
40
+ super().__init__()
41
+
42
+ self.mode = mode
43
+ self.hidden_size = hidden_size
44
+ self.expand_k = expand_k
45
+ self.expand_v = expand_v
46
+ self.num_heads = num_heads
47
+ self.proj_low_rank_dim = proj_low_rank_dim
48
+ self.gate_low_rank_dim = gate_low_rank_dim
49
+
50
+ self.key_dim = int(hidden_size * expand_k)
51
+ self.value_dim = int(hidden_size * expand_v)
52
+ self.layer_idx = layer_idx
53
+
54
+        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
55
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
56
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
57
+
58
+ self.head_qk_dim = self.key_dim // num_heads
59
+ self.head_v_dim = self.value_dim // num_heads
60
+
61
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
62
+ self.x_proj = nn.Sequential(
63
+ LerpLinear(hidden_size, proj_low_rank_dim * 5),
64
+ nn.Tanh(),
65
+ nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
66
+ )
67
+ self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))
68
+
69
+ self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
70
+ self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
71
+ self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
72
+ self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
73
+ self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
74
+ self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_qk_dim))
75
+
76
+ # TODO: fuse GroupNorm and output gate
77
+ self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps)
78
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
79
+ self.gate_fn = ACT2FN[gate_fn]
80
+
81
+ self.apply(self._initialize_weights)
82
+
83
+ def _initialize_weights(self, module: nn.Module):
84
+ if getattr(module, "_is_hf_initialized", False):
85
+ return
86
+ if isinstance(module, nn.Linear):
87
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
88
+ if module.bias is not None:
89
+ nn.init.zeros_(module.bias)
90
+ if isinstance(module, nn.Parameter):
91
+ nn.init.xavier_uniform_(module, gain=2 ** -2.5)
92
+ module._is_hf_initialized = True
93
+
94
+ def forward(
95
+ self,
96
+ hidden_states: torch.Tensor,
97
+ attention_mask: Optional[torch.Tensor] = None,
98
+ past_key_values: Optional[Cache] = None,
99
+ use_cache: Optional[bool] = False,
100
+ output_attentions: Optional[bool] = False,
101
+ **kwargs
102
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
103
+ if attention_mask is not None:
104
+ assert len(attention_mask.shape) == 2, (
105
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
106
+ "for padding purposes (0 indicating padding). "
107
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
108
+ )
109
+
110
+ batch_size, seq_len, hidden_size = hidden_states.shape
111
+ # launching the triton kernel for just one token will actually be slower
112
+ mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
113
+
114
+ last_state = None
115
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
116
+ last_state = past_key_values[self.layer_idx]
117
+
118
+ if attention_mask is not None:
119
+ hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None])
120
+ if hidden_states.shape[1] == 1 and last_state is not None:
121
+ shifted = last_state['conv_state'].unsqueeze(1)
122
+ else:
123
+ shifted = self.time_shift(hidden_states)
124
+ if last_state is not None:
125
+ shifted[:, 0] = last_state['conv_state'][0]
126
+
127
+ delta = shifted - hidden_states
128
+ x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
129
+ x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))
130
+
131
+ r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
132
+ r = self.r_proj(hidden_states, r, delta)
133
+ w = self.w_proj(hidden_states, w, delta)
134
+ k = self.k_proj(hidden_states, k, delta)
135
+ v = self.v_proj(hidden_states, v, delta)
136
+ g = self.g_proj(hidden_states, g, delta)
137
+
138
+ # dealing with left-padding
139
+ if attention_mask is not None:
140
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
141
+ r, w, k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (r, w, k, v))
142
+ w = -torch.exp(w)
143
+ u = self.bonus
144
+
145
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
146
+ if mode == 'fused_recurrent':
147
+ o, recurrent_state = fused_recurrent_rwkv6(
148
+ r=r,
149
+ k=k,
150
+ v=v,
151
+ w=w,
152
+ u=u,
153
+ scale=1.,
154
+ initial_state=recurrent_state,
155
+ output_final_state=use_cache,
156
+ head_first=False
157
+ )
158
+ elif mode == 'chunk':
159
+ o, recurrent_state = chunk_rwkv6(
160
+ q=r,
161
+ k=k,
162
+ v=v,
163
+ g=w,
164
+ u=u,
165
+ scale=1.,
166
+ initial_state=recurrent_state,
167
+ output_final_state=use_cache,
168
+ head_first=False
169
+ )
170
+ else:
171
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
172
+
173
+ if past_key_values is not None:
174
+ past_key_values.update(
175
+ recurrent_state=recurrent_state,
176
+ conv_state=hidden_states[:, -1],
177
+ layer_idx=self.layer_idx,
178
+ offset=r.shape[2]
179
+ )
180
+
181
+ o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) * self.gate_fn(g)
182
+ o = self.o_proj(o)
183
+
184
+ return o, None, past_key_values
185
+
186
+
187
+ class LoRA(nn.Module):
188
+
189
+ def __init__(
190
+ self,
191
+ input_dim: int,
192
+ output_dim: int,
193
+ low_rank_dim: int,
194
+ bias: Optional[bool] = True
195
+ ):
196
+ super().__init__()
197
+
198
+ self.input_dim = input_dim
199
+ self.output_dim = output_dim
200
+ self.low_rank_dim = low_rank_dim
201
+ self.bias = bias
202
+
203
+ self.lora = nn.Sequential(
204
+ nn.Linear(input_dim, low_rank_dim, bias=False),
205
+ nn.Tanh(),
206
+ nn.Linear(low_rank_dim, output_dim, bias=bias)
207
+ )
208
+
209
+ def __repr__(self) -> str:
210
+ s = f"{self.__class__.__name__}("
211
+ s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
212
+ if not self.bias:
213
+ s += f", bias={self.bias}"
214
+ s += ")"
215
+ return s
216
+
217
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
218
+ return self.lora(x)
219
+
220
+
221
+ class LerpLinear(nn.Module):
222
+
223
+ def __init__(
224
+ self,
225
+ input_dim: int,
226
+ output_dim: int,
227
+ low_rank_dim: Optional[int] = None
228
+ ):
229
+ super().__init__()
230
+
231
+ self.input_dim = input_dim
232
+ self.output_dim = output_dim
233
+ self.low_rank_dim = low_rank_dim
234
+
235
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
236
+ if low_rank_dim is None:
237
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
238
+ else:
239
+ self.linear = LoRA(input_dim, output_dim, low_rank_dim)
240
+ self.mu = nn.Parameter(torch.zeros(input_dim))
241
+
242
+ def __repr__(self) -> str:
243
+ s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
244
+ if self.low_rank_dim is not None:
245
+ s += f", low_rank_dim={self.low_rank_dim}"
246
+ s += ")"
247
+ return s
248
+
249
+ def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
250
+ if delta is None:
251
+ shifted = self.time_shift(x)
252
+ if len(shifted.shape) == 2:
253
+ shifted = shifted.unsqueeze(1)
254
+ delta = shifted - x
255
+ return self.linear(x + delta * self.mu)
256
+
257
+
258
+ class DDLerpLinear(nn.Module):
259
+
260
+ def __init__(
261
+ self,
262
+ input_dim: int,
263
+ output_dim: int,
264
+ low_rank_dim: Optional[int] = None
265
+ ):
266
+ super().__init__()
267
+
268
+ self.input_dim = input_dim
269
+ self.output_dim = output_dim
270
+ self.low_rank_dim = low_rank_dim
271
+
272
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
273
+ if low_rank_dim is None:
274
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
275
+ else:
276
+ self.linear = LoRA(input_dim, output_dim, low_rank_dim)
277
+
278
+ def __repr__(self) -> str:
279
+ s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
280
+ if self.low_rank_dim is not None:
281
+ s += f", low_rank_dim={self.low_rank_dim}"
282
+ s += ")"
283
+ return s
284
+
285
+ def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
286
+ if delta is None:
287
+ shifted = self.time_shift(x)
288
+ if len(shifted.shape) == 2:
289
+ shifted = shifted.unsqueeze(1)
290
+ delta = shifted - x
291
+ return self.linear(x + delta * mu)
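The `LerpLinear`/`DDLerpLinear` modules above both rely on the same token-shift interpolation: each position is mixed with the previous token's features before the projection. A small standalone sketch (sizes are illustrative, `mu` is zero-initialized as in the layer):

```python
import torch
import torch.nn as nn

# ZeroPad2d((0, 0, 1, -1)) shifts the sequence right by one step along time,
# so each position mixes its own features with the previous token's features.
B, T, D = 2, 5, 8                      # illustrative sizes
x = torch.randn(B, T, D)
mu = torch.zeros(D)                    # learnable in the layer; zeros here

time_shift = nn.ZeroPad2d((0, 0, 1, -1))
shifted = time_shift(x)                # shifted[:, t] == x[:, t-1], first step is zeros
delta = shifted - x
mixed = x + delta * mu                 # equals x when mu == 0
print(mixed.shape)                     # torch.Size([2, 5, 8])
```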
fla/layers/scan.py ADDED
@@ -0,0 +1,237 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ from fla.modules import RMSNorm
15
+ from fla.modules.activations import swish, sigmoid
16
+ from fla.modules.layernorm import rms_norm_linear
17
+ from fla.ops.scan import parallel_scan, naive_recurrent_scan
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ def build_alibi_tensor_scan(head_num, seq_len, window_len, state_size):
23
+ slopes = torch.tensor([2 ** (-8.0 * i / head_num) for i in range(head_num)])
24
+ alibi = torch.zeros((head_num, seq_len, window_len))
25
+ for i in range(seq_len):
26
+ for j in range(window_len):
27
+ if i < window_len:
28
+ alibi[:, i, j] = slopes * (j - window_len + 1) if i > (window_len - j - 2) else 0
29
+ else:
30
+ alibi[:, i, j] = alibi[:, window_len-1, j]
31
+ # Now concat a zeros tensor of size (head_num, seq_len, state_size) to the left of the above square tensor
32
+ alibi = torch.cat((torch.zeros(head_num, seq_len, state_size), alibi), dim=2)
33
+ return alibi # shape: (head_num, seq_len, state_size + window_size) or (H, T, S + W)
34
+
35
+ def scores_mask(T, W, S):
36
+ # create lower right triangle mask (W, W)
37
+ mask = torch.tril(torch.ones(W, W)).flip(1)
38
+ # concat ones with size (T-W, W) in 0th dim
39
+ mask = torch.cat((mask, torch.ones(T-W, W)), dim=0)
40
+ # concat ones with size (T, S) in 1st dim
41
+ mask = torch.cat((torch.ones(T, S), mask), dim=1)
42
+ return mask # shape: (T, S + W)
43
+
44
+ class SemiCompressedAttention(nn.Module):
45
+
46
+ def __init__(
47
+ self,
48
+ mode: str = 'parallel',
49
+ hidden_size: int = 1024,
50
+ window_size: int = 512,
51
+ state_size: int = 64,
52
+ gate_act: str = 'softmax',
53
+ max_position_embeddings: Optional[int] = 2048,
54
+ expand_k: float = 1.,
55
+ expand_v: float = 1.,
56
+ num_heads: int = 4,
57
+ num_kv_heads: Optional[int] = None,
58
+ elementwise_affine: Optional[bool] = True,
59
+ norm_first: bool = True,
60
+ norm_eps: float = 1e-5,
61
+ gate_logit_normalizer: int = 8,
62
+ use_output_gate: bool = False,
63
+ use_norm: bool = True,
64
+ layer_idx: Optional[int] = None,
65
+ scale: Optional[float] = 1.,
66
+ **kwargs
67
+ ) -> SemiCompressedAttention:
68
+ super().__init__()
69
+
70
+ self.mode = mode
71
+ self.hidden_size = hidden_size
72
+ self.window_size = window_size
73
+ self.state_size = state_size
74
+ self.gate_act = gate_act
75
+ self.max_position_embeddings = max_position_embeddings
76
+ self.expand_k = expand_k
77
+ self.expand_v = expand_v
78
+ self.num_heads = num_heads
79
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
80
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
81
+ self.key_dim = int(hidden_size * expand_k)
82
+ self.value_dim = int(hidden_size * expand_v)
83
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
84
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
85
+ self.head_k_dim = self.key_dim // self.num_heads
86
+ self.head_v_dim = self.value_dim // self.num_heads
87
+
88
+ self.gate_logit_normalizer = gate_logit_normalizer
89
+
90
+ self.use_output_gate = use_output_gate
91
+ self.use_norm = use_norm
92
+ self.scale = scale
93
+
94
+ self.norm_first = norm_first
95
+
96
+ self.layer_idx = layer_idx
97
+
98
+ if layer_idx is None:
99
+ warnings.warn(
100
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
101
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
102
+ "when creating this class."
103
+ )
104
+
105
+ if norm_first:
106
+ self.norm = RMSNorm(self.hidden_size, eps=norm_eps)
107
+
108
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
109
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
110
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
111
+ self.s_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
112
+ self.g_proj = nn.Linear(self.hidden_size, self.num_heads * self.state_size, bias=False)
113
+
114
+ self.norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
115
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
116
+
117
+ self.apply(self._initialize_weights)
118
+
119
+ self.register_buffer('alibi', build_alibi_tensor_scan(self.num_heads, self.max_position_embeddings, self.window_size, self.state_size))
120
+ self.register_buffer('mask', scores_mask(self.max_position_embeddings, self.window_size, self.state_size))
121
+
122
+ def _initialize_weights(self, module: nn.Module):
123
+ if getattr(module, "_is_hf_initialized", False):
124
+ return
125
+ if isinstance(module, nn.Linear):
126
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
127
+ if module.bias is not None:
128
+ nn.init.zeros_(module.bias)
129
+ module._is_hf_initialized = True
130
+
131
+ def forward(
132
+ self,
133
+ hidden_states: torch.Tensor,
134
+ attention_mask: Optional[torch.Tensor] = None,
135
+ past_key_values: Optional[Cache] = None,
136
+ use_cache: Optional[bool] = False,
137
+ output_attentions: Optional[bool] = False,
138
+ **kwargs
139
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
140
+ if attention_mask is not None:
141
+ assert len(attention_mask.shape) == 2, (
142
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
143
+ "for padding purposes (0 indicating padding). "
144
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
145
+ )
146
+
147
+ # launching the triton kernel for just one token will actually be slower
148
+ mode = 'naive' if past_key_values is not None else self.mode
149
+
150
+ if self.norm_first:
151
+ hidden_states = self.norm(hidden_states)
152
+
153
+ last_state = None
154
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
155
+ last_state = past_key_values[self.layer_idx]
156
+
157
+ q = self.q_proj(hidden_states)
158
+ k = self.k_proj(hidden_states)
159
+ v = self.v_proj(hidden_states)
160
+ s = self.s_proj(hidden_states)
161
+ g = self.g_proj(hidden_states)
162
+
163
+ if self.gate_act == 'softmax':
164
+ g = F.softmax(g, dim=-1)
165
+ elif self.gate_act == 'sigmoid':
166
+ g = sigmoid(g)
167
+ else:
168
+ raise NotImplementedError(f"Gate activation `{self.gate_act}` is not supported.")
169
+
170
+ # KV cache is updated before going into SCAN
171
+ if past_key_values is not None:
172
+ k, v = past_key_values.update(
173
+ attn_state=(k, v),
174
+ layer_idx=self.layer_idx,
175
+ offset=q.shape[2],
176
+ # We actually don't want to crop to window for the initial prompt, only for subsequent autoregressive tokens
177
+ cache_kwargs=dict(window_size=self.window_size) if q.shape[-2] == 1 else dict()
178
+ )['attn_state']
179
+
180
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
181
+ if mode == 'parallel':
182
+ # Split heads (but merge with batch dimension because kernels receive (B T C) shape)
183
+ q = rearrange(q, 'b t (h c) -> (b h) t c', h=self.num_heads)
184
+ k = rearrange(k, 'b t (h c) -> (b h) t c', h=self.num_kv_heads)
185
+ v = rearrange(v, 'b t (h c) -> (b h) t c', h=self.num_kv_heads)
186
+ s = rearrange(s, 'b t (h c) -> (b h) t c', h=self.num_kv_heads)
187
+ g = rearrange(g, 'b t (h s) -> (b h) t s', h=self.num_kv_heads)
188
+ o, recurrent_state = parallel_scan(
189
+ q=q,
190
+ k=k,
191
+ v=v,
192
+ s=s,
193
+ g=g,
194
+ window_size=self.window_size,
195
+ num_heads=self.num_heads,
196
+ alibi=self.alibi.to(q.device),
197
+ mask=self.mask.to(q.device),
198
+ initial_state=recurrent_state,
199
+ output_final_state=use_cache,
200
+ scale=self.scale,
201
+ head_first=False
202
+ )
203
+ o = rearrange(o, '(b h) t c -> b t (h c)', h=self.num_heads)
204
+ elif mode == 'naive':
205
+ # TODO: Implement naive recurrent SCAN for inference
206
+ q = rearrange(q, 'b t (h c) -> b h t c', h=self.num_heads)
207
+ k = rearrange(k, 'b t (h c) -> b h t c', h=self.num_kv_heads)
208
+ v = rearrange(v, 'b t (h c) -> b h t c', h=self.num_kv_heads)
209
+ s = rearrange(s, 'b t (h c) -> b h t c', h=self.num_kv_heads)
210
+ g = rearrange(g, 'b t (h s) -> b h t s', h=self.num_kv_heads)
211
+ o, recurrent_state = naive_recurrent_scan(
212
+ q=q,
213
+ k=k,
214
+ v=v,
215
+ s=s,
216
+ g=g,
217
+ window_size=self.window_size,
218
+ alibi=self.alibi.to(q.device),
219
+ mask=self.mask.to(q.device),
220
+ initial_state=recurrent_state,
221
+ output_final_state=use_cache,
222
+ scale=self.scale,
223
+ head_first=False
224
+ )
225
+ o = rearrange(o, 'b h t c -> b t (h c)', h=self.num_heads)
226
+ else:
227
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
228
+
229
+ # Update the recurrent state after SCAN
230
+ if past_key_values is not None:
231
+ past_key_values.update(
232
+ recurrent_state=recurrent_state,
233
+ layer_idx=self.layer_idx
234
+ )
235
+
236
+ o = rms_norm_linear(swish(o), self.norm.weight, self.norm.bias, self.o_proj.weight, self.o_proj.bias)
237
+ return o, None, past_key_values
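The attention score layout used by SCAN above concatenates `state_size` always-visible state slots with a `window_size`-token sliding window. A sketch of `scores_mask` with tiny illustrative sizes makes the `(T, S + W)` layout concrete:

```python
import torch

# Shape check mirroring scores_mask above: S state slots are always visible,
# followed by a sliding window of W past tokens, for each of T positions.
T, W, S = 8, 4, 2                                  # illustrative sizes

mask = torch.tril(torch.ones(W, W)).flip(1)        # lower-right triangle (W, W)
mask = torch.cat((mask, torch.ones(T - W, W)), 0)  # full window once t >= W
mask = torch.cat((torch.ones(T, S), mask), 1)      # state slots always visible

print(mask.shape)  # torch.Size([8, 6]) == (T, S + W)
```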
fla/layers/simple_gla.py ADDED
@@ -0,0 +1,252 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+
13
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
14
+ from fla.modules.activations import ACT2FN
15
+ from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class SimpleGatedLinearAttention(nn.Module):
22
+ r"""
23
+    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
24
+ This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
25
+
26
+ Args:
27
+ mode (str, Optional):
28
+ Which GLA kernel to use.
29
+            Currently available: `chunk` and `fused_recurrent`.
30
+ Default: `chunk`.
31
+ hidden_size (int, Optional):
32
+ The hidden size of the input. Default: 1024.
33
+ expand_k (float, Optional):
34
+ The expansion ratio for the key dim. Default: 1.0.
35
+ expand_v (float, Optional):
36
+ The expansion ratio for the value dim. Default: 1.0.
37
+ num_heads (int, Optional):
38
+ The number of heads. Default: 4.
39
+ num_kv_heads (int, Optional):
40
+ The number of key/value heads, used for MQA. Default: None.
41
+ feature_map (str, Optional):
42
+ Feature map function applied to queries/keys. Default: None.
43
+ use_short_conv (bool, Optional):
44
+ Whether to use short convolutions. Default: `False`.
45
+ conv_size (int, Optional):
46
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
47
+ conv_bias (bool, Optional):
48
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
49
+ gate_fn (str, Optional):
50
+ The activation function for the output gate. Default: `swish`.
51
+ elementwise_affine (bool, Optional):
52
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
53
+ norm_eps (float, Optional):
54
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
55
+ gate_logit_normalizer (int, Optional):
56
+            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
57
+ fuse_norm (bool, Optional):
58
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
59
+ layer_idx (int, Optional):
60
+ The index of the layer. Default: None.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ mode: str = 'chunk',
66
+ hidden_size: int = 1024,
67
+ expand_k: float = 1.,
68
+ expand_v: float = 1.,
69
+ num_heads: int = 4,
70
+ num_kv_heads: Optional[int] = None,
71
+ feature_map: Optional[str] = None,
72
+ use_short_conv: bool = True,
73
+ conv_size: int = 4,
74
+ conv_bias: bool = False,
75
+ gate_fn: str = 'swish',
76
+ elementwise_affine: Optional[bool] = True,
77
+ norm_eps: float = 1e-5,
78
+ gate_logit_normalizer: int = 16,
79
+ fuse_norm: bool = True,
80
+ layer_idx: int = None,
81
+ ) -> SimpleGatedLinearAttention:
82
+ super().__init__()
83
+
84
+ self.mode = mode
85
+ self.hidden_size = hidden_size
86
+ self.expand_k = expand_k
87
+ self.expand_v = expand_v
88
+ self.num_heads = num_heads
89
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
90
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
91
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
92
+
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+ self.conv_bias = conv_bias
96
+
97
+ self.key_dim = int(hidden_size * expand_k)
98
+ self.value_dim = int(hidden_size * expand_v)
99
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
100
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
101
+ self.layer_idx = layer_idx
102
+
103
+        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
104
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
105
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
106
+
107
+ self.head_qk_dim = self.key_dim // num_heads
108
+ self.head_v_dim = self.value_dim // num_heads
109
+
110
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
111
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
112
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
113
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
114
+
115
+ if use_short_conv:
116
+ self.conv_size = conv_size
117
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
118
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
119
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
120
+
121
+ self.gk_proj = nn.Linear(hidden_size, self.num_heads)
122
+
123
+ if gate_fn == 'swish' and fuse_norm:
124
+ self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
125
+ self.fuse_norm_and_gate = True
126
+ else:
127
+ self.fuse_norm_and_gate = False
128
+ self.g_norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
129
+ self.gate_fn = ACT2FN[gate_fn]
130
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
131
+
132
+ self.gate_logit_normalizer = gate_logit_normalizer
133
+
134
+ self.apply(self._initialize_weights)
135
+
136
+ def _initialize_weights(self, module: nn.Module):
137
+ if getattr(module, "_is_hf_initialized", False):
138
+ return
139
+ if isinstance(module, nn.Linear):
140
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
141
+ if module.bias is not None:
142
+ nn.init.zeros_(module.bias)
143
+ module._is_hf_initialized = True
144
+
145
+ def forward(
146
+ self,
147
+ hidden_states: torch.Tensor,
148
+ attention_mask: Optional[torch.Tensor] = None,
149
+ past_key_values: Optional[Cache] = None,
150
+ use_cache: Optional[bool] = False,
151
+ output_attentions: Optional[bool] = False,
152
+ **kwargs
153
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
154
+ if attention_mask is not None:
155
+ assert len(attention_mask.shape) == 2, (
156
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
157
+ "for padding purposes (0 indicating padding). "
158
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
159
+ )
160
+
161
+ # launching the triton kernel for just one token will actually be slower
162
+ mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
163
+
164
+ last_state = None
165
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
166
+ last_state = past_key_values[self.layer_idx]
167
+
168
+ if self.use_short_conv:
169
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
170
+ if last_state is not None:
171
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
172
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
173
+ q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
174
+ mask=conv_mask,
175
+ cache=conv_state_q,
176
+ output_final_state=use_cache)
177
+ k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
178
+ mask=conv_mask,
179
+ cache=conv_state_k,
180
+ output_final_state=use_cache)
181
+ v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
182
+ mask=conv_mask,
183
+ cache=conv_state_v,
184
+ output_final_state=use_cache)
185
+ else:
186
+ q = self.q_proj(hidden_states)
187
+ k = self.k_proj(hidden_states)
188
+ v = self.v_proj(hidden_states)
189
+ gk = self.gk_proj(hidden_states)
190
+
191
+ if self.feature_map_fn is not None:
192
+ q, k = map(self.feature_map_fn, (q, k))
193
+ # dealing with left-padding
194
+ if attention_mask is not None:
195
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
196
+ q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
197
+ if self.num_kv_groups > 1:
198
+ k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
199
+ else:
200
+ k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v))
201
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
202
+
203
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
204
+ if mode == 'chunk':
205
+ o, recurrent_state = chunk_simple_gla(
206
+ q=q,
207
+ k=k,
208
+ v=v,
209
+ gk=gk,
210
+ initial_state=recurrent_state,
211
+ output_final_state=use_cache,
212
+ head_first=False
213
+ )
214
+ elif mode == 'fused_recurrent':
215
+ o, recurrent_state = fused_recurrent_simple_gla(
216
+ q=q,
217
+ k=k,
218
+ v=v,
219
+ gk=gk,
220
+ initial_state=recurrent_state,
221
+ output_final_state=use_cache,
222
+ head_first=False
223
+ )
224
+ else:
225
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
226
+
227
+ if past_key_values is not None:
228
+ past_key_values.update(
229
+ recurrent_state=recurrent_state,
230
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
231
+ layer_idx=self.layer_idx,
232
+ offset=q.shape[2]
233
+ )
234
+
235
+ g = self.g_proj(hidden_states)
236
+ if self.fuse_norm_and_gate:
237
+ g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
238
+ o = self.g_norm_swish_gate(o, g)
239
+ o = rearrange(o, 'b t h d -> b t (h d)')
240
+ else:
241
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
242
+ o = o * self.gate_fn(g)
243
+ o = self.o_proj(o)
244
+
245
+ return o, None, past_key_values
246
+
247
+ def state_size(self, **kwargs) -> int:
248
+ state_size = self.key_dim * self.head_v_dim
249
+ for module in self.children():
250
+ if isinstance(module, ShortConvolution):
251
+ state_size += module.state_size
252
+ return state_size
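As a reference for what the `chunk_simple_gla`/`fused_recurrent_simple_gla` kernels compute, here is a naive per-step recurrence with head-wise (one scalar per head) decay. This is a sketch for intuition only, with illustrative sizes and no scaling term; it is not the fused implementation:

```python
import torch
import torch.nn.functional as F

# Naive simple-GLA recurrence: S_t = exp(g_t) * S_{t-1} + k_t^T v_t, o_t = q_t S_t
B, H, T, K, V = 1, 2, 6, 4, 4                   # illustrative sizes
q = torch.randn(B, H, T, K)
k = torch.randn(B, H, T, K)
v = torch.randn(B, H, T, V)
gk = F.logsigmoid(torch.randn(B, H, T)) / 16    # gate_logit_normalizer = 16

S = torch.zeros(B, H, K, V)                     # recurrent state per head
outs = []
for t in range(T):
    S = S * gk[:, :, t, None, None].exp()               # head-wise decay
    S = S + k[:, :, t, :, None] * v[:, :, t, None, :]   # outer-product update
    outs.append(torch.einsum('bhk,bhkv->bhv', q[:, :, t], S))
o = torch.stack(outs, dim=2)                    # (B, H, T, V)
print(o.shape)
```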
fla/models/__init__.py ADDED
@@ -0,0 +1,39 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
4
+ from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
5
+ from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM,
6
+ DeltaNetModel)
7
+ from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
8
+ from fla.models.gsa import GSAConfig, GSAForCausalLM, GSAModel
9
+ from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
10
+ from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
11
+ from fla.models.linear_attn import (LinearAttentionConfig,
12
+ LinearAttentionForCausalLM,
13
+ LinearAttentionModel)
14
+ from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
15
+ from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model
16
+ from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
17
+ from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
18
+ from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel
19
+ from fla.models.scan import SCANConfig, SCANForCausalLM, SCANModel
20
+ from fla.models.transformer import (TransformerConfig, TransformerForCausalLM,
21
+ TransformerModel)
22
+
23
+ __all__ = [
24
+ 'ABCConfig', 'ABCForCausalLM', 'ABCModel',
25
+ 'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
26
+ 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
27
+ 'GLAConfig', 'GLAForCausalLM', 'GLAModel',
28
+ 'GSAConfig', 'GSAForCausalLM', 'GSAModel',
29
+ 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
30
+ 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
31
+ 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
32
+ 'MambaConfig', 'MambaForCausalLM', 'MambaModel',
33
+ 'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model',
34
+ 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
35
+ 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
36
+ 'SambaConfig', 'SambaForCausalLM', 'SambaModel',
37
+ 'SCANConfig', 'SCANForCausalLM', 'SCANModel',
38
+ 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'
39
+ ]
fla/models/abc/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.abc.configuration_abc import ABCConfig
6
+ from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel
7
+
8
+ AutoConfig.register(ABCConfig.model_type, ABCConfig)
9
+ AutoModel.register(ABCConfig, ABCModel)
10
+ AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)
11
+
12
+
13
+ __all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']
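Because the `__init__` above registers `ABCConfig` with the 🤗 auto classes, a model can be built through `AutoModelForCausalLM` once the package has been imported. A hedged usage sketch (the sizes are illustrative, not a released checkpoint):

```python
from transformers import AutoModelForCausalLM

# Importing from fla.models runs the AutoConfig/AutoModel registration above.
from fla.models import ABCConfig

config = ABCConfig(hidden_size=256, num_hidden_layers=2, num_heads=4, num_slots=32)
model = AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count
```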
fla/models/abc/configuration_abc.py ADDED
@@ -0,0 +1,84 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class ABCConfig(PretrainedConfig):
9
+
10
+ model_type = 'abc'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ gate_low_rank_dim: int = 16,
17
+ clamp_min: float = -32,
18
+ clamp_max: float = 32,
19
+ hidden_ratio: Optional[int] = 4,
20
+ intermediate_size: Optional[int] = None,
21
+ num_hidden_layers: int = 24,
22
+ num_heads: int = 4,
23
+ num_slots: Optional[int] = 64,
24
+ use_short_conv: bool = False,
25
+ conv_size: int = 4,
26
+        expand_k: float = 0.5,
27
+        expand_v: float = 1,
28
+ hidden_act: str = "swish",
29
+ max_position_embeddings: int = 2048,
30
+ elementwise_affine: Optional[bool] = True,
31
+ norm_eps: float = 1e-6,
32
+ attn: Optional[Dict] = None,
33
+ use_cache: bool = True,
34
+ pad_token_id: int = None,
35
+ bos_token_id: int = 1,
36
+ eos_token_id: int = 2,
37
+ tie_word_embeddings: bool = False,
38
+ initializer_range: float = 0.02,
39
+ fuse_norm: bool = True,
40
+ fuse_cross_entropy: bool = True,
41
+ vocab_size: int = 32000,
42
+ **kwargs
43
+ ):
44
+ self.hidden_size = hidden_size
45
+ self.gate_low_rank_dim = gate_low_rank_dim
46
+ self.clamp_min = clamp_min
47
+ self.clamp_max = clamp_max
48
+ self.hidden_ratio = hidden_ratio
49
+ self.intermediate_size = intermediate_size
50
+ self.num_hidden_layers = num_hidden_layers
51
+ self.num_heads = num_heads
52
+ self.num_slots = num_slots
53
+ self.use_short_conv = use_short_conv
54
+ self.conv_size = conv_size
55
+        self.expand_k = expand_k
56
+        self.expand_v = expand_v
57
+ self.hidden_act = hidden_act
58
+ self.max_position_embeddings = max_position_embeddings
59
+ self.elementwise_affine = elementwise_affine
60
+ self.norm_eps = norm_eps
61
+ self.attn = attn
62
+ self.use_cache = use_cache
63
+ self.initializer_range = initializer_range
64
+ self.fuse_norm = fuse_norm
65
+ self.fuse_cross_entropy = fuse_cross_entropy
66
+ self.vocab_size = vocab_size
67
+
68
+ if attn is not None:
69
+ if not isinstance(attn, Dict):
70
+ raise ValueError("attn must be a dictionary")
71
+ if 'layers' not in attn:
72
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
73
+ if 'num_heads' not in attn:
74
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
75
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
76
+ attn['window_size'] = attn.get('window_size', None)
77
+
78
+ super().__init__(
79
+ pad_token_id=pad_token_id,
80
+ bos_token_id=bos_token_id,
81
+ eos_token_id=eos_token_id,
82
+ tie_word_embeddings=tie_word_embeddings,
83
+ **kwargs,
84
+ )
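The validation block above documents the contract for the hybrid `attn` option: the dict must at least carry the hybrid layer indices and a head count, and the remaining keys are defaulted. A minimal sketch (values are illustrative):

```python
from fla.models.abc import ABCConfig

# layers 0 and 4 are replaced by standard softmax attention; the rest stay ABC
config = ABCConfig(
    hidden_size=512,
    num_hidden_layers=8,
    num_heads=4,
    attn={'layers': [0, 4], 'num_heads': 8},
)
print(config.attn)  # {'layers': [0, 4], 'num_heads': 8, 'num_kv_heads': 8, 'window_size': None}
```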
fla/models/abc/modeling_abc.py ADDED
@@ -0,0 +1,403 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast)
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+
19
+ from fla.layers.abc import ABCAttention
20
+ from fla.layers.attn import Attention
21
+ from fla.models.abc.configuration_abc import ABCConfig
22
+ from fla.models.utils import Cache
23
+ from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
24
+ RMSNorm)
25
+ from fla.modules.activations import swiglu_linear
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class ABCMLP(nn.Module):
31
+
32
+ def __init__(
33
+ self,
34
+ hidden_size: int,
35
+ hidden_ratio: Optional[int] = None,
36
+ intermediate_size: Optional[int] = None,
37
+ hidden_act: str = 'swish'
38
+ ) -> ABCMLP:
39
+ super().__init__()
40
+
41
+ self.hidden_size = hidden_size
42
+ # the final number of params is `hidden_ratio * hidden_size^2`
43
+ # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
44
+ if hidden_ratio is None:
45
+ hidden_ratio = 4
46
+ if intermediate_size is None:
47
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
48
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
49
+ self.hidden_ratio = hidden_ratio
50
+ self.intermediate_size = intermediate_size
51
+
52
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
53
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
54
+ self.act_fn = ACT2FN[hidden_act]
55
+
56
+ def forward(self, x):
57
+ y = self.gate_proj(x)
58
+ gate, y = y.chunk(2, -1)
59
+ return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
60
+
61
+
62
+ class ABCBlock(nn.Module):
63
+ def __init__(self, config: ABCConfig, layer_idx: int):
64
+ super().__init__()
65
+ self.hidden_size = config.hidden_size
66
+
67
+ self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
68
+ if config.attn is not None and layer_idx in config.attn['layers']:
69
+ self.attn = Attention(
70
+ hidden_size=config.hidden_size,
71
+ num_heads=config.attn['num_heads'],
72
+ num_kv_heads=config.attn['num_kv_heads'],
73
+ window_size=config.attn['window_size'],
74
+ max_position_embeddings=config.max_position_embeddings,
75
+ layer_idx=layer_idx
76
+ )
77
+ else:
78
+ self.attn = ABCAttention(
79
+ hidden_size=config.hidden_size,
80
+ expand_k=config.expand_k,
81
+ expand_v=config.expand_v,
82
+ num_heads=config.num_heads,
83
+ num_slots=config.num_slots,
84
+ use_short_conv=config.use_short_conv,
85
+ conv_size=config.conv_size,
86
+ gate_fn=config.hidden_act,
87
+ elementwise_affine=config.elementwise_affine,
88
+ norm_eps=config.norm_eps,
89
+ clamp_min=config.clamp_min,
90
+ clamp_max=config.clamp_max,
91
+ fuse_norm=config.fuse_norm,
92
+ layer_idx=layer_idx
93
+ )
94
+ self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
95
+ self.mlp = ABCMLP(
96
+ hidden_size=config.hidden_size,
97
+ hidden_ratio=config.hidden_ratio,
98
+ intermediate_size=config.intermediate_size,
99
+ hidden_act=config.hidden_act
100
+ )
101
+
102
+ def forward(
103
+ self,
104
+ hidden_states: torch.Tensor,
105
+ attention_mask: Optional[torch.Tensor] = None,
106
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
107
+ use_cache: Optional[bool] = False,
108
+ output_attentions: Optional[bool] = False,
109
+ **kwargs,
110
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
111
+
112
+ residual = hidden_states
113
+
114
+ hidden_states = self.attn_norm(hidden_states)
115
+ hidden_states, attentions, past_key_values = self.attn(
116
+ hidden_states=hidden_states,
117
+ attention_mask=attention_mask,
118
+ past_key_values=past_key_values,
119
+ use_cache=use_cache,
120
+ output_attentions=output_attentions
121
+ )
122
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
123
+ hidden_states = self.mlp(hidden_states)
124
+ hidden_states = residual + hidden_states
125
+
126
+ outputs = (hidden_states, attentions, past_key_values)
127
+
128
+ return outputs
129
+
130
+
131
+ class ABCPreTrainedModel(PreTrainedModel):
132
+
133
+ config_class = ABCConfig
134
+ supports_gradient_checkpointing = True
135
+ _no_split_modules = ['ABCBlock']
136
+
137
+ def __init__(self, *inputs, **kwargs):
138
+ super().__init__(*inputs, **kwargs)
139
+
140
+ def _init_weights(
141
+ self,
142
+ module: nn.Module,
143
+ rescale_prenorm_residual: bool = True,
144
+ num_residuals_per_layer: int = 2,
145
+ ):
146
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
147
+ # Slightly different from the TF version which uses truncated_normal for initialization
148
+ # cf https://github.com/pytorch/pytorch/pull/5617
149
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
150
+ if module.bias is not None:
151
+ nn.init.zeros_(module.bias)
152
+ elif isinstance(module, nn.Embedding):
153
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
154
+ if module.padding_idx is not None:
155
+ module.weight.data[module.padding_idx].zero_()
156
+
157
+ if rescale_prenorm_residual:
158
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
159
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
160
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
161
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
162
+ #
163
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
164
+ for name, p in module.named_parameters():
165
+ if name in ["o_proj.weight", "down_proj.weight"]:
166
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
167
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
168
+ # We need to reinit p since this code could be called multiple times
169
+ # Having just p *= scale would repeatedly scale it down
170
+ with torch.no_grad():
171
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
172
+
173
+
174
+ class ABCModel(ABCPreTrainedModel):
175
+
176
+ def __init__(self, config: ABCConfig):
177
+ super().__init__(config)
178
+ self.padding_idx = config.pad_token_id
179
+ self.vocab_size = config.vocab_size
180
+
181
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
182
+ self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
183
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
184
+
185
+ self.gradient_checkpointing = False
186
+
187
+ self.post_init()
188
+
189
+ def get_input_embeddings(self):
190
+ return self.embeddings
191
+
192
+ def set_input_embeddings(self, value):
193
+ self.embeddings = value
194
+
195
+ def forward(
196
+ self,
197
+ input_ids: Optional[torch.LongTensor] = None,
198
+ attention_mask: Optional[torch.Tensor] = None, # noqa
199
+ inputs_embeds: Optional[torch.FloatTensor] = None,
200
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
201
+ use_cache: Optional[bool] = None,
202
+ output_attentions: Optional[bool] = None,
203
+ output_hidden_states: Optional[bool] = None,
204
+ return_dict: Optional[bool] = None
205
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
206
+ if output_attentions:
207
+            warnings.warn("`ABCModel` does not support `output_attentions` for now; setting it to `False`.")
208
+ output_attentions = False
209
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
210
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
211
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
212
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
213
+
214
+ # retrieve input_ids and inputs_embeds
215
+ if input_ids is not None and inputs_embeds is not None:
216
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
217
+ if input_ids is None and inputs_embeds is None:
218
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
219
+
220
+ if inputs_embeds is None:
221
+ inputs_embeds = self.embeddings(input_ids)
222
+ hidden_states = inputs_embeds
223
+
224
+ if use_cache and not isinstance(past_key_values, Cache):
225
+ past_key_values = Cache.from_legacy_cache(past_key_values)
226
+
227
+ if self.gradient_checkpointing and self.training and use_cache:
228
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
229
+ use_cache = False
230
+
231
+ all_hidden_states = () if output_hidden_states else None
232
+ all_attns = () if output_attentions else None
233
+ for layer in self.layers:
234
+ if output_hidden_states:
235
+ all_hidden_states += (hidden_states,)
236
+
237
+ if self.gradient_checkpointing and self.training:
238
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
239
+ layer.__call__,
240
+ hidden_states,
241
+ attention_mask,
242
+ past_key_values,
243
+ use_cache,
244
+ output_attentions
245
+ )
246
+ else:
247
+ hidden_states, attentions, past_key_values = layer(
248
+ hidden_states,
249
+ attention_mask,
250
+ past_key_values=past_key_values,
251
+ use_cache=use_cache,
252
+ output_attentions=output_attentions
253
+ )
254
+
255
+ if output_attentions:
256
+ all_attns += (attentions,)
257
+
258
+ hidden_states = self.norm(hidden_states)
259
+
260
+ # add hidden states from the last decoder layer
261
+ if output_hidden_states:
262
+ all_hidden_states += (hidden_states,)
263
+
264
+ if not return_dict:
265
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
266
+ return BaseModelOutputWithPast(
267
+ last_hidden_state=hidden_states,
268
+ past_key_values=past_key_values,
269
+ hidden_states=all_hidden_states,
270
+ attentions=all_attns
271
+ )
272
+
273
+
274
+ class ABCForCausalLM(ABCPreTrainedModel, GenerationMixin):
275
+
276
+ _tied_weights_keys = ["lm_head.weight"]
277
+
278
+ def __init__(self, config):
279
+ super().__init__(config)
280
+ self.model = ABCModel(config)
281
+ self.vocab_size = config.vocab_size
282
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
283
+
284
+ # Initialize weights and apply final processing
285
+ self.post_init()
286
+
287
+ def get_input_embeddings(self):
288
+ return self.model.embeddings
289
+
290
+ def set_input_embeddings(self, value):
291
+ self.model.embeddings = value
292
+
293
+ def get_output_embeddings(self):
294
+ return self.lm_head
295
+
296
+ def set_output_embeddings(self, new_embeddings):
297
+ self.lm_head = new_embeddings
298
+
299
+ def set_decoder(self, decoder):
300
+ self.model = decoder
301
+
302
+ def get_decoder(self):
303
+ return self.model
304
+
305
+ def generate(self, *args, **kwargs):
306
+ try:
307
+ return super().generate(*args, **kwargs)
308
+ except AttributeError as exception:
309
+ if 'past_key_values' in str(exception):
310
+ raise AttributeError(
311
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
312
+ f"which is not supported for {self.__class__.__name__}. "
313
+ f"Try another generation strategy instead. "
314
+ f"For the available generation strategies, check this doc: "
315
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
316
+ )
317
+ else:
318
+ raise exception
319
+
320
+ def prepare_inputs_for_generation(
321
+ self,
322
+ input_ids: torch.LongTensor = None,
323
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
324
+ inputs_embeds: Optional[torch.FloatTensor] = None,
325
+ **kwargs
326
+ ):
327
+ # only last token for `inputs_ids` if the `past_key_values` is passed along.
328
+ if past_key_values is not None:
329
+ input_ids = input_ids[:, -1:]
330
+
331
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
332
+ if inputs_embeds is not None and past_key_values is None:
333
+ model_inputs = {'inputs_embeds': inputs_embeds}
334
+ else:
335
+ model_inputs = {'input_ids': input_ids}
336
+ model_inputs['past_key_values'] = past_key_values
337
+ return model_inputs
338
+
339
+ def forward(
340
+ self,
341
+ input_ids: torch.LongTensor = None,
342
+ attention_mask: Optional[torch.Tensor] = None,
343
+ inputs_embeds: Optional[torch.Tensor] = None,
344
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
345
+ labels: Optional[torch.LongTensor] = None,
346
+ use_cache: Optional[bool] = None,
347
+ output_attentions: Optional[bool] = None,
348
+ output_hidden_states: Optional[bool] = None,
349
+ return_dict: Optional[bool] = None,
350
+ num_logits_to_keep: Optional[int] = 0
351
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
352
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
353
+ output_hidden_states = (
354
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
355
+ )
356
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
357
+
358
+ outputs = self.model(
359
+ input_ids=input_ids,
360
+ attention_mask=attention_mask,
361
+ inputs_embeds=inputs_embeds,
362
+ past_key_values=past_key_values,
363
+ use_cache=use_cache,
364
+ output_attentions=output_attentions,
365
+ output_hidden_states=output_hidden_states,
366
+ return_dict=return_dict
367
+ )
368
+
369
+ hidden_states = outputs[0]
370
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
371
+ logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:])
372
+
373
+ loss = None
374
+ if labels is not None:
375
+ if self.config.fuse_cross_entropy:
376
+ if fuse_linear_and_cross_entropy:
377
+ loss_fct = FusedLinearCrossEntropyLoss()
378
+ else:
379
+ loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
380
+ else:
381
+ loss_fct = nn.CrossEntropyLoss()
382
+ # Enable model parallelism
383
+ labels = labels.to(hidden_states.device)
384
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
385
+ if fuse_linear_and_cross_entropy:
386
+ loss = loss_fct(hidden_states.view(-1, self.config.hidden_size),
387
+ labels.view(-1),
388
+ self.lm_head.weight,
389
+ self.lm_head.bias)
390
+ else:
391
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
392
+
393
+ if not return_dict:
394
+ output = (logits,) + outputs[1:]
395
+ return (loss,) + output if loss is not None else output
396
+
397
+ return CausalLMOutputWithPast(
398
+ loss=loss,
399
+ logits=logits,
400
+ past_key_values=outputs.past_key_values,
401
+ hidden_states=outputs.hidden_states,
402
+ attentions=outputs.attentions,
403
+ )
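One detail in `ABCForCausalLM.forward` worth spelling out: rather than slicing `logits[:, :-1]` against `labels[:, 1:]`, the labels are shifted left and padded with the ignore index so they stay aligned with the unsliced logits. A tiny standalone check:

```python
import torch

ignore_index = -100  # default ignore index of the cross-entropy losses
labels = torch.tensor([[10, 11, 12, 13]])
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
print(shifted)  # tensor([[  11,   12,   13, -100]])
```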
fla/models/bitnet/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.bitnet.configuration_bitnet import BitNetConfig
6
+ from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel
7
+
8
+ AutoConfig.register(BitNetConfig.model_type, BitNetConfig)
9
+ AutoModel.register(BitNetConfig, BitNetModel)
10
+ AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM)
11
+
12
+
13
+ __all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel']
fla/models/bitnet/configuration_bitnet.py ADDED
@@ -0,0 +1,68 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class BitNetConfig(PretrainedConfig):
+
+     model_type = 'bitnet'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         vocab_size: int = 32000,
+         hidden_size: int = 2048,
+         num_hidden_layers: int = 24,
+         num_heads: int = 32,
+         num_kv_heads: int = None,
+         window_size: Optional[int] = None,
+         rope_theta: Optional[float] = 10000.,
+         max_position_embeddings: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         initializer_range: float = 0.02,
+         elementwise_affine: Optional[bool] = True,
+         norm_first: bool = False,
+         norm_eps: float = 1e-6,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         attention_bias: bool = False,
+         fuse_norm: bool = True,
+         fuse_cross_entropy: bool = True,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads
+         self.window_size = window_size
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+
+         self.initializer_range = initializer_range
+         self.elementwise_affine = elementwise_affine
+         self.norm_first = norm_first
+         self.norm_eps = norm_eps
+         self.use_cache = use_cache
+         self.attention_bias = attention_bias
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.fuse_norm = fuse_norm
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
fla/models/bitnet/modeling_bitnet.py ADDED
@@ -0,0 +1,428 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast)
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+
19
+ from fla.layers.bitattn import BitAttention
20
+ from fla.models.bitnet.configuration_bitnet import BitNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
23
+ RMSNorm)
24
+ from fla.modules.activations import swiglu_bitlinear
25
+ from fla.modules.fused_bitlinear import BitLinear, rms_norm_linear_quant
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class BitNetMLP(nn.Module):
31
+
32
+ def __init__(
33
+ self,
34
+ hidden_size: int,
35
+ hidden_ratio: Optional[int] = None,
36
+ intermediate_size: Optional[int] = None,
37
+ hidden_act: str = 'swish',
38
+ norm_first: bool = True,
39
+ norm_eps: float = 1e-5
40
+ ) -> BitNetMLP:
41
+ super().__init__()
42
+
43
+ self.hidden_size = hidden_size
44
+ # the final number of params is `hidden_ratio * hidden_size^2`
45
+ # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
46
+ if hidden_ratio is None:
47
+ hidden_ratio = 4
48
+ if intermediate_size is None:
49
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
50
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
51
+ self.hidden_ratio = hidden_ratio
52
+ self.intermediate_size = intermediate_size
53
+ self.norm_first = norm_first
54
+
55
+ if norm_first:
56
+ self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)
57
+
58
+ self.gate_proj = BitLinear(self.hidden_size, self.intermediate_size * 2, bias=False)
59
+ self.down_proj = BitLinear(self.intermediate_size, self.hidden_size, bias=False)
60
+ self.act_fn = ACT2FN[hidden_act]
61
+
62
+ def forward(self, x):
63
+ if self.norm_first:
64
+ x = rms_norm_linear_quant(x, self.norm.weight, self.norm.bias, self.gate_proj.weight, self.gate_proj.bias)
65
+ else:
66
+ x = self.gate_proj(x)
67
+ gate, y = x.chunk(2, -1)
68
+ return swiglu_bitlinear(gate, y, self.down_proj.weight, self.down_proj.bias)
69
+
70
+
71
+ class BitNetBlock(nn.Module):
72
+
73
+ def __init__(self, config: BitNetConfig, layer_idx: int):
74
+ super().__init__()
75
+
76
+ self.hidden_size = config.hidden_size
77
+
78
+ if not config.norm_first:
79
+ self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
80
+ self.attn = BitAttention(
81
+ hidden_size=config.hidden_size,
82
+ num_heads=config.num_heads,
83
+ num_kv_heads=config.num_kv_heads,
84
+ window_size=config.window_size,
85
+ rope_theta=config.rope_theta,
86
+ max_position_embeddings=config.max_position_embeddings,
87
+ norm_first=config.norm_first,
88
+ norm_eps=config.norm_eps,
89
+ layer_idx=layer_idx
90
+ )
91
+ if not config.norm_first:
92
+ self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
93
+ self.mlp = BitNetMLP(
94
+ hidden_size=config.hidden_size,
95
+ hidden_ratio=config.hidden_ratio,
96
+ intermediate_size=config.intermediate_size,
97
+ hidden_act=config.hidden_act,
98
+ norm_first=config.norm_first,
99
+ norm_eps=config.norm_eps
100
+ )
101
+
102
+ def forward(
103
+ self,
104
+ hidden_states: torch.Tensor,
105
+ attention_mask: Optional[torch.Tensor] = None,
106
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
107
+ output_attentions: Optional[bool] = False,
108
+ use_cache: Optional[bool] = False,
109
+ **kwargs,
110
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
111
+
112
+ residual = hidden_states
113
+ if hasattr(self, 'attn_norm'):
114
+ hidden_states = self.attn_norm(hidden_states)
115
+ hidden_states, attentions, past_key_values = self.attn(
116
+ hidden_states=hidden_states,
117
+ attention_mask=attention_mask,
118
+ past_key_values=past_key_values,
119
+ use_cache=use_cache,
120
+ output_attentions=output_attentions
121
+ )
122
+ if hasattr(self, 'mlp_norm'):
123
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
124
+ else:
125
+ hidden_states = residual + hidden_states
126
+ residual = hidden_states
127
+ hidden_states = self.mlp(hidden_states)
128
+ hidden_states = residual + hidden_states
129
+
130
+ outputs = (hidden_states,)
131
+
132
+ if output_attentions:
133
+ outputs += (attentions,)
134
+
135
+ if use_cache:
136
+ outputs += (past_key_values,)
137
+
138
+ return outputs
139
+
140
+
141
+ class BitNetPreTrainedModel(PreTrainedModel):
142
+
143
+ config_class = BitNetConfig
144
+ supports_gradient_checkpointing = True
145
+ _no_split_modules = ['BitNetBlock']
146
+
147
+ def __init__(self, *inputs, **kwargs):
148
+ super().__init__(*inputs, **kwargs)
149
+
150
+ def _init_weights(
151
+ self,
152
+ module: nn.Module,
153
+ rescale_prenorm_residual: bool = False,
154
+ num_residuals_per_layer: int = 2,
155
+ ):
156
+ if isinstance(module, (BitLinear, nn.Conv1d)):
157
+ # Slightly different from the TF version which uses truncated_normal for initialization
158
+ # cf https://github.com/pytorch/pytorch/pull/5617
159
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
160
+ if module.bias is not None:
161
+ nn.init.zeros_(module.bias)
162
+ elif isinstance(module, nn.Embedding):
163
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
164
+ if module.padding_idx is not None:
165
+ module.weight.data[module.padding_idx].zero_()
166
+
167
+ if rescale_prenorm_residual:
168
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
169
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
170
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
171
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
172
+ #
173
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
174
+ for name, p in module.named_parameters():
175
+ if name in ["o_proj.weight", "down_proj.weight"]:
176
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
177
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
178
+ # We need to reinit p since this code could be called multiple times
179
+ # Having just p *= scale would repeatedly scale it down
180
+ with torch.no_grad():
181
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
182
+
183
+
184
+ class BitNetModel(BitNetPreTrainedModel):
185
+
186
+ def __init__(self, config: BitNetConfig):
187
+ super().__init__(config)
188
+ self.padding_idx = config.pad_token_id
189
+ self.vocab_size = config.vocab_size
190
+
191
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
192
+ self.layers = nn.ModuleList([BitNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
193
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
194
+
195
+ self.gradient_checkpointing = False
196
+
197
+ self.post_init()
198
+
199
+ def get_input_embeddings(self):
200
+ return self.embeddings
201
+
202
+ def set_input_embeddings(self, value):
203
+ self.embeddings = value
204
+
205
+ def forward(
206
+ self,
207
+ input_ids: Optional[torch.LongTensor] = None,
208
+ attention_mask: Optional[torch.Tensor] = None,
209
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
210
+ inputs_embeds: Optional[torch.FloatTensor] = None,
211
+ use_cache: Optional[bool] = None,
212
+ output_attentions: Optional[bool] = None,
213
+ output_hidden_states: Optional[bool] = None,
214
+ return_dict: Optional[bool] = None
215
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
216
+ if output_attentions:
217
+ warnings.warn(
218
+ "`BitNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
219
+ )
220
+ output_attentions = False
221
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
222
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
223
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
224
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
225
+
226
+ # retrieve input_ids and inputs_embeds
227
+ if input_ids is not None and inputs_embeds is not None:
228
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
229
+ elif input_ids is None and inputs_embeds is None:
230
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
231
+
232
+ if use_cache and not isinstance(past_key_values, Cache):
233
+ past_key_values = Cache.from_legacy_cache(past_key_values)
234
+
235
+ if inputs_embeds is None:
236
+ inputs_embeds = self.embeddings(input_ids)
237
+
238
+ # embed positions
239
+ hidden_states = inputs_embeds
240
+
241
+ if self.gradient_checkpointing and self.training:
242
+ if use_cache:
243
+ logger.warning_once(
244
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
245
+ )
246
+ use_cache = False
247
+
248
+ all_hidden_states = () if output_hidden_states else None
249
+ all_attns = () if output_attentions else None
250
+ next_cache = None
251
+
252
+ for layer in self.layers:
253
+ if output_hidden_states:
254
+ all_hidden_states += (hidden_states,)
255
+
256
+ if self.gradient_checkpointing and self.training:
257
+ layer_outputs = self._gradient_checkpointing_func(
258
+ layer.__call__,
259
+ hidden_states,
260
+ attention_mask,
261
+ past_key_values,
262
+ output_attentions,
263
+ use_cache
264
+ )
265
+ else:
266
+ layer_outputs = layer(
267
+ hidden_states,
268
+ attention_mask=attention_mask,
269
+ past_key_values=past_key_values,
270
+ output_attentions=output_attentions,
271
+ use_cache=use_cache
272
+ )
273
+
274
+ hidden_states = layer_outputs[0]
275
+
276
+ if use_cache:
277
+ next_cache = layer_outputs[2 if output_attentions else 1]
278
+
279
+ if output_attentions:
280
+ all_attns += (layer_outputs[1],)
281
+
282
+ hidden_states = self.norm(hidden_states)
283
+
284
+ # add hidden states from the last decoder layer
285
+ if output_hidden_states:
286
+ all_hidden_states += (hidden_states,)
287
+
288
+ if not return_dict:
289
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
290
+
291
+ return BaseModelOutputWithPast(
292
+ last_hidden_state=hidden_states,
293
+ past_key_values=next_cache,
294
+ hidden_states=all_hidden_states,
295
+ attentions=all_attns
296
+ )
297
+
298
+
299
+ class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
300
+
301
+ _tied_weights_keys = ["lm_head.weight"]
302
+
303
+ def __init__(self, config):
304
+ super().__init__(config)
305
+ self.model = BitNetModel(config)
306
+ self.vocab_size = config.vocab_size
307
+ self.lm_head = BitLinear(config.hidden_size, config.vocab_size, bias=False)
308
+
309
+ # Initialize weights and apply final processing
310
+ self.post_init()
311
+
312
+ def get_input_embeddings(self):
313
+ return self.model.embeddings
314
+
315
+ def set_input_embeddings(self, value):
316
+ self.model.embeddings = value
317
+
318
+ def get_output_embeddings(self):
319
+ return self.lm_head
320
+
321
+ def set_output_embeddings(self, new_embeddings):
322
+ self.lm_head = new_embeddings
323
+
324
+ def set_decoder(self, decoder):
325
+ self.model = decoder
326
+
327
+ def get_decoder(self):
328
+ return self.model
329
+
330
+ def prepare_inputs_for_generation(
331
+ self,
332
+ input_ids: torch.LongTensor = None,
333
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
334
+ attention_mask: Optional[torch.Tensor] = None,
335
+ inputs_embeds: Optional[torch.Tensor] = None,
336
+ use_cache: bool = True,
337
+ num_logits_to_keep: Optional[int] = None,
338
+ **kwargs
339
+ ):
340
+ # only last token for `inputs_ids` if the `past_key_values` is passed along.
341
+ if past_key_values is not None:
342
+ input_ids = input_ids[:, -1:]
343
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
344
+ if inputs_embeds is not None and past_key_values is None:
345
+ model_inputs = {'inputs_embeds': inputs_embeds}
346
+ else:
347
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
348
+ # recompiles graphs as the stride of the inputs is a guard.
349
+ # Ref: https://github.com/huggingface/transformers/pull/29114
350
+ # TODO: use `next_tokens` directly instead.
351
+ model_inputs = {'input_ids': input_ids.contiguous()}
352
+
353
+ if num_logits_to_keep is not None:
354
+ model_inputs['num_logits_to_keep'] = num_logits_to_keep
355
+
356
+ model_inputs.update({
357
+ 'past_key_values': past_key_values,
358
+ 'use_cache': use_cache,
359
+ 'attention_mask': attention_mask,
360
+ 'num_logits_to_keep': num_logits_to_keep,
361
+ })
362
+ return model_inputs
363
+
364
+ def forward(
365
+ self,
366
+ input_ids: torch.LongTensor = None,
367
+ attention_mask: Optional[torch.Tensor] = None,
368
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
369
+ inputs_embeds: Optional[torch.FloatTensor] = None,
370
+ labels: Optional[torch.LongTensor] = None,
371
+ use_cache: Optional[bool] = None,
372
+ output_attentions: Optional[bool] = None,
373
+ output_hidden_states: Optional[bool] = None,
374
+ return_dict: Optional[bool] = None,
375
+ num_logits_to_keep: Optional[int] = 0
376
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
377
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
378
+ output_hidden_states = (
379
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
380
+ )
381
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
382
+
383
+ outputs = self.model(
384
+ input_ids=input_ids,
385
+ attention_mask=attention_mask,
386
+ past_key_values=past_key_values,
387
+ inputs_embeds=inputs_embeds,
388
+ use_cache=use_cache,
389
+ output_attentions=output_attentions,
390
+ output_hidden_states=output_hidden_states,
391
+ return_dict=return_dict
392
+ )
393
+
394
+ hidden_states = outputs[0]
395
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
396
+ logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:])
397
+
398
+ loss = None
399
+ if labels is not None:
400
+ if self.config.fuse_cross_entropy:
401
+ if fuse_linear_and_cross_entropy:
402
+ loss_fct = FusedLinearCrossEntropyLoss()
403
+ else:
404
+ loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
405
+ else:
406
+ loss_fct = nn.CrossEntropyLoss()
407
+ # Enable model parallelism
408
+ labels = labels.to(hidden_states.device)
409
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
410
+ if fuse_linear_and_cross_entropy:
411
+ loss = loss_fct(hidden_states.view(-1, self.config.hidden_size),
412
+ labels.view(-1),
413
+ self.lm_head.weight,
414
+ self.lm_head.bias)
415
+ else:
416
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
417
+
418
+ if not return_dict:
419
+ output = (logits,) + outputs[1:]
420
+ return (loss,) + output if loss is not None else output
421
+
422
+ return CausalLMOutputWithPast(
423
+ loss=loss,
424
+ logits=logits,
425
+ past_key_values=outputs.past_key_values,
426
+ hidden_states=outputs.hidden_states,
427
+ attentions=outputs.attentions,
428
+ )
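
`BitNetMLP` above derives `intermediate_size` when it is not given explicitly: take `2/3 * hidden_size * hidden_ratio` and round up to the next multiple of 256. A small worked example of that rule (the helper name and values are illustrative):

```python
def derive_intermediate_size(hidden_size: int, hidden_ratio: int = 4) -> int:
    # 2/3 of hidden_size * hidden_ratio, rounded up to the next multiple of 256
    intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
    return 256 * ((intermediate_size + 256 - 1) // 256)

print(derive_intermediate_size(2048))  # 5632: int(2048 * 4 * 2 / 3) = 5461 -> 22 * 256
print(derive_intermediate_size(1024))  # 2816
```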
fla/models/delta_net/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
+ from fla.models.delta_net.modeling_delta_net import (DeltaNetForCausalLM,
+                                                      DeltaNetModel)
+
+ AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
+ AutoModel.register(DeltaNetConfig, DeltaNetModel)
+ AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
+
+ __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
fla/models/delta_net/configuration_delta_net.py ADDED
@@ -0,0 +1,87 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class DeltaNetConfig(PretrainedConfig):
+
+     model_type = 'delta_net'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         attn_mode: str = "chunk",
+         hidden_size: int = 2048,
+         expand_k: int = 1,
+         expand_v: int = 1,
+         use_gate: bool = False,
+         use_short_conv: bool = True,
+         conv_size: int = 4,
+         use_beta: bool = True,
+         use_output_norm: bool = True,
+         num_heads: int = 16,
+         qk_norm: str = 'l2',
+         qk_activation: str = 'silu',
+         max_position_embeddings: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         num_hidden_layers: int = 24,
+         norm_first: bool = False,
+         norm_eps: float = 1e-6,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.02,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.attn_mode = attn_mode
+         self.hidden_size = hidden_size
+         self.expand_k = expand_k
+         self.expand_v = expand_v
+         self.use_gate = use_gate
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.use_beta = use_beta
+         self.use_output_norm = use_output_norm
+         self.num_heads = num_heads
+         self.qk_norm = qk_norm
+         self.qk_activation = qk_activation
+         self.max_position_embeddings = max_position_embeddings
+
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.num_hidden_layers = num_hidden_layers
+         self.norm_first = norm_first
+         self.norm_eps = norm_eps
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['window_size'] = attn.get('window_size', None)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
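
The `attn` argument above switches selected layers from DeltaNet to standard softmax attention; `layers` and `num_heads` are required, while `num_kv_heads` and `window_size` fall back to defaults. A hedged example of building such a hybrid config (layer indices and sizes are illustrative):

```python
from fla.models.delta_net import DeltaNetConfig

config = DeltaNetConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    num_heads=8,
    # these block indices use standard attention instead of DeltaNet
    attn={'layers': [3, 7, 11], 'num_heads': 8, 'window_size': 1024},
)
print(config.attn)  # 'num_kv_heads' is filled in from 'num_heads' when omitted
```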
fla/models/delta_net/modeling_delta_net.py ADDED
@@ -0,0 +1,439 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast)
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+
19
+ from fla.layers.attn import Attention
20
+ from fla.layers.delta_net import DeltaNet
21
+ from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
22
+ from fla.models.utils import Cache
23
+ from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
24
+ RMSNorm)
25
+ from fla.modules.activations import swiglu_linear
26
+ from fla.modules.layernorm import rms_norm_linear
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class DeltaNetMLP(nn.Module):
32
+
33
+ def __init__(
34
+ self,
35
+ hidden_size: int,
36
+ hidden_ratio: Optional[int] = None,
37
+ intermediate_size: Optional[int] = None,
38
+ hidden_act: str = 'swish',
39
+ norm_first: bool = True,
40
+ norm_eps: float = 1e-5
41
+ ) -> DeltaNetMLP:
42
+ super().__init__()
43
+
44
+ self.hidden_size = hidden_size
45
+ # the final number of params is `hidden_ratio * hidden_size^2`
46
+ # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
47
+ if hidden_ratio is None:
48
+ hidden_ratio = 4
49
+ if intermediate_size is None:
50
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
51
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
52
+ self.hidden_ratio = hidden_ratio
53
+ self.intermediate_size = intermediate_size
54
+ self.norm_first = norm_first
55
+
56
+ if norm_first:
57
+ self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)
58
+
59
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
60
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
61
+ self.act_fn = ACT2FN[hidden_act]
62
+
63
+ def forward(self, x):
64
+ if self.norm_first:
65
+ x = rms_norm_linear(x, self.norm.weight, self.norm.bias, self.gate_proj.weight, self.gate_proj.bias)
66
+ else:
67
+ x = self.gate_proj(x)
68
+ gate, y = x.chunk(2, -1)
69
+ return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
70
+
71
+
72
+ class DeltaNetBlock(nn.Module):
73
+ def __init__(self, config: DeltaNetConfig, layer_idx: int):
74
+ super().__init__()
75
+ self.hidden_size = config.hidden_size
76
+
77
+ if not config.norm_first:
78
+ self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
79
+ if config.attn is not None and layer_idx in config.attn['layers']:
80
+ self.attn = Attention(
81
+ hidden_size=config.hidden_size,
82
+ num_heads=config.attn['num_heads'],
83
+ num_kv_heads=config.attn['num_kv_heads'],
84
+ window_size=config.attn['window_size'],
85
+ max_position_embeddings=config.max_position_embeddings,
86
+ layer_idx=layer_idx
87
+ )
88
+ else:
89
+ self.attn = DeltaNet(
90
+ mode=config.attn_mode,
91
+ hidden_size=config.hidden_size,
92
+ expand_k=config.expand_k,
93
+ expand_v=config.expand_v,
94
+ num_heads=config.num_heads,
95
+ use_gate=config.use_gate,
96
+ use_beta=config.use_beta,
97
+ use_short_conv=config.use_short_conv,
98
+ use_output_norm=config.use_output_norm,
99
+ conv_size=config.conv_size,
100
+ qk_norm=config.qk_norm,
101
+ qk_activation=config.qk_activation,
102
+ norm_first=config.norm_first,
103
+ norm_eps=config.norm_eps,
104
+ layer_idx=layer_idx
105
+ )
106
+ if not config.norm_first:
107
+ self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
108
+ self.mlp = DeltaNetMLP(
109
+ hidden_size=config.hidden_size,
110
+ hidden_ratio=config.hidden_ratio,
111
+ intermediate_size=config.intermediate_size,
112
+ hidden_act=config.hidden_act,
113
+ norm_first=config.norm_first,
114
+ norm_eps=config.norm_eps
115
+ )
116
+
117
+ def forward(
118
+ self,
119
+ hidden_states: torch.Tensor,
120
+ attention_mask: Optional[torch.Tensor] = None,
121
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
122
+ use_cache: Optional[bool] = False,
123
+ output_attentions: Optional[bool] = False,
124
+ **kwargs
125
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
126
+
127
+ residual = hidden_states
128
+ if hasattr(self, 'attn_norm'):
129
+ hidden_states = self.attn_norm(hidden_states)
130
+ hidden_states, attentions, past_key_values = self.attn(
131
+ hidden_states=hidden_states,
132
+ attention_mask=attention_mask,
133
+ past_key_values=past_key_values,
134
+ use_cache=use_cache,
135
+ output_attentions=output_attentions
136
+ )
137
+ if hasattr(self, 'mlp_norm'):
138
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
139
+ else:
140
+ hidden_states = residual + hidden_states
141
+ residual = hidden_states
142
+ hidden_states = self.mlp(hidden_states)
143
+ hidden_states = residual + hidden_states
144
+
145
+ outputs = (hidden_states, attentions, past_key_values)
146
+
147
+ return outputs
148
+
149
+
150
+ class DeltaNetPreTrainedModel(PreTrainedModel):
151
+
152
+ config_class = DeltaNetConfig
153
+ supports_gradient_checkpointing = True
154
+ _no_split_modules = ['DeltaNetBlock']
155
+
156
+ def __init__(self, *inputs, **kwargs):
157
+ super().__init__(*inputs, **kwargs)
158
+
159
+ def _init_weights(
160
+ self,
161
+ module: nn.Module,
162
+ rescale_prenorm_residual: bool = True,
163
+ num_residuals_per_layer: int = 2,
164
+ ):
165
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
166
+ # Slightly different from the TF version which uses truncated_normal for initialization
167
+ # cf https://github.com/pytorch/pytorch/pull/5617
168
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
169
+ if module.bias is not None:
170
+ nn.init.zeros_(module.bias)
171
+ elif isinstance(module, nn.Embedding):
172
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
173
+ if module.padding_idx is not None:
174
+ module.weight.data[module.padding_idx].zero_()
175
+
176
+ if rescale_prenorm_residual:
177
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
178
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
179
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
180
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
181
+ #
182
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
183
+ for name, p in module.named_parameters():
184
+ if name in ["o_proj.weight", "down_proj.weight"]:
185
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
186
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
187
+ # We need to reinit p since this code could be called multiple times
188
+ # Having just p *= scale would repeatedly scale it down
189
+ with torch.no_grad():
190
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
191
+
192
+
193
+ class DeltaNetModel(DeltaNetPreTrainedModel):
194
+
195
+ def __init__(self, config: DeltaNetConfig):
196
+ super().__init__(config)
197
+ self.padding_idx = config.pad_token_id
198
+ self.vocab_size = config.vocab_size
199
+
200
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
201
+ self.layers = nn.ModuleList([DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
202
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
203
+
204
+ self.gradient_checkpointing = False
205
+
206
+ self.post_init()
207
+
208
+ def get_input_embeddings(self):
209
+ return self.embeddings
210
+
211
+ def set_input_embeddings(self, value):
212
+ self.embeddings = value
213
+
214
+ def forward(
215
+ self,
216
+ input_ids: Optional[torch.LongTensor] = None,
217
+ attention_mask: Optional[torch.Tensor] = None, # noqa
218
+ inputs_embeds: Optional[torch.FloatTensor] = None,
219
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
220
+ use_cache: Optional[bool] = None,
221
+ output_attentions: Optional[bool] = None,
222
+ output_hidden_states: Optional[bool] = None,
223
+ return_dict: Optional[bool] = None
224
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
225
+ if output_attentions:
226
+ warnings.warn("`DeltaNetModel` does not `output_attentions` now, setting it to `False`.")
227
+ output_attentions = False
228
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
229
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
230
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
231
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
232
+
233
+ # retrieve input_ids and inputs_embeds
234
+ if input_ids is not None and inputs_embeds is not None:
235
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
236
+ if input_ids is None and inputs_embeds is None:
237
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
238
+
239
+ if inputs_embeds is None:
240
+ inputs_embeds = self.embeddings(input_ids)
241
+ hidden_states = inputs_embeds
242
+
243
+ if use_cache and not isinstance(past_key_values, Cache):
244
+ past_key_values = Cache.from_legacy_cache(past_key_values)
245
+
246
+ if self.gradient_checkpointing and self.training and use_cache:
247
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
248
+ use_cache = False
249
+
250
+ all_hidden_states = () if output_hidden_states else None
251
+ all_attns = () if output_attentions else None
252
+ for layer in self.layers:
253
+ if output_hidden_states:
254
+ all_hidden_states += (hidden_states,)
255
+
256
+ if self.gradient_checkpointing and self.training:
257
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
258
+ layer.__call__,
259
+ hidden_states,
260
+ attention_mask,
261
+ past_key_values,
262
+ use_cache,
263
+ output_attentions
264
+ )
265
+ else:
266
+ hidden_states, attentions, past_key_values = layer(
267
+ hidden_states,
268
+ attention_mask=attention_mask,
269
+ past_key_values=past_key_values,
270
+ use_cache=use_cache,
271
+ output_attentions=output_attentions
272
+ )
273
+
274
+ if output_attentions:
275
+ all_attns += (attentions,)
276
+
277
+ hidden_states = self.norm(hidden_states)
278
+
279
+ # add hidden states from the last decoder layer
280
+ if output_hidden_states:
281
+ all_hidden_states += (hidden_states,)
282
+
283
+ next_cache = past_key_values
284
+ if not return_dict:
285
+ return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
286
+ return BaseModelOutputWithPast(
287
+ last_hidden_state=hidden_states,
288
+ past_key_values=next_cache,
289
+ hidden_states=all_hidden_states,
290
+ attentions=all_attns
291
+ )
292
+
293
+
294
+ class DeltaNetForCausalLM(DeltaNetPreTrainedModel, GenerationMixin):
295
+
296
+ _tied_weights_keys = ["lm_head.weight"]
297
+
298
+ def __init__(self, config):
299
+ super().__init__(config)
300
+ self.model = DeltaNetModel(config)
301
+ self.vocab_size = config.vocab_size
302
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
303
+
304
+ # Initialize weights and apply final processing
305
+ self.post_init()
306
+
307
+ def get_input_embeddings(self):
308
+ return self.model.embeddings
309
+
310
+ def set_input_embeddings(self, value):
311
+ self.model.embeddings = value
312
+
313
+ def get_output_embeddings(self):
314
+ return self.lm_head
315
+
316
+ def set_output_embeddings(self, new_embeddings):
317
+ self.lm_head = new_embeddings
318
+
319
+ def set_decoder(self, decoder):
320
+ self.model = decoder
321
+
322
+ def get_decoder(self):
323
+ return self.model
324
+
325
+ def generate(self, *args, **kwargs):
326
+ try:
327
+ return super().generate(*args, **kwargs)
328
+ except AttributeError as exception:
329
+ if 'past_key_values' in str(exception):
330
+ raise AttributeError(
331
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
332
+ f"which is not supported for {self.__class__.__name__}. "
333
+ f"Try another generation strategy instead. "
334
+ f"For the available generation strategies, check this doc: "
335
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
336
+ )
337
+ else:
338
+ raise exception
339
+
340
+ def prepare_inputs_for_generation(
341
+ self,
342
+ input_ids: torch.LongTensor = None,
343
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
344
+ attention_mask: Optional[torch.Tensor] = None,
345
+ inputs_embeds: Optional[torch.Tensor] = None,
346
+ use_cache: bool = True,
347
+ num_logits_to_keep: Optional[int] = None,
348
+ **kwargs
349
+ ):
350
+ # only last token for `inputs_ids` if the `past_key_values` is passed along.
351
+ if past_key_values is not None:
352
+ input_ids = input_ids[:, -1:]
353
+
354
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
355
+ if inputs_embeds is not None and past_key_values is None:
356
+ model_inputs = {'inputs_embeds': inputs_embeds}
357
+ else:
358
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
359
+ # recompiles graphs as the stride of the inputs is a guard.
360
+ # Ref: https://github.com/huggingface/transformers/pull/29114
361
+ # TODO: use `next_tokens` directly instead.
362
+ model_inputs = {'input_ids': input_ids.contiguous()}
363
+
364
+ if num_logits_to_keep is not None:
365
+ model_inputs['num_logits_to_keep'] = num_logits_to_keep
366
+
367
+ model_inputs.update({
368
+ 'past_key_values': past_key_values,
369
+ 'use_cache': use_cache,
370
+ 'attention_mask': attention_mask,
371
+ 'num_logits_to_keep': num_logits_to_keep,
372
+ })
373
+ return model_inputs
374
+
375
+ def forward(
376
+ self,
377
+ input_ids: torch.LongTensor = None,
378
+ attention_mask: Optional[torch.Tensor] = None,
379
+ inputs_embeds: Optional[torch.Tensor] = None,
380
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
381
+ labels: Optional[torch.LongTensor] = None,
382
+ use_cache: Optional[bool] = None,
383
+ output_attentions: Optional[bool] = None,
384
+ output_hidden_states: Optional[bool] = None,
385
+ return_dict: Optional[bool] = None,
386
+ num_logits_to_keep: Optional[int] = 0
387
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
388
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
389
+ output_hidden_states = (
390
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
391
+ )
392
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
393
+
394
+ outputs = self.model(
395
+ input_ids=input_ids,
396
+ attention_mask=attention_mask,
397
+ inputs_embeds=inputs_embeds,
398
+ past_key_values=past_key_values,
399
+ use_cache=use_cache,
400
+ output_attentions=output_attentions,
401
+ output_hidden_states=output_hidden_states,
402
+ return_dict=return_dict
403
+ )
404
+
405
+ hidden_states = outputs[0]
406
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
407
+ logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:])
408
+
409
+ loss = None
410
+ if labels is not None:
411
+ if self.config.fuse_cross_entropy:
412
+ if fuse_linear_and_cross_entropy:
413
+ loss_fct = FusedLinearCrossEntropyLoss()
414
+ else:
415
+ loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
416
+ else:
417
+ loss_fct = nn.CrossEntropyLoss()
418
+ # Enable model parallelism
419
+ labels = labels.to(hidden_states.device)
420
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
421
+ if fuse_linear_and_cross_entropy:
422
+ loss = loss_fct(hidden_states.view(-1, self.config.hidden_size),
423
+ labels.view(-1),
424
+ self.lm_head.weight,
425
+ self.lm_head.bias)
426
+ else:
427
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
428
+
429
+ if not return_dict:
430
+ output = (logits,) + outputs[1:]
431
+ return (loss,) + output if loss is not None else output
432
+
433
+ return CausalLMOutputWithPast(
434
+ loss=loss,
435
+ logits=logits,
436
+ past_key_values=outputs.past_key_values,
437
+ hidden_states=outputs.hidden_states,
438
+ attentions=outputs.attentions,
439
+ )
fla/models/gla/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.gla.configuration_gla import GLAConfig
+ from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel
+
+ AutoConfig.register(GLAConfig.model_type, GLAConfig)
+ AutoModel.register(GLAConfig, GLAModel)
+ AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)
+
+
+ __all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
fla/models/gla/configuration_gla.py ADDED
@@ -0,0 +1,90 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class GLAConfig(PretrainedConfig):
+
+     model_type = 'gla'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         hidden_size: int = 2048,
+         expand_k: int = 0.5,
+         expand_v: int = 1,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         num_hidden_layers: int = 24,
+         num_heads: int = 4,
+         num_kv_heads: Optional[int] = None,
+         feature_map: Optional[str] = None,
+         attn_mode: str = "chunk",
+         use_short_conv: bool = False,
+         conv_size: int = 4,
+         use_output_gate: bool = True,
+         clamp_min: Optional[float] = None,
+         hidden_act: str = "swish",
+         max_position_embeddings: int = 2048,
+         elementwise_affine: Optional[bool] = True,
+         norm_eps: float = 1e-6,
+         use_gk: bool = True,
+         use_gv: bool = False,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.02,
+         fuse_norm: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.hidden_size = hidden_size
+         self.expand_k = expand_k
+         self.expand_v = expand_v
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads
+         self.feature_map = feature_map
+         self.attn_mode = attn_mode
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.use_output_gate = use_output_gate
+         self.clamp_min = clamp_min
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.elementwise_affine = elementwise_affine
+         self.norm_eps = norm_eps
+         self.use_gk = use_gk
+         self.use_gv = use_gv
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+         self.fuse_norm = fuse_norm
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['window_size'] = attn.get('window_size', None)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
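
In `GLAConfig`, `expand_k` and `expand_v` scale the key and value projection widths relative to `hidden_size`, so the defaults above give half-width keys and full-width values. A rough sketch of the resulting dimensions, assuming the layer computes them as `int(hidden_size * expand)` split evenly across heads:

```python
hidden_size, num_heads = 2048, 4
expand_k, expand_v = 0.5, 1

key_dim = int(hidden_size * expand_k)    # 1024
value_dim = int(hidden_size * expand_v)  # 2048
head_k_dim = key_dim // num_heads        # 256 per head
head_v_dim = value_dim // num_heads      # 512 per head
print(key_dim, value_dim, head_k_dim, head_v_dim)
```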
fla/models/gla/modeling_gla.py ADDED
@@ -0,0 +1,418 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast)
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+
19
+ from fla.layers.attn import Attention
20
+ from fla.layers.gla import GatedLinearAttention
21
+ from fla.models.gla.configuration_gla import GLAConfig
22
+ from fla.models.utils import Cache
23
+ from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
24
+ RMSNorm)
25
+ from fla.modules.activations import swiglu_linear
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class GLAMLP(nn.Module):
31
+
32
+ def __init__(
33
+ self,
34
+ hidden_size: int,
35
+ hidden_ratio: Optional[int] = None,
36
+ intermediate_size: Optional[int] = None,
37
+ hidden_act: str = 'swish'
38
+ ) -> GLAMLP:
39
+ super().__init__()
40
+
41
+ self.hidden_size = hidden_size
42
+ # the final number of params is `hidden_ratio * hidden_size^2`
43
+ # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
44
+ if hidden_ratio is None:
45
+ hidden_ratio = 4
46
+ if intermediate_size is None:
47
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
48
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
49
+ self.hidden_ratio = hidden_ratio
50
+ self.intermediate_size = intermediate_size
51
+
52
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
53
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
54
+ self.act_fn = ACT2FN[hidden_act]
55
+
56
+ def forward(self, x):
57
+ y = self.gate_proj(x)
58
+ gate, y = y.chunk(2, -1)
59
+ return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
60
+
61
+
62
+ class GLABlock(nn.Module):
63
+ def __init__(self, config: GLAConfig, layer_idx: int):
64
+ super().__init__()
65
+ self.hidden_size = config.hidden_size
66
+
67
+ self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
68
+ if config.attn is not None and layer_idx in config.attn['layers']:
69
+ self.attn = Attention(
70
+ hidden_size=config.hidden_size,
71
+ num_heads=config.attn['num_heads'],
72
+ num_kv_heads=config.attn['num_kv_heads'],
73
+ window_size=config.attn['window_size'],
74
+ max_position_embeddings=config.max_position_embeddings,
75
+ layer_idx=layer_idx
76
+ )
77
+ else:
78
+ self.attn = GatedLinearAttention(
79
+ mode=config.attn_mode,
80
+ hidden_size=config.hidden_size,
81
+ expand_k=config.expand_k,
82
+ expand_v=config.expand_v,
83
+ num_heads=config.num_heads,
84
+ num_kv_heads=config.num_kv_heads,
85
+ feature_map=config.feature_map,
86
+ use_short_conv=config.use_short_conv,
87
+ conv_size=config.conv_size,
88
+ use_output_gate=config.use_output_gate,
89
+ gate_fn=config.hidden_act,
90
+ elementwise_affine=config.elementwise_affine,
91
+ norm_eps=config.norm_eps,
92
+ clamp_min=config.clamp_min,
93
+ fuse_norm=config.fuse_norm,
94
+ layer_idx=layer_idx
95
+ )
96
+ self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
97
+ self.mlp = GLAMLP(
98
+ hidden_size=config.hidden_size,
99
+ hidden_ratio=config.hidden_ratio,
100
+ intermediate_size=config.intermediate_size,
101
+ hidden_act=config.hidden_act
102
+ )
103
+
104
+ def forward(
105
+ self,
106
+ hidden_states: torch.Tensor,
107
+ attention_mask: Optional[torch.Tensor] = None,
108
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
109
+ use_cache: Optional[bool] = False,
110
+ output_attentions: Optional[bool] = False,
111
+ **kwargs,
112
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
113
+ residual = hidden_states
114
+ hidden_states = self.attn_norm(hidden_states)
115
+ hidden_states, attentions, past_key_values = self.attn(
116
+ hidden_states=hidden_states,
117
+ attention_mask=attention_mask,
118
+ past_key_values=past_key_values,
119
+ use_cache=use_cache,
120
+ output_attentions=output_attentions
121
+ )
122
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
123
+ hidden_states = self.mlp(hidden_states)
124
+ hidden_states = residual + hidden_states
125
+
126
+ outputs = (hidden_states, attentions, past_key_values)
127
+
128
+ return outputs
129
+
130
+
131
+ class GLAPreTrainedModel(PreTrainedModel):
132
+
133
+ config_class = GLAConfig
134
+ supports_gradient_checkpointing = True
135
+ _no_split_modules = ['GLABlock']
136
+
137
+ def __init__(self, *inputs, **kwargs):
138
+ super().__init__(*inputs, **kwargs)
139
+
140
+ def _init_weights(
141
+ self,
142
+ module: nn.Module,
143
+ rescale_prenorm_residual: bool = True,
144
+ num_residuals_per_layer: int = 2,
145
+ ):
146
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
147
+ # Slightly different from the TF version which uses truncated_normal for initialization
148
+ # cf https://github.com/pytorch/pytorch/pull/5617
149
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
150
+ if module.bias is not None:
151
+ nn.init.zeros_(module.bias)
152
+ elif isinstance(module, nn.Embedding):
153
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
154
+ if module.padding_idx is not None:
155
+ module.weight.data[module.padding_idx].zero_()
156
+
157
+ if rescale_prenorm_residual:
158
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
159
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
160
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
161
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
162
+ #
163
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
164
+ for name, p in module.named_parameters():
165
+ if name in ["o_proj.weight", "down_proj.weight"]:
166
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
167
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
168
+ # We need to reinit p since this code could be called multiple times
169
+ # Having just p *= scale would repeatedly scale it down
170
+ with torch.no_grad():
171
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
172
+
173
+
174
+ class GLAModel(GLAPreTrainedModel):
175
+
176
+ def __init__(self, config: GLAConfig):
177
+ super().__init__(config)
178
+ self.padding_idx = config.pad_token_id
179
+ self.vocab_size = config.vocab_size
180
+
181
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
182
+ self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
183
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
184
+
185
+ self.gradient_checkpointing = False
186
+
187
+ self.post_init()
188
+
189
+ def get_input_embeddings(self):
190
+ return self.embeddings
191
+
192
+ def set_input_embeddings(self, value):
193
+ self.embeddings = value
194
+
195
+ def forward(
196
+ self,
197
+ input_ids: Optional[torch.LongTensor] = None,
198
+ attention_mask: Optional[torch.Tensor] = None, # noqa
199
+ inputs_embeds: Optional[torch.FloatTensor] = None,
200
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
201
+ use_cache: Optional[bool] = None,
202
+ output_attentions: Optional[bool] = None,
203
+ output_hidden_states: Optional[bool] = None,
204
+ return_dict: Optional[bool] = None
205
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
206
+ if output_attentions:
207
+ warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.")
208
+ output_attentions = False
209
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
210
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
211
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
212
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
213
+
214
+ # retrieve input_ids and inputs_embeds
215
+ if input_ids is not None and inputs_embeds is not None:
216
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
217
+ if input_ids is None and inputs_embeds is None:
218
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
219
+
220
+ if inputs_embeds is None:
221
+ inputs_embeds = self.embeddings(input_ids)
222
+ hidden_states = inputs_embeds
223
+
224
+ if use_cache and not isinstance(past_key_values, Cache):
225
+ past_key_values = Cache.from_legacy_cache(past_key_values)
226
+
227
+ if self.gradient_checkpointing and self.training and use_cache:
228
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
229
+ use_cache = False
230
+
231
+ all_hidden_states = () if output_hidden_states else None
232
+ all_attns = () if output_attentions else None
233
+ for layer in self.layers:
234
+ if output_hidden_states:
235
+ all_hidden_states += (hidden_states,)
236
+
237
+ if self.gradient_checkpointing and self.training:
238
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
239
+ layer.__call__,
240
+ hidden_states,
241
+ attention_mask,
242
+ past_key_values,
243
+ use_cache,
244
+ output_attentions
245
+ )
246
+ else:
247
+ hidden_states, attentions, past_key_values = layer(
248
+ hidden_states,
249
+ attention_mask=attention_mask,
250
+ past_key_values=past_key_values,
251
+ use_cache=use_cache,
252
+ output_attentions=output_attentions
253
+ )
254
+
255
+ if output_attentions:
256
+ all_attns += (attentions,)
257
+
258
+ hidden_states = self.norm(hidden_states)
259
+
260
+ # add hidden states from the last decoder layer
261
+ if output_hidden_states:
262
+ all_hidden_states += (hidden_states,)
263
+
264
+ if not return_dict:
265
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
266
+ return BaseModelOutputWithPast(
267
+ last_hidden_state=hidden_states,
268
+ past_key_values=past_key_values,
269
+ hidden_states=all_hidden_states,
270
+ attentions=all_attns
271
+ )
272
+
273
+
274
+ class GLAForCausalLM(GLAPreTrainedModel, GenerationMixin):
275
+
276
+ _tied_weights_keys = ["lm_head.weight"]
277
+
278
+ def __init__(self, config):
279
+ super().__init__(config)
280
+ self.model = GLAModel(config)
281
+ self.vocab_size = config.vocab_size
282
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
283
+
284
+ # Initialize weights and apply final processing
285
+ self.post_init()
286
+
287
+ def get_input_embeddings(self):
288
+ return self.model.embeddings
289
+
290
+ def set_input_embeddings(self, value):
291
+ self.model.embeddings = value
292
+
293
+ def get_output_embeddings(self):
294
+ return self.lm_head
295
+
296
+ def set_output_embeddings(self, new_embeddings):
297
+ self.lm_head = new_embeddings
298
+
299
+ def set_decoder(self, decoder):
300
+ self.model = decoder
301
+
302
+ def get_decoder(self):
303
+ return self.model
304
+
305
+ def generate(self, *args, **kwargs):
306
+ try:
307
+ return super().generate(*args, **kwargs)
308
+ except AttributeError as exception:
309
+ if 'past_key_values' in str(exception):
310
+ raise AttributeError(
311
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
312
+ f"which is not supported for {self.__class__.__name__}. "
313
+ f"Try another generation strategy instead. "
314
+ f"For the available generation strategies, check this doc: "
315
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
316
+ )
317
+ else:
318
+ raise exception
319
+
320
+ def prepare_inputs_for_generation(
321
+ self,
322
+ input_ids: torch.LongTensor = None,
323
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
324
+ attention_mask: Optional[torch.Tensor] = None,
325
+ inputs_embeds: Optional[torch.Tensor] = None,
326
+ use_cache: bool = True,
327
+ num_logits_to_keep: Optional[int] = None,
328
+ **kwargs
329
+ ):
330
+ # keep only the last token of `input_ids` if `past_key_values` is passed along.
331
+ if past_key_values is not None:
332
+ input_ids = input_ids[:, -1:]
333
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
334
+ if inputs_embeds is not None and past_key_values is None:
335
+ model_inputs = {'inputs_embeds': inputs_embeds}
336
+ else:
337
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
338
+ # recompiles graphs as the stride of the inputs is a guard.
339
+ # Ref: https://github.com/huggingface/transformers/pull/29114
340
+ # TODO: use `next_tokens` directly instead.
341
+ model_inputs = {'input_ids': input_ids.contiguous()}
342
+
343
+ if num_logits_to_keep is not None:
344
+ model_inputs['num_logits_to_keep'] = num_logits_to_keep
345
+
346
+ model_inputs.update({
347
+ 'past_key_values': past_key_values,
348
+ 'use_cache': use_cache,
349
+ 'attention_mask': attention_mask,
350
+ 'num_logits_to_keep': num_logits_to_keep,
351
+ })
352
+ return model_inputs
353
+
354
+ def forward(
355
+ self,
356
+ input_ids: torch.LongTensor = None,
357
+ attention_mask: Optional[torch.Tensor] = None,
358
+ inputs_embeds: Optional[torch.Tensor] = None,
359
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
360
+ labels: Optional[torch.LongTensor] = None,
361
+ use_cache: Optional[bool] = None,
362
+ output_attentions: Optional[bool] = None,
363
+ output_hidden_states: Optional[bool] = None,
364
+ return_dict: Optional[bool] = None,
365
+ num_logits_to_keep: Optional[int] = 0
366
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
367
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
368
+ output_hidden_states = (
369
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
370
+ )
371
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
372
+
373
+ outputs = self.model(
374
+ input_ids=input_ids,
375
+ attention_mask=attention_mask,
376
+ inputs_embeds=inputs_embeds,
377
+ past_key_values=past_key_values,
378
+ use_cache=use_cache,
379
+ output_attentions=output_attentions,
380
+ output_hidden_states=output_hidden_states,
381
+ return_dict=return_dict
382
+ )
383
+
384
+ hidden_states = outputs[0]
385
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
386
+ logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:])
387
+
388
+ loss = None
389
+ if labels is not None:
390
+ if self.config.fuse_cross_entropy:
391
+ if fuse_linear_and_cross_entropy:
392
+ loss_fct = FusedLinearCrossEntropyLoss()
393
+ else:
394
+ loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
395
+ else:
396
+ loss_fct = nn.CrossEntropyLoss()
397
+ # Enable model parallelism
398
+ labels = labels.to(hidden_states.device)
399
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
400
+ if fuse_linear_and_cross_entropy:
401
+ loss = loss_fct(hidden_states.view(-1, self.config.hidden_size),
402
+ labels.view(-1),
403
+ self.lm_head.weight,
404
+ self.lm_head.bias)
405
+ else:
406
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
407
+
408
+ if not return_dict:
409
+ output = (logits,) + outputs[1:]
410
+ return (loss,) + output if loss is not None else output
411
+
412
+ return CausalLMOutputWithPast(
413
+ loss=loss,
414
+ logits=logits,
415
+ past_key_values=outputs.past_key_values,
416
+ hidden_states=outputs.hidden_states,
417
+ attentions=outputs.attentions,
418
+ )
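For orientation, here is a minimal sketch (not part of this commit) of how the `GLAForCausalLM` head above can be exercised. It assumes a CUDA machine with the `fla` Triton kernels available, and that `GLAConfig` exposes the same basic fields (`vocab_size`, `hidden_size`, `num_hidden_layers`, `fuse_cross_entropy`) as the GSA and HGRN configs later in this diff; the tiny hyperparameters are purely illustrative.

```python
# Minimal sketch, not part of the commit; assumes CUDA + Triton for the fla kernels.
import torch

from fla.models.gla import GLAConfig, GLAForCausalLM

config = GLAConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2, fuse_cross_entropy=False)
model = GLAForCausalLM(config).cuda().eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16), device='cuda')
# Passing `labels` makes `forward` compute the shifted next-token cross-entropy loss.
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)  # logits: (2, 16, 1000)
```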
fla/models/gsa/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gsa.configuration_gsa import GSAConfig
6
+ from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel
7
+
8
+ AutoConfig.register(GSAConfig.model_type, GSAConfig)
9
+ AutoModel.register(GSAConfig, GSAModel)
10
+ AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM)
11
+
12
+
13
+ __all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel']
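These registrations are what make the model reachable through the 🤗 Auto classes. A minimal sketch of the effect (not part of this commit), assuming `fla` and its Triton dependency are installed:

```python
# Importing the module runs the AutoConfig/AutoModel registrations above.
from transformers import AutoModel, AutoModelForCausalLM

from fla.models.gsa import GSAConfig

config = GSAConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2, num_heads=2)
backbone = AutoModel.from_config(config)       # -> GSAModel
lm = AutoModelForCausalLM.from_config(config)  # -> GSAForCausalLM
print(type(backbone).__name__, type(lm).__name__)
```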
fla/models/gsa/configuration_gsa.py ADDED
@@ -0,0 +1,94 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class GSAConfig(PretrainedConfig):
9
+
10
+ model_type = 'gsa'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ gate_logit_normalizer: Optional[int] = 8,
17
+ clamp_min: Optional[float] = None,
18
+ clamp_max: Optional[float] = None,
19
+ hidden_ratio: Optional[int] = 4,
20
+ intermediate_size: Optional[int] = None,
21
+ num_hidden_layers: int = 24,
22
+ num_heads: int = 4,
23
+ num_kv_heads: Optional[int] = None,
24
+ num_slots: Optional[int] = 64,
25
+ use_short_conv: bool = False,
26
+ conv_size: int = 4,
27
+ expand_k: float = 1,
28
+ expand_v: float = 1,
29
+ feature_map: str = 'swish',
30
+ use_output_gate: bool = False,
31
+ use_norm: bool = True,
32
+ max_position_embeddings: int = 2048,
33
+ hidden_act: str = "swish",
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_first: bool = True,
36
+ norm_eps: float = 1e-6,
37
+ attn: Optional[Dict] = None,
38
+ use_cache: bool = True,
39
+ pad_token_id: int = None,
40
+ bos_token_id: int = 1,
41
+ eos_token_id: int = 2,
42
+ initializer_range: float = 0.02,
43
+ tie_word_embeddings: bool = False,
44
+ fuse_norm: bool = True,
45
+ fuse_cross_entropy: bool = True,
46
+ vocab_size: int = 32000,
47
+ **kwargs
48
+ ):
49
+ self.hidden_size = hidden_size
50
+ self.gate_logit_normalizer = gate_logit_normalizer
51
+ self.clamp_min = clamp_min
52
+ self.clamp_max = clamp_max
53
+ self.hidden_ratio = hidden_ratio
54
+ self.intermediate_size = intermediate_size
55
+ self.num_hidden_layers = num_hidden_layers
56
+ self.num_heads = num_heads
57
+ self.num_kv_heads = num_kv_heads
58
+ self.num_slots = num_slots
59
+ self.use_short_conv = use_short_conv
60
+ self.conv_size = conv_size
61
+ self.expand_k = expand_k
62
+ self.expand_v = expand_v
63
+ self.feature_map = feature_map
64
+ self.use_output_gate = use_output_gate
65
+ self.use_norm = use_norm
66
+ self.max_position_embeddings = max_position_embeddings
67
+ self.hidden_act = hidden_act
68
+ self.elementwise_affine = elementwise_affine
69
+ self.norm_first = norm_first
70
+ self.norm_eps = norm_eps
71
+ self.attn = attn
72
+ self.use_cache = use_cache
73
+ self.initializer_range = initializer_range
74
+ self.fuse_cross_entropy = fuse_cross_entropy
75
+ self.fuse_norm = fuse_norm
76
+ self.vocab_size = vocab_size
77
+
78
+ if attn is not None:
79
+ if not isinstance(attn, Dict):
80
+ raise ValueError("attn must be a dictionary")
81
+ if 'layers' not in attn:
82
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
83
+ if 'num_heads' not in attn:
84
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
85
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
86
+ attn['window_size'] = attn.get('window_size', None)
87
+
88
+ super().__init__(
89
+ pad_token_id=pad_token_id,
90
+ bos_token_id=bos_token_id,
91
+ eos_token_id=eos_token_id,
92
+ tie_word_embeddings=tie_word_embeddings,
93
+ **kwargs,
94
+ )
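As a usage note (not part of this commit): the `attn` argument is how hybrid GSA/attention models are configured. `layers` and `num_heads` are mandatory, while `num_kv_heads` and `window_size` are filled in with defaults by `__init__` above. A minimal sketch:

```python
from fla.models.gsa import GSAConfig

# Swap layers 2 and 5 for standard softmax attention; the remaining layers use GSA.
config = GSAConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    num_heads=4,
    attn={'layers': [2, 5], 'num_heads': 8},
)
print(config.attn)
# {'layers': [2, 5], 'num_heads': 8, 'num_kv_heads': 8, 'window_size': None}
```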
fla/models/gsa/modeling_gsa.py ADDED
@@ -0,0 +1,442 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast)
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+
19
+ from fla.layers.attn import Attention
20
+ from fla.layers.gsa import GatedSlotAttention
21
+ from fla.models.gsa.configuration_gsa import GSAConfig
22
+ from fla.models.utils import Cache
23
+ from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
24
+ RMSNorm)
25
+ from fla.modules.activations import swiglu_linear
26
+ from fla.modules.layernorm import rms_norm_linear
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class GSAMLP(nn.Module):
32
+
33
+ def __init__(
34
+ self,
35
+ hidden_size: int,
36
+ hidden_ratio: Optional[int] = None,
37
+ intermediate_size: Optional[int] = None,
38
+ hidden_act: str = 'swish',
39
+ norm_first: bool = True,
40
+ norm_eps: float = 1e-5
41
+ ) -> GSAMLP:
42
+ super().__init__()
43
+
44
+ self.hidden_size = hidden_size
45
+ # the final number of params is `hidden_ratio * hidden_size^2`
46
+ # `intermediate_size` is `2/3 * hidden_size * hidden_ratio`, rounded up to the nearest multiple of 256
47
+ if hidden_ratio is None:
48
+ hidden_ratio = 4
49
+ if intermediate_size is None:
50
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
51
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
52
+ self.hidden_ratio = hidden_ratio
53
+ self.intermediate_size = intermediate_size
54
+ self.norm_first = norm_first
55
+
56
+ if norm_first:
57
+ self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)
58
+
59
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
60
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
61
+ self.act_fn = ACT2FN[hidden_act]
62
+
63
+ def forward(self, x):
64
+ if self.norm_first:
65
+ x = rms_norm_linear(x, self.norm.weight, self.norm.bias, self.gate_proj.weight, self.gate_proj.bias)
66
+ else:
67
+ x = self.gate_proj(x)
68
+ gate, y = x.chunk(2, -1)
69
+ return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
70
+
71
+
72
+ class GSABlock(nn.Module):
73
+ def __init__(self, config: GSAConfig, layer_idx: int):
74
+ super().__init__()
75
+ self.hidden_size = config.hidden_size
76
+
77
+ if not config.norm_first:
78
+ self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
79
+ if config.attn is not None and layer_idx in config.attn['layers']:
80
+ self.attn = Attention(
81
+ hidden_size=config.hidden_size,
82
+ num_heads=config.attn['num_heads'],
83
+ num_kv_heads=config.attn['num_kv_heads'],
84
+ window_size=config.attn['window_size'],
85
+ max_position_embeddings=config.max_position_embeddings,
86
+ layer_idx=layer_idx
87
+ )
88
+ else:
89
+ self.attn = GatedSlotAttention(
90
+ hidden_size=config.hidden_size,
91
+ expand_k=config.expand_k,
92
+ expand_v=config.expand_v,
93
+ num_heads=config.num_heads,
94
+ num_kv_heads=config.num_kv_heads,
95
+ num_slots=config.num_slots,
96
+ use_short_conv=config.use_short_conv,
97
+ conv_size=config.conv_size,
98
+ feature_map=config.feature_map,
99
+ use_output_gate=config.use_output_gate,
100
+ use_norm=config.use_norm,
101
+ gate_fn=config.hidden_act,
102
+ gate_logit_normalizer=config.gate_logit_normalizer,
103
+ elementwise_affine=config.elementwise_affine,
104
+ norm_first=config.norm_first,
105
+ norm_eps=config.norm_eps,
106
+ fuse_norm=config.fuse_norm,
107
+ layer_idx=layer_idx
108
+ )
109
+ if not config.norm_first:
110
+ self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
111
+ self.mlp = GSAMLP(
112
+ hidden_size=config.hidden_size,
113
+ hidden_ratio=config.hidden_ratio,
114
+ intermediate_size=config.intermediate_size,
115
+ hidden_act=config.hidden_act,
116
+ norm_first=config.norm_first,
117
+ norm_eps=config.norm_eps
118
+ )
119
+
120
+ def forward(
121
+ self,
122
+ hidden_states: torch.Tensor,
123
+ attention_mask: Optional[torch.Tensor] = None,
124
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
125
+ use_cache: Optional[bool] = False,
126
+ output_attentions: Optional[bool] = False,
127
+ **kwargs
128
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
129
+
130
+ residual = hidden_states
131
+ if hasattr(self, 'attn_norm'):
132
+ hidden_states = self.attn_norm(hidden_states)
133
+ hidden_states, attentions, past_key_values = self.attn(
134
+ hidden_states=hidden_states,
135
+ attention_mask=attention_mask,
136
+ past_key_values=past_key_values,
137
+ use_cache=use_cache,
138
+ output_attentions=output_attentions
139
+ )
140
+ if hasattr(self, 'mlp_norm'):
141
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
142
+ else:
143
+ hidden_states = residual + hidden_states
144
+ residual = hidden_states
145
+ hidden_states = self.mlp(hidden_states)
146
+ hidden_states = residual + hidden_states
147
+
148
+ outputs = (hidden_states, attentions, past_key_values)
149
+
150
+ return outputs
151
+
152
+
153
+ class GSAPreTrainedModel(PreTrainedModel):
154
+
155
+ config_class = GSAConfig
156
+ supports_gradient_checkpointing = True
157
+ _no_split_modules = ['GSABlock']
158
+
159
+ def __init__(self, *inputs, **kwargs):
160
+ super().__init__(*inputs, **kwargs)
161
+
162
+ def _init_weights(
163
+ self,
164
+ module: nn.Module,
165
+ rescale_prenorm_residual: bool = True,
166
+ num_residuals_per_layer: int = 2,
167
+ ):
168
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
169
+ # Slightly different from the TF version which uses truncated_normal for initialization
170
+ # cf https://github.com/pytorch/pytorch/pull/5617
171
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
172
+ if module.bias is not None:
173
+ nn.init.zeros_(module.bias)
174
+ elif isinstance(module, nn.Embedding):
175
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
176
+ if module.padding_idx is not None:
177
+ module.weight.data[module.padding_idx].zero_()
178
+
179
+ if rescale_prenorm_residual:
180
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
181
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
182
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
183
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
184
+ #
185
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
186
+ for name, p in module.named_parameters():
187
+ if name in ["o_proj.weight", "down_proj.weight"]:
188
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
189
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
190
+ # We need to reinit p since this code could be called multiple times
191
+ # Having just p *= scale would repeatedly scale it down
192
+ with torch.no_grad():
193
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
194
+
195
+
196
+ class GSAModel(GSAPreTrainedModel):
197
+
198
+ def __init__(self, config: GSAConfig):
199
+ super().__init__(config)
200
+ self.padding_idx = config.pad_token_id
201
+ self.vocab_size = config.vocab_size
202
+
203
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
204
+ self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
205
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
206
+
207
+ self.gradient_checkpointing = False
208
+
209
+ self.post_init()
210
+
211
+ def get_input_embeddings(self):
212
+ return self.embeddings
213
+
214
+ def set_input_embeddings(self, value):
215
+ self.embeddings = value
216
+
217
+ def forward(
218
+ self,
219
+ input_ids: Optional[torch.LongTensor] = None,
220
+ attention_mask: Optional[torch.Tensor] = None, # noqa
221
+ inputs_embeds: Optional[torch.FloatTensor] = None,
222
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
223
+ use_cache: Optional[bool] = None,
224
+ output_attentions: Optional[bool] = None,
225
+ output_hidden_states: Optional[bool] = None,
226
+ return_dict: Optional[bool] = None
227
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
228
+ if output_attentions:
229
+ warnings.warn("`GSAModel` does not support `output_attentions` for now; setting it to `False`.")
230
+ output_attentions = False
231
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
232
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
233
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
234
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
235
+
236
+ # retrieve input_ids and inputs_embeds
237
+ if input_ids is not None and inputs_embeds is not None:
238
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
239
+ if input_ids is None and inputs_embeds is None:
240
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
241
+
242
+ if inputs_embeds is None:
243
+ inputs_embeds = self.embeddings(input_ids)
244
+ hidden_states = inputs_embeds
245
+
246
+ if use_cache and not isinstance(past_key_values, Cache):
247
+ past_key_values = Cache.from_legacy_cache(past_key_values)
248
+
249
+ if self.gradient_checkpointing and self.training and use_cache:
250
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
251
+ use_cache = False
252
+
253
+ all_hidden_states = () if output_hidden_states else None
254
+ all_attns = () if output_attentions else None
255
+
256
+ for i, layer in enumerate(self.layers):
257
+ if output_hidden_states:
258
+ all_hidden_states += (hidden_states,)
259
+
260
+ if self.gradient_checkpointing and self.training:
261
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
262
+ layer.__call__,
263
+ hidden_states,
264
+ attention_mask,
265
+ past_key_values,
266
+ use_cache,
267
+ output_attentions,
268
+ )
269
+ else:
270
+ hidden_states, attentions, past_key_values = layer(
271
+ hidden_states,
272
+ attention_mask=attention_mask,
273
+ past_key_values=past_key_values,
274
+ use_cache=use_cache,
275
+ output_attentions=output_attentions
276
+ )
277
+
278
+ if output_attentions:
279
+ all_attns += (attentions,)
280
+
281
+ hidden_states = self.norm(hidden_states)
282
+
283
+ # add hidden states from the last decoder layer
284
+ if output_hidden_states:
285
+ all_hidden_states += (hidden_states,)
286
+
287
+ if not return_dict:
288
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
289
+ return BaseModelOutputWithPast(
290
+ last_hidden_state=hidden_states,
291
+ past_key_values=past_key_values,
292
+ hidden_states=all_hidden_states,
293
+ attentions=all_attns
294
+ )
295
+
296
+
297
+ class GSAForCausalLM(GSAPreTrainedModel, GenerationMixin):
298
+
299
+ _tied_weights_keys = ["lm_head.weight"]
300
+
301
+ def __init__(self, config):
302
+
303
+ super().__init__(config)
304
+ self.model = GSAModel(config)
305
+ self.vocab_size = config.vocab_size
306
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
307
+
308
+ # Initialize weights and apply final processing
309
+ self.post_init()
310
+
311
+ def get_input_embeddings(self):
312
+ return self.model.embeddings
313
+
314
+ def set_input_embeddings(self, value):
315
+ self.model.embeddings = value
316
+
317
+ def get_output_embeddings(self):
318
+ return self.lm_head
319
+
320
+ def set_output_embeddings(self, new_embeddings):
321
+ self.lm_head = new_embeddings
322
+
323
+ def set_decoder(self, decoder):
324
+ self.model = decoder
325
+
326
+ def get_decoder(self):
327
+ return self.model
328
+
329
+ def generate(self, *args, **kwargs):
330
+ try:
331
+ return super().generate(*args, **kwargs)
332
+ except AttributeError as exception:
333
+ if 'past_key_values' in str(exception):
334
+ raise AttributeError(
335
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
336
+ f"which is not supported for {self.__class__.__name__}. "
337
+ f"Try another generation strategy instead. "
338
+ f"For the available generation strategies, check this doc: "
339
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
340
+ )
341
+ else:
342
+ raise exception
343
+
344
+ def prepare_inputs_for_generation(
345
+ self,
346
+ input_ids: torch.LongTensor = None,
347
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
348
+ attention_mask: Optional[torch.Tensor] = None,
349
+ inputs_embeds: Optional[torch.Tensor] = None,
350
+ use_cache: bool = True,
351
+ num_logits_to_keep: Optional[int] = None,
352
+ **kwargs
353
+ ):
354
+ # keep only the last token of `input_ids` if `past_key_values` is passed along.
355
+ if past_key_values is not None:
356
+ input_ids = input_ids[:, -1:]
357
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
358
+ if inputs_embeds is not None and past_key_values is None:
359
+ model_inputs = {'inputs_embeds': inputs_embeds}
360
+ else:
361
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
362
+ # recompiles graphs as the stride of the inputs is a guard.
363
+ # Ref: https://github.com/huggingface/transformers/pull/29114
364
+ # TODO: use `next_tokens` directly instead.
365
+ model_inputs = {'input_ids': input_ids.contiguous()}
366
+
367
+ if num_logits_to_keep is not None:
368
+ model_inputs['num_logits_to_keep'] = num_logits_to_keep
369
+
370
+ model_inputs.update({
371
+ 'past_key_values': past_key_values,
372
+ 'use_cache': use_cache,
373
+ 'attention_mask': attention_mask,
374
+ 'num_logits_to_keep': num_logits_to_keep,
375
+ })
376
+ return model_inputs
377
+
378
+ def forward(
379
+ self,
380
+ input_ids: torch.LongTensor = None,
381
+ attention_mask: Optional[torch.Tensor] = None,
382
+ inputs_embeds: Optional[torch.Tensor] = None,
383
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
384
+ labels: Optional[torch.LongTensor] = None,
385
+ use_cache: Optional[bool] = None,
386
+ output_attentions: Optional[bool] = None,
387
+ output_hidden_states: Optional[bool] = None,
388
+ return_dict: Optional[bool] = None,
389
+ num_logits_to_keep: Optional[int] = 0
390
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
391
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
392
+ output_hidden_states = (
393
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
394
+ )
395
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
396
+
397
+ outputs = self.model(
398
+ input_ids=input_ids,
399
+ attention_mask=attention_mask,
400
+ inputs_embeds=inputs_embeds,
401
+ past_key_values=past_key_values,
402
+ use_cache=use_cache,
403
+ output_attentions=output_attentions,
404
+ output_hidden_states=output_hidden_states,
405
+ return_dict=return_dict
406
+ )
407
+
408
+ hidden_states = outputs[0]
409
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
410
+ logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:])
411
+
412
+ loss = None
413
+ if labels is not None:
414
+ if self.config.fuse_cross_entropy:
415
+ if fuse_linear_and_cross_entropy:
416
+ loss_fct = FusedLinearCrossEntropyLoss()
417
+ else:
418
+ loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
419
+ else:
420
+ loss_fct = nn.CrossEntropyLoss()
421
+ # Enable model parallelism
422
+ labels = labels.to(hidden_states.device)
423
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
424
+ if fuse_linear_and_cross_entropy:
425
+ loss = loss_fct(hidden_states.view(-1, self.config.hidden_size),
426
+ labels.view(-1),
427
+ self.lm_head.weight,
428
+ self.lm_head.bias)
429
+ else:
430
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
431
+
432
+ if not return_dict:
433
+ output = (logits,) + outputs[1:]
434
+ return (loss,) + output if loss is not None else output
435
+
436
+ return CausalLMOutputWithPast(
437
+ loss=loss,
438
+ logits=logits,
439
+ past_key_values=outputs.past_key_values,
440
+ hidden_states=outputs.hidden_states,
441
+ attentions=outputs.attentions,
442
+ )
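And a minimal sketch of incremental decoding with the cache (again not part of this commit; it assumes CUDA with the `fla` Triton kernels and a randomly initialized model, so the output is meaningless but exercises the code path):

```python
# `prepare_inputs_for_generation` above feeds only the last token once
# `past_key_values` holds the recurrent state.
import torch

from fla.models.gsa import GSAConfig, GSAForCausalLM

config = GSAConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2, num_heads=2)
model = GSAForCausalLM(config).cuda().eval()

prompt = torch.randint(0, config.vocab_size, (1, 8), device='cuda')
out = model.generate(prompt, max_new_tokens=16, do_sample=False)
print(out.shape)  # (1, <=24): the prompt plus up to 16 new tokens
```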
fla/models/hgrn/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
6
+ from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
7
+
8
+ AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
9
+ AutoModel.register(HGRNConfig, HGRNModel)
10
+ AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)
11
+
12
+
13
+ __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel']
fla/models/hgrn/configuration_hgrn.py ADDED
@@ -0,0 +1,74 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class HGRNConfig(PretrainedConfig):
9
+
10
+ model_type = 'hgrn'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ attn_mode: str = "chunk",
16
+ hidden_size: int = 2048,
17
+ num_hidden_layers: int = 24,
18
+ expand_ratio: Optional[int] = 1,
19
+ use_short_conv: bool = False,
20
+ conv_size: int = 4,
21
+ use_lower_bound: bool = True,
22
+ max_position_embeddings: int = 2048,
23
+ hidden_ratio: Optional[int] = 4,
24
+ intermediate_size: Optional[int] = None,
25
+ hidden_act: str = "swish",
26
+ elementwise_affine: Optional[bool] = True,
27
+ norm_eps: float = 1e-6,
28
+ attn: Optional[Dict] = None,
29
+ use_cache: bool = True,
30
+ pad_token_id: int = None,
31
+ bos_token_id: int = 1,
32
+ eos_token_id: int = 2,
33
+ tie_word_embeddings: bool = False,
34
+ initializer_range: float = 0.02,
35
+ fuse_cross_entropy: bool = True,
36
+ vocab_size: int = 32000,
37
+ **kwargs
38
+ ):
39
+ self.attn_mode = attn_mode
40
+ self.hidden_size = hidden_size
41
+ self.num_hidden_layers = num_hidden_layers
42
+ self.expand_ratio = expand_ratio
43
+ self.use_short_conv = use_short_conv
44
+ self.conv_size = conv_size
45
+ self.use_lower_bound = use_lower_bound
46
+ self.max_position_embeddings = max_position_embeddings
47
+ self.hidden_ratio = hidden_ratio
48
+ self.intermediate_size = intermediate_size
49
+ self.elementwise_affine = elementwise_affine
50
+ self.attn = attn
51
+ self.norm_eps = norm_eps
52
+ self.hidden_act = hidden_act
53
+ self.use_cache = use_cache
54
+ self.initializer_range = initializer_range
55
+ self.fuse_cross_entropy = fuse_cross_entropy
56
+ self.vocab_size = vocab_size
57
+
58
+ if attn is not None:
59
+ if not isinstance(attn, Dict):
60
+ raise ValueError("attn must be a dictionary")
61
+ if 'layers' not in attn:
62
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
63
+ if 'num_heads' not in attn:
64
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
65
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
66
+ attn['window_size'] = attn.get('window_size', None)
67
+
68
+ super().__init__(
69
+ pad_token_id=pad_token_id,
70
+ bos_token_id=bos_token_id,
71
+ eos_token_id=eos_token_id,
72
+ tie_word_embeddings=tie_word_embeddings,
73
+ **kwargs,
74
+ )
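Finally, a small sketch (not part of this commit) showing that `HGRNConfig` round-trips through the standard 🤗 serialization APIs:

```python
from fla.models.hgrn import HGRNConfig

config = HGRNConfig(hidden_size=512, num_hidden_layers=8, expand_ratio=1)
config.save_pretrained('hgrn-tiny')   # writes hgrn-tiny/config.json with model_type='hgrn'
reloaded = HGRNConfig.from_pretrained('hgrn-tiny')
assert reloaded.hidden_size == 512 and reloaded.attn_mode == 'chunk'
```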