amazingvince committed on
Commit
78dbe53
1 Parent(s): 1d6d730

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "attention_probs_dropout_prob": 0.0,
+   "bos_token_id": 1,
+   "decoder_start_token_id": 3,
+   "eos_token_id": 2,
+   "head_dim": 64,
+   "hidden_act": "silu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 1024,
+   "num_attention_heads": 16,
+   "num_decoder_layers": 32,
+   "num_encoder_layers": 16,
+   "num_key_value_heads": 4,
+   "pad_token_id": 3,
+   "rotary_emb_base": 10000.0,
+   "rotary_emb_dim": 32,
+   "rotary_emb_interleaved": false,
+   "rotary_emb_scale_base": null,
+   "transformers_version": "4.44.2",
+   "use_cache": true,
+   "vocab_size": 48256
+ }
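For reference, the attention geometry implied by these values: head_dim = hidden_size / num_attention_heads = 1024 / 16 = 64; with num_key_value_heads = 4, every key/value head is shared by 16 / 4 = 4 query heads (grouped-query attention); and rotary_emb_dim = 32 applies rotary position embeddings to half of each 64-dimensional head, which is the default the configuration class below derives when rotary_emb_dim is left unset.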
configuration_custom_seq2seq_llm.py ADDED
@@ -0,0 +1,56 @@
+ from transformers import PretrainedConfig
+
+
+ class Seq2SeqConfig(PretrainedConfig):
+     def __init__(
+         self,
+         vocab_size=30522,
+         hidden_size=768,
+         num_encoder_layers=6,
+         num_decoder_layers=12,
+         num_attention_heads=12,
+         num_key_value_heads=4,
+         intermediate_size=3072,
+         hidden_act="silu",
+         hidden_dropout_prob=0.0,
+         attention_probs_dropout_prob=0.0,
+         max_position_embeddings=512,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         use_cache=True,
+         rotary_emb_dim=0,
+         rotary_emb_base=10000.0,
+         rotary_emb_scale_base=None,
+         rotary_emb_interleaved=False,
+         **kwargs
+     ):
+         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_encoder_layers = num_encoder_layers
+         self.num_decoder_layers = num_decoder_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.use_cache = use_cache
+         self.rotary_emb_base = rotary_emb_base
+         self.rotary_emb_scale_base = rotary_emb_scale_base
+         self.rotary_emb_interleaved = rotary_emb_interleaved
+
+         # Derive head_dim; rotary_emb_dim defaults to half of head_dim when left at 0
+         self.head_dim = self.hidden_size // self.num_attention_heads
+         self.rotary_emb_dim = rotary_emb_dim if rotary_emb_dim else self.head_dim // 2
+
+         # Ensure rotary_emb_dim is not larger than head_dim
+         if self.rotary_emb_dim > self.head_dim:
+             print(f"Warning: rotary_emb_dim ({self.rotary_emb_dim}) is larger than head_dim ({self.head_dim}). Setting rotary_emb_dim to head_dim.")
+             self.rotary_emb_dim = self.head_dim
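A minimal usage sketch (not part of the upload): it instantiates Seq2SeqConfig with the values stored in config.json and checks the derived rotary geometry. The flat import path is an assumption; adjust it to wherever this file lives locally.

```python
# Hypothetical sketch: build the config that matches config.json above.
from configuration_custom_seq2seq_llm import Seq2SeqConfig  # import path is an assumption

config = Seq2SeqConfig(
    vocab_size=48256,
    hidden_size=1024,
    num_encoder_layers=16,
    num_decoder_layers=32,
    num_attention_heads=16,
    num_key_value_heads=4,
    intermediate_size=3072,
    max_position_embeddings=1024,
    pad_token_id=3,
    bos_token_id=1,
    eos_token_id=2,
    decoder_start_token_id=3,
)

# head_dim is derived as hidden_size // num_attention_heads; rotary_emb_dim
# defaults to half of head_dim, matching the stored value of 32.
assert config.head_dim == 64
assert config.rotary_emb_dim == 32
```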
flash_atten.py ADDED
@@ -0,0 +1,1020 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange, repeat
9
+
10
+ from flash_attn.utils.distributed import get_dim_for_local_rank
11
+
12
+ try:
13
+ from flash_attn import (
14
+ flash_attn_kvpacked_func,
15
+ flash_attn_qkvpacked_func,
16
+ flash_attn_varlen_kvpacked_func,
17
+ flash_attn_varlen_qkvpacked_func,
18
+ flash_attn_with_kvcache,
19
+ )
20
+ except ImportError:
21
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
+ flash_attn_with_kvcache = None
24
+
25
+ try:
26
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
27
+ except ImportError:
28
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
29
+
30
+ try:
31
+ from flash_attn.layers.rotary import RotaryEmbedding
32
+ except ImportError:
33
+ RotaryEmbedding = None
34
+
35
+
36
+ # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
37
+ def get_alibi_slopes(nheads):
38
+ def get_slopes_power_of_2(nheads):
39
+ start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
40
+ ratio = start
41
+ return [start * ratio**i for i in range(nheads)]
42
+
43
+ if math.log2(nheads).is_integer():
44
+ return get_slopes_power_of_2(nheads)
45
+ else:
46
+ closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
47
+ return (
48
+ get_slopes_power_of_2(closest_power_of_2)
49
+ + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
50
+ )
51
+
52
+
53
+ class FlashSelfAttention(nn.Module):
54
+ """Implement the scaled dot product attention with softmax.
55
+ Arguments
56
+ ---------
57
+ softmax_scale: The temperature to use for the softmax attention.
58
+ (default: 1/sqrt(d_keys) where d_keys is computed at
59
+ runtime)
60
+ attention_dropout: The dropout rate to apply to the attention
61
+ (default: 0.0)
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ causal=False,
67
+ softmax_scale=None,
68
+ attention_dropout=0.0,
69
+ window_size=(-1, -1),
70
+ alibi_slopes=None,
71
+ deterministic=False,
72
+ ):
73
+ super().__init__()
74
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
75
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
76
+ self.causal = causal
77
+ self.softmax_scale = softmax_scale
78
+ self.drop = nn.Dropout(attention_dropout)
79
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
80
+ self.window_size = window_size
81
+ self.deterministic = deterministic
82
+
83
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
84
+ """Implements the multihead softmax attention.
85
+ Arguments
86
+ ---------
87
+ qkv: The tensor containing the query, key, and value.
88
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
89
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
90
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
91
+ causal: if passed, will override self.causal
92
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
93
+ of the sequences in the batch, used to index into qkv.
94
+ max_seqlen: int. Maximum sequence length in the batch.
95
+ Returns:
96
+ --------
97
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
98
+ else (B, S, H, D).
99
+ """
100
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
101
+ assert qkv.is_cuda
102
+ causal = self.causal if causal is None else causal
103
+ unpadded = cu_seqlens is not None
104
+ if self.alibi_slopes is not None:
105
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
106
+ if unpadded:
107
+ assert cu_seqlens.dtype == torch.int32
108
+ assert max_seqlen is not None
109
+ assert isinstance(max_seqlen, int)
110
+ return flash_attn_varlen_qkvpacked_func(
111
+ qkv,
112
+ cu_seqlens,
113
+ max_seqlen,
114
+ self.drop.p if self.training else 0.0,
115
+ softmax_scale=self.softmax_scale,
116
+ causal=causal,
117
+ alibi_slopes=self.alibi_slopes,
118
+ window_size=self.window_size,
119
+ deterministic=self.deterministic,
120
+ )
121
+ else:
122
+ return flash_attn_qkvpacked_func(
123
+ qkv,
124
+ self.drop.p if self.training else 0.0,
125
+ softmax_scale=self.softmax_scale,
126
+ causal=causal,
127
+ alibi_slopes=self.alibi_slopes,
128
+ window_size=self.window_size,
129
+ deterministic=self.deterministic,
130
+ )
131
+
132
+
133
+ class FlashCrossAttention(nn.Module):
134
+ """Implement the scaled dot product attention with softmax.
135
+ Arguments
136
+ ---------
137
+ softmax_scale: The temperature to use for the softmax attention.
138
+ (default: 1/sqrt(d_keys) where d_keys is computed at
139
+ runtime)
140
+ attention_dropout: The dropout rate to apply to the attention
141
+ (default: 0.0)
142
+ """
143
+
144
+ def __init__(
145
+ self,
146
+ causal=False,
147
+ softmax_scale=None,
148
+ attention_dropout=0.0,
149
+ alibi_slopes=None,
150
+ window_size=(-1, -1),
151
+ deterministic=False,
152
+ ):
153
+ super().__init__()
154
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
155
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
156
+ self.causal = causal
157
+ self.softmax_scale = softmax_scale
158
+ self.drop = nn.Dropout(attention_dropout)
159
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
160
+ self.window_size = window_size
161
+ self.deterministic = deterministic
162
+
163
+ def forward(
164
+ self,
165
+ q,
166
+ kv,
167
+ causal=None,
168
+ cu_seqlens=None,
169
+ max_seqlen=None,
170
+ cu_seqlens_k=None,
171
+ max_seqlen_k=None,
172
+ ):
173
+ """Implements the multihead softmax attention.
174
+ Arguments
175
+ ---------
176
+ q: The tensor containing the query. (B, Sq, H, D)
177
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
178
+ causal: if passed, will override self.causal
179
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
180
+ of the sequences in the batch, used to index into q.
181
+ max_seqlen: int. Maximum sequence length in the batch of q.
182
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
183
+ of the sequences in the batch, used to index into kv.
184
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
185
+ """
186
+ assert q.dtype in [torch.float16, torch.bfloat16]
187
+ assert q.is_cuda and kv.is_cuda
188
+ causal = self.causal if causal is None else causal
189
+ unpadded = cu_seqlens is not None
190
+ if self.alibi_slopes is not None:
191
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
192
+ if unpadded:
193
+ assert cu_seqlens.dtype == torch.int32
194
+ assert max_seqlen is not None
195
+ assert isinstance(max_seqlen, int)
196
+ assert cu_seqlens_k is not None
197
+ assert cu_seqlens_k.dtype == torch.int32
198
+ assert max_seqlen_k is not None
199
+ assert isinstance(max_seqlen_k, int)
200
+ return flash_attn_varlen_kvpacked_func(
201
+ q,
202
+ kv,
203
+ cu_seqlens,
204
+ cu_seqlens_k,
205
+ max_seqlen,
206
+ max_seqlen_k,
207
+ self.drop.p if self.training else 0.0,
208
+ softmax_scale=self.softmax_scale,
209
+ causal=causal,
210
+ alibi_slopes=self.alibi_slopes,
211
+ window_size=self.window_size,
212
+ deterministic=self.deterministic,
213
+ )
214
+ else:
215
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
216
+ seqlen_k = kv.shape[1]
217
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
218
+ return flash_attn_kvpacked_func(
219
+ q,
220
+ kv,
221
+ self.drop.p if self.training else 0.0,
222
+ causal=causal,
223
+ softmax_scale=self.softmax_scale,
224
+ alibi_slopes=self.alibi_slopes,
225
+ window_size=self.window_size,
226
+ deterministic=self.deterministic,
227
+ )
228
+
229
+
230
+ class SelfAttention(nn.Module):
231
+ """Implement the scaled dot product attention with softmax.
232
+ Arguments
233
+ ---------
234
+ softmax_scale: The temperature to use for the softmax attention.
235
+ (default: 1/sqrt(d_keys) where d_keys is computed at
236
+ runtime)
237
+ attention_dropout: The dropout rate to apply to the attention
238
+ (default: 0.0)
239
+ """
240
+
241
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
242
+ super().__init__()
243
+ self.causal = causal
244
+ self.softmax_scale = softmax_scale
245
+ self.drop = nn.Dropout(attention_dropout)
246
+
247
+ def forward(self, qkv, causal=None, key_padding_mask=None):
248
+ """Implements the multihead softmax attention.
249
+ Arguments
250
+ ---------
251
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
252
+ causal: if passed, will override self.causal
253
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
254
+ False means to mask out. (B, S)
255
+ """
256
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
257
+ causal = self.causal if causal is None else causal
258
+ q, k, v = qkv.unbind(dim=2)
259
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
260
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
261
+ if key_padding_mask is not None:
262
+ padding_mask = torch.full(
263
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
264
+ )
265
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
266
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
267
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
268
+ if causal:
269
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
270
+ # So we have to construct the mask in float
271
+ causal_mask = torch.triu(
272
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
273
+ )
274
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
275
+ scores = scores + causal_mask.to(dtype=scores.dtype)
276
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
277
+ attention_drop = self.drop(attention)
278
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
279
+ return output
280
+
281
+
282
+ class CrossAttention(nn.Module):
283
+ """Implement the scaled dot product attention with softmax.
284
+ Arguments
285
+ ---------
286
+ softmax_scale: The temperature to use for the softmax attention.
287
+ (default: 1/sqrt(d_keys) where d_keys is computed at
288
+ runtime)
289
+ attention_dropout: The dropout rate to apply to the attention
290
+ (default: 0.0)
291
+ """
292
+
293
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
294
+ super().__init__()
295
+ self.causal = causal
296
+ self.softmax_scale = softmax_scale
297
+ self.drop = nn.Dropout(attention_dropout)
298
+
299
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
300
+ """Implements the multihead softmax attention.
301
+ Arguments
302
+ ---------
303
+ q: The tensor containing the query. (B, Sq, H, D)
304
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
305
+ causal: if passed, will override self.causal
306
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
307
+ False means to mask out. (B, Sk)
308
+ """
309
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
310
+ causal = self.causal if causal is None else causal
311
+ seqlen_k = kv.shape[1]
312
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
313
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
314
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
315
+ k, v = kv.unbind(dim=2)
316
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
317
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
318
+ if key_padding_mask is not None:
319
+ padding_mask = torch.full(
320
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
321
+ )
322
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
323
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
324
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
325
+ if causal:
326
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
327
+ row_idx = rearrange(
328
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
329
+ )
330
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
331
+ sk = (
332
+ seqlen_k
333
+ if key_padding_mask is None
334
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
335
+ )
336
+ causal_mask = col_idx > row_idx + sk - seqlen_q
337
+ scores = scores.masked_fill(causal_mask, -10000.0)
338
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
339
+ attention_drop = self.drop(attention)
340
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
341
+ return output
342
+
343
+
344
+ class LinearResidual(nn.Linear):
345
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
346
+
347
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
348
+ return super().forward(input), input
349
+
350
+
351
+ def _update_kv_cache(kv, inference_params, layer_idx):
352
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
353
+ # Pre-allocate memory for key-values for inference.
354
+ num_heads, head_dim = kv.shape[-2:]
355
+ if layer_idx not in inference_params.key_value_memory_dict:
356
+ kv_cache = torch.empty(
357
+ inference_params.max_batch_size,
358
+ inference_params.max_seqlen,
359
+ 2,
360
+ num_heads,
361
+ head_dim,
362
+ dtype=kv.dtype,
363
+ device=kv.device,
364
+ )
365
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
366
+ else:
367
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
368
+ # Adjust key and value for inference
369
+ batch_start = inference_params.batch_size_offset
370
+ batch_end = batch_start + kv.shape[0]
371
+ sequence_start = inference_params.seqlen_offset
372
+ sequence_end = sequence_start + kv.shape[1]
373
+ assert batch_end <= kv_cache.shape[0]
374
+ assert sequence_end <= kv_cache.shape[1]
375
+ assert kv_cache is not None
376
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
377
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
378
+
379
+
380
+ class MHA(nn.Module):
381
+ """Multi-head self-attention and cross-attention"""
382
+
383
+ def __init__(
384
+ self,
385
+ embed_dim,
386
+ num_heads,
387
+ num_heads_kv=None,
388
+ cross_attn=False,
389
+ qkv_proj_bias=True,
390
+ out_proj_bias=True,
391
+ dropout=0.0,
392
+ softmax_scale=None,
393
+ causal=False,
394
+ layer_idx=None,
395
+ dwconv=False,
396
+ rotary_emb_dim=0,
397
+ rotary_emb_base=10000.0,
398
+ rotary_emb_scale_base=None,
399
+ rotary_emb_interleaved=False,
400
+ use_alibi=False,
401
+ window_size=(-1, -1),
402
+ fused_bias_fc=False,
403
+ use_flash_attn=False,
404
+ return_residual=False,
405
+ checkpointing=False,
406
+ device=None,
407
+ dtype=None,
408
+ ) -> None:
409
+ """
410
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
411
+ return_residual: whether to return the input x along with the output. This is for
412
+ performance reason: for post-norm architecture, returning the input allows us
413
+ to fuse the backward of nn.Linear with the residual connection.
414
+ """
415
+ factory_kwargs = {"device": device, "dtype": dtype}
416
+ super().__init__()
417
+ self.embed_dim = embed_dim
418
+ self.cross_attn = cross_attn
419
+ self.causal = causal
420
+ self.layer_idx = layer_idx
421
+ self.dwconv = dwconv
422
+ self.rotary_emb_dim = rotary_emb_dim
423
+ self.use_flash_attn = use_flash_attn
424
+ self.return_residual = return_residual
425
+ self.checkpointing = checkpointing
426
+ if use_alibi:
427
+ assert use_flash_attn, "ALiBi code path requires flash_attn"
428
+ alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
429
+ else:
430
+ alibi_slopes = None
431
+ if window_size != (-1, -1):
432
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
433
+
434
+ self.num_heads = num_heads
435
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
436
+ assert (
437
+ self.num_heads % self.num_heads_kv == 0
438
+ ), "num_heads must be divisible by num_heads_kv"
439
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
440
+ self.head_dim = self.embed_dim // num_heads
441
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
442
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
443
+
444
+ if self.rotary_emb_dim > 0:
445
+ assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
446
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
447
+ self.rotary_emb = RotaryEmbedding(
448
+ self.rotary_emb_dim,
449
+ base=rotary_emb_base,
450
+ scale_base=rotary_emb_scale_base,
451
+ interleaved=rotary_emb_interleaved,
452
+ device=device,
453
+ )
454
+
455
+ if fused_bias_fc and FusedDense is None:
456
+ raise ImportError("fused_dense is not installed")
457
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
458
+ linear_resid_cls = (
459
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
460
+ )
461
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
462
+ inner_attn_cls = (
463
+ partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
464
+ if use_flash_attn
465
+ else SelfAttention
466
+ )
467
+ inner_cross_attn_cls = (
468
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
469
+ if use_flash_attn
470
+ else CrossAttention
471
+ )
472
+ if not self.cross_attn:
473
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
474
+ else:
475
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
476
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
477
+ if self.dwconv:
478
+ if self.num_heads_kv == self.num_heads:
479
+ self.dwconv_qkv = nn.Conv1d(
480
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
481
+ )
482
+ else:
483
+ self.dwconv_q = nn.Conv1d(
484
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
485
+ )
486
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
487
+ self.inner_attn = inner_attn_cls(
488
+ causal=causal,
489
+ softmax_scale=softmax_scale,
490
+ attention_dropout=dropout,
491
+ )
492
+ self.inner_cross_attn = inner_cross_attn_cls(
493
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
494
+ )
495
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
496
+
497
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
498
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
499
+ device = self.out_proj.weight.device
500
+ return torch.empty(
501
+ batch_size,
502
+ max_seqlen,
503
+ 2,
504
+ self.num_heads_kv,
505
+ self.head_dim,
506
+ dtype=dtype,
507
+ device=device,
508
+ )
509
+
510
+ def _update_kv_cache(self, kv, inference_params):
511
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
512
+ assert not self.dwconv, "Generation does not support dwconv yet"
513
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
514
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
515
+
516
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
517
+ """
518
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
519
+ q: (batch_size, seqlen_q, nheads, head_dim)
520
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
521
+ """
522
+ assert inference_params is not None and inference_params.seqlen_offset > 0
523
+ assert self.use_flash_attn
524
+ if self.rotary_emb_dim > 0:
525
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
526
+ self.rotary_emb._update_cos_sin_cache(
527
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
528
+ )
529
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
530
+ else:
531
+ rotary_cos, rotary_sin = None, None
532
+ batch = q.shape[0]
533
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
534
+ cache_seqlens = (
535
+ inference_params.lengths_per_sample[:batch]
536
+ if inference_params.lengths_per_sample is not None
537
+ else inference_params.seqlen_offset
538
+ )
539
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
540
+ context = flash_attn_with_kvcache(
541
+ q,
542
+ kv_cache[:, :, 0],
543
+ kv_cache[:, :, 1],
544
+ kv[:, :, 0],
545
+ kv[:, :, 1],
546
+ rotary_cos=rotary_cos,
547
+ rotary_sin=rotary_sin,
548
+ cache_seqlens=cache_seqlens,
549
+ softmax_scale=self.inner_cross_attn.softmax_scale,
550
+ causal=self.inner_cross_attn.causal,
551
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
552
+ alibi_slopes=alibi_slopes,
553
+ )
554
+ return context
555
+
556
+ def _update_kvcache_attention(self, q, kv, inference_params):
557
+ """Write kv to inference_params, then do attention"""
558
+ if (
559
+ inference_params.seqlen_offset == 0
560
+ or flash_attn_with_kvcache is None
561
+ or not self.use_flash_attn
562
+ ):
563
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
564
+ kv = self._update_kv_cache(kv, inference_params)
565
+ return self.inner_cross_attn(q, kv)
566
+ else:
567
+ batch = q.shape[0]
568
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
569
+ cache_seqlens = (
570
+ inference_params.lengths_per_sample[:batch]
571
+ if inference_params.lengths_per_sample is not None
572
+ else inference_params.seqlen_offset
573
+ )
574
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
575
+ return flash_attn_with_kvcache(
576
+ q,
577
+ kv_cache[:, :, 0],
578
+ kv_cache[:, :, 1],
579
+ kv[:, :, 0],
580
+ kv[:, :, 1],
581
+ cache_seqlens=cache_seqlens,
582
+ softmax_scale=self.inner_cross_attn.softmax_scale,
583
+ causal=self.inner_cross_attn.causal,
584
+ alibi_slopes=alibi_slopes,
585
+ )
586
+
587
+ def forward(
588
+ self,
589
+ x,
590
+ x_kv=None,
591
+ key_padding_mask=None,
592
+ cu_seqlens=None,
593
+ max_seqlen=None,
594
+ mixer_subset=None,
595
+ inference_params=None,
596
+ **kwargs,
597
+ ):
598
+ """
599
+ Arguments:
600
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
601
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
602
+ is the is the sum of the sequence lengths in the batch.
603
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
604
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
605
+ of the sequences in the batch, used to index into x. Only applicable when using
606
+ FlashAttention.
607
+ max_seqlen: int. Maximum sequence length in the batch.
608
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
609
+ (batch, seqlen). Only applicable when not using FlashAttention.
610
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
611
+ before applying the query projection. Useful for e.g., ViT where we only care
612
+ about the CLS token in the last layer.
613
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
614
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
615
+ """
616
+ if cu_seqlens is not None:
617
+ assert max_seqlen is not None
618
+ assert key_padding_mask is None
619
+ assert self.use_flash_attn
620
+ assert not self.dwconv
621
+ assert self.rotary_emb_dim == 0
622
+ if key_padding_mask is not None:
623
+ assert cu_seqlens is None
624
+ assert max_seqlen is None
625
+ assert not self.use_flash_attn
626
+ if inference_params is not None:
627
+ assert key_padding_mask is None
628
+ assert cu_seqlens is None and max_seqlen is None
629
+ assert not self.dwconv
630
+
631
+ kwargs = (
632
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
633
+ if self.use_flash_attn
634
+ else {"key_padding_mask": key_padding_mask, **kwargs}
635
+ )
636
+ seqlen_offset = (
637
+ 0
638
+ if inference_params is None
639
+ else (
640
+ inference_params.lengths_per_sample
641
+ if inference_params.lengths_per_sample is not None
642
+ else inference_params.seqlen_offset
643
+ )
644
+ )
645
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
646
+ batch, seqlen = x.shape[:2]
647
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
648
+ assert x_kv is None and mixer_subset is None
649
+ if not self.return_residual:
650
+ qkv = self.Wqkv(x)
651
+ else:
652
+ qkv, x = self.Wqkv(x)
653
+ if self.dwconv:
654
+ qkv = rearrange(
655
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
656
+ ).contiguous()
657
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
658
+ if (
659
+ inference_params is None
660
+ or inference_params.seqlen_offset == 0
661
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
662
+ or not self.use_flash_attn
663
+ ):
664
+ if self.rotary_emb_dim > 0:
665
+ qkv = self.rotary_emb(
666
+ qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
667
+ )
668
+ if inference_params is None:
669
+ if not self.checkpointing:
670
+ context = self.inner_attn(qkv, **kwargs)
671
+ else:
672
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
673
+ else:
674
+ context = self._update_kvcache_attention(
675
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
676
+ )
677
+ else:
678
+ context = self._apply_rotary_update_kvcache_attention(
679
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
680
+ )
681
+ else:
682
+ if self.cross_attn:
683
+ if not self.return_residual:
684
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
685
+ kv = self.Wkv(x_kv if x_kv is not None else x)
686
+ else:
687
+ if x_kv is not None:
688
+ kv, x_kv = self.Wkv(x_kv)
689
+ else:
690
+ kv, x = self.Wkv(x)
691
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
692
+ else:
693
+ assert self.num_heads_kv != self.num_heads
694
+ if not self.return_residual:
695
+ qkv = self.Wqkv(x)
696
+ else:
697
+ qkv, x = self.Wqkv(x)
698
+ q = qkv[..., : self.num_heads * self.head_dim]
699
+ kv = qkv[..., self.num_heads * self.head_dim :]
700
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
701
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
702
+ if self.dwconv:
703
+ q = rearrange(
704
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
705
+ ).contiguous()
706
+ kv = rearrange(
707
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
708
+ ).contiguous()
709
+ if (
710
+ inference_params is None
711
+ or inference_params.seqlen_offset == 0
712
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
713
+ or not self.use_flash_attn
714
+ ):
715
+ if self.rotary_emb_dim > 0:
716
+ q, kv = self.rotary_emb(
717
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
718
+ )
719
+ if inference_params is None:
720
+ if not self.checkpointing:
721
+ context = self.inner_cross_attn(q, kv, **kwargs)
722
+ else:
723
+ context = torch.utils.checkpoint.checkpoint(
724
+ self.inner_cross_attn, q, kv, **kwargs
725
+ )
726
+ else:
727
+ context = self._update_kvcache_attention(q, kv, inference_params)
728
+ else:
729
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
730
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
731
+ return out if not self.return_residual else (out, x)
732
+
733
+
734
+ class ParallelMHA(nn.Module):
735
+ """Multi-head self-attention and cross-attention"""
736
+
737
+ def __init__(
738
+ self,
739
+ embed_dim,
740
+ num_heads,
741
+ process_group,
742
+ num_heads_kv=None,
743
+ qkv_proj_bias=True,
744
+ out_proj_bias=True,
745
+ dropout=0.0,
746
+ softmax_scale=None,
747
+ causal=False,
748
+ layer_idx=None,
749
+ rotary_emb_dim=0,
750
+ rotary_emb_base=10000.0,
751
+ rotary_emb_scale_base=None,
752
+ rotary_emb_interleaved=False,
753
+ use_alibi=False,
754
+ window_size=(-1, -1),
755
+ use_flash_attn=False,
756
+ checkpointing=False,
757
+ sequence_parallel=True,
758
+ device=None,
759
+ dtype=None,
760
+ ) -> None:
761
+ factory_kwargs = {"device": device, "dtype": dtype}
762
+ super().__init__()
763
+ self.embed_dim = embed_dim
764
+ self.causal = causal
765
+ self.layer_idx = layer_idx
766
+ self.rotary_emb_dim = rotary_emb_dim
767
+ self.use_flash_attn = use_flash_attn
768
+ self.checkpointing = checkpointing
769
+ self.process_group = process_group
770
+ self.world_size = process_group.size()
771
+ self.local_rank = torch.distributed.get_rank(process_group)
772
+
773
+ self.num_heads = num_heads
774
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
775
+
776
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
777
+ assert (
778
+ self.num_heads % self.num_heads_kv == 0
779
+ ), "num_heads must be divisible by num_heads_kv"
780
+
781
+ self.num_heads_per_rank = get_dim_for_local_rank(
782
+ self.num_heads, self.world_size, self.local_rank
783
+ )
784
+ self.num_heads_kv_per_rank = get_dim_for_local_rank(
785
+ self.num_heads_kv, self.world_size, self.local_rank
786
+ )
787
+ self.head_dim = self.embed_dim // num_heads
788
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
789
+
790
+ if use_alibi:
791
+ assert use_flash_attn, "ALiBi code path requires flash_attn"
792
+ num_heads_local = math.ceil(self.num_heads / self.world_size)
793
+ alibi_slopes = torch.tensor(
794
+ get_alibi_slopes(num_heads)[
795
+ self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
796
+ ],
797
+ device=device,
798
+ )
799
+ else:
800
+ alibi_slopes = None
801
+ if window_size != (-1, -1):
802
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
803
+
804
+ if self.rotary_emb_dim > 0:
805
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
806
+ self.rotary_emb = RotaryEmbedding(
807
+ self.rotary_emb_dim,
808
+ base=rotary_emb_base,
809
+ scale_base=rotary_emb_scale_base,
810
+ interleaved=rotary_emb_interleaved,
811
+ device=device,
812
+ )
813
+
814
+ if ColumnParallelLinear is None or RowParallelLinear is None:
815
+ raise ImportError("fused_dense is not installed")
816
+ self.Wqkv = ColumnParallelLinear(
817
+ embed_dim,
818
+ qkv_dim,
819
+ process_group,
820
+ bias=qkv_proj_bias,
821
+ sequence_parallel=sequence_parallel,
822
+ multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
823
+ **factory_kwargs,
824
+ )
825
+ inner_attn_cls = (
826
+ partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
827
+ if use_flash_attn
828
+ else SelfAttention
829
+ )
830
+ inner_cross_attn_cls = (
831
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
832
+ if use_flash_attn
833
+ else CrossAttention
834
+ )
835
+ self.inner_attn = inner_attn_cls(
836
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
837
+ )
838
+ self.inner_cross_attn = inner_cross_attn_cls(
839
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
840
+ )
841
+ self.out_proj = RowParallelLinear(
842
+ embed_dim,
843
+ embed_dim,
844
+ process_group,
845
+ bias=out_proj_bias,
846
+ sequence_parallel=sequence_parallel,
847
+ multiple_of=self.head_dim,
848
+ **factory_kwargs,
849
+ )
850
+
851
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
852
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
853
+ device = self.out_proj.weight.device
854
+ return torch.empty(
855
+ batch_size,
856
+ max_seqlen,
857
+ 2,
858
+ self.num_heads_kv_per_rank,
859
+ self.head_dim,
860
+ dtype=dtype,
861
+ device=device,
862
+ )
863
+
864
+ def _update_kv_cache(self, kv, inference_params):
865
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
866
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
867
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
868
+
869
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
870
+ """
871
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
872
+ q: (batch_size, seqlen_q, nheads, head_dim)
873
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
874
+ """
875
+ assert inference_params is not None and inference_params.seqlen_offset > 0
876
+ assert self.use_flash_attn
877
+ if self.rotary_emb_dim > 0:
878
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
879
+ self.rotary_emb._update_cos_sin_cache(
880
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
881
+ )
882
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
883
+ else:
884
+ rotary_cos, rotary_sin = None, None
885
+ batch = q.shape[0]
886
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
887
+ cache_seqlens = (
888
+ inference_params.lengths_per_sample[:batch]
889
+ if inference_params.lengths_per_sample is not None
890
+ else inference_params.seqlen_offset
891
+ )
892
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
893
+ context = flash_attn_with_kvcache(
894
+ q,
895
+ kv_cache[:, :, 0],
896
+ kv_cache[:, :, 1],
897
+ kv[:, :, 0],
898
+ kv[:, :, 1],
899
+ rotary_cos=rotary_cos,
900
+ rotary_sin=rotary_sin,
901
+ cache_seqlens=cache_seqlens,
902
+ softmax_scale=self.inner_cross_attn.softmax_scale,
903
+ causal=self.inner_cross_attn.causal,
904
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
905
+ alibi_slopes=alibi_slopes,
906
+ )
907
+ return context
908
+
909
+ def _update_kvcache_attention(self, q, kv, inference_params):
910
+ """Write kv to inference_params, then do attention"""
911
+ if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
912
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
913
+ kv = self._update_kv_cache(kv, inference_params)
914
+ return self.inner_cross_attn(q, kv)
915
+ else:
916
+ batch = q.shape[0]
917
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
918
+ cache_seqlens = (
919
+ inference_params.lengths_per_sample[:batch]
920
+ if inference_params.lengths_per_sample is not None
921
+ else inference_params.seqlen_offset
922
+ )
923
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
924
+ context = flash_attn_with_kvcache(
925
+ q,
926
+ kv_cache[:, :, 0],
927
+ kv_cache[:, :, 1],
928
+ kv[:, :, 0],
929
+ kv[:, :, 1],
930
+ cache_seqlens=cache_seqlens,
931
+ softmax_scale=self.inner_cross_attn.softmax_scale,
932
+ causal=self.inner_cross_attn.causal,
933
+ alibi_slopes=alibi_slopes,
934
+ )
935
+ return context
936
+
937
+ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
938
+ """
939
+ Arguments:
940
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
941
+ If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
942
+ split x during sequence parallel, we split the batch * seqlen dimension
943
+ (in case batch is small).
944
+ """
945
+ qkv = self.Wqkv(x)
946
+ if seqlen is not None:
947
+ qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
948
+ seqlen_offset = (
949
+ 0
950
+ if inference_params is None
951
+ else (
952
+ inference_params.lengths_per_sample
953
+ if inference_params.lengths_per_sample is not None
954
+ else inference_params.seqlen_offset
955
+ )
956
+ )
957
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
958
+ if self.num_heads_kv == self.num_heads:
959
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
960
+ if (
961
+ inference_params is None
962
+ or inference_params.seqlen_offset == 0
963
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
964
+ or not self.use_flash_attn
965
+ ):
966
+ if self.rotary_emb_dim > 0:
967
+ qkv = self.rotary_emb(
968
+ qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
969
+ )
970
+ if inference_params is None:
971
+ if not self.checkpointing:
972
+ context = self.inner_attn(qkv, **kwargs)
973
+ else:
974
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
975
+ else:
976
+ context = self._update_kvcache_attention(
977
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
978
+ )
979
+ else:
980
+ context = self._apply_rotary_update_kvcache_attention(
981
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
982
+ )
983
+ else:
984
+ q = rearrange(
985
+ qkv[..., : self.num_heads_per_rank * self.head_dim],
986
+ "... (h d) -> ... h d",
987
+ d=self.head_dim,
988
+ )
989
+ kv = rearrange(
990
+ qkv[..., self.num_heads_per_rank * self.head_dim :],
991
+ "... (two hkv d) -> ... two hkv d",
992
+ two=2,
993
+ d=self.head_dim,
994
+ )
995
+ if (
996
+ inference_params is None
997
+ or inference_params.seqlen_offset == 0
998
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
999
+ or not self.use_flash_attn
1000
+ ):
1001
+ if self.rotary_emb_dim > 0:
1002
+ q, kv = self.rotary_emb(
1003
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
1004
+ )
1005
+ if inference_params is None:
1006
+ if not self.checkpointing:
1007
+ context = self.inner_cross_attn(q, kv, **kwargs)
1008
+ else:
1009
+ context = torch.utils.checkpoint.checkpoint(
1010
+ self.inner_cross_attn, q, kv, **kwargs
1011
+ )
1012
+ else:
1013
+ context = self._update_kvcache_attention(q, kv, inference_params)
1014
+ else:
1015
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
1016
+ context = rearrange(context, "b s h d -> b s (h d)")
1017
+ if seqlen is not None:
1018
+ context = rearrange(context, "b s d -> (b s) d")
1019
+ out = self.out_proj(context)
1020
+ return out
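A minimal, hypothetical sketch of how the encoder/decoder layers in the modeling file drive the MHA class above on its pure-PyTorch fallback path (use_flash_attn=False). It assumes flash_attn and einops are installed, since this module imports them unconditionally, and leaves rotary_emb_dim at 0 so the example stays on the CPU-only code path; the values mirror the grouped-query geometry from config.json.

```python
# Hypothetical sketch (not part of the upload): exercising MHA's non-flash GQA path.
import torch
from flash_atten import MHA  # assumes this file is importable as a local module

# Same geometry as config.json: 1024 hidden, 16 query heads, 4 shared KV heads.
mha = MHA(embed_dim=1024, num_heads=16, num_heads_kv=4, dropout=0.0, causal=False)

x = torch.randn(2, 8, 1024)                             # (batch, seqlen, hidden)
key_padding_mask = torch.ones(2, 8, dtype=torch.bool)   # True = keep this position
key_padding_mask[:, -2:] = False                        # mask out the last two positions

out = mha(x, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 8, 1024])
```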
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7134d1e0be9b58fd378c0b681679ae41855daa9cbb0dc80b24e826ef59861ce
+ size 2692370584
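A loading sketch under stated assumptions (not documented in the upload): config.json carries no auto_map, so the custom classes are imported directly from a local clone of the repository rather than through the Auto* classes, and liger_kernel plus flash_attn must be installed for the modeling code below to import. The local path is a placeholder.

```python
# Hypothetical loading sketch, assuming the repository has been cloned locally.
from modeling_custom_seq2seq_llm import CustomSeq2SeqLLM  # import path is an assumption

model = CustomSeq2SeqLLM.from_pretrained("path/to/local/clone")  # reads config.json + model.safetensors
model.eval()
```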
modeling_custom_seq2seq_llm.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import Seq2SeqLMOutput
+ from transformers.activations import ACT2FN
+ from .flash_atten import MHA  # Import the MHA class from the provided implementation
+ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+ from .configuration_custom_seq2seq_llm import Seq2SeqConfig
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.eps = eps
+
+     def forward(self, hidden_states):
+         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+         return self.weight * hidden_states.to(self.weight.dtype)
+
+
+ class CustomSeq2SeqLLM(PreTrainedModel):
+     config_class = Seq2SeqConfig
+     base_model_prefix = "custom_seq2seq_llm"
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.shared = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.encoder = CustomEncoder(config)
+         self.decoder = CustomDecoder(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.loss_fn = LigerCrossEntropyLoss()
+         self.init_weights()
+
+     def get_encoder(self):
+         return self.encoder
+
+     def get_decoder(self):
+         return self.decoder
+
+     def get_input_embeddings(self):
+         return self.shared
+
+     def set_input_embeddings(self, value):
+         self.shared = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         decoder_input_ids=None,
+         decoder_attention_mask=None,
+         encoder_outputs=None,
+         past_key_values=None,
+         labels=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         position_ids=None,
+     ):
+         if position_ids is None and input_ids is not None:
+             position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
+             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+         if encoder_outputs is None and input_ids is not None:
+             encoder_outputs = self.encoder(
+                 self.shared(input_ids),
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+             )
+
+         if decoder_input_ids is None:
+             if labels is not None:
+                 # Teacher forcing: derive decoder inputs by shifting the labels right.
+                 decoder_input_ids = self._shift_right(labels)
+             elif input_ids is not None:
+                 decoder_input_ids = input_ids
+             else:
+                 raise ValueError("Either decoder_input_ids, labels, or input_ids must be provided.")
+
+         decoder_outputs = self.decoder(
+             self.shared(decoder_input_ids),
+             encoder_outputs,
+             attention_mask=decoder_attention_mask,
+             position_ids=position_ids,
+         )
+
+         lm_logits = self.lm_head(decoder_outputs)
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_fn(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+         return Seq2SeqLMOutput(
+             loss=loss,
+             logits=lm_logits,
+             encoder_last_hidden_state=encoder_outputs,
+             decoder_hidden_states=decoder_outputs,
+         )
+
+     def _shift_right(self, input_ids):
+         # The pad token id doubles as the decoder start token here.
+         shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+         shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+         shifted_input_ids[..., 0] = self.config.pad_token_id
+         return shifted_input_ids
+
+ class CustomEncoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_encoder_layers)])
+         self.layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states, attention_mask=None, position_ids=None):
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, attention_mask, position_ids)
+         return self.layer_norm(hidden_states)
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.self_attn = MHA(config.hidden_size, config.num_attention_heads,
+                              num_heads_kv=config.num_key_value_heads,
+                              dropout=config.attention_probs_dropout_prob,
+                              causal=False,
+                              rotary_emb_dim=config.rotary_emb_dim,
+                              rotary_emb_base=config.rotary_emb_base,
+                              rotary_emb_scale_base=config.rotary_emb_scale_base,
+                              rotary_emb_interleaved=config.rotary_emb_interleaved)
+         self.feed_forward = LigerSwiGLUMLP(config)
+         self.layer_norm1 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.layer_norm2 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states, attention_mask=None, position_ids=None):
+         normed_hidden_states = self.layer_norm1(hidden_states)
+         attention_output = self.self_attn(normed_hidden_states, key_padding_mask=attention_mask)
+         hidden_states = hidden_states + attention_output
+
+         normed_hidden_states = self.layer_norm2(hidden_states)
+         feed_forward_output = self.feed_forward(normed_hidden_states)
+         hidden_states = hidden_states + feed_forward_output
+
+         return hidden_states
+
+ class CustomDecoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.layers = nn.ModuleList([
+             DecoderLayer(config, use_cross_attention=self._should_use_cross_attention(i, config.num_decoder_layers))
+             for i in range(config.num_decoder_layers)
+         ])
+         self.layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def _should_use_cross_attention(self, layer_idx, total_layers):
+         # Cross-attention is placed in the first, last, and every even-indexed decoder layer.
+         return layer_idx == 0 or layer_idx == total_layers - 1 or layer_idx % 2 == 0
+
+     def forward(self, hidden_states, encoder_hidden_states, attention_mask=None, position_ids=None):
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, encoder_hidden_states, attention_mask, position_ids)
+         return self.layer_norm(hidden_states)
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, config, use_cross_attention=True):
+         super().__init__()
+         self.use_cross_attention = use_cross_attention
+         self.self_attn = MHA(config.hidden_size, config.num_attention_heads,
+                              num_heads_kv=config.num_key_value_heads,
+                              dropout=config.attention_probs_dropout_prob,
+                              causal=True,
+                              rotary_emb_dim=config.rotary_emb_dim,
+                              rotary_emb_base=config.rotary_emb_base,
+                              rotary_emb_scale_base=config.rotary_emb_scale_base,
+                              rotary_emb_interleaved=config.rotary_emb_interleaved)
+         if use_cross_attention:
+             self.cross_attn = MHA(config.hidden_size, config.num_attention_heads,
+                                   num_heads_kv=config.num_key_value_heads,
+                                   dropout=config.attention_probs_dropout_prob,
+                                   causal=False,
+                                   cross_attn=True)
+             self.layer_norm2 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.feed_forward = LigerSwiGLUMLP(config)
+         self.layer_norm1 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.layer_norm3 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states, encoder_hidden_states, attention_mask=None, position_ids=None):
+         normed_hidden_states = self.layer_norm1(hidden_states)
+         self_attention_output = self.self_attn(normed_hidden_states, key_padding_mask=attention_mask)
+         hidden_states = hidden_states + self_attention_output
+
+         if self.use_cross_attention:
+             normed_hidden_states = self.layer_norm2(hidden_states)
+             cross_attention_output = self.cross_attn(normed_hidden_states, x_kv=encoder_hidden_states, key_padding_mask=attention_mask)
+             hidden_states = hidden_states + cross_attention_output
+
+         normed_hidden_states = self.layer_norm3(hidden_states)
+         feed_forward_output = self.feed_forward(normed_hidden_states)
+         hidden_states = hidden_states + feed_forward_output
+
+         return hidden_states
+
+ class FeedForward(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.act = ACT2FN[config.hidden_act]
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
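Together, configuration_custom_seq2seq_llm.py and modeling_custom_seq2seq_llm.py define an encoder-decoder LM built from RMSNorm, SwiGLU MLPs, grouped-query flash attention, and rotary embeddings. A hedged usage sketch follows; it assumes the two modules are packaged so their relative imports resolve (in practice they are typically loaded as remote code through transformers), and that flash-attn, liger-kernel, and a CUDA device are available. Shapes and ids are toy values, not tied to the uploaded checkpoint, and depending on how MHA is configured the flash-attention path may additionally require half precision.

import torch
from configuration_custom_seq2seq_llm import Seq2SeqConfig
from modeling_custom_seq2seq_llm import CustomSeq2SeqLLM

config = Seq2SeqConfig()                       # defaults: 6 encoder / 12 decoder layers, hidden_size 768
model = CustomSeq2SeqLLM(config).to("cuda").eval()

input_ids = torch.randint(0, config.vocab_size, (1, 32), device="cuda")
labels = torch.randint(0, config.vocab_size, (1, 32), device="cuda")

with torch.no_grad():
    out = model(input_ids=input_ids, labels=labels)  # decoder inputs come from _shift_right(labels)

print(out.loss)          # LigerCrossEntropyLoss over the flattened logits
print(out.logits.shape)  # (1, 32, config.vocab_size)
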
random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:50b14a8e8e4cb1f87530bb13452da585006a1a54e1fa02069afa73d0775f0736
+ size 14344