nreimers commited on
Commit
967b085
1 Parent(s): 2113e9d

add convert files

Browse files
Files changed (2) hide show
  1. convert-gtr.ipynb +1027 -0
  2. convert_to_fp16.py +9 -0
convert-gtr.ipynb ADDED
@@ -0,0 +1,1027 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "17bffc12",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import AutoTokenizer\n",
11
+ "from sentence_transformers import util\n",
12
+ "import os\n",
13
+ "import numpy as np\n",
14
+ "import torch.nn.functional as F\n",
15
+ "from transformers import T5EncoderModel\n",
16
+ "import sentence_transformers"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "id": "160d8ce6",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "\n",
27
+ "#Mean Pooling - Take attention mask into account for correct averaging\n",
28
+ "def mean_pooling(model_output, attention_mask):\n",
29
+ " token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n",
30
+ " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
31
+ " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
32
+ "\n",
33
+ "\n",
34
+ " "
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 3,
40
+ "id": "2f67f426",
41
+ "metadata": {},
42
+ "outputs": [
43
+ {
44
+ "name": "stderr",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "INFO:absl:Using /tmp/tfhub_modules to cache modules.\n",
48
+ "2022-02-01 20:04:53.747606: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
49
+ "2022-02-01 20:04:53.747647: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1835] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
50
+ "Skipping registering GPU devices...\n",
51
+ "2022-02-01 20:04:53.747987: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n",
52
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
53
+ "WARNING:absl:Importing a function (__inference_closure_12264) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n",
54
+ "WARNING:absl:Importing a function (__inference_closure_8418) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n",
55
+ "WARNING:absl:Importing a function (__inference_closure_4202) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n"
56
+ ]
57
+ }
58
+ ],
59
+ "source": [
60
+ "import tensorflow as tf\n",
61
+ "import tensorflow_hub as hub\n",
62
+ "import tensorflow_text as text \n",
63
+ "\n",
64
+ "model_size_tf, model_size_hf = \"base\", \"base\"\n",
65
+ "hub_url = f\"https://tfhub.dev/google/gtr/gtr-{model_size_tf}/1\"\n",
66
+ "encoder = hub.load(hub_url)\n",
67
+ "\n",
68
+ "v = encoder.signatures['serving_default'].variables"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 4,
74
+ "id": "5f4c8d94",
75
+ "metadata": {
76
+ "scrolled": true
77
+ },
78
+ "outputs": [
79
+ {
80
+ "data": {
81
+ "text/plain": [
82
+ "{'encoder__encoder_norm__scale:0': TensorShape([768]),\n",
83
+ " 'encoder__layers_0__attention__key__kernel:0': TensorShape([768, 768]),\n",
84
+ " 'encoder__layers_0__attention__out__kernel:0': TensorShape([768, 768]),\n",
85
+ " 'encoder__layers_0__attention__query__kernel:0': TensorShape([768, 768]),\n",
86
+ " 'encoder__layers_0__attention__value__kernel:0': TensorShape([768, 768]),\n",
87
+ " 'encoder__layers_0__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
88
+ " 'encoder__layers_0__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
89
+ " 'encoder__layers_0__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
90
+ " 'encoder__layers_0__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
91
+ " 'encoder__layers_1__attention__key__kernel:0': TensorShape([768, 768]),\n",
92
+ " 'encoder__layers_1__attention__out__kernel:0': TensorShape([768, 768]),\n",
93
+ " 'encoder__layers_1__attention__query__kernel:0': TensorShape([768, 768]),\n",
94
+ " 'encoder__layers_1__attention__value__kernel:0': TensorShape([768, 768]),\n",
95
+ " 'encoder__layers_1__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
96
+ " 'encoder__layers_1__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
97
+ " 'encoder__layers_1__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
98
+ " 'encoder__layers_1__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
99
+ " 'encoder__layers_10__attention__key__kernel:0': TensorShape([768, 768]),\n",
100
+ " 'encoder__layers_10__attention__out__kernel:0': TensorShape([768, 768]),\n",
101
+ " 'encoder__layers_10__attention__query__kernel:0': TensorShape([768, 768]),\n",
102
+ " 'encoder__layers_10__attention__value__kernel:0': TensorShape([768, 768]),\n",
103
+ " 'encoder__layers_10__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
104
+ " 'encoder__layers_10__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
105
+ " 'encoder__layers_10__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
106
+ " 'encoder__layers_10__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
107
+ " 'encoder__layers_11__attention__key__kernel:0': TensorShape([768, 768]),\n",
108
+ " 'encoder__layers_11__attention__out__kernel:0': TensorShape([768, 768]),\n",
109
+ " 'encoder__layers_11__attention__query__kernel:0': TensorShape([768, 768]),\n",
110
+ " 'encoder__layers_11__attention__value__kernel:0': TensorShape([768, 768]),\n",
111
+ " 'encoder__layers_11__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
112
+ " 'encoder__layers_11__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
113
+ " 'encoder__layers_11__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
114
+ " 'encoder__layers_11__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
115
+ " 'encoder__layers_2__attention__key__kernel:0': TensorShape([768, 768]),\n",
116
+ " 'encoder__layers_2__attention__out__kernel:0': TensorShape([768, 768]),\n",
117
+ " 'encoder__layers_2__attention__query__kernel:0': TensorShape([768, 768]),\n",
118
+ " 'encoder__layers_2__attention__value__kernel:0': TensorShape([768, 768]),\n",
119
+ " 'encoder__layers_2__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
120
+ " 'encoder__layers_2__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
121
+ " 'encoder__layers_2__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
122
+ " 'encoder__layers_2__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
123
+ " 'encoder__layers_3__attention__key__kernel:0': TensorShape([768, 768]),\n",
124
+ " 'encoder__layers_3__attention__out__kernel:0': TensorShape([768, 768]),\n",
125
+ " 'encoder__layers_3__attention__query__kernel:0': TensorShape([768, 768]),\n",
126
+ " 'encoder__layers_3__attention__value__kernel:0': TensorShape([768, 768]),\n",
127
+ " 'encoder__layers_3__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
128
+ " 'encoder__layers_3__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
129
+ " 'encoder__layers_3__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
130
+ " 'encoder__layers_3__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
131
+ " 'encoder__layers_4__attention__key__kernel:0': TensorShape([768, 768]),\n",
132
+ " 'encoder__layers_4__attention__out__kernel:0': TensorShape([768, 768]),\n",
133
+ " 'encoder__layers_4__attention__query__kernel:0': TensorShape([768, 768]),\n",
134
+ " 'encoder__layers_4__attention__value__kernel:0': TensorShape([768, 768]),\n",
135
+ " 'encoder__layers_4__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
136
+ " 'encoder__layers_4__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
137
+ " 'encoder__layers_4__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
138
+ " 'encoder__layers_4__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
139
+ " 'encoder__layers_5__attention__key__kernel:0': TensorShape([768, 768]),\n",
140
+ " 'encoder__layers_5__attention__out__kernel:0': TensorShape([768, 768]),\n",
141
+ " 'encoder__layers_5__attention__query__kernel:0': TensorShape([768, 768]),\n",
142
+ " 'encoder__layers_5__attention__value__kernel:0': TensorShape([768, 768]),\n",
143
+ " 'encoder__layers_5__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
144
+ " 'encoder__layers_5__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
145
+ " 'encoder__layers_5__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
146
+ " 'encoder__layers_5__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
147
+ " 'encoder__layers_6__attention__key__kernel:0': TensorShape([768, 768]),\n",
148
+ " 'encoder__layers_6__attention__out__kernel:0': TensorShape([768, 768]),\n",
149
+ " 'encoder__layers_6__attention__query__kernel:0': TensorShape([768, 768]),\n",
150
+ " 'encoder__layers_6__attention__value__kernel:0': TensorShape([768, 768]),\n",
151
+ " 'encoder__layers_6__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
152
+ " 'encoder__layers_6__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
153
+ " 'encoder__layers_6__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
154
+ " 'encoder__layers_6__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
155
+ " 'encoder__layers_7__attention__key__kernel:0': TensorShape([768, 768]),\n",
156
+ " 'encoder__layers_7__attention__out__kernel:0': TensorShape([768, 768]),\n",
157
+ " 'encoder__layers_7__attention__query__kernel:0': TensorShape([768, 768]),\n",
158
+ " 'encoder__layers_7__attention__value__kernel:0': TensorShape([768, 768]),\n",
159
+ " 'encoder__layers_7__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
160
+ " 'encoder__layers_7__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
161
+ " 'encoder__layers_7__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
162
+ " 'encoder__layers_7__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
163
+ " 'encoder__layers_8__attention__key__kernel:0': TensorShape([768, 768]),\n",
164
+ " 'encoder__layers_8__attention__out__kernel:0': TensorShape([768, 768]),\n",
165
+ " 'encoder__layers_8__attention__query__kernel:0': TensorShape([768, 768]),\n",
166
+ " 'encoder__layers_8__attention__value__kernel:0': TensorShape([768, 768]),\n",
167
+ " 'encoder__layers_8__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
168
+ " 'encoder__layers_8__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
169
+ " 'encoder__layers_8__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
170
+ " 'encoder__layers_8__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
171
+ " 'encoder__layers_9__attention__key__kernel:0': TensorShape([768, 768]),\n",
172
+ " 'encoder__layers_9__attention__out__kernel:0': TensorShape([768, 768]),\n",
173
+ " 'encoder__layers_9__attention__query__kernel:0': TensorShape([768, 768]),\n",
174
+ " 'encoder__layers_9__attention__value__kernel:0': TensorShape([768, 768]),\n",
175
+ " 'encoder__layers_9__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
176
+ " 'encoder__layers_9__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
177
+ " 'encoder__layers_9__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
178
+ " 'encoder__layers_9__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
179
+ " 'encoder__relpos_bias__rel_embedding:0': TensorShape([12, 32]),\n",
180
+ " 'projection_layer__kernel:0': TensorShape([768, 768]),\n",
181
+ " 'token_embedder__embedding:0': TensorShape([32128, 768])}"
182
+ ]
183
+ },
184
+ "execution_count": 4,
185
+ "metadata": {},
186
+ "output_type": "execute_result"
187
+ }
188
+ ],
189
+ "source": [
190
+ "tf_name_weight = {var.name: var for var in v}\n",
191
+ "tf_name_shape = {var.name: var.shape for var in v}\n",
192
+ "tf_name_shape"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": 7,
198
+ "id": "6d223b07",
199
+ "metadata": {
200
+ "scrolled": true
201
+ },
202
+ "outputs": [
203
+ {
204
+ "data": {
205
+ "application/vnd.jupyter.widget-view+json": {
206
+ "model_id": "e4b637f4a8f847fcaa7b09fb729227b0",
207
+ "version_major": 2,
208
+ "version_minor": 0
209
+ },
210
+ "text/plain": [
211
+ "HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=45229452544.0), HTML(value='')))"
212
+ ]
213
+ },
214
+ "metadata": {},
215
+ "output_type": "display_data"
216
+ },
217
+ {
218
+ "name": "stdout",
219
+ "output_type": "stream",
220
+ "text": [
221
+ "\n"
222
+ ]
223
+ },
224
+ {
225
+ "name": "stderr",
226
+ "output_type": "stream",
227
+ "text": [
228
+ "Some weights of the model checkpoint at t5-11b were not used when initializing T5EncoderModel: ['decoder.block.13.layer.1.EncDecAttention.o.weight', 'decoder.block.14.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.0.layer_norm.weight', 'decoder.block.6.layer.1.EncDecAttention.v.weight', 'decoder.block.15.layer.0.SelfAttention.v.weight', 'decoder.block.3.layer.1.layer_norm.weight', 'decoder.block.11.layer.2.DenseReluDense.wi.weight', 'decoder.block.11.layer.2.DenseReluDense.wo.weight', 'decoder.block.3.layer.0.SelfAttention.o.weight', 'decoder.block.12.layer.2.DenseReluDense.wo.weight', 'decoder.block.8.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.1.layer_norm.weight', 'decoder.block.9.layer.2.DenseReluDense.wi.weight', 'decoder.block.15.layer.0.SelfAttention.q.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.14.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'decoder.block.14.layer.0.SelfAttention.q.weight', 'decoder.block.7.layer.0.SelfAttention.o.weight', 'decoder.block.9.layer.1.EncDecAttention.v.weight', 'decoder.block.21.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.19.layer.2.layer_norm.weight', 'decoder.block.14.layer.1.EncDecAttention.k.weight', 'decoder.block.12.layer.1.EncDecAttention.o.weight', 'decoder.block.18.layer.2.DenseReluDense.wi.weight', 'decoder.block.3.layer.0.SelfAttention.q.weight', 'decoder.block.9.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight', 'decoder.block.23.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.v.weight', 'decoder.block.14.layer.1.EncDecAttention.v.weight', 'decoder.block.8.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.3.layer.0.layer_norm.weight', 'decoder.block.10.layer.2.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.21.layer.0.layer_norm.weight', 'decoder.block.12.layer.2.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.2.layer_norm.weight', 'decoder.block.5.layer.0.layer_norm.weight', 'decoder.block.1.layer.1.EncDecAttention.o.weight', 'decoder.block.17.layer.0.SelfAttention.v.weight', 'decoder.block.11.layer.2.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.k.weight', 'decoder.block.15.layer.1.EncDecAttention.v.weight', 'decoder.block.12.layer.0.SelfAttention.q.weight', 'decoder.block.23.layer.0.SelfAttention.o.weight', 'decoder.block.16.layer.0.SelfAttention.q.weight', 'decoder.block.21.layer.0.SelfAttention.o.weight', 'decoder.block.19.layer.0.SelfAttention.q.weight', 'decoder.block.19.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.13.layer.0.SelfAttention.o.weight', 'decoder.block.20.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.2.DenseReluDense.wi.weight', 'decoder.block.5.layer.1.EncDecAttention.o.weight', 'decoder.block.17.layer.2.layer_norm.weight', 'decoder.block.21.layer.1.EncDecAttention.k.weight', 'decoder.block.16.layer.1.EncDecAttention.q.weight', 'decoder.block.8.layer.1.layer_norm.weight', 'decoder.block.5.layer.2.DenseReluDense.wi.weight', 'decoder.block.9.layer.1.EncDecAttention.o.weight', 'decoder.block.22.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.4.layer.1.EncDecAttention.k.weight', 'decoder.block.22.layer.2.DenseReluDense.wi.weight', 'decoder.block.23.layer.2.layer_norm.weight', 'decoder.block.2.layer.2.DenseReluDense.wo.weight', 'decoder.block.23.layer.0.SelfAttention.q.weight', 'decoder.block.20.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.17.layer.1.EncDecAttention.v.weight', 'decoder.block.9.layer.1.EncDecAttention.k.weight', 'decoder.block.12.layer.0.SelfAttention.k.weight', 'decoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.16.layer.2.DenseReluDense.wi.weight', 'decoder.block.2.layer.0.layer_norm.weight', 'decoder.block.17.layer.1.EncDecAttention.q.weight', 'decoder.block.13.layer.0.SelfAttention.k.weight', 'decoder.block.21.layer.2.layer_norm.weight', 'decoder.block.4.layer.1.layer_norm.weight', 'decoder.block.13.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.2.DenseReluDense.wi.weight', 'decoder.block.13.layer.2.DenseReluDense.wi.weight', 'decoder.block.16.layer.1.EncDecAttention.v.weight', 'decoder.block.15.layer.1.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.19.layer.1.EncDecAttention.v.weight', 'decoder.block.3.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.2.DenseReluDense.wo.weight', 'decoder.block.20.layer.0.SelfAttention.v.weight', 'decoder.block.4.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.1.layer_norm.weight', 'decoder.block.11.layer.0.SelfAttention.o.weight', 'decoder.block.19.layer.1.EncDecAttention.o.weight', 'decoder.block.23.layer.2.DenseReluDense.wo.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.18.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.16.layer.1.layer_norm.weight', 'decoder.block.12.layer.1.EncDecAttention.v.weight', 'decoder.block.17.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.k.weight', 'decoder.block.11.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.2.layer_norm.weight', 'decoder.block.8.layer.1.EncDecAttention.q.weight', 'decoder.block.16.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.15.layer.1.EncDecAttention.k.weight', 'decoder.block.19.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.0.SelfAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.q.weight', 'decoder.block.2.layer.1.EncDecAttention.q.weight', 'decoder.block.17.layer.2.DenseReluDense.wi.weight', 'decoder.block.5.layer.2.layer_norm.weight', 'decoder.block.13.layer.2.layer_norm.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.16.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.1.EncDecAttention.q.weight', 'decoder.block.12.layer.1.layer_norm.weight', 'decoder.block.10.layer.1.EncDecAttention.o.weight', 'decoder.block.9.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.20.layer.1.EncDecAttention.v.weight', 'decoder.block.20.layer.0.SelfAttention.q.weight', 'decoder.block.22.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.1.layer_norm.weight', 'decoder.block.4.layer.1.EncDecAttention.v.weight', 'decoder.block.22.layer.0.SelfAttention.v.weight', 'decoder.block.15.layer.2.layer_norm.weight', 'decoder.block.23.layer.2.DenseReluDense.wi.weight', 'decoder.block.23.layer.1.EncDecAttention.v.weight', 'decoder.block.8.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.0.SelfAttention.v.weight', 'decoder.block.4.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.1.EncDecAttention.k.weight', 'decoder.block.3.layer.2.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.22.layer.0.SelfAttention.q.weight', 'decoder.block.18.layer.2.DenseReluDense.wo.weight', 'decoder.block.10.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.q.weight', 'decoder.block.20.layer.2.DenseReluDense.wo.weight', 'decoder.block.11.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.2.DenseReluDense.wi.weight', 'decoder.block.10.layer.0.SelfAttention.q.weight', 'decoder.block.17.layer.1.layer_norm.weight', 'decoder.block.20.layer.1.layer_norm.weight', 'decoder.block.18.layer.0.layer_norm.weight', 'decoder.block.21.layer.0.SelfAttention.v.weight', 'decoder.block.20.layer.0.SelfAttention.o.weight', 'decoder.block.22.layer.1.EncDecAttention.o.weight', 'decoder.block.21.layer.1.EncDecAttention.q.weight', 'decoder.block.16.layer.2.layer_norm.weight', 'decoder.block.2.layer.1.EncDecAttention.k.weight', 'decoder.block.10.layer.1.EncDecAttention.v.weight', 'decoder.block.10.layer.1.layer_norm.weight', 'decoder.block.3.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.1.EncDecAttention.o.weight', 'decoder.block.16.layer.1.EncDecAttention.o.weight', 'decoder.block.17.layer.1.EncDecAttention.k.weight', 'decoder.block.15.layer.0.SelfAttention.k.weight', 'decoder.block.11.layer.0.layer_norm.weight', 'decoder.block.23.layer.1.EncDecAttention.q.weight', 'decoder.block.13.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.13.layer.0.layer_norm.weight', 'decoder.block.9.layer.0.SelfAttention.o.weight', 'decoder.block.19.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.1.EncDecAttention.q.weight', 'decoder.block.16.layer.0.layer_norm.weight', 'decoder.block.8.layer.2.DenseReluDense.wi.weight', 'decoder.block.17.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.1.EncDecAttention.k.weight', 'decoder.block.14.layer.2.DenseReluDense.wi.weight', 'decoder.block.6.layer.1.layer_norm.weight', 'decoder.block.17.layer.0.SelfAttention.o.weight', 'decoder.block.19.layer.1.layer_norm.weight', 'decoder.block.13.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.20.layer.1.EncDecAttention.q.weight', 'decoder.block.23.layer.0.SelfAttention.v.weight', 'decoder.block.3.layer.0.SelfAttention.v.weight', 'decoder.block.18.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.o.weight', 'decoder.block.22.layer.0.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.v.weight', 'decoder.block.12.layer.1.EncDecAttention.q.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.15.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.o.weight', 'decoder.block.15.layer.1.EncDecAttention.q.weight', 'decoder.block.19.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.0.SelfAttention.o.weight', 'decoder.block.14.layer.2.layer_norm.weight', 'decoder.block.22.layer.0.SelfAttention.k.weight', 'decoder.block.13.layer.0.SelfAttention.q.weight', 'decoder.block.11.layer.1.EncDecAttention.v.weight', 'decoder.block.22.layer.1.layer_norm.weight', 'decoder.block.21.layer.0.SelfAttention.q.weight', 'decoder.block.6.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.0.SelfAttention.v.weight', 'decoder.block.21.layer.0.SelfAttention.k.weight', 'decoder.block.19.layer.0.SelfAttention.k.weight', 'decoder.block.10.layer.0.layer_norm.weight', 'decoder.block.2.layer.2.DenseReluDense.wi.weight', 'decoder.block.17.layer.0.SelfAttention.k.weight', 'decoder.block.23.layer.0.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wi.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'decoder.block.19.layer.0.SelfAttention.o.weight', 'decoder.block.5.layer.0.SelfAttention.k.weight', 'decoder.block.10.layer.2.DenseReluDense.wo.weight', 'decoder.block.2.layer.0.SelfAttention.q.weight', 'decoder.block.22.layer.1.EncDecAttention.v.weight', 'decoder.block.23.layer.1.layer_norm.weight', 'decoder.block.5.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.o.weight', 'decoder.block.14.layer.0.SelfAttention.o.weight', 'decoder.block.17.layer.0.layer_norm.weight', 'decoder.final_layer_norm.weight', 'decoder.block.10.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.0.layer_norm.weight', 'decoder.block.23.layer.1.EncDecAttention.k.weight', 'decoder.block.21.layer.1.EncDecAttention.o.weight', 'decoder.block.13.layer.1.layer_norm.weight', 'decoder.block.1.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.0.SelfAttention.v.weight', 'decoder.block.1.layer.0.layer_norm.weight', 'decoder.block.5.layer.0.SelfAttention.v.weight', 'decoder.block.13.layer.2.DenseReluDense.wo.weight', 'decoder.block.8.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.23.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.15.layer.0.layer_norm.weight', 'decoder.block.7.layer.2.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.k.weight', 'decoder.block.15.layer.2.DenseReluDense.wo.weight', 'decoder.block.8.layer.1.EncDecAttention.o.weight', 'decoder.block.22.layer.0.SelfAttention.o.weight', 'decoder.block.17.layer.0.SelfAttention.q.weight', 'decoder.block.9.layer.0.SelfAttention.v.weight', 'decoder.block.9.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.0.layer_norm.weight', 'decoder.block.14.layer.0.layer_norm.weight', 'decoder.block.9.layer.0.SelfAttention.q.weight', 'decoder.block.16.layer.0.SelfAttention.o.weight', 'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.20.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.2.layer_norm.weight', 'decoder.block.13.layer.1.EncDecAttention.v.weight', 'decoder.block.7.layer.2.DenseReluDense.wo.weight', 'decoder.block.21.layer.2.DenseReluDense.wo.weight', 'decoder.block.15.layer.2.DenseReluDense.wi.weight', 'decoder.block.10.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.0.SelfAttention.v.weight', 'decoder.block.11.layer.1.EncDecAttention.k.weight', 'decoder.block.22.layer.2.layer_norm.weight', 'decoder.block.2.layer.1.layer_norm.weight', 'decoder.block.8.layer.2.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.q.weight', 'decoder.block.12.layer.1.EncDecAttention.k.weight', 'decoder.block.11.layer.0.SelfAttention.v.weight', 'decoder.block.22.layer.1.EncDecAttention.q.weight', 'decoder.block.5.layer.1.EncDecAttention.q.weight', 'decoder.block.11.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.20.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.21.layer.1.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.20.layer.2.layer_norm.weight', 'decoder.block.19.layer.1.EncDecAttention.q.weight', 'decoder.block.10.layer.1.EncDecAttention.k.weight', 'decoder.block.20.layer.0.layer_norm.weight', 'decoder.block.18.layer.0.SelfAttention.k.weight', 'decoder.block.21.layer.2.DenseReluDense.wi.weight', 'decoder.block.11.layer.1.EncDecAttention.q.weight', 'decoder.block.15.layer.0.SelfAttention.o.weight', 'decoder.block.5.layer.2.DenseReluDense.wo.weight', 'decoder.block.10.layer.1.EncDecAttention.q.weight', 'decoder.block.9.layer.2.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.v.weight', 'decoder.block.9.layer.0.layer_norm.weight', 'decoder.block.16.layer.2.DenseReluDense.wo.weight', 'decoder.block.9.layer.1.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.q.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.18.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.1.EncDecAttention.v.weight', 'decoder.block.14.layer.1.EncDecAttention.q.weight', 'decoder.block.10.layer.0.SelfAttention.o.weight', 'decoder.block.16.layer.0.SelfAttention.k.weight', 'decoder.block.18.layer.1.EncDecAttention.o.weight', 'decoder.block.11.layer.1.layer_norm.weight', 'decoder.block.2.layer.1.EncDecAttention.v.weight', 'decoder.block.4.layer.0.SelfAttention.k.weight', 'decoder.block.19.layer.0.layer_norm.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.8.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.layer_norm.weight']\n",
229
+ "- This IS expected if you are initializing T5EncoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
230
+ "- This IS NOT expected if you are initializing T5EncoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
231
+ ]
232
+ },
233
+ {
234
+ "name": "stderr",
235
+ "output_type": "stream",
236
+ "text": [
237
+ "Some weights of T5EncoderModel were not initialized from the model checkpoint at t5-11b and are newly initialized: ['encoder.embed_tokens.weight']\n",
238
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
239
+ ]
240
+ },
241
+ {
242
+ "data": {
243
+ "text/plain": [
244
+ "{'shared.weight': torch.Size([32128, 1024]),\n",
245
+ " 'encoder.embed_tokens.weight': torch.Size([32128, 1024]),\n",
246
+ " 'encoder.block.0.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
247
+ " 'encoder.block.0.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
248
+ " 'encoder.block.0.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
249
+ " 'encoder.block.0.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
250
+ " 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight': torch.Size([32, 128]),\n",
251
+ " 'encoder.block.0.layer.0.layer_norm.weight': torch.Size([1024]),\n",
252
+ " 'encoder.block.0.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
253
+ " 'encoder.block.0.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
254
+ " 'encoder.block.0.layer.1.layer_norm.weight': torch.Size([1024]),\n",
255
+ " 'encoder.block.1.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
256
+ " 'encoder.block.1.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
257
+ " 'encoder.block.1.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
258
+ " 'encoder.block.1.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
259
+ " 'encoder.block.1.layer.0.layer_norm.weight': torch.Size([1024]),\n",
260
+ " 'encoder.block.1.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
261
+ " 'encoder.block.1.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
262
+ " 'encoder.block.1.layer.1.layer_norm.weight': torch.Size([1024]),\n",
263
+ " 'encoder.block.2.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
264
+ " 'encoder.block.2.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
265
+ " 'encoder.block.2.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
266
+ " 'encoder.block.2.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
267
+ " 'encoder.block.2.layer.0.layer_norm.weight': torch.Size([1024]),\n",
268
+ " 'encoder.block.2.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
269
+ " 'encoder.block.2.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
270
+ " 'encoder.block.2.layer.1.layer_norm.weight': torch.Size([1024]),\n",
271
+ " 'encoder.block.3.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
272
+ " 'encoder.block.3.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
273
+ " 'encoder.block.3.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
274
+ " 'encoder.block.3.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
275
+ " 'encoder.block.3.layer.0.layer_norm.weight': torch.Size([1024]),\n",
276
+ " 'encoder.block.3.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
277
+ " 'encoder.block.3.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
278
+ " 'encoder.block.3.layer.1.layer_norm.weight': torch.Size([1024]),\n",
279
+ " 'encoder.block.4.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
280
+ " 'encoder.block.4.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
281
+ " 'encoder.block.4.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
282
+ " 'encoder.block.4.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
283
+ " 'encoder.block.4.layer.0.layer_norm.weight': torch.Size([1024]),\n",
284
+ " 'encoder.block.4.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
285
+ " 'encoder.block.4.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
286
+ " 'encoder.block.4.layer.1.layer_norm.weight': torch.Size([1024]),\n",
287
+ " 'encoder.block.5.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
288
+ " 'encoder.block.5.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
289
+ " 'encoder.block.5.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
290
+ " 'encoder.block.5.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
291
+ " 'encoder.block.5.layer.0.layer_norm.weight': torch.Size([1024]),\n",
292
+ " 'encoder.block.5.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
293
+ " 'encoder.block.5.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
294
+ " 'encoder.block.5.layer.1.layer_norm.weight': torch.Size([1024]),\n",
295
+ " 'encoder.block.6.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
296
+ " 'encoder.block.6.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
297
+ " 'encoder.block.6.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
298
+ " 'encoder.block.6.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
299
+ " 'encoder.block.6.layer.0.layer_norm.weight': torch.Size([1024]),\n",
300
+ " 'encoder.block.6.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
301
+ " 'encoder.block.6.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
302
+ " 'encoder.block.6.layer.1.layer_norm.weight': torch.Size([1024]),\n",
303
+ " 'encoder.block.7.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
304
+ " 'encoder.block.7.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
305
+ " 'encoder.block.7.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
306
+ " 'encoder.block.7.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
307
+ " 'encoder.block.7.layer.0.layer_norm.weight': torch.Size([1024]),\n",
308
+ " 'encoder.block.7.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
309
+ " 'encoder.block.7.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
310
+ " 'encoder.block.7.layer.1.layer_norm.weight': torch.Size([1024]),\n",
311
+ " 'encoder.block.8.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
312
+ " 'encoder.block.8.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
313
+ " 'encoder.block.8.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
314
+ " 'encoder.block.8.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
315
+ " 'encoder.block.8.layer.0.layer_norm.weight': torch.Size([1024]),\n",
316
+ " 'encoder.block.8.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
317
+ " 'encoder.block.8.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
318
+ " 'encoder.block.8.layer.1.layer_norm.weight': torch.Size([1024]),\n",
319
+ " 'encoder.block.9.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
320
+ " 'encoder.block.9.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
321
+ " 'encoder.block.9.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
322
+ " 'encoder.block.9.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
323
+ " 'encoder.block.9.layer.0.layer_norm.weight': torch.Size([1024]),\n",
324
+ " 'encoder.block.9.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
325
+ " 'encoder.block.9.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
326
+ " 'encoder.block.9.layer.1.layer_norm.weight': torch.Size([1024]),\n",
327
+ " 'encoder.block.10.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
328
+ " 'encoder.block.10.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
329
+ " 'encoder.block.10.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
330
+ " 'encoder.block.10.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
331
+ " 'encoder.block.10.layer.0.layer_norm.weight': torch.Size([1024]),\n",
332
+ " 'encoder.block.10.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
333
+ " 'encoder.block.10.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
334
+ " 'encoder.block.10.layer.1.layer_norm.weight': torch.Size([1024]),\n",
335
+ " 'encoder.block.11.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
336
+ " 'encoder.block.11.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
337
+ " 'encoder.block.11.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
338
+ " 'encoder.block.11.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
339
+ " 'encoder.block.11.layer.0.layer_norm.weight': torch.Size([1024]),\n",
340
+ " 'encoder.block.11.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
341
+ " 'encoder.block.11.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
342
+ " 'encoder.block.11.layer.1.layer_norm.weight': torch.Size([1024]),\n",
343
+ " 'encoder.block.12.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
344
+ " 'encoder.block.12.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
345
+ " 'encoder.block.12.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
346
+ " 'encoder.block.12.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
347
+ " 'encoder.block.12.layer.0.layer_norm.weight': torch.Size([1024]),\n",
348
+ " 'encoder.block.12.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
349
+ " 'encoder.block.12.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
350
+ " 'encoder.block.12.layer.1.layer_norm.weight': torch.Size([1024]),\n",
351
+ " 'encoder.block.13.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
352
+ " 'encoder.block.13.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
353
+ " 'encoder.block.13.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
354
+ " 'encoder.block.13.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
355
+ " 'encoder.block.13.layer.0.layer_norm.weight': torch.Size([1024]),\n",
356
+ " 'encoder.block.13.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
357
+ " 'encoder.block.13.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
358
+ " 'encoder.block.13.layer.1.layer_norm.weight': torch.Size([1024]),\n",
359
+ " 'encoder.block.14.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
360
+ " 'encoder.block.14.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
361
+ " 'encoder.block.14.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
362
+ " 'encoder.block.14.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
363
+ " 'encoder.block.14.layer.0.layer_norm.weight': torch.Size([1024]),\n",
364
+ " 'encoder.block.14.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
365
+ " 'encoder.block.14.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
366
+ " 'encoder.block.14.layer.1.layer_norm.weight': torch.Size([1024]),\n",
367
+ " 'encoder.block.15.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
368
+ " 'encoder.block.15.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
369
+ " 'encoder.block.15.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
370
+ " 'encoder.block.15.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
371
+ " 'encoder.block.15.layer.0.layer_norm.weight': torch.Size([1024]),\n",
372
+ " 'encoder.block.15.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
373
+ " 'encoder.block.15.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
374
+ " 'encoder.block.15.layer.1.layer_norm.weight': torch.Size([1024]),\n",
375
+ " 'encoder.block.16.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
376
+ " 'encoder.block.16.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
377
+ " 'encoder.block.16.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
378
+ " 'encoder.block.16.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
379
+ " 'encoder.block.16.layer.0.layer_norm.weight': torch.Size([1024]),\n",
380
+ " 'encoder.block.16.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
381
+ " 'encoder.block.16.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
382
+ " 'encoder.block.16.layer.1.layer_norm.weight': torch.Size([1024]),\n",
383
+ " 'encoder.block.17.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
384
+ " 'encoder.block.17.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
385
+ " 'encoder.block.17.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
386
+ " 'encoder.block.17.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
387
+ " 'encoder.block.17.layer.0.layer_norm.weight': torch.Size([1024]),\n",
388
+ " 'encoder.block.17.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
389
+ " 'encoder.block.17.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
390
+ " 'encoder.block.17.layer.1.layer_norm.weight': torch.Size([1024]),\n",
391
+ " 'encoder.block.18.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
392
+ " 'encoder.block.18.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
393
+ " 'encoder.block.18.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
394
+ " 'encoder.block.18.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
395
+ " 'encoder.block.18.layer.0.layer_norm.weight': torch.Size([1024]),\n",
396
+ " 'encoder.block.18.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
397
+ " 'encoder.block.18.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
398
+ " 'encoder.block.18.layer.1.layer_norm.weight': torch.Size([1024]),\n",
399
+ " 'encoder.block.19.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
400
+ " 'encoder.block.19.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
401
+ " 'encoder.block.19.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
402
+ " 'encoder.block.19.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
403
+ " 'encoder.block.19.layer.0.layer_norm.weight': torch.Size([1024]),\n",
404
+ " 'encoder.block.19.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
405
+ " 'encoder.block.19.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
406
+ " 'encoder.block.19.layer.1.layer_norm.weight': torch.Size([1024]),\n",
407
+ " 'encoder.block.20.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
408
+ " 'encoder.block.20.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
409
+ " 'encoder.block.20.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
410
+ " 'encoder.block.20.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
411
+ " 'encoder.block.20.layer.0.layer_norm.weight': torch.Size([1024]),\n",
412
+ " 'encoder.block.20.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
413
+ " 'encoder.block.20.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
414
+ " 'encoder.block.20.layer.1.layer_norm.weight': torch.Size([1024]),\n",
415
+ " 'encoder.block.21.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
416
+ " 'encoder.block.21.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
417
+ " 'encoder.block.21.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
418
+ " 'encoder.block.21.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
419
+ " 'encoder.block.21.layer.0.layer_norm.weight': torch.Size([1024]),\n",
420
+ " 'encoder.block.21.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
421
+ " 'encoder.block.21.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
422
+ " 'encoder.block.21.layer.1.layer_norm.weight': torch.Size([1024]),\n",
423
+ " 'encoder.block.22.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
424
+ " 'encoder.block.22.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
425
+ " 'encoder.block.22.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
426
+ " 'encoder.block.22.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
427
+ " 'encoder.block.22.layer.0.layer_norm.weight': torch.Size([1024]),\n",
428
+ " 'encoder.block.22.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
429
+ " 'encoder.block.22.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
430
+ " 'encoder.block.22.layer.1.layer_norm.weight': torch.Size([1024]),\n",
431
+ " 'encoder.block.23.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
432
+ " 'encoder.block.23.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
433
+ " 'encoder.block.23.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
434
+ " 'encoder.block.23.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
435
+ " 'encoder.block.23.layer.0.layer_norm.weight': torch.Size([1024]),\n",
436
+ " 'encoder.block.23.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
437
+ " 'encoder.block.23.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
438
+ " 'encoder.block.23.layer.1.layer_norm.weight': torch.Size([1024]),\n",
439
+ " 'encoder.final_layer_norm.weight': torch.Size([1024])}"
440
+ ]
441
+ },
442
+ "execution_count": 7,
443
+ "metadata": {},
444
+ "output_type": "execute_result"
445
+ }
446
+ ],
447
+ "source": [
448
+ "tokenizer = AutoTokenizer.from_pretrained(f\"t5-{model_size_hf}\")\n",
449
+ "t5 = T5EncoderModel.from_pretrained(f\"t5-{model_size_hf}\") \n",
450
+ "pt_name_shape = {name: weight.shape for name, weight in t5.state_dict().items()}\n",
451
+ "pt_name_shape"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": 8,
457
+ "id": "1d3c9865",
458
+ "metadata": {},
459
+ "outputs": [],
460
+ "source": [
461
+ "def convert_name(name):\n",
462
+ " fct_map = {\n",
463
+ " \"attention\": \"SelfAttention\",\n",
464
+ " \"mlp\": \"DenseReluDense\",\n",
465
+ " \"pre_attention_layer_norm\": \"layer_norm\",\n",
466
+ " \"pre_mlp_layer_norm\": \"layer_norm\",\n",
467
+ " }\n",
468
+ " name_map = {\n",
469
+ " 'key': 'k',\n",
470
+ " 'out': 'o',\n",
471
+ " 'query': 'q',\n",
472
+ " 'value': 'v'\n",
473
+ " }\n",
474
+ " \n",
475
+ " fixed_names = {\n",
476
+ " \"token_embedder__embedding:0\": \"shared.weight\",\n",
477
+ " \"encoder__encoder_norm__scale:0\": \"encoder.final_layer_norm.weight\",\n",
478
+ " \"encoder__relpos_bias__rel_embedding:0\": \"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\"\n",
479
+ " }\n",
480
+ " \n",
481
+ " if name in fixed_names:\n",
482
+ " return fixed_names[name]\n",
483
+ " \n",
484
+ " out = \"\"\n",
485
+ " splits = name.split(\"__\")\n",
486
+ " layer = splits[1].split(\"_\")[1]\n",
487
+ " fct = fct_map.get(splits[2], splits[2])\n",
488
+ " if 'layer_norm' in name:\n",
489
+ " sublayer = \"1\" if \"pre_mlp_layer_norm\" in name else \"0\" #Not sure on the right setting here\n",
490
+ " #sublayer = \"0\" if \"pre_mlp_layer_norm\" in name else \"1\" #Not sure on the right setting here\n",
491
+ " out = f\"encoder.block.{layer}.layer.{sublayer}.{fct}.weight\"\n",
492
+ " elif name.startswith(\"encoder__layers_\"):\n",
493
+ " sublayer = \"0\" if fct == \"SelfAttention\" else \"1\"\n",
494
+ " name = name_map.get(splits[3], splits[3])\n",
495
+ " out = f\"encoder.block.{layer}.layer.{sublayer}.{fct}.{name}.weight\"\n",
496
+ " \n",
497
+ " return out"
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "execution_count": 9,
503
+ "id": "1ca9590e",
504
+ "metadata": {},
505
+ "outputs": [],
506
+ "source": [
507
+ "def equal_shapes(shape1, shape2):\n",
508
+ " if len(shape1) != len(shape2):\n",
509
+ " return False\n",
510
+ " \n",
511
+ " for idx in range(len(shape1)):\n",
512
+ " if shape1[idx] != shape2[idx]:\n",
513
+ " return False\n",
514
+ " \n",
515
+ " return True"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": 10,
521
+ "id": "ced52a5f",
522
+ "metadata": {},
523
+ "outputs": [
524
+ {
525
+ "name": "stdout",
526
+ "output_type": "stream",
527
+ "text": [
528
+ "Remaining weights: {'encoder.embed_tokens.weight'}\n"
529
+ ]
530
+ }
531
+ ],
532
+ "source": [
533
+ "def need_transpose(name):\n",
534
+ " #HF function: https://github.com/huggingface/transformers/blob/c962c2adbff678ae6d2e98378bed5b8d1a9831d9/src/transformers/models/t5/modeling_t5.py#L161\n",
535
+ " return name != \"shared.weight\"\n",
536
+ "\n",
537
+ " \n",
538
+ "\n",
539
+ "names_to_ignore = {\"projection_layer__kernel:0\"}\n",
540
+ "#Additional dense layer on top\n",
541
+ "\n",
542
+ "#Check we used all names\n",
543
+ "pt_all_names = set(t5.state_dict().keys())\n",
544
+ "\n",
545
+ "for var in v:\n",
546
+ " name = var.name\n",
547
+ " if name in names_to_ignore:\n",
548
+ " continue\n",
549
+ " \n",
550
+ " pt_name = convert_name(name)\n",
551
+ " if pt_name not in pt_all_names:\n",
552
+ " print(\"Name not found:\", name, \"=>\", pt_name)\n",
553
+ " else:\n",
554
+ " pt_all_names.remove(pt_name)\n",
555
+ " tf_shape = tf_name_shape[name].as_list()\n",
556
+ " pt_shape = list(pt_name_shape[pt_name])\n",
557
+ " \n",
558
+ " if need_transpose(pt_name):\n",
559
+ " pt_shape = list(reversed(pt_shape))\n",
560
+ " \n",
561
+ " if not equal_shapes(tf_shape, pt_shape):\n",
562
+ " print(\"Different shape:\", name, tf_shape, pt_name, pt_shape )\n",
563
+ " \n",
564
+ "print(\"Remaining weights:\", pt_all_names)\n",
565
+ "#All layers match"
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "code",
570
+ "execution_count": 11,
571
+ "id": "1190984f",
572
+ "metadata": {},
573
+ "outputs": [
574
+ {
575
+ "name": "stderr",
576
+ "output_type": "stream",
577
+ "text": [
578
+ "Some weights of T5EncoderModel were not initialized from the model checkpoint at t5-11b and are newly initialized: ['encoder.embed_tokens.weight']\n",
579
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
580
+ ]
581
+ },
582
+ {
583
+ "name": "stdout",
584
+ "output_type": "stream",
585
+ "text": [
586
+ "encoder__encoder_norm__scale:0 ((1024,)) =transpose=> encoder.final_layer_norm.weight torch.Size([1024])\n",
587
+ "encoder__layers_0__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
588
+ "encoder__layers_0__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
589
+ "encoder__layers_0__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
590
+ "encoder__layers_0__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
591
+ "encoder__layers_0__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.0.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
592
+ "encoder__layers_0__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
593
+ "encoder__layers_0__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.0.layer.0.layer_norm.weight torch.Size([1024])\n",
594
+ "encoder__layers_0__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.0.layer.1.layer_norm.weight torch.Size([1024])\n",
595
+ "encoder__layers_1__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
596
+ "encoder__layers_1__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.1.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
597
+ "encoder__layers_1__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
598
+ "encoder__layers_1__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
599
+ "encoder__layers_1__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.1.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
600
+ "encoder__layers_1__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.1.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
601
+ "encoder__layers_1__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.1.layer.0.layer_norm.weight torch.Size([1024])\n",
602
+ "encoder__layers_1__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.1.layer.1.layer_norm.weight torch.Size([1024])\n",
603
+ "encoder__layers_10__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
604
+ "encoder__layers_10__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.10.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
605
+ "encoder__layers_10__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
606
+ "encoder__layers_10__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
607
+ "encoder__layers_10__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.10.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
608
+ "encoder__layers_10__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.10.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
609
+ "encoder__layers_10__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.10.layer.0.layer_norm.weight torch.Size([1024])\n",
610
+ "encoder__layers_10__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.10.layer.1.layer_norm.weight torch.Size([1024])\n",
611
+ "encoder__layers_11__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
612
+ "encoder__layers_11__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.11.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
613
+ "encoder__layers_11__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
614
+ "encoder__layers_11__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
615
+ "encoder__layers_11__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.11.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
616
+ "encoder__layers_11__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.11.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
617
+ "encoder__layers_11__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.11.layer.0.layer_norm.weight torch.Size([1024])\n",
618
+ "encoder__layers_11__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.11.layer.1.layer_norm.weight torch.Size([1024])\n",
619
+ "encoder__layers_12__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
620
+ "encoder__layers_12__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.12.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
621
+ "encoder__layers_12__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
622
+ "encoder__layers_12__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
623
+ "encoder__layers_12__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.12.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
624
+ "encoder__layers_12__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.12.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
625
+ "encoder__layers_12__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.12.layer.0.layer_norm.weight torch.Size([1024])\n",
626
+ "encoder__layers_12__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.12.layer.1.layer_norm.weight torch.Size([1024])\n",
627
+ "encoder__layers_13__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
628
+ "encoder__layers_13__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.13.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
629
+ "encoder__layers_13__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
630
+ "encoder__layers_13__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
631
+ "encoder__layers_13__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.13.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
632
+ "encoder__layers_13__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.13.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
633
+ "encoder__layers_13__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.13.layer.0.layer_norm.weight torch.Size([1024])\n",
634
+ "encoder__layers_13__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.13.layer.1.layer_norm.weight torch.Size([1024])\n",
635
+ "encoder__layers_14__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
636
+ "encoder__layers_14__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.14.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
637
+ "encoder__layers_14__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
638
+ "encoder__layers_14__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
639
+ "encoder__layers_14__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.14.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
640
+ "encoder__layers_14__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.14.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
641
+ "encoder__layers_14__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.14.layer.0.layer_norm.weight torch.Size([1024])\n",
642
+ "encoder__layers_14__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.14.layer.1.layer_norm.weight torch.Size([1024])\n",
643
+ "encoder__layers_15__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n"
644
+ ]
645
+ },
646
+ {
647
+ "name": "stdout",
648
+ "output_type": "stream",
649
+ "text": [
650
+ "encoder__layers_15__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.15.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
651
+ "encoder__layers_15__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
652
+ "encoder__layers_15__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
653
+ "encoder__layers_15__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.15.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
654
+ "encoder__layers_15__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.15.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
655
+ "encoder__layers_15__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.15.layer.0.layer_norm.weight torch.Size([1024])\n",
656
+ "encoder__layers_15__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.15.layer.1.layer_norm.weight torch.Size([1024])\n",
657
+ "encoder__layers_16__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
658
+ "encoder__layers_16__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.16.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
659
+ "encoder__layers_16__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
660
+ "encoder__layers_16__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
661
+ "encoder__layers_16__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.16.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
662
+ "encoder__layers_16__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.16.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
663
+ "encoder__layers_16__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.16.layer.0.layer_norm.weight torch.Size([1024])\n",
664
+ "encoder__layers_16__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.16.layer.1.layer_norm.weight torch.Size([1024])\n",
665
+ "encoder__layers_17__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
666
+ "encoder__layers_17__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.17.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
667
+ "encoder__layers_17__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
668
+ "encoder__layers_17__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
669
+ "encoder__layers_17__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.17.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
670
+ "encoder__layers_17__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.17.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
671
+ "encoder__layers_17__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.17.layer.0.layer_norm.weight torch.Size([1024])\n",
672
+ "encoder__layers_17__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.17.layer.1.layer_norm.weight torch.Size([1024])\n",
673
+ "encoder__layers_18__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
674
+ "encoder__layers_18__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.18.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
675
+ "encoder__layers_18__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
676
+ "encoder__layers_18__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
677
+ "encoder__layers_18__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.18.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
678
+ "encoder__layers_18__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.18.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
679
+ "encoder__layers_18__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.18.layer.0.layer_norm.weight torch.Size([1024])\n",
680
+ "encoder__layers_18__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.18.layer.1.layer_norm.weight torch.Size([1024])\n",
681
+ "encoder__layers_19__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
682
+ "encoder__layers_19__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.19.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
683
+ "encoder__layers_19__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
684
+ "encoder__layers_19__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
685
+ "encoder__layers_19__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.19.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
686
+ "encoder__layers_19__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.19.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
687
+ "encoder__layers_19__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.19.layer.0.layer_norm.weight torch.Size([1024])\n",
688
+ "encoder__layers_19__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.19.layer.1.layer_norm.weight torch.Size([1024])\n",
689
+ "encoder__layers_2__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
690
+ "encoder__layers_2__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.2.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
691
+ "encoder__layers_2__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
692
+ "encoder__layers_2__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
693
+ "encoder__layers_2__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.2.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
694
+ "encoder__layers_2__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.2.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
695
+ "encoder__layers_2__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.2.layer.0.layer_norm.weight torch.Size([1024])\n",
696
+ "encoder__layers_2__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.2.layer.1.layer_norm.weight torch.Size([1024])\n",
697
+ "encoder__layers_20__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
698
+ "encoder__layers_20__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.20.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
699
+ "encoder__layers_20__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
700
+ "encoder__layers_20__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
701
+ "encoder__layers_20__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.20.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
702
+ "encoder__layers_20__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.20.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
703
+ "encoder__layers_20__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.20.layer.0.layer_norm.weight torch.Size([1024])\n",
704
+ "encoder__layers_20__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.20.layer.1.layer_norm.weight torch.Size([1024])\n",
705
+ "encoder__layers_21__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
706
+ "encoder__layers_21__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.21.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n"
707
+ ]
708
+ },
709
+ {
710
+ "name": "stdout",
711
+ "output_type": "stream",
712
+ "text": [
713
+ "encoder__layers_21__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
714
+ "encoder__layers_21__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
715
+ "encoder__layers_21__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.21.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
716
+ "encoder__layers_21__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.21.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
717
+ "encoder__layers_21__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.21.layer.0.layer_norm.weight torch.Size([1024])\n",
718
+ "encoder__layers_21__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.21.layer.1.layer_norm.weight torch.Size([1024])\n",
719
+ "encoder__layers_22__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
720
+ "encoder__layers_22__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.22.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
721
+ "encoder__layers_22__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
722
+ "encoder__layers_22__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
723
+ "encoder__layers_22__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.22.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
724
+ "encoder__layers_22__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.22.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
725
+ "encoder__layers_22__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.22.layer.0.layer_norm.weight torch.Size([1024])\n",
726
+ "encoder__layers_22__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.22.layer.1.layer_norm.weight torch.Size([1024])\n",
727
+ "encoder__layers_23__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
728
+ "encoder__layers_23__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.23.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
729
+ "encoder__layers_23__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
730
+ "encoder__layers_23__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
731
+ "encoder__layers_23__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.23.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
732
+ "encoder__layers_23__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.23.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
733
+ "encoder__layers_23__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.23.layer.0.layer_norm.weight torch.Size([1024])\n",
734
+ "encoder__layers_23__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.23.layer.1.layer_norm.weight torch.Size([1024])\n",
735
+ "encoder__layers_3__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
736
+ "encoder__layers_3__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.3.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
737
+ "encoder__layers_3__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
738
+ "encoder__layers_3__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
739
+ "encoder__layers_3__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.3.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
740
+ "encoder__layers_3__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.3.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
741
+ "encoder__layers_3__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.3.layer.0.layer_norm.weight torch.Size([1024])\n",
742
+ "encoder__layers_3__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.3.layer.1.layer_norm.weight torch.Size([1024])\n",
743
+ "encoder__layers_4__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
744
+ "encoder__layers_4__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.4.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
745
+ "encoder__layers_4__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
746
+ "encoder__layers_4__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
747
+ "encoder__layers_4__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.4.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
748
+ "encoder__layers_4__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.4.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
749
+ "encoder__layers_4__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.4.layer.0.layer_norm.weight torch.Size([1024])\n",
750
+ "encoder__layers_4__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.4.layer.1.layer_norm.weight torch.Size([1024])\n",
751
+ "encoder__layers_5__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
752
+ "encoder__layers_5__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.5.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
753
+ "encoder__layers_5__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
754
+ "encoder__layers_5__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
755
+ "encoder__layers_5__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.5.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
756
+ "encoder__layers_5__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.5.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
757
+ "encoder__layers_5__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.5.layer.0.layer_norm.weight torch.Size([1024])\n",
758
+ "encoder__layers_5__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.5.layer.1.layer_norm.weight torch.Size([1024])\n",
759
+ "encoder__layers_6__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
760
+ "encoder__layers_6__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.6.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
761
+ "encoder__layers_6__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
762
+ "encoder__layers_6__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
763
+ "encoder__layers_6__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.6.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
764
+ "encoder__layers_6__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.6.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
765
+ "encoder__layers_6__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.6.layer.0.layer_norm.weight torch.Size([1024])\n",
766
+ "encoder__layers_6__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.6.layer.1.layer_norm.weight torch.Size([1024])\n",
767
+ "encoder__layers_7__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
768
+ "encoder__layers_7__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.7.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
769
+ "encoder__layers_7__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
770
+ "encoder__layers_7__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n"
771
+ ]
772
+ },
773
+ {
774
+ "name": "stdout",
775
+ "output_type": "stream",
776
+ "text": [
777
+ "encoder__layers_7__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.7.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
778
+ "encoder__layers_7__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.7.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
779
+ "encoder__layers_7__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.7.layer.0.layer_norm.weight torch.Size([1024])\n",
780
+ "encoder__layers_7__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.7.layer.1.layer_norm.weight torch.Size([1024])\n",
781
+ "encoder__layers_8__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
782
+ "encoder__layers_8__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.8.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
783
+ "encoder__layers_8__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
784
+ "encoder__layers_8__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
785
+ "encoder__layers_8__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.8.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
786
+ "encoder__layers_8__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.8.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
787
+ "encoder__layers_8__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.8.layer.0.layer_norm.weight torch.Size([1024])\n",
788
+ "encoder__layers_8__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.8.layer.1.layer_norm.weight torch.Size([1024])\n",
789
+ "encoder__layers_9__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
790
+ "encoder__layers_9__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.9.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
791
+ "encoder__layers_9__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
792
+ "encoder__layers_9__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
793
+ "encoder__layers_9__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.9.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
794
+ "encoder__layers_9__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.9.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
795
+ "encoder__layers_9__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.9.layer.0.layer_norm.weight torch.Size([1024])\n",
796
+ "encoder__layers_9__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.9.layer.1.layer_norm.weight torch.Size([1024])\n",
797
+ "encoder__relpos_bias__rel_embedding:0 ((128, 32)) =transpose=> encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 128])\n",
798
+ "token_embedder__embedding:0 ((32128, 1024)) => shared.weight torch.Size([32128, 1024])\n",
799
+ "Linear(in_features=1024, out_features=768, bias=False)\n",
800
+ "Remaining weights: set()\n"
801
+ ]
802
+ }
803
+ ],
804
+ "source": [
805
+ "import torch\n",
806
+ "tokenizer = AutoTokenizer.from_pretrained(f\"t5-{model_size_hf}\")\n",
807
+ "T5EncoderModel._keys_to_ignore_on_load_unexpected = [\"decoder.*\"]\n",
808
+ "t5 = T5EncoderModel.from_pretrained(f\"t5-{model_size_hf}\")\n",
809
+ "t5_state = t5.state_dict()\n",
810
+ "\n",
811
+ "state_all_names = set(t5_state.keys())\n",
812
+ "\n",
813
+ "for var in v:\n",
814
+ " tf_name = var.name\n",
815
+ " if tf_name in names_to_ignore:\n",
816
+ " continue\n",
817
+ " \n",
818
+ " pt_name = convert_name(tf_name)\n",
819
+ " weights = np.float32(var.numpy())\n",
820
+ " \n",
821
+ " state_all_names.remove(pt_name)\n",
822
+ " \n",
823
+ " tranpose_status = \"=>\"\n",
824
+ " if need_transpose(pt_name):\n",
825
+ " tranpose_status = \"=transpose=>\"\n",
826
+ " weights = weights.transpose()\n",
827
+ " \n",
828
+ " print(tf_name, f\"({var.shape})\", tranpose_status, pt_name, t5_state[pt_name].shape)\n",
829
+ " \n",
830
+ " original_shape = t5_state[pt_name].shape\n",
831
+ " t5_state[pt_name] = torch.nn.Parameter(torch.tensor(weights))\n",
832
+ " new_shape = t5_state[pt_name].shape\n",
833
+ " \n",
834
+ " if not equal_shapes(original_shape, new_shape):\n",
835
+ " print(\"Different shape:\", tf_name, original_shape, pt_name, new_shape)\n",
836
+ " break\n",
837
+ "\n",
838
+ "#Encoder Word embeddings\n",
839
+ "t5_state['encoder.embed_tokens.weight'] = t5_state['shared.weight']\n",
840
+ "state_all_names.remove('encoder.embed_tokens.weight')\n",
841
+ " \n",
842
+ "#Load back the weights\n",
843
+ "t5.load_state_dict(t5_state) \n",
844
+ "\n",
845
+ "tf_linear_weight = tf_name_weight[\"projection_layer__kernel:0\"]\n",
846
+ "linear = torch.nn.Linear(tf_linear_weight.shape[0], tf_linear_weight.shape[1], bias=False)\n",
847
+ "original_shape = linear.weight.shape\n",
848
+ "linear.weight = torch.nn.Parameter(torch.tensor(np.float32(tf_linear_weight.numpy()).transpose()))\n",
849
+ "new_shape = linear.weight.shape\n",
850
+ "if not equal_shapes(original_shape, new_shape):\n",
851
+ " print(\"Different shape at linear layer\")\n",
852
+ " \n",
853
+ "print(linear)\n",
854
+ "print(\"Remaining weights:\", state_all_names)\n",
855
+ "assert len(state_all_names) == 0\n"
856
+ ]
857
+ },
858
+ {
859
+ "cell_type": "code",
860
+ "execution_count": 12,
861
+ "id": "d59d5a2c",
862
+ "metadata": {},
863
+ "outputs": [
864
+ {
865
+ "name": "stdout",
866
+ "output_type": "stream",
867
+ "text": [
868
+ "torch.Size([8, 768])\n"
869
+ ]
870
+ },
871
+ {
872
+ "data": {
873
+ "text/plain": [
874
+ "tensor([[1.0000, 0.8303, 0.2995, 0.3906, 0.2986, 0.3062, 0.3430, 0.3734],\n",
875
+ " [0.8303, 1.0000, 0.3455, 0.4187, 0.3043, 0.3464, 0.4388, 0.3959],\n",
876
+ " [0.2995, 0.3455, 1.0000, 0.6648, 0.4726, 0.4597, 0.3798, 0.3454],\n",
877
+ " [0.3906, 0.4187, 0.6648, 1.0000, 0.5167, 0.5195, 0.3746, 0.4006],\n",
878
+ " [0.2986, 0.3043, 0.4726, 0.5167, 1.0000, 0.7602, 0.3923, 0.3550],\n",
879
+ " [0.3062, 0.3464, 0.4597, 0.5195, 0.7602, 1.0000, 0.4338, 0.3432],\n",
880
+ " [0.3430, 0.4388, 0.3798, 0.3746, 0.3923, 0.4338, 1.0000, 0.6090],\n",
881
+ " [0.3734, 0.3959, 0.3454, 0.4006, 0.3550, 0.3432, 0.6090, 1.0000]])"
882
+ ]
883
+ },
884
+ "execution_count": 12,
885
+ "metadata": {},
886
+ "output_type": "execute_result"
887
+ }
888
+ ],
889
+ "source": [
890
+ "english_sentences = [\"Berlin is the capital of Germany\", \"Berlin is a large city in Germany\",\n",
891
+ " \"Tensorflow can be used for deep learning\", \"Pytorch, developed by Facebook AI, is a deep learning framework\",\n",
892
+ " \"Is Scipy or numpy better?\", \"Which is faster: scipy or pandas?\",\n",
893
+ " \"Cats can live for quite a long time\", \"Cats are humans best friend\"]\n",
894
+ "\n",
895
+ "encoded_input = tokenizer(english_sentences, return_tensors=\"pt\", padding=True)\n",
896
+ "\n",
897
+ "with torch.no_grad():\n",
898
+ " model_output = t5(**encoded_input)\n",
899
+ " \n",
900
+ " # Perform pooling\n",
901
+ " hf_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n",
902
+ "\n",
903
+ " # Apply linear layer\n",
904
+ " hf_embeddings = linear(hf_embeddings)\n",
905
+ " \n",
906
+ " print(hf_embeddings.shape)\n",
907
+ "\n",
908
+ " # Normalize embeddings\n",
909
+ " hf_embeddings = F.normalize(hf_embeddings, p=2, dim=1)\n",
910
+ "\n",
911
+ "# Cos\n",
912
+ "util.dot_score(hf_embeddings, hf_embeddings)"
913
+ ]
914
+ },
915
+ {
916
+ "cell_type": "code",
917
+ "execution_count": 13,
918
+ "id": "677a8bab",
919
+ "metadata": {},
920
+ "outputs": [
921
+ {
922
+ "name": "stderr",
923
+ "output_type": "stream",
924
+ "text": [
925
+ "2022-01-31 23:13:39.702310: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n",
926
+ "2022-01-31 23:13:41.448337: I tensorflow/compiler/xla/service/service.cc:171] XLA service 0x7f41641cf460 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n",
927
+ "2022-01-31 23:13:41.448385: I tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Host, Default Version\n",
928
+ "2022-01-31 23:13:44.375222: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:210] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
929
+ "2022-01-31 23:14:17.816928: I tensorflow/compiler/jit/xla_compilation_cache.cc:363] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n",
930
+ "2022-01-31 23:14:17.866550: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 3089104896 exceeds 10% of free system memory.\n"
931
+ ]
932
+ },
933
+ {
934
+ "name": "stdout",
935
+ "output_type": "stream",
936
+ "text": [
937
+ "(8, 768)\n"
938
+ ]
939
+ },
940
+ {
941
+ "data": {
942
+ "text/plain": [
943
+ "tensor([[1.0000, 0.8303, 0.2996, 0.3908, 0.2984, 0.3062, 0.3428, 0.3735],\n",
944
+ " [0.8303, 1.0000, 0.3453, 0.4187, 0.3044, 0.3462, 0.4387, 0.3961],\n",
945
+ " [0.2996, 0.3453, 1.0000, 0.6643, 0.4724, 0.4596, 0.3803, 0.3454],\n",
946
+ " [0.3908, 0.4187, 0.6643, 1.0000, 0.5169, 0.5196, 0.3744, 0.4003],\n",
947
+ " [0.2984, 0.3044, 0.4724, 0.5169, 1.0000, 0.7603, 0.3920, 0.3550],\n",
948
+ " [0.3062, 0.3462, 0.4596, 0.5196, 0.7603, 1.0000, 0.4333, 0.3427],\n",
949
+ " [0.3428, 0.4387, 0.3803, 0.3744, 0.3920, 0.4333, 1.0000, 0.6087],\n",
950
+ " [0.3735, 0.3961, 0.3454, 0.4003, 0.3550, 0.3427, 0.6087, 1.0000]])"
951
+ ]
952
+ },
953
+ "execution_count": 13,
954
+ "metadata": {},
955
+ "output_type": "execute_result"
956
+ }
957
+ ],
958
+ "source": [
959
+ "# Test the models - Original embeddings\n",
960
+ "english_embeds = encoder(english_sentences)[0].numpy()\n",
961
+ "print(english_embeds.shape)\n",
962
+ "util.dot_score(english_embeds, english_embeds)"
963
+ ]
964
+ },
965
+ {
966
+ "cell_type": "code",
967
+ "execution_count": 14,
968
+ "id": "34b44ef7",
969
+ "metadata": {},
970
+ "outputs": [],
971
+ "source": [
972
+ "folder = f'models/gtr-t5-{model_size_hf}'\n",
973
+ "t5.save_pretrained(folder)\n",
974
+ "tokenizer.save_pretrained(folder)\n",
975
+ "os.makedirs(os.path.join(folder, '2_Dense'), exist_ok=True)\n",
976
+ "\n",
977
+ "\n",
978
+ "dense = sentence_transformers.models.Dense(linear.in_features, linear.out_features, \n",
979
+ " bias=False, activation_function=torch.nn.Identity())\n",
980
+ "dense.linear = linear\n",
981
+ "dense.save(os.path.join(folder, '2_Dense'))\n"
982
+ ]
983
+ },
984
+ {
985
+ "cell_type": "markdown",
986
+ "id": "8f6e006b",
987
+ "metadata": {},
988
+ "source": [
989
+ "# FP16 experiment"
990
+ ]
991
+ },
992
+ {
993
+ "cell_type": "code",
994
+ "execution_count": null,
995
+ "id": "38b1b35e",
996
+ "metadata": {},
997
+ "outputs": [],
998
+ "source": [
999
+ "#FP16 experiment\n",
1000
+ "#t5 = T5EncoderModel.from_pretrained('models/gtr-t5-base')\n",
1001
+ "#t5.half()\n",
1002
+ "#t5.save_pretrained('models/gtr-t5-base-fp16')"
1003
+ ]
1004
+ }
1005
+ ],
1006
+ "metadata": {
1007
+ "kernelspec": {
1008
+ "display_name": "Python 3 (ipykernel)",
1009
+ "language": "python",
1010
+ "name": "python3"
1011
+ },
1012
+ "language_info": {
1013
+ "codemirror_mode": {
1014
+ "name": "ipython",
1015
+ "version": 3
1016
+ },
1017
+ "file_extension": ".py",
1018
+ "mimetype": "text/x-python",
1019
+ "name": "python",
1020
+ "nbconvert_exporter": "python",
1021
+ "pygments_lexer": "ipython3",
1022
+ "version": "3.8.8"
1023
+ }
1024
+ },
1025
+ "nbformat": 4,
1026
+ "nbformat_minor": 5
1027
+ }
convert_to_fp16.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from transformers import T5EncoderModel
3
+
4
+ in_path = sys.argv[1]
5
+ out_path = sys.argv[2]
6
+
7
+ model = T5EncoderModel.from_pretrained(in_path)
8
+ model.half()
9
+ model.save_pretrained(out_path)