suriyagunasekar committed
Commit
0f4ae0e
1 Parent(s): 47c069f

Upload MixFormerSequentialForCausalLM

config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "_name_or_path": "phi-1",
+   "activation_function": "gelu_new",
+   "architecture": {
+     "block_cls": "parallel",
+     "mixer": {},
+     "mlp": {
+       "mlp_cls": "mlp"
+     }
+   },
+   "architectures": [
+     "MixFormerSequentialForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig",
+     "AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
+   },
+   "embd_layer": "default",
+   "embd_pdrop": 0.0,
+   "initializer_range": 0.02,
+   "layer_norm_epsilon": 1e-05,
+   "model_type": "mixformer-sequential",
+   "n_embd": 2048,
+   "n_head": 32,
+   "n_inner": null,
+   "n_layer": 24,
+   "n_positions": 2048,
+   "phyagi_version": "0.0.4.dev",
+   "resid_pdrop": 0.0,
+   "rotary_dim": 32,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.32.1",
+   "vocab_size": 51200
+ }
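For reference, the `auto_map` entries above route `AutoConfig`/`AutoModelForCausalLM` to the bundled Python modules, so loading requires `trust_remote_code=True`. A minimal sketch, where `REPO_ID` is a placeholder for this repository's Hub id or local path and the repository is assumed to also ship a tokenizer:

from transformers import AutoModelForCausalLM, AutoTokenizer

# REPO_ID is a placeholder; trust_remote_code=True lets transformers import the
# configuration_mixformer_sequential / modeling_mixformer_sequential files below.
tokenizer = AutoTokenizer.from_pretrained("REPO_ID")
model = AutoModelForCausalLM.from_pretrained("REPO_ID", trust_remote_code=True)

inputs = tokenizer("def print_prime(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0]))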
configuration_mixformer_sequential.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT license.
+
+ import math
+ from typing import Any, Dict, List, Optional, Union
+
+ from transformers import PretrainedConfig
+
+
+ class MixFormerSequentialConfig(PretrainedConfig):
+     """MixFormer (sequential for DeepSpeed) configuration."""
+
+     model_type = "mixformer-sequential"
+
+     attribute_map = {
+         "max_position_embeddings": "n_positions",
+         "hidden_size": "n_embd",
+         "num_attention_heads": "n_head",
+         "num_hidden_layers": "n_layer",
+         "input_emb_layer": "embd_layer",  # `input_emb_layer` key is for backward compatibility
+         "blocks": "architecture",  # `blocks` key is for backward compatibility
+     }
+
+     def __init__(
+         self,
+         vocab_size: Optional[int] = 50304,
+         n_positions: Optional[int] = 2048,
+         n_embd: Optional[int] = 1024,
+         n_layer: Optional[int] = 20,
+         n_inner: Optional[int] = None,
+         n_head: Optional[int] = 16,
+         rotary_dim: Optional[int] = 32,
+         activation_function: Optional[str] = "gelu_new",
+         embd_layer: Optional[str] = "default",
+         architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
+         embd_pdrop: Optional[float] = 0.0,
+         resid_pdrop: Optional[float] = 0.0,
+         layer_norm_epsilon: Optional[float] = 1e-5,
+         initializer_range: Optional[float] = 0.02,
+         tie_word_embeddings: Optional[bool] = False,
+         pad_vocab_size_multiple: Optional[int] = 64,
+         **kwargs
+     ) -> None:
+         self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+         self.n_positions = n_positions
+         self.n_embd = n_embd
+         self.n_layer = n_layer
+         self.n_inner = n_inner
+         self.n_head = n_head
+         self.rotary_dim = min(rotary_dim, n_embd // n_head)
+         self.activation_function = activation_function
+         self.embd_layer = embd_layer
+         self.architecture = architecture
+         self.embd_pdrop = embd_pdrop
+         self.resid_pdrop = resid_pdrop
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+
+         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
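A quick illustration of what the constructor normalizes (a sketch, assuming the file is importable from the working directory): `vocab_size` is padded up to the next multiple of `pad_vocab_size_multiple`, `rotary_dim` is clamped to the per-head dimension, and the `attribute_map` exposes the standard transformers names.

from configuration_mixformer_sequential import MixFormerSequentialConfig

cfg = MixFormerSequentialConfig(vocab_size=50257, n_embd=1024, n_head=16, rotary_dim=128)
assert cfg.vocab_size == 50304                           # ceil(50257 / 64) * 64
assert cfg.rotary_dim == 64                              # min(128, 1024 // 16)
assert cfg.max_position_embeddings == cfg.n_positions    # resolved via attribute_map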
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.32.1"
+ }
modeling_mixformer_sequential.py ADDED
@@ -0,0 +1,742 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT license.
+
+ from __future__ import annotations
+
+ import math
+ import copy
+ from typing import Any, Dict, Optional, Tuple
+ from dataclasses import dataclass, field
+
+ import torch
+ import torch.nn as nn
+
+ from einops import rearrange
+ from transformers.activations import ACT2FN
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from .configuration_mixformer_sequential import MixFormerSequentialConfig
+
+ @dataclass
+ class InferenceParams:
+     """Inference parameters that are passed to the main model in order
+     to efficiently calculate and store the context during inference."""
+     max_sequence_len: int
+     max_batch_size: int
+     sequence_len_offset: int = 0
+     batch_size_offset: int = 0
+     key_value_memory_dict: dict = field(default_factory=dict)
+     fused_ft_kernel: bool = False
+     lengths_per_sample: Optional[torch.Tensor] = None
+
+
+ class Embedding(nn.Module):
+     """Token embedding with dropout."""
+
+     def __init__(self, config: PretrainedConfig) -> None:
+         super().__init__()
+
+         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+         self.drop = nn.Dropout(config.embd_pdrop)
+
+     def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
+         input_shape = input_ids.size()
+         input_ids = input_ids.view(-1, input_shape[-1])
+
+         hidden_states = self.wte(input_ids)
+         hidden_states = self.drop(hidden_states)
+
+         return hidden_states
+
+ class RotaryEmbedding(nn.Module):
+     """PyTorch implementation of `flash-attn` RotaryEmbedding layer."""
+
+     def __init__(
+         self,
+         dim: int,
+         base: Optional[int] = 10000,
+         scale_base: Optional[float] = None,
+         device: Optional[str] = None,
+         **kwargs,
+     ) -> None:
+         super().__init__()
+
+         if scale_base is not None:
+             raise NotImplementedError
+
+         # Generate and save the inverse frequency buffer (non-trainable)
+         self.dim = dim
+         self.base = base
+         self.scale_base = scale_base
+         self.device = device
+
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
+         self.register_buffer("inv_freq", inv_freq)
+
+         scale = (
+             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+             if scale_base is not None
+             else None
+         )
+         self.register_buffer("scale", scale)
+
+         self._seq_len_cached = 0
+         self._cos_cached = None
+         self._sin_cached = None
+         self._cos_k_cached = None
+         self._sin_k_cached = None
+
+     def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0) -> None:
+         # Reset the tables if the sequence length has changed,
+         # or if we're on a new device (possibly due to tracing for instance)
+         seqlen = x.shape[1] + seqlen_offset
+
+         # Re-generate the inverse frequency buffer if it's not fp32
+         # (for instance if model.half() was called)
+         if self.inv_freq.dtype != torch.float32:
+             self.inv_freq = 1.0 / (
+                 self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
+             )
+
+         if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
+             self._seq_len_cached = seqlen
+             t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
+
+             # Don't do einsum, it converts fp32 to fp16
+             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+             freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
+             if self.scale is None:
+                 self._cos_cached = torch.cos(freqs).to(x.dtype)
+                 self._sin_cached = torch.sin(freqs).to(x.dtype)
+             else:
+                 power = (
+                     torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+                 ) / self.scale_base
+                 scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+
+                 # We want the multiplication by scale to happen in fp32
+                 self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
+                 self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
+                 self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
+                 self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
+
+     def apply_rotary_emb_qkv(
+         self,
+         qkv: torch.FloatTensor,
+         sin: torch.FloatTensor,
+         cos: torch.FloatTensor,
+         sin_k: Optional[torch.FloatTensor] = None,
+         cos_k: Optional[torch.FloatTensor] = None,
+     ) -> torch.FloatTensor:
+         _, seqlen, three, _, headdim = qkv.shape
+         assert three == 3
+
+         rotary_seqlen, rotary_dim = cos.shape
+         rotary_dim *= 2
+         assert rotary_dim <= headdim
+         assert seqlen <= rotary_seqlen
+
+         cos_k = cos if cos_k is None else cos_k
+         sin_k = sin if sin_k is None else sin_k
+         assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
+
+         q_rot = qkv[:, :, 0, :, :rotary_dim]
+         q_pass = qkv[:, :, 0, :, rotary_dim:]
+
+         k_rot = qkv[:, :, 1, :, :rotary_dim]
+         k_pass = qkv[:, :, 1, :, rotary_dim:]
+
+         # Splits the queries and keys in half
+         q1, q2 = q_rot.chunk(2, dim=-1)
+         k1, k2 = k_rot.chunk(2, dim=-1)
+         c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
+
+         # Casts to fp32 are necessary to prevent fp16 overflow issues
+         q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
+
+         # Computes the new keys and queries, recasting to original dtype
+         q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
+
+         k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
+
+         return torch.cat(
+             [
+                 torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
+                 torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
+                 qkv[:, :, 2:3, :, :],
+             ],
+             axis=2,
+         )
+
+     def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> torch.Tensor:
+         """Perform the forward pass.
+
+         Args:
+             qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
+             seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
+
+         Returns:
+             The rotated `qkv`.
+
+         """
+
+         self._update_cos_sin_cache(qkv, seqlen_offset)
+
+         return self.apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
+
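A small usage sketch of the class above. Because the modeling file uses a relative import of the configuration module, the snippets below assume both files live in a hypothetical local package named `mixformer`. Rotary embeddings only rotate the first `dim` channels of the query and key slices of the packed qkv tensor; the value slice and the remaining channels pass through unchanged.

import torch
from mixformer.modeling_mixformer_sequential import RotaryEmbedding

rotary = RotaryEmbedding(dim=32)
qkv = torch.randn(1, 16, 3, 4, 64)                      # (batch, seqlen, 3, n_head, head_dim)
rotated = rotary(qkv)
assert rotated.shape == qkv.shape
assert torch.equal(rotated[:, :, 2], qkv[:, :, 2])      # values are untouched
assert torch.equal(rotated[..., 32:], qkv[..., 32:])    # non-rotary channels are untouched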
+ def _update_kv_cache(kv, inference_params, layer_idx):
+     """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
+     """
+     # Pre-allocate memory for key-values for inference.
+     num_heads, head_dim = kv.shape[-2:]
+     if layer_idx not in inference_params.key_value_memory_dict:
+         kv_cache = torch.empty(
+             inference_params.max_batch_size, inference_params.max_sequence_len, 2,
+             num_heads, head_dim, dtype=kv.dtype, device=kv.device
+         )
+         inference_params.key_value_memory_dict[layer_idx] = kv_cache
+     else:
+         kv_cache = inference_params.key_value_memory_dict[layer_idx]
+
+     # Adjust key and value for inference
+     batch_start = inference_params.batch_size_offset
+     batch_end = batch_start + kv.shape[0]
+     sequence_start = inference_params.sequence_len_offset
+     sequence_end = sequence_start + kv.shape[1]
+     assert batch_end <= kv_cache.shape[0]
+     assert sequence_end <= kv_cache.shape[1]
+
+     assert kv_cache is not None
+     kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+     kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+     return kv
+
+
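A sketch of how the cache above grows during incremental decoding (same hypothetical `mixformer` package assumption): the first call writes the prompt's keys/values into the pre-allocated buffer, and later calls append one step at a time via `sequence_len_offset`.

import torch
from mixformer.modeling_mixformer_sequential import InferenceParams, _update_kv_cache

params = InferenceParams(max_sequence_len=32, max_batch_size=1)
prompt_kv = torch.randn(1, 5, 2, 4, 16)                 # (batch, seqlen, 2, n_head, head_dim)
assert _update_kv_cache(prompt_kv, params, layer_idx=0).shape == (1, 5, 2, 4, 16)

params.sequence_len_offset = 5                          # next write lands right after the prompt
step_kv = torch.randn(1, 1, 2, 4, 16)
assert _update_kv_cache(step_kv, params, layer_idx=0).shape == (1, 6, 2, 4, 16)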
+ class MLP(nn.Module):
+     """Multi-Layer Perceptron.
+
+     Reference:
+         Attention Is All You Need.
+         https://arxiv.org/pdf/1706.03762.pdf.
+
+     """
+
+     def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None:
+         super().__init__()
+
+         act_fn = config.activation_function if act_fn is None else act_fn
+         assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
+
+         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
+         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
+
+         self.fc1 = nn.Linear(config.n_embd, n_inner)
+         self.fc2 = nn.Linear(n_inner, config.n_embd)
+         self.act = ACT2FN[act_fn]
+
+     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+         old_keys = [prefix + "fc_in.weight", prefix + "fc_out.weight", prefix + "fc_in.bias", prefix + "fc_out.bias"]
+         new_keys = [prefix + "fc1.weight", prefix + "fc2.weight", prefix + "fc1.bias", prefix + "fc2.bias"]
+
+         if all(k in state_dict for k in old_keys) and not all(k in state_dict for k in new_keys):
+             # Older version of `MLP` saved with different key names.
+             for old_key, new_key in zip(old_keys, new_keys):
+                 state_dict[new_key] = state_dict.pop(old_key)
+
+         return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+
+         return hidden_states
+
+
+ class FusedMLP(nn.Module):
+     """Fused Multi-Layer Perceptron from `flash-attn`.
+
+     Reference:
+         https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
+
+     """
+     def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None,
+                  raise_on_missing: bool = False) -> None:
+         super().__init__()
+
+         act_fn = config.activation_function if act_fn is None else act_fn
+         assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
+
+         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
+         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
+
+         gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
+         activation = "gelu_approx" if act_fn in gelu_activations else "relu"
+
+         self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
+
+     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+         return self.mlp(hidden_states)
+
+ class SelfAttention(nn.Module):
+     """Implement the scaled dot product attention with softmax.
+     Arguments
+     ---------
+         softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+         attention_dropout: The dropout rate to apply to the attention
+                            (default: 0.0)
+     """
+     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+         super().__init__()
+         self.causal = causal
+         self.softmax_scale = softmax_scale
+         self.drop = nn.Dropout(attention_dropout)
+
+     def forward(self, qkv, causal=None, key_padding_mask=None):
+         """Implements the multihead softmax attention.
+         Arguments
+         ---------
+             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+             causal: if passed, will override self.causal
+             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                 False means to mask out. (B, S)
+         """
+         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+         causal = self.causal if causal is None else causal
+         q, k, v = qkv.unbind(dim=2)
+         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+         if key_padding_mask is not None:
+             padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
+                                       device=scores.device)
+             padding_mask.masked_fill_(key_padding_mask, 0.0)
+             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
+         if causal:
+             # "triu_tril_cuda_template" not implemented for 'BFloat16'
+             # So we have to construct the mask in float
+             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+             scores = scores + causal_mask.to(dtype=scores.dtype)
+         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+         attention_drop = self.drop(attention)
+         output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+         return output
+
+
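A usage sketch (same hypothetical `mixformer` package assumption): `SelfAttention` consumes the packed `(batch, seqlen, 3, n_head, head_dim)` tensor and returns per-head context vectors of shape `(batch, seqlen, n_head, head_dim)`.

import torch
from mixformer.modeling_mixformer_sequential import SelfAttention

attn = SelfAttention(causal=True)
qkv = torch.randn(2, 8, 3, 4, 16)
out = attn(qkv)
assert out.shape == (2, 8, 4, 16)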
+ class CrossAttention(nn.Module):
+     """Implement the scaled dot product attention with softmax.
+     Arguments
+     ---------
+         softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+         attention_dropout: The dropout rate to apply to the attention
+                            (default: 0.0)
+     """
+     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+         super().__init__()
+         self.causal = causal
+         self.softmax_scale = softmax_scale
+         self.drop = nn.Dropout(attention_dropout)
+
+     def forward(self, q, kv, causal=None, key_padding_mask=None):
+         """Implements the multihead softmax attention.
+         Arguments
+         ---------
+             q: The tensor containing the query. (B, Sq, H, D)
+             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+             causal: if passed, will override self.causal
+             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                 False means to mask out. (B, Sk)
+         """
+         batch_size, seqlen_q = q.shape[0], q.shape[1]
+         causal = self.causal if causal is None else causal
+         seqlen_k = kv.shape[1]
+         assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+         k, v = kv.unbind(dim=2)
+         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+         if key_padding_mask is not None:
+             padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
+                                       device=scores.device)
+             padding_mask.masked_fill_(key_padding_mask, 0.0)
+             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
+         if causal:
+             # "triu_tril_cuda_template" not implemented for 'BFloat16'
+             # So we have to construct the mask in float
+             causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
+                                                 device=scores.device), 1)
+             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+             scores = scores + causal_mask.to(dtype=scores.dtype)
+         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+         attention_drop = self.drop(attention)
+         output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+         return output
+
+ def find_mha_dims(
+     config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
+ ) -> Tuple[int, int]:
+     """Validate and return the number of heads and head dimension for multi-head attention.
+
+     Args:
+         config: Model configuration.
+         n_head: Number of heads.
+         head_dim: Head dimension.
+
+     Returns:
+         Number of heads and head dimension.
+
+     """
+
+     assert all(
+         hasattr(config, attr) for attr in ["n_embd", "n_head"]
+     ), "`config` must have `n_embd` and `n_head` attributes."
+
+     if head_dim is None:
+         assert (
+             config.n_embd % config.n_head == 0
+         ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
+
+     if n_head is None and head_dim is None:
+         head_dim = config.n_embd // config.n_head
+         n_head = config.n_head
+     elif n_head is None or head_dim is None:
+         raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
+
+     return n_head, head_dim
+
+
+ class MHA(nn.Module):
+     """Multi-head attention layer."""
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         rotary_dim: Optional[int] = None,
+         n_head: Optional[int] = None,
+         head_dim: Optional[int] = None,
+         bias: Optional[bool] = True,
+         dropout: Optional[float] = 0.0,
+         softmax_scale: Optional[float] = None,
+         causal: Optional[bool] = True,
+         layer_idx: Optional[int] = None,
+         rotary_emb_scale_base: Optional[float] = None,
+         return_residual: Optional[bool] = False,
+         checkpointing: Optional[bool] = False,
+         device: Optional[str] = None,
+         dtype: Optional[torch.dtype] = None,
+         fused_dense: Optional[bool] = True,
+         flash_attn: Optional[bool] = True,
+         cutlass_attn: Optional[bool] = False,
+         flash_rotary: Optional[bool] = True,
+         raise_on_missing: Optional[bool] = False
+     ) -> None:
+         super().__init__()
+
+         factory_kwargs = {"device": device, "dtype": dtype}
+         n_head, head_dim = find_mha_dims(config, n_head, head_dim)
+
+         self.hidden_size = config.n_embd
+         self.n_head = n_head
+         self.head_dim = head_dim
+         self.op_size = n_head * head_dim
+
+         self.causal = causal
+         self.layer_idx = layer_idx
+         self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
+         self.fused_dense = fused_dense
+         self.flash_attn = flash_attn
+         self.cutlass_attn = cutlass_attn
+         self.flash_rotary = flash_rotary
+         self.return_residual = return_residual
+         self.checkpointing = checkpointing
+
+         if self.rotary_emb_dim > 0:
+             rotary_kwargs = {"device": device}
+             if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
+                 rotary_kwargs["scale_base"] = rotary_emb_scale_base
+
+             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
+         else:
+             pass
+
+         self.Wqkv = nn.Linear(self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs)
+         self.out_proj = nn.Linear(self.op_size, self.hidden_size, bias=bias, **factory_kwargs)
+
+         self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+         self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+
+     def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
+         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
+
+         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+
+         return _update_kv_cache(kv, inference_params, self.layer_idx)
+
+     def forward(
+         self,
+         x: torch.FloatTensor,
+         x_kv: Optional[torch.FloatTensor] = None,
+         key_padding_mask: Optional[torch.BoolTensor] = None,
+         cu_seqlens: Optional[torch.LongTensor] = None,
+         max_seqlen: Optional[int] = None,
+         mixer_subset: Optional[torch.LongTensor] = None,
+         past_cache: Optional[InferenceParams] = None,
+         **kwargs
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+         """Perform the forward pass.
+
+         Args:
+             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
+                 cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
+                 is the sum of the sequence lengths in the batch.
+             x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
+             key_padding_mask: boolean mask, True means to keep, False means to mask out.
+                 (batch, seqlen). Only applicable when not using FlashAttention.
+             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                 of the sequences in the batch, used to index into x. Only applicable when using
+                 FlashAttention.
+             max_seqlen: int. Maximum sequence length in the batch.
+             mixer_subset: for cross-attention only. If not None, will take a subset of x
+                 before applying the query projection. Useful for e.g., ViT where we only care
+                 about the CLS token in the last layer.
+             past_cache: For generation only.
+
+         Returns:
+             (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
+             else (total, hidden_dim) where total is the sum of the sequence lengths
+             in the batch.
+
+         """
+
+         if cu_seqlens is not None:
+             assert max_seqlen is not None
+             assert key_padding_mask is None
+             assert self.flash_attn
+             assert self.rotary_emb_dim == 0
+
+         if key_padding_mask is not None:
+             assert cu_seqlens is None
+             assert max_seqlen is None
+             assert not self.flash_attn
+
+         if past_cache is not None:
+             assert key_padding_mask is None
+             assert cu_seqlens is None and max_seqlen is None
+
+         attn_kwargs = {"key_padding_mask": key_padding_mask}
+
+         assert x_kv is None and mixer_subset is None
+
+         qkv = self.Wqkv(x)
+         qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+
+         if past_cache is None:
+             if self.rotary_emb_dim > 0:
+                 qkv = self.rotary_emb(qkv)
+             context = self.inner_attn(qkv, **attn_kwargs)
+
+         else:
+             if self.rotary_emb_dim > 0:
+                 qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
+             q = qkv[:, :, 0]
+             kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
+             # If we're processing the prompt, causal=None (use self.causal).
+             # If we're decoding, then causal=False.
+             causal = None if past_cache.sequence_len_offset == 0 else False
+             context = self.inner_cross_attn(q, kv, causal=causal)
+
+         out = rearrange(context, "... h d -> ... (h d)")
+         out = self.out_proj(out)
+
+         return out if not self.return_residual else (out, x)
+
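A sketch of the no-cache path above (same hypothetical `mixformer` package assumption): with `rotary_dim > 0` the rotary embedding is applied to the projected qkv before `SelfAttention`, and the output projection maps back to `n_embd`.

import torch
from mixformer.configuration_mixformer_sequential import MixFormerSequentialConfig
from mixformer.modeling_mixformer_sequential import MHA

cfg = MixFormerSequentialConfig(n_embd=64, n_head=4, rotary_dim=16)
mha = MHA(cfg, layer_idx=0)
x = torch.randn(1, 8, 64)
assert mha(x).shape == (1, 8, 64)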
+ class ParallelBlock(nn.Module):
+     """Parallel block.
+
+     This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
+
+     """
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         mixer: Optional[Dict[str, Any]] = None,
+         mlp: Optional[Dict[str, Any]] = None,
+         block_idx: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+
+         self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+         self.block_idx = block_idx
+
+         self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
+         mlp_cls = mlp.pop('mlp_cls')
+         if mlp_cls == 'fused_mlp':
+             self.mlp = FusedMLP(config=config, **mlp)
+         else:
+             self.mlp = MLP(config=config, **mlp)
+
+     def forward(self, hidden_states: torch.FloatTensor,
+                 past_cache: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+         residual = hidden_states
+         hidden_states = self.ln(hidden_states)
+
+         attn_outputs = self.mixer(hidden_states, past_cache=past_cache)
+         if isinstance(attn_outputs, tuple):
+             attn_outputs = attn_outputs[0]
+
+         attn_outputs = self.resid_dropout(attn_outputs)
+         feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
+
+         hidden_states = attn_outputs + feed_forward_hidden_states + residual
+
+         return hidden_states
+
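A sketch of the parallel residual structure above (same hypothetical `mixformer` package assumption): attention and MLP both read the same LayerNorm output, and their outputs are added to the residual in a single sum.

import torch
from mixformer.configuration_mixformer_sequential import MixFormerSequentialConfig
from mixformer.modeling_mixformer_sequential import ParallelBlock

cfg = MixFormerSequentialConfig(n_embd=64, n_head=4)
block = ParallelBlock(cfg, mixer={}, mlp={"mlp_cls": "mlp"}, block_idx=0)
hidden = torch.randn(1, 8, 64)
assert block(hidden).shape == (1, 8, 64)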
+ class CausalLMHead(nn.Module):
+     """Causal Language Modeling head.
+
+     Reference:
+         Improving Language Understanding by Generative Pre-Training.
+         https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
+
+     """
+
+     def __init__(self, config: PretrainedConfig) -> None:
+         super().__init__()
+
+         self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.linear = nn.Linear(config.n_embd, config.vocab_size)
+
+     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+         hidden_states = self.ln(hidden_states)
+         logits = self.linear(hidden_states).to(torch.float32)
+
+         return logits
+
+
+ class CausalLMLoss(nn.Module):
+     """Causal Language Modeling loss.
+
+     Reference:
+         Improving Language Understanding by Generative Pre-Training.
+         https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
+
+     """
+
+     def __init__(self, shift_labels: Optional[bool] = True) -> None:
+         super().__init__()
+
+         self.shift_labels = shift_labels
+         self.loss_fct = nn.CrossEntropyLoss()
+
+     def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
+         if self.shift_labels:
+             logits = logits[..., :-1, :].contiguous()
+             labels = labels[..., 1:].contiguous()
+
+         loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+         return loss
+
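A sketch of the label shift above (same hypothetical `mixformer` package assumption): with `shift_labels=True`, the logits at position t are scored against the token at position t + 1, so the last position is dropped from the logits and the first from the labels.

import torch
from mixformer.modeling_mixformer_sequential import CausalLMLoss

loss_fn = CausalLMLoss()
logits = torch.randn(1, 8, 100)                 # (batch, seqlen, vocab)
labels = torch.randint(0, 100, (1, 8))
loss = loss_fn(logits, labels)
assert loss.ndim == 0                           # scalar cross-entropy over the 7 shifted positions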
+ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
+     """MixFormer (sequential for DeepSpeed) pre-trained model."""
+
+     config_class = MixFormerSequentialConfig
+     base_model_prefix = "transformer"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, *inputs, **kwargs) -> None:
+         super().__init__(*inputs, **kwargs)
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs) -> Dict[str, Any]:
+         if "use_cache" in kwargs and not kwargs["use_cache"]:
+             return {"input_ids": input_ids}
+
+         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
+             past_key_values = InferenceParams(
+                 max_batch_size=input_ids.shape[0],
+                 max_sequence_len=self.config.n_positions,
+                 sequence_len_offset=0,
+                 batch_size_offset=0,
+                 fused_ft_kernel=False,
+                 key_value_memory_dict={},
+             )
+         else:
+             # assume past_key_values has cached all but last token in input_ids
+             past_key_values.sequence_len_offset = len(input_ids[0]) - 1
+             input_ids = input_ids[:, -1].unsqueeze(-1)
+
+         return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
+
+
+ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
+     """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
+
+     _keys_to_ignore_on_load_missing = [""]
+     _keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
+
+     def __init__(self, config: MixFormerSequentialConfig) -> None:
+         super().__init__(config)
+
+         modules = [Embedding(config)]
+         block_config = config.architecture
+
+         if not isinstance(block_config, list):
+             block_config = [block_config for _ in range(config.n_layer)]
+
+         if config.n_layer != len(block_config):
+             config.n_layer = len(block_config)
+
+         for block_idx, block in enumerate(block_config):
+             # `block_cls` with `legacy` value is for backward compatibility
+             # `path` key is for backward compatibility
+             block = copy.deepcopy(block) or {"block_cls": "parallel"}
+             block_cls = block.pop("path", None) or block.pop("block_cls", None)
+
+             block["block_idx"] = block_idx
+             modules.append(ParallelBlock(config, **block))
+
+         modules.append(CausalLMHead(config))
+
+         self.layers = nn.Sequential(*modules)
+         self.loss = CausalLMLoss()
+
+         self.post_init()
+
+     def get_input_embeddings(self) -> nn.Embedding:
+         return self.layers[0].wte
+
+     def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+         self.layers[0].wte = new_embeddings
+
+     def get_output_embeddings(self) -> nn.Linear:
+         return self.layers[-1].linear
+
+     def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+         self.layers[-1].linear = new_embeddings
+
+     def forward(
+         self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[torch.FloatTensor] = None, **kwargs
+     ) -> CausalLMOutputWithPast:
+
+         if not past_key_values:
+             lm_logits = self.layers(input_ids)
+         else:
+             hidden_layer = self.layers[0](input_ids)
+             for module in self.layers[1:-1]:
+                 hidden_layer = module(hidden_layer, past_cache=past_key_values)
+             lm_logits = self.layers[-1](hidden_layer)
+
+         loss = None
+         if labels is not None:
+             loss = self.loss(lm_logits, labels)
+
+         return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
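Finally, a sketch of the sequential layout with a tiny randomly initialized model (same hypothetical `mixformer` package assumption; the architecture dict mirrors the one in config.json): `self.layers` is `[Embedding, n_layer ParallelBlocks, CausalLMHead]`, and passing `labels` returns the shifted cross-entropy loss.

import torch
from mixformer.configuration_mixformer_sequential import MixFormerSequentialConfig
from mixformer.modeling_mixformer_sequential import MixFormerSequentialForCausalLM

cfg = MixFormerSequentialConfig(
    vocab_size=1024, n_positions=128, n_embd=64, n_layer=2, n_head=4,
    architecture={"block_cls": "parallel", "mixer": {}, "mlp": {"mlp_cls": "mlp"}},
)
model = MixFormerSequentialForCausalLM(cfg)
input_ids = torch.randint(0, cfg.vocab_size, (1, 16))
out = model(input_ids, labels=input_ids)
assert out.logits.shape == (1, 16, cfg.vocab_size)
assert out.loss.ndim == 0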
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0bd82d6ec4fc74e2800a71d20c6ef8ea85724fdb41e34711858377ef5b46b4e2
+ size 5673167489