Alexandru Gherghescu committed on
Commit
bbb5d39
1 Parent(s): 8b5602c

Add original model weights + conversion script

README.md CHANGED
@@ -36,3 +36,9 @@ See `preprocessing.py` on how the data was preprocessed and tokenized.
  See `pre_training.py` on how the model was pre-trained.

  See `inference.py` for an example.
+
+ ## Converted model
+
+ Inside `gpt1-converted-weights/` is the converted safetensors model from the
+ original weights, which can be used directly with the code inside this repo. The
+ conversion script and original weights can also be found there.
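
As a quick illustration of the README note above, here is one way the converted checkpoint could be loaded — a minimal sketch, not part of the commit. It assumes `torch` and `transformers` are installed and that `modeling_gpt1.py`/`configuration_gpt1.py` from `gpt1-converted-weights/` are importable (e.g. copied next to the script). The tokenizer is an assumption as well: the converted folder ships no tokenizer, so the Hub's `openai-gpt` tokenizer is used here on the assumption that it matches the original 40478-token BPE vocabulary in `original_gpt1_params/`.

```python
import torch
from transformers import AutoTokenizer

# Custom model code committed in gpt1-converted-weights/ (assumed importable).
from modeling_gpt1 import GPT1ForCausalLM

# Load the converted safetensors weights from the local folder.
model = GPT1ForCausalLM.from_pretrained("gpt1-converted-weights/")
model.eval()

# Assumption: the "openai-gpt" tokenizer on the Hub uses the same BPE
# vocabulary (40478 tokens) as the files in original_gpt1_params/.
tokenizer = AutoTokenizer.from_pretrained("openai-gpt")

input_ids = tokenizer("the book is on the", return_tensors="pt").input_ids
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
print(tokenizer.decode(output[0]))
```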
gpt1-converted-weights/config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "architectures": [
+     "GPT1ForCausalLM"
+   ],
+   "attention_dropout": 0.1,
+   "auto_map": {
+     "AutoConfig": "configuration_gpt1.GPT1Config",
+     "AutoModelForCausalLM": "modeling_gpt1.GPT1ForCausalLM"
+   },
+   "embd_pdrop": 0.1,
+   "hidden_act": "gelu",
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 512,
+   "model_type": "gpt1",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "resid_pdrop": 0.1,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.1",
+   "vocab_size": 40478
+ }
gpt1-converted-weights/configuration_gpt1.py ADDED
@@ -0,0 +1,42 @@
+ """ GPT1 model configuration """
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class GPT1Config(PretrainedConfig):
+     model_type = "gpt1"
+
+     def __init__(
+         self,
+         vocab_size=40478,
+         hidden_size=768,
+         intermediate_size=3072,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         resid_pdrop=0.1,
+         embd_pdrop=0.1,
+         attention_dropout=0.1,
+         hidden_act="gelu",
+         max_position_embeddings=512,
+         initializer_range=0.02,
+         layer_norm_eps=1e-5,
+         tie_word_embeddings=True,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.resid_pdrop = resid_pdrop
+         self.embd_pdrop = embd_pdrop
+         self.attention_dropout = attention_dropout
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
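
For orientation, the defaults above reproduce the original GPT-1 hyperparameters, and a rough parameter count can be derived from them directly. The sketch below is illustrative arithmetic only; it assumes the layer layout defined in `modeling_gpt1.py` further down (four attention projections, a two-layer MLP and two LayerNorms per block, tied token embeddings, plus an untied `lm_head` bias) and that `configuration_gpt1.py` is importable from the working directory.

```python
from configuration_gpt1 import GPT1Config

cfg = GPT1Config()

token_emb = cfg.vocab_size * cfg.hidden_size                 # tied with lm_head.weight
pos_emb = cfg.max_position_embeddings * cfg.hidden_size
attn = 4 * (cfg.hidden_size * cfg.hidden_size + cfg.hidden_size)   # q/k/v/o weights + biases
mlp = 2 * cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size + cfg.hidden_size
norms = 2 * 2 * cfg.hidden_size                              # two LayerNorms (weight + bias)
lm_head_bias = cfg.vocab_size

total = token_emb + pos_emb + cfg.num_hidden_layers * (attn + mlp + norms) + lm_head_bias
print(f"{total:,} parameters")      # ~116.6M, the usual "117M" GPT-1 figure
print(f"{total * 4 / 1e6:.0f} MB")  # ~466 MB in float32, matching model.safetensors below
```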
gpt1-converted-weights/generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.38.1"
+ }
gpt1-converted-weights/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dc19245dd9599204701492aecf9b89d5b130001085743adb249409040390ec02
+ size 466321576
gpt1-converted-weights/modeling_gpt1.py ADDED
@@ -0,0 +1,237 @@
+ """ PyTorch GPT1 model."""
+
+ import math
+
+ import torch
+ from torch import nn
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     CausalLMOutput,
+ )
+ from transformers.activations import get_activation
+
+ from configuration_gpt1 import GPT1Config
+
+
+ class GPT1MLP(nn.Module):
+     def __init__(self, config: GPT1Config):
+         super().__init__()
+         self.activation_fn = get_activation(config.hidden_act)
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_state):
+         hidden_state = self.fc1(hidden_state)
+         hidden_state = self.activation_fn(hidden_state)
+         hidden_state = self.fc2(hidden_state)
+         return hidden_state
+
+
+ class GPT1Attention(nn.Module):
+     def __init__(self, config: GPT1Config):
+         """
+         Multi-head attention layer.
+         """
+         super().__init__()
+
+         assert config.hidden_size % config.num_attention_heads == 0
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.attn_dropout = nn.Dropout(p=config.attention_dropout)
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
+
+     def forward(self, hidden_state, attn_mask):
+         bs, seq_len, _ = hidden_state.size() # (batch_size, seq_len, dim)
+
+         # linearly project the inputs
+         Q = self.q_proj(hidden_state) # (batch_size, seq_len, n_heads * head_dim)
+         K = self.k_proj(hidden_state)
+         V = self.v_proj(hidden_state)
+
+         # split into n_heads to compute attention
+         queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_len, head_dim)
+         keys = K.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         values = V.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # compute attention matmul
+         keys = keys.transpose(2, 3) # (batch_size, n_heads, head_dim, seq_len)
+         attn_scores = queries @ keys # (batch_size, n_heads, seq_len, seq_len)
+
+         # scale
+         attn_scores = attn_scores / math.sqrt(self.head_dim)
+
+         # mask
+         if attn_mask is not None:
+             attn_scores = attn_scores + attn_mask
+
+         # softmax (attention probabilities) + dropout
+         attn_probs = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(Q.dtype)
+         attn_probs = self.attn_dropout(attn_probs)
+
+         # matmul
+         attn_output = attn_probs @ values # (batch_size, n_heads, seq_len, head_dim)
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(bs, seq_len, self.hidden_size) # (batch_size, seq_len, n_heads * head_dim)
+
+         # final linear
+         attn_output = self.o_proj(attn_output)
+         return attn_output
+
+
+ class GPT1DecoderLayer(nn.Module):
+     def __init__(self, config: GPT1Config):
+         super().__init__()
+         self.attention = GPT1Attention(config)
+         self.mlp = GPT1MLP(config)
+
+         self.attention_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
+                                            eps=config.layer_norm_eps)
+         self.mlp_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
+                                      eps=config.layer_norm_eps)
+
+         self.res_dropout = nn.Dropout(p=config.resid_pdrop)
+
+     def forward(self, hidden_state, attn_mask):
+         # attention
+         residual = hidden_state
+         hidden_state = self.attention(hidden_state, attn_mask)
+         hidden_state = self.res_dropout(hidden_state)
+         hidden_state = residual + hidden_state
+         hidden_state = self.attention_norm(hidden_state)
+
+         # feed forward fully connected
+         residual = hidden_state
+         hidden_state = self.mlp(hidden_state)
+         hidden_state = self.res_dropout(hidden_state)
+         hidden_state = residual + hidden_state
+         hidden_state = self.mlp_norm(hidden_state)
+
+         return hidden_state
+
+
+ class GPT1PreTrainedModel(PreTrainedModel):
+     config_class = GPT1Config
+     supports_gradient_checkpointing = False
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+
+ class GPT1Model(GPT1PreTrainedModel):
+
+     def __init__(self, config: GPT1Config):
+         super().__init__(config)
+
+         # embeddings
+         self.embs = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.embs_dropout = nn.Dropout(p=config.embd_pdrop)
+
+         # positional encoding (learned)
+         self.pos_emb = nn.Embedding(config.max_position_embeddings,
+                                     config.hidden_size)
+
+         self.layers = nn.ModuleList(
+             [GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
+         )
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embs
+
+     def set_input_embeddings(self, value):
+         self.embs = value
+
+     def forward(self, input_ids, *args, **kwargs):
+         position_ids = torch.arange(input_ids.size(-1),
+                                     dtype=torch.long,
+                                     device=input_ids.device).unsqueeze_(0)
+
+         input_embeds = self.embs(input_ids) # (bs, seq_len, dim)
+         position_embeds = self.pos_emb(position_ids)
+         hidden_state = self.embs_dropout(input_embeds) + position_embeds
+
+         seq_len = input_ids.size(-1)
+         attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
+         attn_mask = torch.triu(attn_mask, diagonal=1)
+
+         causal_mask = attn_mask.to(dtype=input_embeds.dtype,
+                                    device=input_embeds.device)
+
+         for layer in self.layers:
+             hidden_state = layer(hidden_state, attn_mask=causal_mask)
+
+         return BaseModelOutput(
+             last_hidden_state=hidden_state
+         )
+
+
+ class GPT1ForCausalLM(GPT1PreTrainedModel):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: GPT1Config):
+         super().__init__(config)
+         self.model = GPT1Model(config)
+         self.vocab_size = config.vocab_size
+
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+
+         # initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embs
+
+     def set_input_embeddings(self, value):
+         self.model.embs = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def get_decoder(self):
+         return self.model
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def forward(self, input_ids, labels=None, *args, **kwargs):
+         output = self.model(input_ids)
+
+         hidden_state = output[0]
+         logits = self.lm_head(hidden_state).float()
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             loss_fn = torch.nn.CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             loss = loss_fn(shift_logits, shift_labels)
+
+         return CausalLMOutput(
+             loss=loss,
+             logits=logits
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
+         return { 'input_ids': input_ids }
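
As a quick smoke test of the modeling code above (not part of the commit), a randomly initialized model can be run on dummy token ids to confirm the expected output shapes and the shifted cross-entropy loss path. This assumes `configuration_gpt1.py` and `modeling_gpt1.py` sit next to the test script.

```python
import torch

from configuration_gpt1 import GPT1Config
from modeling_gpt1 import GPT1ForCausalLM

config = GPT1Config()
model = GPT1ForCausalLM(config)  # random init, ~117M parameters
model.eval()

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

with torch.no_grad():
    out = model(input_ids, labels=input_ids)

# One row of logits per position over the full vocabulary.
assert out.logits.shape == (batch_size, seq_len, config.vocab_size)
# Labels are shifted internally, so the loss is a scalar next-token loss.
print(out.loss)  # roughly ln(40478) ≈ 10.6 for an untrained model
```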
original_gpt1_params/.ipynb_checkpoints/encoder_bpe_40000-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
original_gpt1_params/.ipynb_checkpoints/params_shapes-checkpoint.json ADDED
@@ -0,0 +1 @@
+ [[512, 768], [40478, 768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768]]
original_gpt1_params/.ipynb_checkpoints/vocab_40000-checkpoint.bpe ADDED
The diff for this file is too large to render. See raw diff
 
original_gpt1_params/encoder_bpe_40000.json ADDED
The diff for this file is too large to render. See raw diff
 
original_gpt1_params/params_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d9cd095b901dfbfbe0ce5e01d151dfe0b791e955d71149969ba65a6eab4480f
+ size 46614044
original_gpt1_params/params_1.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca074893c040fa69cbf2fc95c06feda45a4e1492d03b645e2076e89ccf7ddd9f
+ size 46614044
original_gpt1_params/params_2.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:966c25fbd632f0df18c4d4380ba57f23410f43311a96616f00b3d05ae6592f58
+ size 46614044
original_gpt1_params/params_3.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40df0d328f5d3d1b2bec768855a5d2eeeaf2b2124758ef98116f76a02526fd92
+ size 46614044
original_gpt1_params/params_4.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:139f098dcd620ccf0200530e9ce9ff1c342714ff881a0c7258ac9faac4a06e6a
+ size 46614040
original_gpt1_params/params_5.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ad27b5cb245db9a29657270ff637d3ff1c15fd9df3683324a2936674cef8c3c5
+ size 46614040
original_gpt1_params/params_6.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af5bb5c76ddfea50683e0b9895fe704ae689853ed8bb3f1b3fee4daff2f27d45
+ size 46614040
original_gpt1_params/params_7.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:27f55501d895ce1adb9b254aa762519a242edf2bcd2b43298b89538b5591566c
+ size 46614040
original_gpt1_params/params_8.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:17a2b695128ea0aae98a360351b92769b879bc0f2835862949b6405b0ce88569
+ size 46614040
original_gpt1_params/params_9.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f1355fcd519db223f65db7fa7b79dcaf9b4c653915ffe4bd417d87f7903225c1
+ size 46614040
original_gpt1_params/params_shapes.json ADDED
@@ -0,0 +1 @@
+ [[512, 768], [40478, 768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768], [1, 768, 2304], [2304], [1, 768, 768], [768], [768], [768], [1, 768, 3072], [3072], [1, 3072, 768], [768], [768], [768]]
original_gpt1_params/vocab_40000.bpe ADDED
The diff for this file is too large to render. See raw diff
 
tf_weights_to_hf.py ADDED
@@ -0,0 +1,85 @@
+ import json
+
+ import torch
+ import numpy as np
+
+ from modeling_gpt1 import GPT1ForCausalLM, GPT1Model
+ from configuration_gpt1 import GPT1Config
+
+
+ GPT1Config.register_for_auto_class()
+ GPT1Model.register_for_auto_class('AutoModel')
+ GPT1ForCausalLM.register_for_auto_class('AutoModelForCausalLM')
+
+ def lists_are_equal(list1, list2):
+     for i, j in zip(list1, list2):
+         if i != j:
+             return False
+     return True
+
+ # get the original weights from the GPT1 params.npy files
+ def get_weights_from_tf_model():
+
+     shapes = json.load(open('original_gpt1_params/params_shapes.json'))
+     offsets = np.cumsum([np.prod(shape) for shape in shapes])
+
+     init_params = [np.load('original_gpt1_params/params_{}.npy'.format(n)) for n in range(10)]
+     init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
+     init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
+
+     config = GPT1Config()
+     model = GPT1ForCausalLM(config)
+
+     # print(shapes[:15])
+     # print([k for k, v in model.named_parameters()][:10])
+
+     # embs layer
+     model.model.embs.weight.data = torch.from_numpy(init_params[1])
+
+     # pos enc layer
+     model.model.pos_emb.weight.data = torch.from_numpy(init_params[0])
+
+     layers = model.model.layers
+
+     for i in range(0, 12):
+
+         idx = 12 * i + 2
+
+         # attention q, k, v projections
+         init_params[idx] = np.squeeze(init_params[idx], axis=0)
+         q, k, v = torch.split(torch.tensor(init_params[idx]), 768, dim=-1)
+         layers[i].attention.q_proj.weight.data = q.detach().clone().transpose(-1, -2).contiguous()
+         layers[i].attention.k_proj.weight.data = k.detach().clone().transpose(-1, -2).contiguous()
+         layers[i].attention.v_proj.weight.data = v.detach().clone().transpose(-1, -2).contiguous()
+
+         # attention q, k, v biases
+         q_bias, k_bias, v_bias = torch.split(torch.tensor(init_params[idx + 1]), 768, dim=-1)
+         layers[i].attention.q_proj.bias.data = q_bias.detach().clone().contiguous()
+         layers[i].attention.k_proj.bias.data = k_bias.detach().clone().contiguous()
+         layers[i].attention.v_proj.bias.data = v_bias.detach().clone().contiguous()
+
+         # attention output proj + bias
+         init_params[idx + 2] = np.squeeze(init_params[idx + 2], axis=0)
+         layers[i].attention.o_proj.weight.data = torch.from_numpy(init_params[idx + 2]).transpose(-1, -2).contiguous()
+         layers[i].attention.o_proj.bias.data = torch.from_numpy(init_params[idx + 3])
+
+         # attention norm + bias
+         layers[i].attention_norm.weight.data = torch.from_numpy(init_params[idx + 4])
+         layers[i].attention_norm.bias.data = torch.from_numpy(init_params[idx + 5])
+
+         # mlp layer
+         init_params[idx + 6] = np.squeeze(init_params[idx + 6], axis=0)
+         layers[i].mlp.fc1.weight.data = torch.from_numpy(init_params[idx + 6]).transpose(-1, -2).contiguous()
+         layers[i].mlp.fc1.bias.data = torch.from_numpy(init_params[idx + 7])
+         init_params[idx + 8] = np.squeeze(init_params[idx + 8], axis=0)
+         layers[i].mlp.fc2.weight.data = torch.from_numpy(init_params[idx + 8]).transpose(-1, -2).contiguous()
+         layers[i].mlp.fc2.bias.data = torch.from_numpy(init_params[idx + 9])
+
+         # mlp norm + bias
+         layers[i].mlp_norm.weight.data = torch.from_numpy(init_params[idx + 10])
+         layers[i].mlp_norm.bias.data = torch.from_numpy(init_params[idx + 11])
+
+     model.save_pretrained('gpt1-converted-weights/')
+
+
+ get_weights_from_tf_model()
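
A possible follow-up check, not part of the script above: once the conversion has produced `gpt1-converted-weights/`, the checkpoint can be reloaded to confirm that the embedding shapes line up with `params_shapes.json` and that the output head stays tied to the token embeddings (as `tie_word_embeddings=True` in the config is expected to enforce).

```python
import json

from modeling_gpt1 import GPT1ForCausalLM

# Reload the converted checkpoint from disk.
model = GPT1ForCausalLM.from_pretrained('gpt1-converted-weights/')

shapes = json.load(open('original_gpt1_params/params_shapes.json'))
# Entry 0 is the positional embedding, entry 1 the token embedding.
assert list(model.model.pos_emb.weight.shape) == shapes[0]  # [512, 768]
assert list(model.model.embs.weight.shape) == shapes[1]     # [40478, 768]

# With tied word embeddings, lm_head.weight should share storage with embs.weight.
assert model.lm_head.weight.data_ptr() == model.model.embs.weight.data_ptr()
```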