mqyqlx committed
Commit 29a399e
1 Parent(s): 97c2057

upload model and code

config.json ADDED
@@ -0,0 +1,33 @@
{
  "architectures": [
    "DCFormer"
  ],
  "auto_map": {
    "AutoConfig": "configuration_dcformer.DCFormerConfig",
    "AutoModelForCausalLM": "modeling_dcformer.DCFormer"
  },
  "block_size": 2048,
  "bos_token_id": 1,
  "dim": 2560,
  "eos_token_id": 2,
  "head_dim": 80,
  "intermediate_size": 6912,
  "is_training": false,
  "model_type": "dcformer",
  "n_head": 32,
  "n_layer": 32,
  "n_local_heads": 32,
  "norm_eps": 1e-06,
  "q_chunk_size": 128,
  "query_wise": false,
  "rope_base": 10000,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.33.2",
  "use_dcmha": true,
  "use_gradient_checkpointing": false,
  "use_qk_norm": true,
  "vocab_size": 50257,
  "window_size": 256,
  "window_type": null
}
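
The auto_map above points AutoConfig and AutoModelForCausalLM at the custom classes shipped in this repo, so loading requires trust_remote_code=True. A minimal sketch of loading and inspecting the config (assuming the repo id Caiyun-AI/DCFormer-2.8B used in generation_demo.py below):

from transformers import AutoConfig

# trust_remote_code is required because the config/model classes live in this repo
config = AutoConfig.from_pretrained("Caiyun-AI/DCFormer-2.8B", trust_remote_code=True)
print(config.model_type)                             # dcformer
print(config.dim, config.n_head, config.head_dim)    # 2560 32 80  (head_dim = dim // n_head)
print(config.window_size, config.window_type)        # 256 None -> alternating local/global attention layers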
configuration_dcformer.py ADDED
@@ -0,0 +1,76 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from typing import Optional, Tuple, List


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


class DCFormerConfig(PretrainedConfig):
    '''
    DCFormerConfig is a config class for DCFormer, which is adapted from
    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21
    '''
    model_type = "dcformer"

    def __init__(
        self,
        block_size: int = 2048,
        vocab_size: int = 32000,
        n_layer: int = 32,
        n_head: int = 32,
        dim: int = 2560,
        intermediate_size: int = None,
        n_local_heads: int = -1,
        head_dim: int = 64,
        rope_base: float = 10000,
        norm_eps: float = 1e-5,
        use_gradient_checkpointing: bool = False,
        is_training: bool = False,
        q_chunk_size: int = 128,
        use_dcmha: bool = True,
        use_qk_norm: bool = False,
        window_size: Optional[int] = 256,
        window_type: Optional[str] = None,
        query_wise: bool = False,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.dim = dim
        self.intermediate_size = intermediate_size
        self.n_local_heads = n_local_heads
        self.head_dim = head_dim
        self.rope_base = rope_base
        self.norm_eps = norm_eps
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.is_training = is_training
        self.q_chunk_size = q_chunk_size
        self.use_dcmha = use_dcmha
        self.use_qk_norm = use_qk_norm
        self.window_size = window_size
        self.window_type = window_type
        self.query_wise = query_wise
        # post init: derive defaults that depend on other fields
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
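
When intermediate_size is left unset, the constructor derives it as roughly two thirds of 4*dim rounded up to a multiple of 256, and head_dim is always recomputed as dim // n_head (the head_dim argument is effectively overridden). A quick check against the values in config.json above, assuming configuration_dcformer.py is importable locally:

from configuration_dcformer import DCFormerConfig

cfg = DCFormerConfig(dim=2560, n_head=32, vocab_size=50257)
# 4 * 2560 = 10240; int(2 * 10240 / 3) = 6826; find_multiple(6826, 256) = 6912
print(cfg.intermediate_size)   # 6912, matches config.json
print(cfg.head_dim)            # 2560 // 32 = 80, matches config.json
print(cfg.n_local_heads)       # -1 resolves to n_head = 32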
generation_demo.py ADDED
@@ -0,0 +1,33 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCFormer-2.8B")
model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCFormer-2.8B", trust_remote_code=True)

device = torch.device('cuda')
MAX_BATCH_SIZE = 1
MAX_SEQ_LENGTH = 2048
NUM_TOKENS_TO_GENERATE = 100
COMPILE = True

_ = model.to(device=device, dtype=torch.float16)
with torch.device(device):
    # allocate the static KV (and KW) caches once, before incremental decoding
    model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, set_kv_cache=True)

def decode_one_token(model, cur_token, input_pos):
    # greedy single-step decode: feed one token, take the argmax of the last-position logits
    logits = model(cur_token, input_pos=input_pos, return_tensor=True)
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

prompt = "Beijing is the capital of China. London is the capital of"
input_ids = tokenizer.encode(prompt, return_tensors='pt')

# optionally compile the single-token step for faster incremental decoding
compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None

with torch.no_grad():
    generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
    text = tokenizer.decode(generated_ids[0])
    print('generated text:', text)
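
generate() simply calls the supplied single-step function as fn(model, cur_token, input_pos), so the greedy decode_one_token above can be swapped for a sampling step without touching the model code. A sketch of a temperature/top-k variant (the temperature and top_k values are illustrative, not taken from this repo), reusing model, tokenizer, input_ids and device from the demo above:

import torch

def sample_one_token(model, cur_token, input_pos, temperature=0.8, top_k=50):
    # same signature as decode_one_token, but samples from the top-k distribution
    logits = model(cur_token, input_pos=input_pos, return_tensor=True)[:, -1]
    logits = logits / temperature
    topk_vals, _ = torch.topk(logits, top_k, dim=-1)
    logits[logits < topk_vals[:, [-1]]] = float('-inf')   # mask everything outside the top-k
    probs = torch.softmax(logits.float(), dim=-1)
    return torch.multinomial(probs, num_samples=1)

with torch.no_grad():
    # pass the sampling step in place of the compiled greedy step (it can also be torch.compile'd)
    generated_ids = model.generate(input_ids.to(device),
                                   num_tokens_to_generate=NUM_TOKENS_TO_GENERATE,
                                   compiled_decode_one_token=sample_one_token)
    print(tokenizer.decode(generated_ids[0]))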
modeling_dcformer.py ADDED
@@ -0,0 +1,627 @@
from dataclasses import dataclass
from typing import Optional, Tuple, List
from collections import namedtuple

import math
import time
import json
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

try:
    from .configuration_dcformer import DCFormerConfig
except:
    from configuration_dcformer import DCFormerConfig

from transformers.modeling_utils import PreTrainedModel


class KVKWCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, window_size=2048, dtype=torch.float16, use_kw_cache=True):
        super().__init__()
        self.head_dim = head_dim
        self.kw_dim = 2 * n_heads
        self.n_heads = n_heads
        self.window_size = window_size
        self.use_kw_cache = use_kw_cache
        if window_size is None:
            self.seq_length = max_seq_length
        else:
            self.seq_length = min(window_size, max_seq_length)
        cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim)
        kw_cache_shape = (max_batch_size, self.seq_length, 2, n_heads, n_heads)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
        if self.use_kw_cache:
            self.register_buffer('kw_cache', torch.zeros(kw_cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val, kw_val=None):  # kw_val B,N,S,2,N B2NSD
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[-1] == k_val.shape[2]
        B, N, S, D = v_val.shape
        k_out = self.k_cache
        v_out = self.v_cache
        if self.use_kw_cache:
            kw_out = self.kw_cache
        else:
            kw_out = None

        if self.window_size is None:
            k_out[:, :, input_pos] = k_val
            v_out[:, :, input_pos] = v_val
            if self.use_kw_cache and kw_val is not None:
                kw_out[:, input_pos] = kw_val
        elif S == 1:
            input_pos = input_pos % self.seq_length
            v_out[:, :, input_pos] = v_val
            k_out[:, :, input_pos] = k_val
            if self.use_kw_cache and kw_val is not None:
                kw_out[:, input_pos] = kw_val
        else:  # prefill
            start = max(0, input_pos[-1] - self.seq_length + 1)
            input_pos = input_pos[start:] % self.seq_length
            v_out[:, :, input_pos] = v_val[:, :, start:]
            k_out[:, :, input_pos] = k_val[:, :, start:]
            if self.use_kw_cache and kw_val is not None:
                kw_out[:, input_pos] = kw_val[:, start:]
        return k_out, v_out, kw_out


class DCFormer(PreTrainedModel):
    '''
    DCFormer's implementation is adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L89
    '''

    def __init__(self, config: DCFormerConfig) -> None:
        super().__init__(config)
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(DCFormerBlock(config, lidx) for lidx in range(config.n_layer))
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.use_gradient_checkpointing = config.use_gradient_checkpointing
        self.is_training = config.is_training

        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.window_size = config.window_size
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length, set_kv_cache=True):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        if not self.is_training:
            for b in self.layers:
                if set_kv_cache:
                    use_kw_cache = False if b.attention.query_wise else True
                    b.attention.kv_cache = KVKWCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, window_size=b.attention.window_size, use_kw_cache=use_kw_cache)
                b.attention.dyn_w_proj.merge_weights()
                if not b.attention.use_sw:
                    dtype = b.attention.wo.weight.dtype
                    device = b.attention.wo.weight.device
                    b.attention.dyn_w_proj.sw = b.attention.dyn_w_proj.sw.to(device=device, dtype=dtype)
                    b.attention.dyn_w_proj.pre_proj.w = b.attention.dyn_w_proj.pre_proj.w.to(device=device, dtype=dtype)
                    b.attention.dyn_w_proj.post_proj.w = b.attention.dyn_w_proj.post_proj.w.to(device=device, dtype=dtype)

        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base).to(self.tok_embeddings.weight.device)
        if self.is_training:
            self.causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool, device=self.tok_embeddings.weight.device))
        elif self.window_size is None:
            self.causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))
        else:
            self.causal_mask = torch.stack([make_window_mask(max_seq_length, self.config.window_size), torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))])  # LG

    def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
        batch_size, seq_length = input_ids.shape
        input_pos = torch.arange(seq_length, device=self.device)
        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate + 1, dtype=torch.int, device=self.device)
        generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
        logits = self.forward(input_ids, input_pos=input_pos, return_tensor=True)
        _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
        next_token[:batch_size] = _next_token
        generated_ids[:, seq_length] = next_token[:batch_size, 0]
        input_pos = torch.tensor([seq_length], device=self.device)
        for _ in range(1, num_tokens_to_generate):
            if compiled_decode_one_token is not None:
                next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
            else:
                next_token = self.decode_one_token(next_token.clone(), input_pos)
            generated_ids[:, input_pos + 1] = next_token.int()[:batch_size]
            input_pos += 1
        return generated_ids

    def decode_one_token(self, cur_token, input_pos):
        logits = self.forward(
            cur_token,
            input_pos=input_pos,
            return_tensor=True
        )
        new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        return new_token

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        if input_pos is None:
            input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
        if self.window_size is None or self.is_training:
            mask = self.causal_mask[None, None, input_pos]
        else:
            mask = self.causal_mask[None, None, :, input_pos]
        freqs_cis = self.freqs_cis[input_pos][:idx.shape[-1]]
        x = self.tok_embeddings(idx)
        for i, layer in enumerate(self.layers):
            if self.is_training or self.window_size is None:
                layer_mask = mask
            elif self.window_size is not None:
                layer_mask = mask[:, :, 1] if layer.attention.window_size is None else mask[:, :, 0]
            if self.use_gradient_checkpointing:
                x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
            else:
                x = layer(x, input_pos, freqs_cis, layer_mask)
        x = self.norm(x)
        logits = self.output(x)
        if return_tensor:
            return logits
        else:
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
            return CausalLMOutput(logits=logits)


class DCFormerBlock(nn.Module):
    def __init__(self, config: DCFormerConfig, lidx) -> None:
        super().__init__()
        self.lidx = lidx
        self.attention = DCMHAttention(config, lidx)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class DynamicWeightProjection(nn.Module):

    def __init__(self, num_heads=32, num_groups=1, residual=True, query_input_dim=4096, dynamic_squeeze_ratio=16, dynamic_w_hidden_dim=128, dtype=torch.float16, use_sw=False):
        super().__init__()
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.query_input_dim = query_input_dim
        self.dynamic_squeeze_ratio = dynamic_squeeze_ratio
        self.dynamic_w_hidden_dim = dynamic_w_hidden_dim
        self.dw_hidden_activation = nn.GELU()
        self.num_heads_per_group = self.num_heads // self.num_groups
        self.dw_activation = nn.Tanh()
        self.dw1_norm = RMSnormNoscale(dim=-1)
        self.use_sw = use_sw
        self.pre_proj = CrossHeadProjection('pre', num_heads=self.num_heads, use_sw=use_sw)
        self.post_proj = CrossHeadProjection('post', num_heads=self.num_heads, use_sw=use_sw)

        dynamic_hidden_dim = self.num_heads_per_group // self.dynamic_squeeze_ratio
        self.dynamic_hidden_dim = dynamic_hidden_dim
        self.dw1 = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, 4, self.dynamic_w_hidden_dim, dtype=dtype))  # (4096, 1, 4, 128)
        G, K, M = self.num_groups, self.dynamic_w_hidden_dim, self.num_heads_per_group
        I = dynamic_hidden_dim * 2
        self.qkw = nn.parameter.Parameter(torch.zeros([G, 4, K, I, M], dtype=dtype))  # (1, 4, 128, 4, 32)
        self.dd = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, self.num_heads_per_group * 4, dtype=dtype))  # (4096, 1, 128)

        self.merge_weights()

    def merge_weights(self):
        self.dw_m = nn.parameter.Parameter(torch.cat([self.dw1.reshape(self.query_input_dim, -1), self.dd.squeeze(1)], dim=-1)).to(self.dw1.device)  # E,(4*K + K)  K=2*N*I
        self.qkw_m = nn.parameter.Parameter(self.qkw.permute(0, 1, 2, 3, 4).reshape(4, self.dynamic_w_hidden_dim, -1)).to(self.dw1.device)  # (4,K,I*M)
        if self.use_sw:
            self.sw = nn.parameter.Parameter(torch.stack([self.pre_proj.w, self.post_proj.w]).squeeze(1) + torch.eye(self.num_heads)).to(self.dw1.device)  # (2,N,N) sw + identity matrix
        else:
            self.sw = (torch.eye(self.num_heads).expand(2, self.num_heads, self.num_heads)).to(self.dw1.device)  # identity matrix (2,N,N)

    def forward(self, query_vec, KW: Optional[torch.Tensor] = None, gen_cache: Optional[bool] = True):
        dw_hidden = torch.einsum('BTD,DGCK->BTGCK', query_vec, self.dw1)  # C=4 [pre,post]*[query,key]
        dw_hidden = self.dw_hidden_activation(dw_hidden)  # BTGCK
        w1, w2 = torch.split(torch.einsum('BTGCK,GCKIM->BTGCIM', dw_hidden, self.qkw), self.qkw.shape[-2] // 2, dim=-2)  # BTGC(2I)M -> [BTGCIM] * 2
        w1 = self.dw1_norm(w1)  # BTGCIM
        pre_qw1, pre_kw1, post_qw1, post_kw1 = unbind(w1, 4, dim=3)  # BTG4IM->[BTGIM]*4
        pre_qw2, pre_kw2, post_qw2, post_kw2 = unbind(w2, 4, dim=3)
        dd = torch.einsum('BTD,DGM->BTGM', query_vec, self.dd)  # BTG(4M)
        dd = self.dw_activation(dd)
        pre_qdd, pre_kdd, post_qdd, post_kdd = torch.split(dd, dd.shape[-1] // 4, dim=-1)  # BTG(4N)->[BTGN]*4
        pre_dw_args = (pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd)
        post_dw_args = (post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd)
        if gen_cache:  # generate KW cache
            pre_kw = torch.einsum('BSGIM, BSGIN->BSMN', pre_kw1, pre_kw2) + torch.diag_embed(pre_kdd.squeeze(2))  # merge kw and kdd
            post_kw = torch.einsum('BSGIM, BSGIN->BSMN', post_kw1, post_kw2) + torch.diag_embed(post_kdd.squeeze(2))
            KW = torch.stack((pre_kw, post_kw), dim=-3)  # BSMN,BSMN->BS2MN
        return pre_dw_args, post_dw_args, KW


class RMSnormNoscale(nn.Module):

    def __init__(self, epsilon=1e-6, dim=-1):
        super().__init__()
        self.dim = dim
        self.epsilon = epsilon

    def forward(self, inputs):
        var = inputs.pow(2).mean(dim=self.dim, keepdim=True)
        normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
        return normed_inputs


class RMSnorm(nn.Module):

    def __init__(self, hid_dim=128, epsilon=1e-6, dim=-1):
        super().__init__()
        self.dim = dim
        self.hid_dim = hid_dim
        self.epsilon = epsilon
        self.scale = nn.parameter.Parameter(data=torch.ones(self.hid_dim))

    def forward(self, inputs):
        var = inputs.pow(2).mean(dim=self.dim, keepdim=True)
        normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
        normed_inputs = normed_inputs * self.scale
        return normed_inputs


class CrossHeadProjection(nn.Module):

    def __init__(self, mode, num_heads=16, num_groups=1, dtype=torch.float16, use_sw=False):
        super().__init__()
        self.mode = mode
        self.use_sw = use_sw
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_heads_per_group = self.num_heads // self.num_groups
        if self.use_sw:
            self.w = nn.parameter.Parameter(data=torch.zeros(self.num_groups, self.num_heads_per_group, self.num_heads_per_group, dtype=dtype))
        else:
            self.register_buffer('w', torch.eye(self.num_heads_per_group, dtype=dtype).expand(self.num_groups, self.num_heads_per_group, self.num_heads_per_group))

    def forward(self, inputs,
                dws: Optional[Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]] = None,
                query_vec=None, key_vec=None,
                proj_w: Optional[Tensor] = None,
                fast_infer=True):
        if proj_w is not None:
            ret = torch.einsum('BNTS,BSNM->BMTS', inputs, proj_w)
        else:
            assert dws is not None
            qw1, qw2, kw1, kw2, qdd, kdd = dws
            inputs = inputs.unsqueeze(1)  # BNTS->BGNTS
            # apply sw
            ret = torch.einsum('BGMTS,GMN->BGNTS', inputs, self.w) if self.use_sw else inputs
            if fast_infer:
                inputs_label = 'BGMTS'
                hidden_sym = 'I'; hidden_label = inputs_label.replace('M', 'I')  # BGITS
                # apply qw and kw
                for sym, (w1, w2) in zip(['T', 'S'], [(qw1, qw2), (kw1, kw2)]):
                    dw_label = f'B{sym}G{hidden_sym}M'  # w1: BTGIM, dw_label: BTGIM
                    dynamic_hidden_dim = w1.shape[dw_label.index(hidden_sym)]
                    eqn1 = f'{inputs_label},{dw_label}->{hidden_label}'  # 'BGMTS,BTGMI->BGITS'
                    eqn2 = f'{hidden_label},{dw_label}->{inputs_label}'  # 'BGITS,BTGMI->BGMTS'
                    for i in range(dynamic_hidden_dim):
                        hidden = torch.einsum(eqn1.replace(hidden_sym, ''), inputs, w1[..., i, :])  # BGMTS,BTG(I)M->BGTS
                        out = torch.einsum(eqn2.replace(hidden_sym, ''), hidden, w2[..., i, :])  # 'BG(I)TS,BTG(I)M->BGMTS'
                        ret = ret + out
                # apply qdd and kdd
                for sym, dd in zip(['T', 'S'], [qdd, kdd]):
                    dd_label = f'B{sym}GM'
                    dout = torch.einsum(f'{inputs_label},{dd_label}->{inputs_label}', inputs, dd)  # BGMTS,B(T/S)GM->BGMTS
                    ret = ret + dout
            else:
                # apply qw and kw (BTGIN)
                x_inter = torch.einsum('BGNTS, BTGIN->BGTSI', inputs, qw1)
                qw_out = torch.einsum('BGTSI, BTGIN->BGNTS', x_inter, qw2)
                ret = ret + qw_out
                x_inter = torch.einsum('BGNTS, BSGIN->BGTSI', inputs, kw1)
                kw_out = torch.einsum('BGTSI, BSGIN->BGNTS', x_inter, kw2)
                ret = ret + kw_out

                # apply qdd(BTGN) and kdd(BSGN)
                ret = ret + torch.einsum('BGNTS, BTGN->BGNTS', inputs, qdd)
                ret = ret + torch.einsum('BGNTS, BSGN->BGNTS', inputs, kdd)
            ret = ret.squeeze(1)  # BGNTS->BNTS
        return ret


class DCMHAttention(nn.Module):
    def __init__(self, config: DCFormerConfig, lidx, use_sw=False):
        super().__init__()
        assert config.dim % config.n_head == 0
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.lidx = lidx
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.is_training = config.is_training
        self.dim = config.dim
        self.use_dcmha = config.use_dcmha
        self.scale_factor = 1 / math.sqrt(self.head_dim)
        self.q_chunk_size = config.q_chunk_size
        self.use_sw = use_sw
        self.dyn_w_proj = DynamicWeightProjection(num_heads=self.n_head, query_input_dim=config.dim, dynamic_squeeze_ratio=self.n_head // 2, dynamic_w_hidden_dim=self.n_head * 4, use_sw=use_sw)
        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            self.q_norm = RMSnorm(hid_dim=self.head_dim)
            self.k_norm = RMSnorm(hid_dim=self.head_dim)

        self.window_types = {
            "LG": [256, None],
            "LGLL": [256, None, 256, 256],
            "LGL6": [256, None, 256, 256, 256, 256, 256, 256],
        }

        self.query_wise = config.query_wise
        if config.window_type is None:  # LG
            self.window_size = None if self.lidx % 2 == 1 else config.window_size
        else:
            window_l = self.window_types[config.window_type]
            self.window_size = window_l[self.lidx % len(window_l)]

        if not self.is_training:
            self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def _generate_fast(self, x, input_pos, q, k, v, k_mask):
        B, T, D = x.shape
        N, I = self.n_head, self.dyn_w_proj.dynamic_hidden_dim  # 32, 2
        dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1)  # BTD, D(4K+4N) -> BT(4K+4N) -> BT(4K), BT(4N)
        dw_hidden = dw_hidden.view((B, T, 4, -1, 1))  # BT(4K) -> BT4K1
        dw = (self.dyn_w_proj.dw_hidden_activation(dw_hidden) * self.dyn_w_proj.qkw_m).sum(-2)  # gelu, BT4K1, 4K(IM)->BT4K(IM)->BT4(IM)
        w1, w2 = dw.view((B, T, 2, 2, -1, N)).split(I, -2)  # BT4(IM)->BT{pre/post}{q/k}IM->[BT22IM] * 2
        w1 = self.dyn_w_proj.dw1_norm(w1)  # BT22IN
        qkdd = self.dyn_w_proj.dw_activation(dd.view((B, T, 2, 2, N)))  # BT2{2}N1->BT2{2}N tanh
        qkw = torch.einsum('BTKJIN,BTKJIM->BTKJNM', w1, w2) + torch.diag_embed(qkdd)  # j=k=2, BT2{2}NM q/k, pre/post
        if self.query_wise:  # TODO: do not generate kw and kdd
            qw, _ = qkw.unbind(3)  # BS2NM
            kw_new = None
            qw = qw + self.dyn_w_proj.sw
        else:
            qw, kw_new = qkw.unbind(3)  # BS{pre/post}{q/k}NM -> BS{pre/post}NM * 2
            kw_new = kw_new + self.dyn_w_proj.sw  # BS2NM + 2NM -> BS2NM
        if self.kv_cache is not None:
            k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new)  # BNT2M
        logits = q @ k.transpose(-2, -1) * self.scale_factor
        if self.query_wise:
            w = qw  # B12NM
        else:
            w = qw + kw_out  # B12NM,BS2NM -> BS2NM
        wl, w = w.permute(0, 2, 3, 4, 1).unbind(1)  # BS2NM->B2NMS->[BNMS]*2
        logits = (logits * wl).sum(1).unsqueeze(2)  # BN1S, BNMS -> BNMS -> BMS -> BM1S
        min_value = torch.finfo(torch.float16).min
        logits = torch.where(k_mask, logits, min_value)
        probs = logits.softmax(-1)
        probs = (probs * w).sum(1).unsqueeze(2)
        y = probs @ v
        return y

    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)  # BSND
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if self.use_qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))  # BNSD

        if self.is_training:
            N, D, I = self.n_head, self.head_dim, self.dyn_w_proj.dynamic_hidden_dim  # 6.7B
            B, T, E = x.shape
            if self.use_dcmha:
                project_logits = True
                project_probs = True
                if project_probs:
                    dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1)
                    dw_hidden = self.dyn_w_proj.dw_hidden_activation(dw_hidden)
                    dw_hidden = dw_hidden.view(dw_hidden.shape[:2] + (4, -1))  # B T (4 K) -> B T 4 K  # reshape
                    dw = torch.einsum('B T C K, C K D -> B T C D', dw_hidden, self.dyn_w_proj.qkw_m)  # BT4K,4K(MI)->BT4(MI)
                    shape = (B, T, 2*2, -1, N)  # if project_logits else (B,T,2,N,-1)  # BT(pre/post)(q/k)IN
                    w1, w2 = dw.view(shape).split(I, -2)
                    w1 = self.dyn_w_proj.dw1_norm(w1)  # BT22IN
                    if self.use_sw:
                        pre_sw, post_sw = self.dyn_w_proj.sw.unbind(0)
                    else:
                        pre_sw, post_sw = None, None
                    pre_qw1, pre_kw1, post_qw1, post_kw1 = w1.unbind(2)  # BT(2{*2})IN->[BTIN]*4
                    pre_qw2, pre_kw2, post_qw2, post_kw2 = w2.unbind(2)
                    qkdd = F.tanh(dd).squeeze(-1).view(shape[:-2] + (N,))  # BT(2{*2})N1->BT(2{*2})N
                    pre_qdd, pre_kdd, post_qdd, post_kdd = qkdd.unbind(2)  # BT(2{*2})N->[BTN]*4

                y = torch.zeros(B, N, T, D).to(q.device, dtype=torch.float16)
                for i in range(T // self.q_chunk_size):
                    start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
                    kv_start = max(0, stop - self.q_chunk_size - self.window_size)
                    _q = q[:, :, start : stop, :]
                    _k, _v = k[:, :, kv_start : stop, :], v[:, :, kv_start : stop, :]
                    _atten_mask = mask[:, :, start : stop, kv_start : stop]
                    _pre_proj_dw_args = slice_dw(pre_sw, pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd, start, stop, kv_start) \
                        if project_logits else None
                    _post_proj_dw_args = slice_dw(post_sw, post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd, start, stop, kv_start) \
                        if project_probs else None
                    _o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args)
                    y[:, :, start:stop] = _o
            else:
                y = torch.zeros(B, N, T, D).to(q.device, dtype=torch.float16)
                for i in range(T // self.q_chunk_size):
                    start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
                    kv_start = max(0, stop - self.q_chunk_size - self.window_size)
                    _q = q[:, :, start : stop, :]
                    _k, _v = k[:, :, kv_start : stop, :], v[:, :, kv_start : stop, :]
                    _atten_mask = mask[:, :, start : stop, kv_start : stop]
                    _pre_proj_dw_args, _post_proj_dw_args = None, None
                    _o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args)
                    y[:, :, start:stop] = _o
        else:  # inference
            if seqlen == 1:  # one-token generation
                k_mask = mask if self.window_size is None else mask[:, :, :, :self.kv_cache.seq_length]
                if fast_infer:
                    y = self._generate_fast(x, input_pos, q, k, v, k_mask)
                else:
                    assert not self.query_wise
                    # generate dw from hidden_state
                    pre_proj_dw_args, post_proj_dw_args, kw_new = self.dyn_w_proj(x, gen_cache=True)

                    # update kvkw cache
                    kw_new = kw_new + self.dyn_w_proj.sw  # absorb residual or sw into kw cache
                    if self.kv_cache is not None:
                        k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new)  # BNSD, BNSD, BS2NN

                    logits = q @ k.transpose(-2, -1) * self.scale_factor
                    # merge pre_w and apply it
                    pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd = pre_proj_dw_args
                    pre_qw = torch.einsum('BTGIN, BTGIM->BTNM', pre_qw1, pre_qw2) + torch.diag_embed(pre_qdd.squeeze(2))
                    pre_w = pre_qw + kw_out[:, :, 0]  # B1NM, BSNM -> BSNM
                    logits = self.dyn_w_proj.pre_proj(logits, proj_w=pre_w.squeeze(1))

                    logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min)
                    probs = logits.softmax(-1)

                    # merge post_w and apply it
                    post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd = post_proj_dw_args
                    post_qw = torch.einsum('BTGIN, BTGIM->BTNM', post_qw1, post_qw2) + torch.diag_embed(post_qdd.squeeze(2))
                    post_w = post_qw + kw_out[:, :, 1]
                    probs = self.dyn_w_proj.post_proj(probs, proj_w=post_w.squeeze(1))

                    y = probs @ v
            else:  # prefill
                k_mask = mask[:, :, :, :k.shape[-2]]
                pre_proj_dw_args, post_proj_dw_args, kw_new = self.dyn_w_proj(x, gen_cache=True)
                kw_new = kw_new + self.dyn_w_proj.sw  # absorb residual or sw into kw cache
                if self.kv_cache is not None:
                    self.kv_cache.update(input_pos, k, v, kw_val=kw_new)  # BNSD, BNSD, BS2NN
                logits = q @ k.transpose(-2, -1) * self.scale_factor
                logits = self.dyn_w_proj.pre_proj(logits, dws=pre_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True)  # XD BN1S
                logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min)
                probs = logits.softmax(-1)
                probs = self.dyn_w_proj.post_proj(probs, dws=post_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True)  # BN1S
                y = probs @ v

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: DCFormerConfig) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def _atten_context(query, key, value, atten_mask, pre_proj_dw_args, post_proj_dw_args):
    logits = query @ key.transpose(-2, -1)
    if pre_proj_dw_args is not None: logits = _cross_head_proj(logits, *pre_proj_dw_args)
    logits = torch.where(atten_mask, logits, torch.finfo(torch.float16).min)
    probs = logits.softmax(-1)
    if post_proj_dw_args is not None: probs = _cross_head_proj(probs, *post_proj_dw_args)
    o = probs @ value  # BNTS,BNSD->BNTD
    return o


def _cross_head_proj(inputs, sw, qw1, qw2, kw1, kw2, qdd, kdd, loop_over_dynamic_hd=False):
    out = inputs + torch.einsum('BNTS,NM->BMTS', inputs, sw) if sw is not None else inputs
    for i in range(2):  # qw1.shape[-2]
        qhidden = (inputs * qw1[..., i, :].transpose(-2, -1).unsqueeze(-1)).sum(1)  # BNTS,(BTN->BNT->BNT1)->BNTS->BTS
        qout = qhidden.unsqueeze(1) * qw2[..., i, :].transpose(-2, -1).unsqueeze(-1)  # (BTS->B1TS),(BTN->BNT->BNT1)->BNTS
        out = out + qout
        khidden = (inputs * kw1[..., i, :].transpose(-2, -1).unsqueeze(-2)).sum(1)  # BNTS,(BSN->BNS->BN1S)->BNTS->BTS
        kout = khidden.unsqueeze(1) * kw2[..., i, :].transpose(-2, -1).unsqueeze(-2)  # (BTS->B1TS),(BSN->BNS->BNS1)->BNTS
        out = out + kout
    qdout = inputs * qdd.transpose(-2, -1).unsqueeze(-1); out = out + qdout  # BNTS,(BTN->BNT->BNT1)->BNTS
    kdout = inputs * kdd.transpose(-2, -1).unsqueeze(-2); out = out + kdout  # BNTS,(BSN->BNS->BN1S)->BNTS
    return out


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


def make_window_mask(t, window_size):
    col_idx = torch.tile(torch.arange(t).unsqueeze(0), [t, 1])
    row_idx = torch.tile(torch.arange(t).unsqueeze(1), [1, t])
    bias_mask = (col_idx + window_size >= row_idx).tril().view(t, t)
    return bias_mask


def slice_dw(sw, qw1, qw2, kw1, kw2, qdd, kdd, start, stop, kv_start):
    return (sw,
            qw1[:, start : stop] if qw1 is not None else None,
            qw2[:, start : stop] if qw2 is not None else None,
            kw1[:, kv_start : stop] if kw1 is not None else None,
            kw2[:, kv_start : stop] if kw2 is not None else None,
            qdd[:, start : stop] if qdd is not None else None,
            kdd[:, kv_start : stop] if kdd is not None else None)


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.float16)


def unbind(ary, n, dim=0):
    return [torch.squeeze(a, dim=dim) for a in torch.split(ary, ary.shape[dim] // n, dim=dim)]


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor:
    if mode == 'half':
        xshaped = x.float().reshape(*x.shape[:-1], 2, -1).transpose(-1, -2)
    elif mode == 'alternative':
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
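
With window_type left as None (the "LG" pattern), DCMHAttention gives even-indexed layers a sliding window of window_size tokens and odd-indexed layers full global attention, and KVKWCache sizes each layer's cache to match, writing windowed layers as a ring buffer via input_pos % seq_length. A small sketch of the resulting per-layer layout for the 32-layer, window_size=256 configuration above (illustrative only, plain Python):

# reproduces the window assignment from DCMHAttention.__init__ for window_type=None ("LG" pattern)
n_layer, window_size = 32, 256
per_layer = [None if lidx % 2 == 1 else window_size for lidx in range(n_layer)]
print(per_layer[:4])   # [256, None, 256, None] -> local, global, local, global, ...

# cache length per layer, as in KVKWCache.__init__ with max_seq_length=2048
max_seq_length = 2048
cache_len = [max_seq_length if w is None else min(w, max_seq_length) for w in per_layer]
print(cache_len[:4])   # [256, 2048, 256, 2048]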
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32d8ef97eb234f95509026b1064ade0d6d33a4f83a72dd025604cf7d130bce27
size 5808488649
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
{
  "add_prefix_space": false,
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": true,
  "eos_token": "<|endoftext|>",
  "model_max_length": 1000000000000000019884624838656,
  "tokenizer_class": "GPTNeoXTokenizer",
  "unk_token": "<|endoftext|>"
}
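
The tokenizer is a GPT-NeoX style BPE tokenizer with <|endoftext|> doing triple duty as bos, eos and unk token. A minimal sanity check, assuming the same repo id as in generation_demo.py:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCFormer-2.8B")
print(tokenizer.bos_token, tokenizer.eos_token)   # <|endoftext|> <|endoftext|>
ids = tokenizer.encode("Beijing is the capital of China.")
print(len(ids), tokenizer.decode(ids))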