mqyqlx committed
Commit
b3abc18
1 Parent(s): eb28db4

add model and code

README.md CHANGED
@@ -1,3 +1,64 @@
  ---
+ language:
+ - en
+ tags:
+ - pytorch
+ - causal-lm
+ - dcformer
+ - dcmha
  license: mit
  ---
+ DCPythia-6.9B is a language model pretrained on the Pile with 300B tokens. By comparing it with Pythia-6.9B, we validate the scaling performance of Dynamically
+ Composable Multi-Head Attention (DCMHA), a parameter- and computation-efficient attention architecture that tackles the shortcomings of MHA and increases the expressive power of the model
+ by dynamically composing attention heads. Please see downstream evaluations and more details in the paper [Improving Transformers with Dynamically Composable Multi-Head Attention](). In addition, we open-source the JAX training code on [GitHub](https://github.com/Caiyun-AI/DCFormer/).
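+
+ As a rough illustration of the core idea (an editorial sketch, not the model's actual implementation -- see `modeling_dcpythia.py` in this commit for that), DCMHA composes the attention-score tensor across the head axis, so each effective head becomes an input-dependent mixture of all heads:
+
+ ```
+ import torch
+
+ B, N, T, S = 2, 8, 16, 16          # batch, heads, query length, key length
+ scores = torch.randn(B, N, T, S)   # per-head attention logits
+ w = torch.randn(B, N, N)           # input-dependent head-composition matrix (hypothetical shape)
+
+ # each output head m mixes all input heads n; DCMHA applies such maps both
+ # before the softmax (to logits) and after it (to probabilities)
+ composed = torch.einsum('bnts,bnm->bmts', scores, w)
+ probs = composed.softmax(-1)
+ ```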
+
+ We recommend using the <strong>compiled version</strong> of DCPythia with *torch.compile* for inference acceleration. Please refer to the Generation section below for the compilation setup.
+
+ # Usage
+
+ ## Env
+
+ You need to upgrade transformers to avoid [loading problems](https://github.com/huggingface/transformers/pull/29175):
+
+ ```
+ pip install "transformers>=4.40.2"
+ ```
+
+
+ ## Generation
+
+ ```
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ import os
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+ tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCPythia-6.9B")
+ model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCPythia-6.9B", trust_remote_code=True)
+
+ device = torch.device('cuda')
+ MAX_BATCH_SIZE = 1
+ MAX_SEQ_LENGTH = 2048
+ NUM_TOKENS_TO_GENERATE = 100
+ COMPILE = True
+
+ _ = model.to(device=device, dtype=torch.float16)
+ with torch.device(device):
+     model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, set_kv_cache=True)
+
+ def decode_one_token(model, cur_token, input_pos):
+     logits = model(cur_token, input_pos=input_pos, return_tensor=True)
+     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+     return new_token
+
+ prompt = "Beijing is the capital of China. London is the capital of"
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
+
+ compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None
+
+ with torch.no_grad():
+     generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
+     text = tokenizer.decode(generated_ids[0])
+     print('generated text:', text)
+ ```
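+
+ Note that the first pass through `compiled_decode_one_token` triggers compilation, so the initial tokens are slow; subsequent steps reuse the compiled graph. A rough way to measure steady-state speed (an editorial sketch building on the setup above):
+
+ ```
+ import time
+
+ with torch.no_grad():  # warm-up: pays the one-time compilation cost
+     _ = model.generate(input_ids.to(device), num_tokens_to_generate=5, compiled_decode_one_token=compiled_decode_one_token)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ with torch.no_grad():
+     _ = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
+ torch.cuda.synchronize()
+ print(f'{NUM_TOKENS_TO_GENERATE / (time.time() - t0):.1f} tokens/s')
+ ```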
config.json ADDED
@@ -0,0 +1,36 @@
+ {
+   "architectures": [
+     "DCPythia"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_dcpythia.DCPythiaConfig",
+     "AutoModelForCausalLM": "modeling_dcpythia.DCPythia"
+   },
+   "block_size": 2048,
+   "bos_token_id": 0,
+   "dim": 4096,
+   "eos_token_id": 0,
+   "head_dim": 128,
+   "intermediate_size": 16384,
+   "is_training": false,
+   "model_type": "dcpythia",
+   "n_head": 32,
+   "n_layer": 32,
+   "n_local_heads": 32,
+   "norm_eps": 1e-05,
+   "q_chunk_size": 128,
+   "query_wise": false,
+   "rope_base": 10000,
+   "rotary_pct": 0.25,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float16",
+   "transformers_version": "4.33.2",
+   "use_dcmha": true,
+   "use_gradient_checkpointing": false,
+   "use_linear_bias": true,
+   "use_parallel_residual": true,
+   "use_qk_norm": true,
+   "vocab_size": 50257,
+   "window_size": 256,
+   "window_type": null
+ }
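*Editorial note:* since `auto_map` above registers the custom classes, loading this config (or the model) goes through the Hub's remote-code path, which is why `trust_remote_code=True` is required:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Caiyun-AI/DCPythia-6.9B", trust_remote_code=True)
print(config.model_type, config.n_layer, config.use_dcmha)  # dcpythia 32 True
```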
configuration_dcpythia.py ADDED
@@ -0,0 +1,76 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from typing import Optional, Tuple, List
+
+
+ class DCPythiaConfig(PretrainedConfig):
+     model_type = "dcpythia"
+
+     '''
+     DCPythiaConfig is a config class for DCPythia, adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21
+     '''
+     def __init__(
+         self,
+         block_size: int = 2048,
+         vocab_size: int = 32000,
+         n_layer: int = 32,
+         n_head: int = 32,
+         dim: int = 2560,
+         intermediate_size: int = None,
+         n_local_heads: int = -1,
+         head_dim: int = 64,
+         rope_base: float = 10000,
+         norm_eps: float = 1e-5,
+         use_gradient_checkpointing: bool = False,
+         is_training: bool = False,
+         q_chunk_size: int = 128,
+         use_dcmha: bool = True,
+         use_qk_norm: bool = False,
+         window_size: Optional[int] = 256,
+         window_type: Optional[str] = None,
+         query_wise: bool = False,
+         pad_token_id: Optional[int] = None,
+         use_parallel_residual: bool = True,
+         use_linear_bias: bool = True,
+         rotary_pct: float = 0.25,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         **kwargs,
+     ):
+         self.block_size = block_size
+         self.vocab_size = vocab_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.dim = dim
+         self.intermediate_size = intermediate_size
+         self.n_local_heads = n_local_heads
+         self.head_dim = head_dim
+         self.rope_base = rope_base
+         self.norm_eps = norm_eps
+         self.use_gradient_checkpointing = use_gradient_checkpointing
+         self.is_training = is_training
+         self.q_chunk_size = q_chunk_size
+         self.use_dcmha = use_dcmha
+         self.use_qk_norm = use_qk_norm
+         self.window_size = window_size
+         self.window_type = window_type
+         self.query_wise = query_wise
+         self.use_parallel_residual = use_parallel_residual
+         self.use_linear_bias = use_linear_bias
+         self.rotary_pct = rotary_pct
+         # post init: derive defaults that depend on other fields
+         if self.n_local_heads == -1:
+             self.n_local_heads = self.n_head
+         if self.intermediate_size is None:
+             self.intermediate_size = 4 * self.dim
+         self.head_dim = self.dim // self.n_head
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
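*Editorial note:* the constructor always recomputes `head_dim = dim // n_head`, so the `head_dim` argument is effectively ignored; a quick check under that assumption:

```python
from configuration_dcpythia import DCPythiaConfig

cfg = DCPythiaConfig(dim=4096, n_head=32, head_dim=999)
assert cfg.head_dim == 128                # overridden by dim // n_head
assert cfg.intermediate_size == 4 * 4096  # default derived when not given
```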
generation_demo.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ import os
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+ tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCPythia-6.9B")
+ model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCPythia-6.9B", trust_remote_code=True)
+
+ device = torch.device('cuda')
+ MAX_BATCH_SIZE = 1
+ MAX_SEQ_LENGTH = 2048
+ NUM_TOKENS_TO_GENERATE = 100
+ COMPILE = True
+
+ # Move the weights to the GPU in fp16 and pre-allocate the static KV/KW caches
+ # used by the custom generation loop.
+ _ = model.to(device=device, dtype=torch.float16)
+ with torch.device(device):
+     model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, set_kv_cache=True)
+
+ def decode_one_token(model, cur_token, input_pos):
+     # Greedy single-token decode step; compiled below when COMPILE is set.
+     logits = model(cur_token, input_pos=input_pos, return_tensor=True)
+     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+     return new_token
+
+ prompt = "Beijing is the capital of China. London is the capital of"
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
+
+ compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None
+
+ with torch.no_grad():
+     generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
+     text = tokenizer.decode(generated_ids[0])
+     print('generated text:', text)
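*Editorial note:* this demo loads the full fp16 checkpoint onto a single CUDA device; the three `pytorch_model-*.bin` shards later in this commit total roughly 14.8 GB, so plan GPU memory for the weights plus the pre-allocated KV/KW caches.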
modeling_dcpythia.py ADDED
@@ -0,0 +1,630 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, List
+ from collections import namedtuple
+
+ import math
+ import time
+ import json
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.nn import functional as F
+ from torch.utils.checkpoint import checkpoint
+
+ try:
+     from .configuration_dcpythia import DCPythiaConfig
+ except ImportError:
+     from configuration_dcpythia import DCPythiaConfig
+ from transformers.modeling_utils import PreTrainedModel
+
+
+ class KVKWCache(nn.Module):
+     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, window_size=2048, dtype=torch.float16, use_kw_cache=True):
+         super().__init__()
+         self.head_dim = head_dim
+         self.kw_dim = 2 * n_heads
+         self.n_heads = n_heads
+         self.window_size = window_size
+         self.use_kw_cache = use_kw_cache
+         if window_size is None:
+             self.seq_length = max_seq_length
+         else:
+             self.seq_length = min(window_size, max_seq_length)
+         cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim)
+         kw_cache_shape = (max_batch_size, self.seq_length, 2, n_heads, n_heads)
+         self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
+         self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
+         if self.use_kw_cache:
+             self.register_buffer('kw_cache', torch.zeros(kw_cache_shape, dtype=dtype))
+
+     def update(self, input_pos, k_val, v_val, kw_val=None):  # kw_val: (B, S, 2, N, N)
+         # input_pos: [S], k_val: [B, H, S, D]
+         assert input_pos.shape[-1] == k_val.shape[2]
+         B, N, S, D = v_val.shape
+         k_out = self.k_cache
+         v_out = self.v_cache
+         if self.use_kw_cache:
+             kw_out = self.kw_cache
+         else:
+             kw_out = None
+
+         if self.window_size is None:
+             k_out[:, :, input_pos] = k_val
+             v_out[:, :, input_pos] = v_val
+             if self.use_kw_cache and kw_val is not None:
+                 kw_out[:, input_pos] = kw_val
+         elif S == 1:  # one-token decode: write into the ring buffer
+             input_pos = input_pos % self.seq_length
+             v_out[:, :, input_pos] = v_val
+             k_out[:, :, input_pos] = k_val
+             if self.use_kw_cache and kw_val is not None:
+                 kw_out[:, input_pos] = kw_val
+         else:  # prefill: keep only the last `seq_length` positions
+             start = max(0, input_pos[-1] - self.seq_length + 1)
+             input_pos = input_pos[start:] % self.seq_length
+             v_out[:, :, input_pos] = v_val[:, :, start:]
+             k_out[:, :, input_pos] = k_val[:, :, start:]
+             if self.use_kw_cache and kw_val is not None:
+                 kw_out[:, input_pos] = kw_val[:, start:]
+         return k_out, v_out, kw_out
+
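+ # Editorial note (illustration only, not executed by the model): with a sliding
+ # window, the cache above is a ring buffer -- absolute positions are wrapped with
+ # `% self.seq_length`, so new tokens overwrite the oldest entries, e.g.:
+ #
+ #     cache = torch.zeros(4)          # window of 4 positions
+ #     for pos in range(7):
+ #         cache[pos % 4] = pos        # same indexing as update() when S == 1
+ #     # cache is now [4., 5., 6., 3.]: only the last 4 positions survive
+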
+ class DCPythia(PreTrainedModel):
+     config_class = DCPythiaConfig
+
+     def __init__(self, config: DCPythiaConfig) -> None:
+         super().__init__(config)
+         self.config = config
+
+         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+         self.layers = nn.ModuleList(DCPythiaBlock(config, lidx) for lidx in range(config.n_layer))
+         self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
+         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)  # no bias in Pythia
+         self.use_gradient_checkpointing = config.use_gradient_checkpointing
+         self.is_training = config.is_training
+
+         self.freqs_cis: Optional[Tensor] = None
+         self.rotary_ndims = int(config.head_dim * config.rotary_pct)
+         self.mask_cache: Optional[Tensor] = None
+         self.window_size = config.window_size
+         self.max_batch_size = -1
+         self.max_seq_length = -1
+
+     def setup_caches(self, max_batch_size, max_seq_length, set_kv_cache=True):
+         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+             return
+         head_dim = self.config.dim // self.config.n_head
+         max_seq_length = find_multiple(max_seq_length, 8)
+         self.max_seq_length = max_seq_length
+         self.max_batch_size = max_batch_size
+         if not self.is_training:
+             for b in self.layers:
+                 if set_kv_cache:
+                     use_kw_cache = False if b.attention.query_wise else True
+                     b.attention.kv_cache = KVKWCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, window_size=b.attention.window_size, use_kw_cache=use_kw_cache)
+                 b.attention.dyn_w_proj.merge_weights()
+                 if not b.attention.use_sw:
+                     dtype = b.attention.wo.weight.dtype
+                     device = b.attention.wo.weight.device
+                     b.attention.dyn_w_proj.sw = b.attention.dyn_w_proj.sw.to(device=device, dtype=dtype)
+                     b.attention.dyn_w_proj.pre_proj.w = b.attention.dyn_w_proj.pre_proj.w.to(device=device, dtype=dtype)
+                     b.attention.dyn_w_proj.post_proj.w = b.attention.dyn_w_proj.post_proj.w.to(device=device, dtype=dtype)
+
+         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.rotary_ndims, self.config.rope_base).to(self.tok_embeddings.weight.device)
+         if self.is_training:
+             self.causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool, device=self.tok_embeddings.weight.device))
+         elif self.window_size is None:
+             self.causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))
+         else:
+             # stack a local (windowed) mask and a global causal mask: LG
+             self.causal_mask = torch.stack([make_window_mask(max_seq_length, self.config.window_size), torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))])
+
+     def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
+         batch_size, seq_length = input_ids.shape
+         input_pos = torch.arange(seq_length, device=self.device)
+         generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate + 1, dtype=torch.int, device=self.device)
+         generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
+         logits = self.forward(input_ids, input_pos=input_pos, return_tensor=True)
+         _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+         next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
+         next_token[:batch_size] = _next_token
+         generated_ids[:, seq_length] = next_token[:batch_size, 0]
+         input_pos = torch.tensor([seq_length], device=self.device)
+         for _ in range(1, num_tokens_to_generate):
+             if compiled_decode_one_token is not None:
+                 next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
+             else:
+                 next_token = self.decode_one_token(next_token.clone(), input_pos)
+             generated_ids[:, input_pos + 1] = next_token.int()[:batch_size]
+             input_pos += 1
+         return generated_ids
+
+     def decode_one_token(self, cur_token, input_pos):
+         logits = self.forward(
+             cur_token,
+             input_pos=input_pos,
+             return_tensor=True,
+         )
+         new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+         return new_token
+
+     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
+         assert self.freqs_cis is not None, "Caches must be initialized first"
+         if input_pos is None:
+             input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
+         if self.window_size is None or self.is_training:
+             mask = self.causal_mask[None, None, input_pos]
+         else:
+             mask = self.causal_mask[None, None, :, input_pos]
+         freqs_cis = self.freqs_cis[input_pos][:idx.shape[-1]]
+         x = self.tok_embeddings(idx)
+         for i, layer in enumerate(self.layers):
+             if self.is_training or self.window_size is None:
+                 layer_mask = mask
+             elif self.window_size is not None:
+                 layer_mask = mask[:, :, 1] if layer.attention.window_size is None else mask[:, :, 0]
+             if self.use_gradient_checkpointing:
+                 x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
+             else:
+                 x = layer(x, input_pos, freqs_cis, layer_mask)
+         x = self.norm(x)
+         logits = self.output(x)
+         if return_tensor:
+             return logits
+         else:
+             CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+             return CausalLMOutput(logits=logits)
+
+ class DCPythiaBlock(nn.Module):
+     def __init__(self, config: DCPythiaConfig, lidx) -> None:
+         super().__init__()
+         self.lidx = lidx
+         self.attention = DCMHAttention(config, lidx)
+         self.feed_forward = FeedForward(config)
+         self.ffn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
+         self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
+         self.use_parallel_residual = config.use_parallel_residual
+
+     def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True)
+         if self.use_parallel_residual:  # GPT-NeoX-style: the FFN reads x, not h
+             out = h + self.feed_forward(self.ffn_norm(x))
+         else:
+             out = h + self.feed_forward(self.ffn_norm(h))
+         return out
+
+ class DynamicWeightProjection(nn.Module):
+
+     def __init__(self, num_heads=32, num_groups=1, residual=True, query_input_dim=4096, dynamic_squeeze_ratio=16, dynamic_w_hidden_dim=128, dtype=torch.float16, use_sw=False):
+         super().__init__()
+         self.num_heads = num_heads
+         self.num_groups = num_groups
+         self.query_input_dim = query_input_dim
+         self.dynamic_squeeze_ratio = dynamic_squeeze_ratio
+         self.dynamic_w_hidden_dim = dynamic_w_hidden_dim
+         self.dw_hidden_activation = nn.GELU()
+         self.num_heads_per_group = self.num_heads // self.num_groups
+         self.dw_activation = nn.Tanh()
+         self.dw1_norm = RMSnormNoscale(dim=-1)
+         self.use_sw = use_sw
+         self.pre_proj = CrossHeadProjection('pre', num_heads=self.num_heads, use_sw=use_sw)
+         self.post_proj = CrossHeadProjection('post', num_heads=self.num_heads, use_sw=use_sw)
+
+         dynamic_hidden_dim = self.num_heads_per_group // self.dynamic_squeeze_ratio
+         self.dynamic_hidden_dim = dynamic_hidden_dim
+         self.dw1 = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, 4, self.dynamic_w_hidden_dim, dtype=dtype))  # (4096, 1, 4, 128)
+         G, K, M = self.num_groups, self.dynamic_w_hidden_dim, self.num_heads_per_group
+         I = dynamic_hidden_dim * 2
+         self.qkw = nn.parameter.Parameter(torch.zeros([G, 4, K, I, M], dtype=dtype))  # (1, 4, 128, 4, 32)
+         self.dd = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, self.num_heads_per_group * 4, dtype=dtype))  # (4096, 1, 128)
+
+         self.merge_weights()
+
+     def merge_weights(self):
+         self.dw_m = nn.parameter.Parameter(torch.cat([self.dw1.reshape(self.query_input_dim, -1), self.dd.squeeze(1)], dim=-1)).to(self.dw1.device)  # E,(4*K + K), K=2*N*I
+         self.qkw_m = nn.parameter.Parameter(self.qkw.permute(0, 1, 2, 3, 4).reshape(4, self.dynamic_w_hidden_dim, -1)).to(self.dw1.device)  # (4,K,I*M)
+         if self.use_sw:
+             self.sw = nn.parameter.Parameter(torch.stack([self.pre_proj.w, self.post_proj.w]).squeeze(1) + torch.eye(self.num_heads)).to(self.dw1.device)  # (2,N,N) sw + identity matrix
+         else:
+             self.sw = (torch.eye(self.num_heads).expand(2, self.num_heads, self.num_heads)).to(self.dw1.device)  # identity matrix (2,N,N)
+
+     def forward(self, query_vec, KW: Optional[torch.Tensor] = None, gen_cache: Optional[bool] = True):
+         dw_hidden = torch.einsum('BTD,DGCK->BTGCK', query_vec, self.dw1)  # C=4 [pre,post]*[query,key]
+         dw_hidden = self.dw_hidden_activation(dw_hidden)  # BTGCK
+         w1, w2 = torch.split(torch.einsum('BTGCK,GCKIM->BTGCIM', dw_hidden, self.qkw), self.qkw.shape[-2] // 2, dim=-2)  # BTGC(2I)M -> [BTGCIM] * 2
+         w1 = self.dw1_norm(w1)  # BTGCIM
+         pre_qw1, pre_kw1, post_qw1, post_kw1 = unbind(w1, 4, dim=3)  # BTG4IM -> [BTGIM]*4
+         pre_qw2, pre_kw2, post_qw2, post_kw2 = unbind(w2, 4, dim=3)
+         dd = torch.einsum('BTD,DGM->BTGM', query_vec, self.dd)  # BTG(4M)
+         dd = self.dw_activation(dd)
+         pre_qdd, pre_kdd, post_qdd, post_kdd = torch.split(dd, dd.shape[-1] // 4, dim=-1)  # BTG(4N) -> [BTGN]*4
+         pre_dw_args = (pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd)
+         post_dw_args = (post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd)
+         if gen_cache:  # generate KW cache
+             pre_kw = torch.einsum('BSGIM,BSGIN->BSMN', pre_kw1, pre_kw2) + torch.diag_embed(pre_kdd.squeeze(2))  # merge kw and kdd
+             post_kw = torch.einsum('BSGIM,BSGIN->BSMN', post_kw1, post_kw2) + torch.diag_embed(post_kdd.squeeze(2))
+             KW = torch.stack((pre_kw, post_kw), dim=-3)  # BSMN,BSMN -> BS2MN
+         return pre_dw_args, post_dw_args, KW
+
+
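+ # Editorial note (illustration only): the dynamic mixing matrices above are low
+ # rank -- per position, an N x N head-composition matrix is assembled from I rank-1
+ # outer products plus a diagonal, mirroring the einsum in forward():
+ #
+ #     w1 = torch.randn(B, T, I, N); w2 = torch.randn(B, T, I, N)
+ #     dd = torch.tanh(torch.randn(B, T, N))
+ #     kw = torch.einsum('BTIM,BTIN->BTMN', w1, w2) + torch.diag_embed(dd)  # (B,T,N,N)
+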
+ class RMSnormNoscale(nn.Module):
+
+     def __init__(self, epsilon=1e-6, dim=-1):
+         super().__init__()
+         self.dim = dim
+         self.epsilon = epsilon
+
+     def forward(self, inputs):
+         var = inputs.pow(2).mean(dim=self.dim, keepdim=True)
+         normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
+         return normed_inputs
+
+
+ class RMSnorm(nn.Module):
+
+     def __init__(self, hid_dim=128, epsilon=1e-6, dim=-1):
+         super().__init__()
+         self.dim = dim
+         self.hid_dim = hid_dim
+         self.epsilon = epsilon
+         self.scale = nn.parameter.Parameter(data=torch.ones(self.hid_dim))
+
+     def forward(self, inputs):
+         var = inputs.pow(2).mean(dim=self.dim, keepdim=True)
+         normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
+         normed_inputs = normed_inputs * self.scale
+         return normed_inputs
+
+
+ class CrossHeadProjection(nn.Module):
+
+     def __init__(self, mode, num_heads=16, num_groups=1, dtype=torch.float16, use_sw=False):
+         super().__init__()
+         self.mode = mode
+         self.use_sw = use_sw
+         self.num_heads = num_heads
+         self.num_groups = num_groups
+         self.num_heads_per_group = self.num_heads // self.num_groups
+         if self.use_sw:
+             self.w = nn.parameter.Parameter(data=torch.zeros(self.num_groups, self.num_heads_per_group, self.num_heads_per_group, dtype=dtype))
+         else:
+             self.register_buffer('w', torch.eye(self.num_heads_per_group, dtype=dtype).expand(self.num_groups, self.num_heads_per_group, self.num_heads_per_group))
+
+     def forward(self, inputs,
+                 dws: Optional[Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]] = None,
+                 query_vec=None, key_vec=None,
+                 proj_w: Optional[Tensor] = None,
+                 fast_infer=True):
+         if proj_w is not None:
+             ret = torch.einsum('BNTS,BSNM->BMTS', inputs, proj_w)
+         else:
+             assert dws is not None
+             qw1, qw2, kw1, kw2, qdd, kdd = dws
+             inputs = inputs.unsqueeze(1)  # BNTS -> BGNTS
+             # apply sw
+             ret = torch.einsum('BGMTS,GMN->BGNTS', inputs, self.w) if self.use_sw else inputs
+             if fast_infer:
+                 inputs_label = 'BGMTS'
+                 hidden_sym = 'I'; hidden_label = inputs_label.replace('M', 'I')  # BGITS
+                 # apply qw and kw
+                 for sym, (w1, w2) in zip(['T', 'S'], [(qw1, qw2), (kw1, kw2)]):
+                     dw_label = f'B{sym}G{hidden_sym}M'  # w1: BTGIM, dw_label: BTGIM
+                     dynamic_hidden_dim = w1.shape[dw_label.index(hidden_sym)]
+                     eqn1 = f'{inputs_label},{dw_label}->{hidden_label}'  # 'BGMTS,BTGIM->BGITS'
+                     eqn2 = f'{hidden_label},{dw_label}->{inputs_label}'  # 'BGITS,BTGIM->BGMTS'
+                     for i in range(dynamic_hidden_dim):
+                         hidden = torch.einsum(eqn1.replace(hidden_sym, ''), inputs, w1[..., i, :])  # BGMTS,BTG(I)M->BGTS
+                         out = torch.einsum(eqn2.replace(hidden_sym, ''), hidden, w2[..., i, :])  # BG(I)TS,BTG(I)M->BGMTS
+                         ret = ret + out
+                 # apply qdd and kdd
+                 for sym, dd in zip(['T', 'S'], [qdd, kdd]):
+                     dd_label = f'B{sym}GM'
+                     dout = torch.einsum(f'{inputs_label},{dd_label}->{inputs_label}', inputs, dd)  # BGMTS,B(T/S)GM->BGMTS
+                     ret = ret + dout
+             else:
+                 # apply qw and kw (BTGIN)
+                 x_inter = torch.einsum('BGNTS,BTGIN->BGTSI', inputs, qw1)
+                 qw_out = torch.einsum('BGTSI,BTGIN->BGNTS', x_inter, qw2)
+                 ret = ret + qw_out
+                 x_inter = torch.einsum('BGNTS,BSGIN->BGTSI', inputs, kw1)
+                 kw_out = torch.einsum('BGTSI,BSGIN->BGNTS', x_inter, kw2)
+                 ret = ret + kw_out
+
+                 # apply qdd (BTGN) and kdd (BSGN)
+                 ret = ret + torch.einsum('BGNTS,BTGN->BGNTS', inputs, qdd)
+                 ret = ret + torch.einsum('BGNTS,BSGN->BGNTS', inputs, kdd)
+             ret = ret.squeeze(1)  # BGNTS -> BNTS
+         return ret
+
+
+ class DCMHAttention(nn.Module):
+     def __init__(self, config: DCPythiaConfig, lidx, use_sw=False):
+         super().__init__()
+         assert config.dim % config.n_head == 0
+         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+         # key, query, value projections for all heads, but in a batch
+         self.lidx = lidx
+         self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.use_linear_bias)
+         self.wo = nn.Linear(config.dim, config.dim, bias=config.use_linear_bias)
+         self.kv_cache = None
+
+         self.n_head = config.n_head
+         self.head_dim = config.head_dim
+         self.n_local_heads = config.n_local_heads
+         self.is_training = config.is_training
+         self.dim = config.dim
+         self.use_dcmha = config.use_dcmha
+         self.scale_factor = 1 / math.sqrt(self.head_dim)
+         self.q_chunk_size = config.q_chunk_size
+         self.use_sw = use_sw
+         self.dyn_w_proj = DynamicWeightProjection(num_heads=self.n_head, query_input_dim=config.dim, dynamic_squeeze_ratio=self.n_head // 2, dynamic_w_hidden_dim=self.n_head * 4, use_sw=use_sw)
+         self.use_qk_norm = config.use_qk_norm
+         if self.use_qk_norm:
+             self.q_norm = RMSnorm(hid_dim=self.head_dim)
+             self.k_norm = RMSnorm(hid_dim=self.head_dim)
+
+         self.window_types = {
+             "LG": [256, None],
+             "LGLL": [256, None, 256, 256],
+             "LGL6": [256, None, 256, 256, 256, 256, 256, 256],
+         }
+
+         self.query_wise = config.query_wise
+         if config.window_type is None:  # LG
+             self.window_size = None if self.lidx % 2 == 1 else config.window_size
+         else:
+             window_l = self.window_types[config.window_type]
+             self.window_size = window_l[self.lidx % len(window_l)]
+
+         self.rotary_ndims = int(self.head_dim * config.rotary_pct)
+
+         if not self.is_training:
+             self._register_load_state_dict_pre_hook(self.load_hook)
+
+     def load_hook(self, state_dict, prefix, *args):
+         if prefix + "wq.weight" in state_dict:
+             wq = state_dict.pop(prefix + "wq.weight")
+             wk = state_dict.pop(prefix + "wk.weight")
+             wv = state_dict.pop(prefix + "wv.weight")
+             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+         if prefix + "wq.bias" in state_dict:
+             wq_b = state_dict.pop(prefix + "wq.bias")
+             wk_b = state_dict.pop(prefix + "wk.bias")
+             wv_b = state_dict.pop(prefix + "wv.bias")
+             state_dict[prefix + "wqkv.bias"] = torch.cat([wq_b, wk_b, wv_b])
+
+     def _generate_fast(self, x, input_pos, q, k, v, k_mask):
+         B, T, D = x.shape
+         N, I = self.n_head, self.dyn_w_proj.dynamic_hidden_dim  # 32, 2
+         dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1)  # BTD, D(4K+4N) -> BT(4K+4N) -> BT(4K), BT(4N)
+         dw_hidden = dw_hidden.view((B, T, 4, -1, 1))  # BT(4K) -> BT4K1
+         dw = (self.dyn_w_proj.dw_hidden_activation(dw_hidden) * self.dyn_w_proj.qkw_m).sum(-2)  # gelu; BT4K1, 4K(IM) -> BT4K(IM) -> BT4(IM)
+         w1, w2 = dw.view((B, T, 2, 2, -1, N)).split(I, -2)  # BT4(IM) -> BT{pre/post}{q/k}IM -> [BT22IM] * 2
+         w1 = self.dyn_w_proj.dw1_norm(w1)  # BT22IN
+         qkdd = self.dyn_w_proj.dw_activation(dd.view((B, T, 2, 2, N)))  # BT2{2}N1 -> BT2{2}N, tanh
+         qkw = torch.einsum('BTKJIN,BTKJIM->BTKJNM', w1, w2) + torch.diag_embed(qkdd)  # j=k=2; BT2{2}NM, q/k, pre/post
+         if self.query_wise:  # TODO: do not generate kw and kdd
+             qw, _ = qkw.unbind(3)  # BS2NM
+             kw_new = None
+             qw = qw + self.dyn_w_proj.sw
+         else:
+             qw, kw_new = qkw.unbind(3)  # BS{pre/post}{q/k}NM -> BS{pre/post}NM * 2
+             kw_new = kw_new + self.dyn_w_proj.sw  # BS2NM + 2NM -> BS2NM
+         if self.kv_cache is not None:
+             k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new)
+         logits = q @ k.transpose(-2, -1) * self.scale_factor
+         if self.query_wise:
+             w = qw  # B12NM
+         else:
+             w = qw + kw_out  # B12NM, BS2NM -> BS2NM
+         wl, w = w.permute(0, 2, 3, 4, 1).unbind(1)  # BS2NM -> B2NMS -> [BNMS] * 2
+         logits = (logits * wl).sum(1).unsqueeze(2)  # BN1S, BNMS -> BNMS -> BMS -> BM1S
+         min_value = torch.finfo(torch.float16).min
+         logits = torch.where(k_mask, logits, min_value)
+         probs = logits.softmax(-1)
+         probs = (probs * w).sum(1).unsqueeze(2)
+         y = probs @ v
+         return y
+
+     def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True) -> Tensor:
+         bsz, seqlen, _ = x.shape
+
+         kv_size = self.n_local_heads * self.head_dim
+         q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+         q = q.view(bsz, seqlen, self.n_head, self.head_dim)  # BSND
+         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         if self.use_qk_norm:
+             q, k = self.q_norm(q), self.k_norm(k)
+
+         if self.rotary_ndims == self.head_dim:
+             q = apply_rotary_emb(q, freqs_cis)  # BTND
+             k = apply_rotary_emb(k, freqs_cis)
+         else:  # partial rotary (rotary_pct < 1), as in Pythia
+             q_rot = q[..., :self.rotary_ndims]
+             q_pass = q[..., self.rotary_ndims:]
+             k_rot = k[..., :self.rotary_ndims]
+             k_pass = k[..., self.rotary_ndims:]
+             q_rot = apply_rotary_emb(q_rot, freqs_cis, mode='half')  # BTND
+             k_rot = apply_rotary_emb(k_rot, freqs_cis, mode='half')
+             q = torch.cat((q_rot, q_pass), dim=-1)
+             k = torch.cat((k_rot, k_pass), dim=-1)
+
+         q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))  # BNSD
+
+         if self.is_training:
+             N, D, I = self.n_head, self.head_dim, self.dyn_w_proj.dynamic_hidden_dim  # 6.7B
+             B, T, E = x.shape
+             if self.use_dcmha:
+                 project_logits = True
+                 project_probs = True
+                 if project_probs:
+                     dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1)
+                     dw_hidden = self.dyn_w_proj.dw_hidden_activation(dw_hidden)
+                     dw_hidden = dw_hidden.view(dw_hidden.shape[:2] + (4, -1))  # BT(4K) -> BT4K
+                     dw = torch.einsum('B T C K, C K D -> B T C D', dw_hidden, self.dyn_w_proj.qkw_m)  # BT4K,4K(MI)->BT4(MI)
+                     shape = (B, T, 2*2, -1, N)  # BT{pre/post}{q/k}IN
+                     w1, w2 = dw.view(shape).split(I, -2)
+                     w1 = self.dyn_w_proj.dw1_norm(w1)  # BT22IN
+                     if self.use_sw:
+                         pre_sw, post_sw = self.dyn_w_proj.sw.unbind(0)
+                     else:
+                         pre_sw, post_sw = None, None
+                     pre_qw1, pre_kw1, post_qw1, post_kw1 = w1.unbind(2)  # BT(2{*2})IN -> [BTIN]*4
+                     pre_qw2, pre_kw2, post_qw2, post_kw2 = w2.unbind(2)
+                     qkdd = F.tanh(dd).squeeze(-1).view(shape[:-2] + (N,))  # BT(2{*2})N1 -> BT(2{*2})N
+                     pre_qdd, pre_kdd, post_qdd, post_kdd = qkdd.unbind(2)  # BT(2{*2})N -> [BTN]*4
+
+                 y = torch.zeros(B, N, T, D).to(q.device, dtype=torch.float16)
+                 for i in range(T // self.q_chunk_size):
+                     start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
+                     kv_start = max(0, stop - self.q_chunk_size - self.window_size)
+                     _q = q[:, :, start:stop, :]
+                     _k, _v = k[:, :, kv_start:stop, :], v[:, :, kv_start:stop, :]
+                     _atten_mask = mask[:, :, start:stop, kv_start:stop]
+                     _pre_proj_dw_args = slice_dw(pre_sw, pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd, start, stop, kv_start) \
+                         if project_logits else None
+                     _post_proj_dw_args = slice_dw(post_sw, post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd, start, stop, kv_start) \
+                         if project_probs else None
+                     _o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args)
+                     y[:, :, start:stop] = _o
+             else:
+                 y = torch.zeros(B, N, T, D).to(q.device, dtype=torch.float16)
+                 for i in range(T // self.q_chunk_size):
+                     start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
+                     kv_start = max(0, stop - self.q_chunk_size - self.window_size)
+                     _q = q[:, :, start:stop, :]
+                     _k, _v = k[:, :, kv_start:stop, :], v[:, :, kv_start:stop, :]
+                     _atten_mask = mask[:, :, start:stop, kv_start:stop]
+                     _pre_proj_dw_args, _post_proj_dw_args = None, None
+                     _o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args)
+                     y[:, :, start:stop] = _o
+         else:  # inference
+             if seqlen == 1:  # one-token generation
+                 k_mask = mask if self.window_size is None else mask[:, :, :, :self.kv_cache.seq_length]
+                 if fast_infer:
+                     y = self._generate_fast(x, input_pos, q, k, v, k_mask)
+                 else:
+                     assert not self.query_wise
+                     # generate dw from hidden_state
+                     pre_proj_dw_args, post_proj_dw_args, kw_new = self.dyn_w_proj(x, gen_cache=True)
+
+                     # update kvkw cache
+                     kw_new = kw_new + self.dyn_w_proj.sw  # absorb residual or sw into kw cache
+                     if self.kv_cache is not None:
+                         k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new)  # BNSD, BNSD, BS2NN
+
+                     logits = q @ k.transpose(-2, -1) * self.scale_factor
+                     # merge pre_w and apply it
+                     pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd = pre_proj_dw_args
+                     pre_qw = torch.einsum('BTGIN,BTGIM->BTNM', pre_qw1, pre_qw2) + torch.diag_embed(pre_qdd.squeeze(2))
+                     pre_w = pre_qw + kw_out[:, :, 0]  # B1NM, BSNM -> BSNM
+                     logits = self.dyn_w_proj.pre_proj(logits, proj_w=pre_w.squeeze(1))
+
+                     logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min)
+                     probs = logits.softmax(-1)
+
+                     # merge post_w and apply it
+                     post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd = post_proj_dw_args
+                     post_qw = torch.einsum('BTGIN,BTGIM->BTNM', post_qw1, post_qw2) + torch.diag_embed(post_qdd.squeeze(2))
+                     post_w = post_qw + kw_out[:, :, 1]
+                     probs = self.dyn_w_proj.post_proj(probs, proj_w=post_w.squeeze(1))
+
+                     y = probs @ v
+             else:  # prefill
+                 k_mask = mask[:, :, :, :k.shape[-2]]
+                 pre_proj_dw_args, post_proj_dw_args, kw_new = self.dyn_w_proj(x, gen_cache=True)
+                 kw_new = kw_new + self.dyn_w_proj.sw  # absorb residual or sw into kw cache
+                 if self.kv_cache is not None:
+                     self.kv_cache.update(input_pos, k, v, kw_val=kw_new)  # BNSD, BNSD, BS2NN
+                 logits = q @ k.transpose(-2, -1) * self.scale_factor
+                 logits = self.dyn_w_proj.pre_proj(logits, dws=pre_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True)  # BNTS
+                 logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min)
+                 probs = logits.softmax(-1)
+                 probs = self.dyn_w_proj.post_proj(probs, dws=post_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True)  # BNTS
+                 y = probs @ v
+
+         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+         y = self.wo(y)
+         return y
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, config: DCPythiaConfig) -> None:
+         super().__init__()
+         self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=config.use_linear_bias)
+         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=config.use_linear_bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.w2(F.gelu(self.w1(x)))
+
+ def _atten_context(query, key, value, atten_mask, pre_proj_dw_args, post_proj_dw_args):
+     logits = query @ key.transpose(-2, -1)
+     if pre_proj_dw_args is not None: logits = _cross_head_proj(logits, *pre_proj_dw_args)
+     logits = torch.where(atten_mask, logits, torch.finfo(torch.float16).min)
+     probs = logits.softmax(-1)
+     if post_proj_dw_args is not None: probs = _cross_head_proj(probs, *post_proj_dw_args)
+     o = probs @ value  # BNTS,BNSD->BNTD
+     return o
+
+ def _cross_head_proj(inputs, sw, qw1, qw2, kw1, kw2, qdd, kdd, loop_over_dynamic_hd=False):
+     out = inputs + torch.einsum('BNTS,NM->BMTS', inputs, sw) if sw is not None else inputs
+     for i in range(2):  # i.e. range(qw1.shape[-2])
+         qhidden = (inputs * qw1[..., i, :].transpose(-2, -1).unsqueeze(-1)).sum(1)  # BNTS,(BTN->BNT->BNT1)->BNTS->BTS
+         qout = qhidden.unsqueeze(1) * qw2[..., i, :].transpose(-2, -1).unsqueeze(-1)  # (BTS->B1TS),(BTN->BNT->BNT1)->BNTS
+         out = out + qout
+         khidden = (inputs * kw1[..., i, :].transpose(-2, -1).unsqueeze(-2)).sum(1)  # BNTS,(BSN->BNS->BN1S)->BNTS->BTS
+         kout = khidden.unsqueeze(1) * kw2[..., i, :].transpose(-2, -1).unsqueeze(-2)  # (BTS->B1TS),(BSN->BNS->BNS1)->BNTS
+         out = out + kout
+     qdout = inputs * qdd.transpose(-2, -1).unsqueeze(-1); out = out + qdout  # BNTS,(BTN->BNT->BNT1)->BNTS
+     kdout = inputs * kdd.transpose(-2, -1).unsqueeze(-2); out = out + kdout  # BNTS,(BSN->BNS->BN1S)->BNTS
+     return out
+
+ def find_multiple(n: int, k: int) -> int:
+     if n % k == 0:
+         return n
+     return n + k - (n % k)
+
+ def make_window_mask(t, window_size):
+     col_idx = torch.tile(torch.arange(t).unsqueeze(0), [t, 1])
+     row_idx = torch.tile(torch.arange(t).unsqueeze(1), [1, t])
+     bias_mask = (col_idx + window_size >= row_idx).tril().view(t, t)
+     return bias_mask
+
+ def slice_dw(sw, qw1, qw2, kw1, kw2, qdd, kdd, start, stop, kv_start):
+     return (sw,
+             qw1[:, start:stop] if qw1 is not None else None,
+             qw2[:, start:stop] if qw2 is not None else None,
+             kw1[:, kv_start:stop] if kw1 is not None else None,
+             kw2[:, kv_start:stop] if kw2 is not None else None,
+             qdd[:, start:stop] if qdd is not None else None,
+             kdd[:, kv_start:stop] if kdd is not None else None)
+
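+ # Editorial note (illustration only): make_window_mask lets each query attend to
+ # itself and the previous `window_size` keys; e.g. make_window_mask(5, 2).int() is
+ #     [[1,0,0,0,0],
+ #      [1,1,0,0,0],
+ #      [1,1,1,0,0],
+ #      [0,1,1,1,0],
+ #      [0,0,1,1,1]]
+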
+ def precompute_freqs_cis(
+     seq_len: int, n_elem: int, base: int = 10000
+ ) -> Tensor:
+     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+     t = torch.arange(seq_len, device=freqs.device)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+     return cache.to(dtype=torch.float16)
+
+ def unbind(ary, n, dim=0):
+     return [torch.squeeze(a, dim=dim) for a in torch.split(ary, ary.shape[dim] // n, dim=dim)]
+
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor:
+     if mode == 'half':
+         xshaped = x.float().reshape(*x.shape[:-1], 2, -1).transpose(-1, -2)
+     elif mode == 'alternative':
+         xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+     freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
+     x_out2 = torch.stack(
+         [
+             xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+             xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+         ],
+         -1,
+     )
+     x_out2 = x_out2.flatten(3)
+     return x_out2.type_as(x)
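*Editorial note:* a minimal shape check for the rotary helpers above (a sketch; assumes the module is importable and `head_dim == n_elem`):

```python
import torch
from modeling_dcpythia import precompute_freqs_cis, apply_rotary_emb

freqs = precompute_freqs_cis(seq_len=8, n_elem=32)  # (8, 16, 2), float16
x = torch.randn(1, 8, 4, 32)                        # (B, T, n_heads, head_dim)
out = apply_rotary_emb(x, freqs, mode='half')
assert out.shape == x.shape                         # rotation preserves shape
```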
pytorch_model-00001-of-00003.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:47e401e33d5c4435e46001b0c16fc008ca3e0cf0b545c3af7a4f5539d55a49ff
+ size 4931648137
pytorch_model-00002-of-00003.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3277a1c7ed043890757913c886701cbae9dc1b2d19789b1373633424d8fb0db7
+ size 4959405880
pytorch_model-00003-of-00003.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:024e2d2704825b3f5b825faaafc0726ecc2dc35a367fb1f6a3d65ada85df4ba0
+ size 4915805832
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,779 @@
+ {
+   "metadata": {
+     "total_size": 14806597632
+   },
+   "weight_map": {
+     "layers.0.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.0.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.0.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.0.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.0.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.0.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.0.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.0.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.1.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.1.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.1.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.1.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.1.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.1.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.1.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
+     "layers.10.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.10.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.10.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.10.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.10.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.10.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.10.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.10.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.11.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.11.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.11.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.11.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.11.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.11.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.11.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.12.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.12.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.12.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.12.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.12.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.12.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.12.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.13.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.13.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.13.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.13.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.13.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.13.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.13.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.14.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.14.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.14.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.14.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.14.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.14.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.14.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.15.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.15.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.15.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.15.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.15.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.15.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.15.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.16.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.16.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.16.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.16.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.16.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.16.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.16.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
+     "layers.17.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
237
+ "layers.17.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
238
+ "layers.17.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
239
+ "layers.17.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
240
+ "layers.17.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
241
+ "layers.17.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
242
+ "layers.17.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
243
+ "layers.17.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
244
+ "layers.17.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
245
+ "layers.17.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
246
+ "layers.18.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
247
+ "layers.18.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
248
+ "layers.18.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
249
+ "layers.18.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
250
+ "layers.18.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
251
+ "layers.18.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
252
+ "layers.18.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
253
+ "layers.18.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
254
+ "layers.18.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
255
+ "layers.18.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
256
+ "layers.18.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
257
+ "layers.18.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
258
+ "layers.18.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
259
+ "layers.18.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
260
+ "layers.18.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
261
+ "layers.18.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
262
+ "layers.18.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
263
+ "layers.18.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
264
+ "layers.18.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
265
+ "layers.18.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
266
+ "layers.18.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
267
+ "layers.18.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
268
+ "layers.18.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
269
+ "layers.18.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
270
+ "layers.19.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
271
+ "layers.19.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
272
+ "layers.19.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
273
+ "layers.19.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
274
+ "layers.19.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
275
+ "layers.19.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
276
+ "layers.19.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
277
+ "layers.19.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
278
+ "layers.19.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
279
+ "layers.19.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
280
+ "layers.19.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
281
+ "layers.19.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
282
+ "layers.19.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
283
+ "layers.19.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
284
+ "layers.19.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
285
+ "layers.19.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
286
+ "layers.19.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
287
+ "layers.19.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
288
+ "layers.19.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
289
+ "layers.19.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
290
+ "layers.19.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
291
+ "layers.19.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
292
+ "layers.19.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
293
+ "layers.19.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
294
+ "layers.2.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
295
+ "layers.2.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
296
+ "layers.2.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
297
+ "layers.2.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
298
+ "layers.2.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
299
+ "layers.2.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
300
+ "layers.2.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
301
+ "layers.2.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
302
+ "layers.2.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
303
+ "layers.2.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
304
+ "layers.2.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
305
+ "layers.2.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
306
+ "layers.2.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
307
+ "layers.2.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
308
+ "layers.2.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
309
+ "layers.2.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
310
+ "layers.2.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
311
+ "layers.2.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
312
+ "layers.2.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
313
+ "layers.2.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
314
+ "layers.2.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
315
+ "layers.2.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
316
+ "layers.2.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
317
+ "layers.2.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
318
+ "layers.20.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
319
+ "layers.20.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
320
+ "layers.20.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
321
+ "layers.20.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
322
+ "layers.20.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
323
+ "layers.20.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
324
+ "layers.20.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
325
+ "layers.20.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
326
+ "layers.20.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
327
+ "layers.20.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
328
+ "layers.20.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
329
+ "layers.20.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
330
+ "layers.20.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
331
+ "layers.20.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
332
+ "layers.20.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
333
+ "layers.20.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
334
+ "layers.20.attention_norm.bias": "pytorch_model-00002-of-00003.bin",
335
+ "layers.20.attention_norm.weight": "pytorch_model-00002-of-00003.bin",
336
+ "layers.20.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
337
+ "layers.20.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
338
+ "layers.20.feed_forward.w2.bias": "pytorch_model-00002-of-00003.bin",
339
+ "layers.20.feed_forward.w2.weight": "pytorch_model-00002-of-00003.bin",
340
+ "layers.20.ffn_norm.bias": "pytorch_model-00002-of-00003.bin",
341
+ "layers.20.ffn_norm.weight": "pytorch_model-00002-of-00003.bin",
342
+ "layers.21.attention.dyn_w_proj.dd": "pytorch_model-00002-of-00003.bin",
343
+ "layers.21.attention.dyn_w_proj.dw1": "pytorch_model-00002-of-00003.bin",
344
+ "layers.21.attention.dyn_w_proj.dw_m": "pytorch_model-00002-of-00003.bin",
345
+ "layers.21.attention.dyn_w_proj.post_proj.w": "pytorch_model-00002-of-00003.bin",
346
+ "layers.21.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00002-of-00003.bin",
347
+ "layers.21.attention.dyn_w_proj.qkw": "pytorch_model-00002-of-00003.bin",
348
+ "layers.21.attention.dyn_w_proj.qkw_m": "pytorch_model-00002-of-00003.bin",
349
+ "layers.21.attention.k_norm.scale": "pytorch_model-00002-of-00003.bin",
350
+ "layers.21.attention.kv_cache.k_cache": "pytorch_model-00002-of-00003.bin",
351
+ "layers.21.attention.kv_cache.kw_cache": "pytorch_model-00002-of-00003.bin",
352
+ "layers.21.attention.kv_cache.v_cache": "pytorch_model-00002-of-00003.bin",
353
+ "layers.21.attention.q_norm.scale": "pytorch_model-00002-of-00003.bin",
354
+ "layers.21.attention.wo.bias": "pytorch_model-00002-of-00003.bin",
355
+ "layers.21.attention.wo.weight": "pytorch_model-00002-of-00003.bin",
356
+ "layers.21.attention.wqkv.bias": "pytorch_model-00002-of-00003.bin",
357
+ "layers.21.attention.wqkv.weight": "pytorch_model-00002-of-00003.bin",
358
+ "layers.21.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
359
+ "layers.21.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
360
+ "layers.21.feed_forward.w1.bias": "pytorch_model-00002-of-00003.bin",
361
+ "layers.21.feed_forward.w1.weight": "pytorch_model-00002-of-00003.bin",
362
+ "layers.21.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
363
+ "layers.21.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
364
+ "layers.21.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
365
+ "layers.21.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
366
+ "layers.22.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
367
+ "layers.22.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
368
+ "layers.22.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
369
+ "layers.22.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
370
+ "layers.22.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
371
+ "layers.22.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
372
+ "layers.22.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
373
+ "layers.22.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
374
+ "layers.22.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
375
+ "layers.22.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
376
+ "layers.22.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
377
+ "layers.22.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
378
+ "layers.22.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
379
+ "layers.22.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
380
+ "layers.22.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
381
+ "layers.22.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
382
+ "layers.22.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
383
+ "layers.22.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
384
+ "layers.22.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
385
+ "layers.22.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
386
+ "layers.22.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
387
+ "layers.22.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
388
+ "layers.22.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
389
+ "layers.22.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
390
+ "layers.23.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
391
+ "layers.23.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
392
+ "layers.23.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
393
+ "layers.23.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
394
+ "layers.23.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
395
+ "layers.23.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
396
+ "layers.23.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
397
+ "layers.23.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
398
+ "layers.23.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
399
+ "layers.23.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
400
+ "layers.23.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
401
+ "layers.23.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
402
+ "layers.23.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
403
+ "layers.23.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
404
+ "layers.23.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
405
+ "layers.23.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
406
+ "layers.23.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
407
+ "layers.23.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
408
+ "layers.23.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
409
+ "layers.23.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
410
+ "layers.23.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
411
+ "layers.23.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
412
+ "layers.23.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
413
+ "layers.23.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
414
+ "layers.24.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
415
+ "layers.24.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
416
+ "layers.24.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
417
+ "layers.24.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
418
+ "layers.24.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
419
+ "layers.24.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
420
+ "layers.24.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
421
+ "layers.24.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
422
+ "layers.24.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
423
+ "layers.24.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
424
+ "layers.24.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
425
+ "layers.24.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
426
+ "layers.24.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
427
+ "layers.24.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
428
+ "layers.24.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
429
+ "layers.24.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
430
+ "layers.24.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
431
+ "layers.24.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
432
+ "layers.24.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
433
+ "layers.24.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
434
+ "layers.24.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
435
+ "layers.24.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
436
+ "layers.24.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
437
+ "layers.24.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
438
+ "layers.25.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
439
+ "layers.25.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
440
+ "layers.25.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
441
+ "layers.25.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
442
+ "layers.25.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
443
+ "layers.25.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
444
+ "layers.25.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
445
+ "layers.25.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
446
+ "layers.25.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
447
+ "layers.25.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
448
+ "layers.25.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
449
+ "layers.25.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
450
+ "layers.25.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
451
+ "layers.25.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
452
+ "layers.25.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
453
+ "layers.25.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
454
+ "layers.25.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
455
+ "layers.25.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
456
+ "layers.25.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
457
+ "layers.25.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
458
+ "layers.25.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
459
+ "layers.25.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
460
+ "layers.25.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
461
+ "layers.25.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
462
+ "layers.26.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
463
+ "layers.26.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
464
+ "layers.26.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
465
+ "layers.26.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
466
+ "layers.26.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
467
+ "layers.26.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
468
+ "layers.26.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
469
+ "layers.26.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
470
+ "layers.26.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
471
+ "layers.26.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
472
+ "layers.26.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
473
+ "layers.26.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
474
+ "layers.26.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
475
+ "layers.26.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
476
+ "layers.26.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
477
+ "layers.26.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
478
+ "layers.26.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
479
+ "layers.26.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
480
+ "layers.26.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
481
+ "layers.26.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
482
+ "layers.26.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
483
+ "layers.26.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
484
+ "layers.26.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
485
+ "layers.26.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
486
+ "layers.27.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
487
+ "layers.27.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
488
+ "layers.27.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
489
+ "layers.27.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
490
+ "layers.27.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
491
+ "layers.27.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
492
+ "layers.27.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
493
+ "layers.27.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
494
+ "layers.27.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
495
+ "layers.27.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
496
+ "layers.27.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
497
+ "layers.27.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
498
+ "layers.27.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
499
+ "layers.27.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
500
+ "layers.27.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
501
+ "layers.27.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
502
+ "layers.27.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
503
+ "layers.27.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
504
+ "layers.27.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
505
+ "layers.27.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
506
+ "layers.27.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
507
+ "layers.27.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
508
+ "layers.27.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
509
+ "layers.27.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
510
+ "layers.28.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
511
+ "layers.28.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
512
+ "layers.28.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
513
+ "layers.28.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
514
+ "layers.28.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
515
+ "layers.28.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
516
+ "layers.28.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
517
+ "layers.28.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
518
+ "layers.28.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
519
+ "layers.28.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
520
+ "layers.28.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
521
+ "layers.28.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
522
+ "layers.28.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
523
+ "layers.28.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
524
+ "layers.28.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
525
+ "layers.28.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
526
+ "layers.28.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
527
+ "layers.28.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
528
+ "layers.28.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
529
+ "layers.28.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
530
+ "layers.28.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
531
+ "layers.28.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
532
+ "layers.28.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
533
+ "layers.28.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
534
+ "layers.29.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
535
+ "layers.29.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
536
+ "layers.29.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
537
+ "layers.29.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
538
+ "layers.29.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
539
+ "layers.29.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
540
+ "layers.29.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
541
+ "layers.29.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
542
+ "layers.29.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
543
+ "layers.29.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
544
+ "layers.29.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
545
+ "layers.29.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
546
+ "layers.29.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
547
+ "layers.29.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
548
+ "layers.29.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
549
+ "layers.29.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
550
+ "layers.29.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
551
+ "layers.29.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
552
+ "layers.29.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
553
+ "layers.29.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
554
+ "layers.29.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
555
+ "layers.29.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
556
+ "layers.29.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
557
+ "layers.29.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
558
+ "layers.3.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
559
+ "layers.3.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
560
+ "layers.3.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
561
+ "layers.3.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
562
+ "layers.3.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
563
+ "layers.3.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
564
+ "layers.3.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
565
+ "layers.3.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
566
+ "layers.3.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
567
+ "layers.3.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
568
+ "layers.3.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
569
+ "layers.3.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
570
+ "layers.3.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
571
+ "layers.3.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
572
+ "layers.3.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
573
+ "layers.3.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
574
+ "layers.3.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
575
+ "layers.3.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
576
+ "layers.3.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
577
+ "layers.3.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
578
+ "layers.3.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
579
+ "layers.3.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
580
+ "layers.3.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
581
+ "layers.3.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
582
+ "layers.30.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
583
+ "layers.30.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
584
+ "layers.30.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
585
+ "layers.30.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
586
+ "layers.30.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
587
+ "layers.30.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
588
+ "layers.30.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
589
+ "layers.30.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
590
+ "layers.30.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
591
+ "layers.30.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
592
+ "layers.30.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
593
+ "layers.30.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
594
+ "layers.30.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
595
+ "layers.30.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
596
+ "layers.30.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
597
+ "layers.30.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
598
+ "layers.30.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
599
+ "layers.30.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
600
+ "layers.30.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
601
+ "layers.30.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
602
+ "layers.30.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
603
+ "layers.30.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
604
+ "layers.30.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
605
+ "layers.30.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
606
+ "layers.31.attention.dyn_w_proj.dd": "pytorch_model-00003-of-00003.bin",
607
+ "layers.31.attention.dyn_w_proj.dw1": "pytorch_model-00003-of-00003.bin",
608
+ "layers.31.attention.dyn_w_proj.dw_m": "pytorch_model-00003-of-00003.bin",
609
+ "layers.31.attention.dyn_w_proj.post_proj.w": "pytorch_model-00003-of-00003.bin",
610
+ "layers.31.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00003-of-00003.bin",
611
+ "layers.31.attention.dyn_w_proj.qkw": "pytorch_model-00003-of-00003.bin",
612
+ "layers.31.attention.dyn_w_proj.qkw_m": "pytorch_model-00003-of-00003.bin",
613
+ "layers.31.attention.k_norm.scale": "pytorch_model-00003-of-00003.bin",
614
+ "layers.31.attention.kv_cache.k_cache": "pytorch_model-00003-of-00003.bin",
615
+ "layers.31.attention.kv_cache.kw_cache": "pytorch_model-00003-of-00003.bin",
616
+ "layers.31.attention.kv_cache.v_cache": "pytorch_model-00003-of-00003.bin",
617
+ "layers.31.attention.q_norm.scale": "pytorch_model-00003-of-00003.bin",
618
+ "layers.31.attention.wo.bias": "pytorch_model-00003-of-00003.bin",
619
+ "layers.31.attention.wo.weight": "pytorch_model-00003-of-00003.bin",
620
+ "layers.31.attention.wqkv.bias": "pytorch_model-00003-of-00003.bin",
621
+ "layers.31.attention.wqkv.weight": "pytorch_model-00003-of-00003.bin",
622
+ "layers.31.attention_norm.bias": "pytorch_model-00003-of-00003.bin",
623
+ "layers.31.attention_norm.weight": "pytorch_model-00003-of-00003.bin",
624
+ "layers.31.feed_forward.w1.bias": "pytorch_model-00003-of-00003.bin",
625
+ "layers.31.feed_forward.w1.weight": "pytorch_model-00003-of-00003.bin",
626
+ "layers.31.feed_forward.w2.bias": "pytorch_model-00003-of-00003.bin",
627
+ "layers.31.feed_forward.w2.weight": "pytorch_model-00003-of-00003.bin",
628
+ "layers.31.ffn_norm.bias": "pytorch_model-00003-of-00003.bin",
629
+ "layers.31.ffn_norm.weight": "pytorch_model-00003-of-00003.bin",
630
+ "layers.4.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
631
+ "layers.4.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
632
+ "layers.4.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
633
+ "layers.4.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
634
+ "layers.4.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
635
+ "layers.4.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
636
+ "layers.4.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
637
+ "layers.4.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
638
+ "layers.4.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
639
+ "layers.4.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
640
+ "layers.4.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
641
+ "layers.4.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
642
+ "layers.4.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
643
+ "layers.4.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
644
+ "layers.4.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
645
+ "layers.4.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
646
+ "layers.4.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
647
+ "layers.4.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
648
+ "layers.4.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
649
+ "layers.4.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
650
+ "layers.4.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
651
+ "layers.4.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
652
+ "layers.4.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
653
+ "layers.4.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
654
+ "layers.5.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
655
+ "layers.5.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
656
+ "layers.5.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
657
+ "layers.5.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
658
+ "layers.5.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
659
+ "layers.5.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
660
+ "layers.5.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
661
+ "layers.5.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
662
+ "layers.5.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
663
+ "layers.5.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
664
+ "layers.5.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
665
+ "layers.5.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
666
+ "layers.5.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
667
+ "layers.5.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
668
+ "layers.5.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
669
+ "layers.5.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
670
+ "layers.5.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
671
+ "layers.5.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
672
+ "layers.5.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
673
+ "layers.5.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
674
+ "layers.5.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
675
+ "layers.5.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
676
+ "layers.5.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
677
+ "layers.5.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
678
+ "layers.6.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
679
+ "layers.6.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
680
+ "layers.6.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
681
+ "layers.6.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
682
+ "layers.6.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
683
+ "layers.6.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
684
+ "layers.6.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
685
+ "layers.6.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
686
+ "layers.6.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
687
+ "layers.6.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
688
+ "layers.6.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
689
+ "layers.6.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
690
+ "layers.6.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
691
+ "layers.6.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
692
+ "layers.6.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
693
+ "layers.6.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
694
+ "layers.6.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
695
+ "layers.6.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
696
+ "layers.6.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
697
+ "layers.6.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
698
+ "layers.6.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
699
+ "layers.6.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
700
+ "layers.6.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
701
+ "layers.6.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
702
+ "layers.7.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
703
+ "layers.7.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
704
+ "layers.7.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
705
+ "layers.7.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
706
+ "layers.7.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
707
+ "layers.7.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
708
+ "layers.7.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
709
+ "layers.7.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
710
+ "layers.7.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
711
+ "layers.7.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
712
+ "layers.7.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
713
+ "layers.7.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
714
+ "layers.7.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
715
+ "layers.7.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
716
+ "layers.7.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
717
+ "layers.7.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
718
+ "layers.7.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
719
+ "layers.7.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
720
+ "layers.7.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
721
+ "layers.7.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
722
+ "layers.7.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
723
+ "layers.7.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
724
+ "layers.7.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
725
+ "layers.7.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
726
+ "layers.8.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
727
+ "layers.8.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
728
+ "layers.8.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
729
+ "layers.8.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
730
+ "layers.8.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
731
+ "layers.8.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
732
+ "layers.8.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
733
+ "layers.8.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
734
+ "layers.8.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
735
+ "layers.8.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
736
+ "layers.8.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
737
+ "layers.8.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
738
+ "layers.8.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
739
+ "layers.8.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
740
+ "layers.8.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
741
+ "layers.8.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
742
+ "layers.8.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
743
+ "layers.8.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
744
+ "layers.8.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
745
+ "layers.8.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
746
+ "layers.8.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
747
+ "layers.8.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
748
+ "layers.8.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
749
+ "layers.8.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
750
+ "layers.9.attention.dyn_w_proj.dd": "pytorch_model-00001-of-00003.bin",
751
+ "layers.9.attention.dyn_w_proj.dw1": "pytorch_model-00001-of-00003.bin",
752
+ "layers.9.attention.dyn_w_proj.dw_m": "pytorch_model-00001-of-00003.bin",
753
+ "layers.9.attention.dyn_w_proj.post_proj.w": "pytorch_model-00001-of-00003.bin",
754
+ "layers.9.attention.dyn_w_proj.pre_proj.w": "pytorch_model-00001-of-00003.bin",
755
+ "layers.9.attention.dyn_w_proj.qkw": "pytorch_model-00001-of-00003.bin",
756
+ "layers.9.attention.dyn_w_proj.qkw_m": "pytorch_model-00001-of-00003.bin",
757
+ "layers.9.attention.k_norm.scale": "pytorch_model-00001-of-00003.bin",
758
+ "layers.9.attention.kv_cache.k_cache": "pytorch_model-00001-of-00003.bin",
759
+ "layers.9.attention.kv_cache.kw_cache": "pytorch_model-00001-of-00003.bin",
760
+ "layers.9.attention.kv_cache.v_cache": "pytorch_model-00001-of-00003.bin",
761
+ "layers.9.attention.q_norm.scale": "pytorch_model-00001-of-00003.bin",
762
+ "layers.9.attention.wo.bias": "pytorch_model-00001-of-00003.bin",
763
+ "layers.9.attention.wo.weight": "pytorch_model-00001-of-00003.bin",
764
+ "layers.9.attention.wqkv.bias": "pytorch_model-00001-of-00003.bin",
765
+ "layers.9.attention.wqkv.weight": "pytorch_model-00001-of-00003.bin",
766
+ "layers.9.attention_norm.bias": "pytorch_model-00001-of-00003.bin",
767
+ "layers.9.attention_norm.weight": "pytorch_model-00001-of-00003.bin",
768
+ "layers.9.feed_forward.w1.bias": "pytorch_model-00001-of-00003.bin",
769
+ "layers.9.feed_forward.w1.weight": "pytorch_model-00001-of-00003.bin",
770
+ "layers.9.feed_forward.w2.bias": "pytorch_model-00001-of-00003.bin",
771
+ "layers.9.feed_forward.w2.weight": "pytorch_model-00001-of-00003.bin",
772
+ "layers.9.ffn_norm.bias": "pytorch_model-00001-of-00003.bin",
773
+ "layers.9.ffn_norm.weight": "pytorch_model-00001-of-00003.bin",
774
+ "norm.bias": "pytorch_model-00003-of-00003.bin",
775
+ "norm.weight": "pytorch_model-00003-of-00003.bin",
776
+ "output.weight": "pytorch_model-00003-of-00003.bin",
777
+ "tok_embeddings.weight": "pytorch_model-00001-of-00003.bin"
778
+ }
779
+ }
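
*Note:* the map above is the tail of the sharded-checkpoint index. Each entry keys a parameter or registered buffer by its module path (the `kv_cache.*` entries are buffers, which is why they appear alongside weights) and points to the shard file that stores it; a single layer can straddle shards, as layer 21 does across shards 2 and 3. A minimal sketch for inspecting the index, assuming the standard `pytorch_model.bin.index.json` filename used by sharded PyTorch checkpoints on the Hub:

```
import json
from collections import Counter

from huggingface_hub import hf_hub_download

# Fetch the shard index from the repo (filename assumed: the standard name
# for sharded .bin checkpoints on the Hub).
index_path = hf_hub_download("Caiyun-AI/DCPythia-6.9B", "pytorch_model.bin.index.json")
with open(index_path) as f:
    weight_map = json.load(f)["weight_map"]

# Which shard holds a given tensor? Layer 21 straddles shards 2 and 3.
print(weight_map["layers.21.attention.wqkv.weight"])  # pytorch_model-00002-of-00003.bin
print(weight_map["layers.21.ffn_norm.weight"])        # pytorch_model-00003-of-00003.bin

# How many tensors live in each shard file?
print(Counter(weight_map.values()))
```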
tokenizer.json ADDED
The diff for this file is too large to render.
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+ "add_prefix_space": false,
+ "bos_token": "<|endoftext|>",
+ "clean_up_tokenization_spaces": true,
+ "eos_token": "<|endoftext|>",
+ "model_max_length": 1000000000000000019884624838656,
+ "tokenizer_class": "GPTNeoXTokenizer",
+ "unk_token": "<|endoftext|>"
+ }
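
*Note:* this tokenizer config pins all three special tokens (bos/eos/unk) to GPT-NeoX's `<|endoftext|>`, matching `bos_token_id`/`eos_token_id` of 0 in `config.json`. A quick sanity check, reusing the repo name from this commit:

```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCPythia-6.9B")

# All three special tokens resolve to "<|endoftext|>", per tokenizer_config.json.
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.unk_token)
# Encoding the special token should yield its single id (0, per config.json).
print(tokenizer.encode("<|endoftext|>"))
```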