nreimers commited on
Commit
86e4ec2
1 Parent(s): 82c075f
1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
2_Dense/config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"in_features": 768, "out_features": 768, "bias": false, "activation_function": "torch.nn.modules.linear.Identity"}
2_Dense/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9319f42e32d06c3e599b0d0d2aeb23bdeacfe71d019238d86d6413a778be8c1d
3
+ size 2360171
README.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: sentence-similarity
3
+ language: en
4
+ license: apache-2.0
5
+ tags:
6
+ - sentence-transformers
7
+ - feature-extraction
8
+ - sentence-similarity
9
+ - transformers
10
+ ---
11
+
12
+ # sentence-transformers/sentence-t5-base
13
+
14
+ This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 768 dimensional dense vector space. The model works well for sentence similarity tasks, but doesn't perform that well for semantic search tasks.
15
+
16
+ This model was converted from the Tensorflow model [st5-base-1](https://tfhub.dev/google/sentence-t5/st5-base/1) to PyTorch. When using this model, have a look at the publication: [Sentence-T5: Scalable sentence encoders from pre-trained text-to-text models](https://arxiv.org/abs/2108.08877). The tfhub model and this PyTorch model can produce slightly different embeddings, however, when run on the same benchmarks, they produce identical results.
17
+
18
+ The model uses only the encoder from a T5-base model. The weights are stored in FP16.
19
+
20
+
21
+ ## Usage (Sentence-Transformers)
22
+
23
+ Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
24
+
25
+ ```
26
+ pip install -U sentence-transformers
27
+ ```
28
+
29
+ Then you can use the model like this:
30
+
31
+ ```python
32
+ from sentence_transformers import SentenceTransformer
33
+ sentences = ["This is an example sentence", "Each sentence is converted"]
34
+
35
+ model = SentenceTransformer('sentence-transformers/sentence-t5-base')
36
+ embeddings = model.encode(sentences)
37
+ print(embeddings)
38
+ ```
39
+
40
+ The model requires sentence-transformers version 2.2.0 or newer.
41
+
42
+ ## Evaluation Results
43
+
44
+ For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=sentence-transformers/sentence-t5-base)
45
+
46
+
47
+
48
+ ## Citing & Authors
49
+
50
+ If you find this model helpful, please cite the respective publication:
51
+ [Sentence-T5: Scalable sentence encoders from pre-trained text-to-text models](https://arxiv.org/abs/2108.08877)
config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "models/sentence-t5-base",
3
+ "architectures": [
4
+ "T5EncoderModel"
5
+ ],
6
+ "d_ff": 3072,
7
+ "d_kv": 64,
8
+ "d_model": 768,
9
+ "decoder_start_token_id": 0,
10
+ "dropout_rate": 0.1,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "relu",
13
+ "initializer_factor": 1.0,
14
+ "is_encoder_decoder": true,
15
+ "layer_norm_epsilon": 1e-06,
16
+ "model_type": "t5",
17
+ "n_positions": 512,
18
+ "num_decoder_layers": 12,
19
+ "num_heads": 12,
20
+ "num_layers": 12,
21
+ "output_past": true,
22
+ "pad_token_id": 0,
23
+ "relative_attention_num_buckets": 32,
24
+ "task_specific_params": {
25
+ "summarization": {
26
+ "early_stopping": true,
27
+ "length_penalty": 2.0,
28
+ "max_length": 200,
29
+ "min_length": 30,
30
+ "no_repeat_ngram_size": 3,
31
+ "num_beams": 4,
32
+ "prefix": "summarize: "
33
+ },
34
+ "translation_en_to_de": {
35
+ "early_stopping": true,
36
+ "max_length": 300,
37
+ "num_beams": 4,
38
+ "prefix": "translate English to German: "
39
+ },
40
+ "translation_en_to_fr": {
41
+ "early_stopping": true,
42
+ "max_length": 300,
43
+ "num_beams": 4,
44
+ "prefix": "translate English to French: "
45
+ },
46
+ "translation_en_to_ro": {
47
+ "early_stopping": true,
48
+ "max_length": 300,
49
+ "num_beams": 4,
50
+ "prefix": "translate English to Romanian: "
51
+ }
52
+ },
53
+ "torch_dtype": "float16",
54
+ "transformers_version": "4.11.3",
55
+ "use_cache": true,
56
+ "vocab_size": 32128
57
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.2.0",
4
+ "transformers": "4.7.0",
5
+ "pytorch": "1.9.0+cu102"
6
+ }
7
+ }
convert.ipynb ADDED
@@ -0,0 +1,981 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "17bffc12",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import AutoTokenizer\n",
11
+ "from sentence_transformers import util\n",
12
+ "import os\n",
13
+ "import numpy as np\n",
14
+ "import torch.nn.functional as F\n",
15
+ "from transformers import T5EncoderModel\n",
16
+ "import torch"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 3,
22
+ "id": "160d8ce6",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "#Mean Pooling - Take attention mask into account for correct averaging\n",
27
+ "def mean_pooling(model_output, attention_mask):\n",
28
+ " token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n",
29
+ " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
30
+ " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 16,
36
+ "id": "2f67f426",
37
+ "metadata": {},
38
+ "outputs": [
39
+ {
40
+ "name": "stderr",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "WARNING:absl:Importing a function (__inference_<lambda>_9720) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n",
44
+ "WARNING:absl:Importing a function (__inference_<lambda>_3354) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n",
45
+ "WARNING:absl:Importing a function (__inference_<lambda>_6722) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "import tensorflow as tf\n",
51
+ "import tensorflow_hub as hub\n",
52
+ "import tensorflow_text as text \n",
53
+ "\n",
54
+ "model_size = \"base\"\n",
55
+ "hub_url = f\"https://tfhub.dev/google/sentence-t5/st5-{model_size}/1\"\n",
56
+ "encoder = hub.load(hub_url)\n",
57
+ "\n",
58
+ "v = encoder.signatures['serving_default'].variables"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 17,
64
+ "id": "5f4c8d94",
65
+ "metadata": {
66
+ "scrolled": true
67
+ },
68
+ "outputs": [
69
+ {
70
+ "data": {
71
+ "text/plain": [
72
+ "{'encoder__encoder_norm__scale:0': TensorShape([768]),\n",
73
+ " 'encoder__layers_0__attention__key__kernel:0': TensorShape([768, 768]),\n",
74
+ " 'encoder__layers_0__attention__out__kernel:0': TensorShape([768, 768]),\n",
75
+ " 'encoder__layers_0__attention__query__kernel:0': TensorShape([768, 768]),\n",
76
+ " 'encoder__layers_0__attention__value__kernel:0': TensorShape([768, 768]),\n",
77
+ " 'encoder__layers_0__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
78
+ " 'encoder__layers_0__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
79
+ " 'encoder__layers_0__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
80
+ " 'encoder__layers_0__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
81
+ " 'encoder__layers_1__attention__key__kernel:0': TensorShape([768, 768]),\n",
82
+ " 'encoder__layers_1__attention__out__kernel:0': TensorShape([768, 768]),\n",
83
+ " 'encoder__layers_1__attention__query__kernel:0': TensorShape([768, 768]),\n",
84
+ " 'encoder__layers_1__attention__value__kernel:0': TensorShape([768, 768]),\n",
85
+ " 'encoder__layers_1__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
86
+ " 'encoder__layers_1__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
87
+ " 'encoder__layers_1__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
88
+ " 'encoder__layers_1__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
89
+ " 'encoder__layers_10__attention__key__kernel:0': TensorShape([768, 768]),\n",
90
+ " 'encoder__layers_10__attention__out__kernel:0': TensorShape([768, 768]),\n",
91
+ " 'encoder__layers_10__attention__query__kernel:0': TensorShape([768, 768]),\n",
92
+ " 'encoder__layers_10__attention__value__kernel:0': TensorShape([768, 768]),\n",
93
+ " 'encoder__layers_10__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
94
+ " 'encoder__layers_10__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
95
+ " 'encoder__layers_10__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
96
+ " 'encoder__layers_10__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
97
+ " 'encoder__layers_11__attention__key__kernel:0': TensorShape([768, 768]),\n",
98
+ " 'encoder__layers_11__attention__out__kernel:0': TensorShape([768, 768]),\n",
99
+ " 'encoder__layers_11__attention__query__kernel:0': TensorShape([768, 768]),\n",
100
+ " 'encoder__layers_11__attention__value__kernel:0': TensorShape([768, 768]),\n",
101
+ " 'encoder__layers_11__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
102
+ " 'encoder__layers_11__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
103
+ " 'encoder__layers_11__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
104
+ " 'encoder__layers_11__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
105
+ " 'encoder__layers_2__attention__key__kernel:0': TensorShape([768, 768]),\n",
106
+ " 'encoder__layers_2__attention__out__kernel:0': TensorShape([768, 768]),\n",
107
+ " 'encoder__layers_2__attention__query__kernel:0': TensorShape([768, 768]),\n",
108
+ " 'encoder__layers_2__attention__value__kernel:0': TensorShape([768, 768]),\n",
109
+ " 'encoder__layers_2__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
110
+ " 'encoder__layers_2__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
111
+ " 'encoder__layers_2__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
112
+ " 'encoder__layers_2__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
113
+ " 'encoder__layers_3__attention__key__kernel:0': TensorShape([768, 768]),\n",
114
+ " 'encoder__layers_3__attention__out__kernel:0': TensorShape([768, 768]),\n",
115
+ " 'encoder__layers_3__attention__query__kernel:0': TensorShape([768, 768]),\n",
116
+ " 'encoder__layers_3__attention__value__kernel:0': TensorShape([768, 768]),\n",
117
+ " 'encoder__layers_3__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
118
+ " 'encoder__layers_3__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
119
+ " 'encoder__layers_3__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
120
+ " 'encoder__layers_3__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
121
+ " 'encoder__layers_4__attention__key__kernel:0': TensorShape([768, 768]),\n",
122
+ " 'encoder__layers_4__attention__out__kernel:0': TensorShape([768, 768]),\n",
123
+ " 'encoder__layers_4__attention__query__kernel:0': TensorShape([768, 768]),\n",
124
+ " 'encoder__layers_4__attention__value__kernel:0': TensorShape([768, 768]),\n",
125
+ " 'encoder__layers_4__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
126
+ " 'encoder__layers_4__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
127
+ " 'encoder__layers_4__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
128
+ " 'encoder__layers_4__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
129
+ " 'encoder__layers_5__attention__key__kernel:0': TensorShape([768, 768]),\n",
130
+ " 'encoder__layers_5__attention__out__kernel:0': TensorShape([768, 768]),\n",
131
+ " 'encoder__layers_5__attention__query__kernel:0': TensorShape([768, 768]),\n",
132
+ " 'encoder__layers_5__attention__value__kernel:0': TensorShape([768, 768]),\n",
133
+ " 'encoder__layers_5__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
134
+ " 'encoder__layers_5__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
135
+ " 'encoder__layers_5__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
136
+ " 'encoder__layers_5__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
137
+ " 'encoder__layers_6__attention__key__kernel:0': TensorShape([768, 768]),\n",
138
+ " 'encoder__layers_6__attention__out__kernel:0': TensorShape([768, 768]),\n",
139
+ " 'encoder__layers_6__attention__query__kernel:0': TensorShape([768, 768]),\n",
140
+ " 'encoder__layers_6__attention__value__kernel:0': TensorShape([768, 768]),\n",
141
+ " 'encoder__layers_6__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
142
+ " 'encoder__layers_6__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
143
+ " 'encoder__layers_6__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
144
+ " 'encoder__layers_6__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
145
+ " 'encoder__layers_7__attention__key__kernel:0': TensorShape([768, 768]),\n",
146
+ " 'encoder__layers_7__attention__out__kernel:0': TensorShape([768, 768]),\n",
147
+ " 'encoder__layers_7__attention__query__kernel:0': TensorShape([768, 768]),\n",
148
+ " 'encoder__layers_7__attention__value__kernel:0': TensorShape([768, 768]),\n",
149
+ " 'encoder__layers_7__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
150
+ " 'encoder__layers_7__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
151
+ " 'encoder__layers_7__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
152
+ " 'encoder__layers_7__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
153
+ " 'encoder__layers_8__attention__key__kernel:0': TensorShape([768, 768]),\n",
154
+ " 'encoder__layers_8__attention__out__kernel:0': TensorShape([768, 768]),\n",
155
+ " 'encoder__layers_8__attention__query__kernel:0': TensorShape([768, 768]),\n",
156
+ " 'encoder__layers_8__attention__value__kernel:0': TensorShape([768, 768]),\n",
157
+ " 'encoder__layers_8__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
158
+ " 'encoder__layers_8__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
159
+ " 'encoder__layers_8__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
160
+ " 'encoder__layers_8__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
161
+ " 'encoder__layers_9__attention__key__kernel:0': TensorShape([768, 768]),\n",
162
+ " 'encoder__layers_9__attention__out__kernel:0': TensorShape([768, 768]),\n",
163
+ " 'encoder__layers_9__attention__query__kernel:0': TensorShape([768, 768]),\n",
164
+ " 'encoder__layers_9__attention__value__kernel:0': TensorShape([768, 768]),\n",
165
+ " 'encoder__layers_9__mlp__wi__kernel:0': TensorShape([768, 3072]),\n",
166
+ " 'encoder__layers_9__mlp__wo__kernel:0': TensorShape([3072, 768]),\n",
167
+ " 'encoder__layers_9__pre_attention_layer_norm__scale:0': TensorShape([768]),\n",
168
+ " 'encoder__layers_9__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n",
169
+ " 'encoder__relpos_bias__rel_embedding:0': TensorShape([12, 32]),\n",
170
+ " 'projection_layer__kernel:0': TensorShape([768, 768]),\n",
171
+ " 'token_embedder__embedding:0': TensorShape([32128, 768])}"
172
+ ]
173
+ },
174
+ "execution_count": 17,
175
+ "metadata": {},
176
+ "output_type": "execute_result"
177
+ }
178
+ ],
179
+ "source": [
180
+ "tf_name_weight = {var.name: var for var in v}\n",
181
+ "tf_name_shape = {var.name: var.shape for var in v}\n",
182
+ "tf_name_shape"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": 6,
188
+ "id": "1d3c9865",
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "def convert_name(name):\n",
193
+ " fct_map = {\n",
194
+ " \"attention\": \"SelfAttention\",\n",
195
+ " \"mlp\": \"DenseReluDense\",\n",
196
+ " \"pre_attention_layer_norm\": \"layer_norm\",\n",
197
+ " \"pre_mlp_layer_norm\": \"layer_norm\",\n",
198
+ " }\n",
199
+ " name_map = {\n",
200
+ " 'key': 'k',\n",
201
+ " 'out': 'o',\n",
202
+ " 'query': 'q',\n",
203
+ " 'value': 'v'\n",
204
+ " }\n",
205
+ " \n",
206
+ " fixed_names = {\n",
207
+ " \"token_embedder__embedding:0\": \"shared.weight\",\n",
208
+ " \"encoder__encoder_norm__scale:0\": \"encoder.final_layer_norm.weight\",\n",
209
+ " \"encoder__relpos_bias__rel_embedding:0\": \"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\"\n",
210
+ " }\n",
211
+ " \n",
212
+ " if name in fixed_names:\n",
213
+ " return fixed_names[name]\n",
214
+ " \n",
215
+ " out = \"\"\n",
216
+ " splits = name.split(\"__\")\n",
217
+ " layer = splits[1].split(\"_\")[1]\n",
218
+ " fct = fct_map.get(splits[2], splits[2])\n",
219
+ " if 'layer_norm' in name:\n",
220
+ " sublayer = \"1\" if \"pre_mlp_layer_norm\" in name else \"0\" #Not sure on the right setting here\n",
221
+ " #sublayer = \"0\" if \"pre_mlp_layer_norm\" in name else \"1\" #Not sure on the right setting here\n",
222
+ " out = f\"encoder.block.{layer}.layer.{sublayer}.{fct}.weight\"\n",
223
+ " elif name.startswith(\"encoder__layers_\"):\n",
224
+ " sublayer = \"0\" if fct == \"SelfAttention\" else \"1\"\n",
225
+ " name = name_map.get(splits[3], splits[3])\n",
226
+ " out = f\"encoder.block.{layer}.layer.{sublayer}.{fct}.{name}.weight\"\n",
227
+ " \n",
228
+ " return out"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 7,
234
+ "id": "1ca9590e",
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "def equal_shapes(shape1, shape2):\n",
239
+ " if len(shape1) != len(shape2):\n",
240
+ " return False\n",
241
+ " \n",
242
+ " for idx in range(len(shape1)):\n",
243
+ " if shape1[idx] != shape2[idx]:\n",
244
+ " return False\n",
245
+ " \n",
246
+ " return True"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": 8,
252
+ "id": "6d223b07",
253
+ "metadata": {
254
+ "scrolled": true
255
+ },
256
+ "outputs": [
257
+ {
258
+ "name": "stderr",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "Some weights of T5EncoderModel were not initialized from the model checkpoint at t5-11b and are newly initialized: ['encoder.embed_tokens.weight']\n",
262
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
263
+ ]
264
+ },
265
+ {
266
+ "data": {
267
+ "text/plain": [
268
+ "{'shared.weight': torch.Size([32128, 1024]),\n",
269
+ " 'encoder.embed_tokens.weight': torch.Size([32128, 1024]),\n",
270
+ " 'encoder.block.0.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
271
+ " 'encoder.block.0.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
272
+ " 'encoder.block.0.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
273
+ " 'encoder.block.0.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
274
+ " 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight': torch.Size([32, 128]),\n",
275
+ " 'encoder.block.0.layer.0.layer_norm.weight': torch.Size([1024]),\n",
276
+ " 'encoder.block.0.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
277
+ " 'encoder.block.0.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
278
+ " 'encoder.block.0.layer.1.layer_norm.weight': torch.Size([1024]),\n",
279
+ " 'encoder.block.1.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
280
+ " 'encoder.block.1.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
281
+ " 'encoder.block.1.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
282
+ " 'encoder.block.1.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
283
+ " 'encoder.block.1.layer.0.layer_norm.weight': torch.Size([1024]),\n",
284
+ " 'encoder.block.1.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
285
+ " 'encoder.block.1.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
286
+ " 'encoder.block.1.layer.1.layer_norm.weight': torch.Size([1024]),\n",
287
+ " 'encoder.block.2.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
288
+ " 'encoder.block.2.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
289
+ " 'encoder.block.2.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
290
+ " 'encoder.block.2.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
291
+ " 'encoder.block.2.layer.0.layer_norm.weight': torch.Size([1024]),\n",
292
+ " 'encoder.block.2.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
293
+ " 'encoder.block.2.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
294
+ " 'encoder.block.2.layer.1.layer_norm.weight': torch.Size([1024]),\n",
295
+ " 'encoder.block.3.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
296
+ " 'encoder.block.3.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
297
+ " 'encoder.block.3.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
298
+ " 'encoder.block.3.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
299
+ " 'encoder.block.3.layer.0.layer_norm.weight': torch.Size([1024]),\n",
300
+ " 'encoder.block.3.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
301
+ " 'encoder.block.3.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
302
+ " 'encoder.block.3.layer.1.layer_norm.weight': torch.Size([1024]),\n",
303
+ " 'encoder.block.4.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
304
+ " 'encoder.block.4.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
305
+ " 'encoder.block.4.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
306
+ " 'encoder.block.4.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
307
+ " 'encoder.block.4.layer.0.layer_norm.weight': torch.Size([1024]),\n",
308
+ " 'encoder.block.4.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
309
+ " 'encoder.block.4.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
310
+ " 'encoder.block.4.layer.1.layer_norm.weight': torch.Size([1024]),\n",
311
+ " 'encoder.block.5.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
312
+ " 'encoder.block.5.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
313
+ " 'encoder.block.5.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
314
+ " 'encoder.block.5.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
315
+ " 'encoder.block.5.layer.0.layer_norm.weight': torch.Size([1024]),\n",
316
+ " 'encoder.block.5.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
317
+ " 'encoder.block.5.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
318
+ " 'encoder.block.5.layer.1.layer_norm.weight': torch.Size([1024]),\n",
319
+ " 'encoder.block.6.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
320
+ " 'encoder.block.6.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
321
+ " 'encoder.block.6.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
322
+ " 'encoder.block.6.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
323
+ " 'encoder.block.6.layer.0.layer_norm.weight': torch.Size([1024]),\n",
324
+ " 'encoder.block.6.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
325
+ " 'encoder.block.6.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
326
+ " 'encoder.block.6.layer.1.layer_norm.weight': torch.Size([1024]),\n",
327
+ " 'encoder.block.7.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
328
+ " 'encoder.block.7.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
329
+ " 'encoder.block.7.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
330
+ " 'encoder.block.7.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
331
+ " 'encoder.block.7.layer.0.layer_norm.weight': torch.Size([1024]),\n",
332
+ " 'encoder.block.7.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
333
+ " 'encoder.block.7.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
334
+ " 'encoder.block.7.layer.1.layer_norm.weight': torch.Size([1024]),\n",
335
+ " 'encoder.block.8.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
336
+ " 'encoder.block.8.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
337
+ " 'encoder.block.8.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
338
+ " 'encoder.block.8.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
339
+ " 'encoder.block.8.layer.0.layer_norm.weight': torch.Size([1024]),\n",
340
+ " 'encoder.block.8.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
341
+ " 'encoder.block.8.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
342
+ " 'encoder.block.8.layer.1.layer_norm.weight': torch.Size([1024]),\n",
343
+ " 'encoder.block.9.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
344
+ " 'encoder.block.9.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
345
+ " 'encoder.block.9.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
346
+ " 'encoder.block.9.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
347
+ " 'encoder.block.9.layer.0.layer_norm.weight': torch.Size([1024]),\n",
348
+ " 'encoder.block.9.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
349
+ " 'encoder.block.9.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
350
+ " 'encoder.block.9.layer.1.layer_norm.weight': torch.Size([1024]),\n",
351
+ " 'encoder.block.10.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
352
+ " 'encoder.block.10.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
353
+ " 'encoder.block.10.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
354
+ " 'encoder.block.10.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
355
+ " 'encoder.block.10.layer.0.layer_norm.weight': torch.Size([1024]),\n",
356
+ " 'encoder.block.10.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
357
+ " 'encoder.block.10.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
358
+ " 'encoder.block.10.layer.1.layer_norm.weight': torch.Size([1024]),\n",
359
+ " 'encoder.block.11.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
360
+ " 'encoder.block.11.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
361
+ " 'encoder.block.11.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
362
+ " 'encoder.block.11.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
363
+ " 'encoder.block.11.layer.0.layer_norm.weight': torch.Size([1024]),\n",
364
+ " 'encoder.block.11.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
365
+ " 'encoder.block.11.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
366
+ " 'encoder.block.11.layer.1.layer_norm.weight': torch.Size([1024]),\n",
367
+ " 'encoder.block.12.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
368
+ " 'encoder.block.12.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
369
+ " 'encoder.block.12.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
370
+ " 'encoder.block.12.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
371
+ " 'encoder.block.12.layer.0.layer_norm.weight': torch.Size([1024]),\n",
372
+ " 'encoder.block.12.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
373
+ " 'encoder.block.12.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
374
+ " 'encoder.block.12.layer.1.layer_norm.weight': torch.Size([1024]),\n",
375
+ " 'encoder.block.13.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
376
+ " 'encoder.block.13.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
377
+ " 'encoder.block.13.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
378
+ " 'encoder.block.13.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
379
+ " 'encoder.block.13.layer.0.layer_norm.weight': torch.Size([1024]),\n",
380
+ " 'encoder.block.13.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
381
+ " 'encoder.block.13.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
382
+ " 'encoder.block.13.layer.1.layer_norm.weight': torch.Size([1024]),\n",
383
+ " 'encoder.block.14.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
384
+ " 'encoder.block.14.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
385
+ " 'encoder.block.14.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
386
+ " 'encoder.block.14.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
387
+ " 'encoder.block.14.layer.0.layer_norm.weight': torch.Size([1024]),\n",
388
+ " 'encoder.block.14.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
389
+ " 'encoder.block.14.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
390
+ " 'encoder.block.14.layer.1.layer_norm.weight': torch.Size([1024]),\n",
391
+ " 'encoder.block.15.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
392
+ " 'encoder.block.15.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
393
+ " 'encoder.block.15.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
394
+ " 'encoder.block.15.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
395
+ " 'encoder.block.15.layer.0.layer_norm.weight': torch.Size([1024]),\n",
396
+ " 'encoder.block.15.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
397
+ " 'encoder.block.15.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
398
+ " 'encoder.block.15.layer.1.layer_norm.weight': torch.Size([1024]),\n",
399
+ " 'encoder.block.16.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
400
+ " 'encoder.block.16.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
401
+ " 'encoder.block.16.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
402
+ " 'encoder.block.16.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
403
+ " 'encoder.block.16.layer.0.layer_norm.weight': torch.Size([1024]),\n",
404
+ " 'encoder.block.16.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
405
+ " 'encoder.block.16.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
406
+ " 'encoder.block.16.layer.1.layer_norm.weight': torch.Size([1024]),\n",
407
+ " 'encoder.block.17.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
408
+ " 'encoder.block.17.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
409
+ " 'encoder.block.17.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
410
+ " 'encoder.block.17.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
411
+ " 'encoder.block.17.layer.0.layer_norm.weight': torch.Size([1024]),\n",
412
+ " 'encoder.block.17.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
413
+ " 'encoder.block.17.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
414
+ " 'encoder.block.17.layer.1.layer_norm.weight': torch.Size([1024]),\n",
415
+ " 'encoder.block.18.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
416
+ " 'encoder.block.18.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
417
+ " 'encoder.block.18.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
418
+ " 'encoder.block.18.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
419
+ " 'encoder.block.18.layer.0.layer_norm.weight': torch.Size([1024]),\n",
420
+ " 'encoder.block.18.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
421
+ " 'encoder.block.18.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
422
+ " 'encoder.block.18.layer.1.layer_norm.weight': torch.Size([1024]),\n",
423
+ " 'encoder.block.19.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
424
+ " 'encoder.block.19.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
425
+ " 'encoder.block.19.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
426
+ " 'encoder.block.19.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
427
+ " 'encoder.block.19.layer.0.layer_norm.weight': torch.Size([1024]),\n",
428
+ " 'encoder.block.19.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
429
+ " 'encoder.block.19.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
430
+ " 'encoder.block.19.layer.1.layer_norm.weight': torch.Size([1024]),\n",
431
+ " 'encoder.block.20.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
432
+ " 'encoder.block.20.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
433
+ " 'encoder.block.20.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
434
+ " 'encoder.block.20.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
435
+ " 'encoder.block.20.layer.0.layer_norm.weight': torch.Size([1024]),\n",
436
+ " 'encoder.block.20.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
437
+ " 'encoder.block.20.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
438
+ " 'encoder.block.20.layer.1.layer_norm.weight': torch.Size([1024]),\n",
439
+ " 'encoder.block.21.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
440
+ " 'encoder.block.21.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
441
+ " 'encoder.block.21.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
442
+ " 'encoder.block.21.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
443
+ " 'encoder.block.21.layer.0.layer_norm.weight': torch.Size([1024]),\n",
444
+ " 'encoder.block.21.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
445
+ " 'encoder.block.21.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
446
+ " 'encoder.block.21.layer.1.layer_norm.weight': torch.Size([1024]),\n",
447
+ " 'encoder.block.22.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
448
+ " 'encoder.block.22.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
449
+ " 'encoder.block.22.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
450
+ " 'encoder.block.22.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
451
+ " 'encoder.block.22.layer.0.layer_norm.weight': torch.Size([1024]),\n",
452
+ " 'encoder.block.22.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
453
+ " 'encoder.block.22.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
454
+ " 'encoder.block.22.layer.1.layer_norm.weight': torch.Size([1024]),\n",
455
+ " 'encoder.block.23.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n",
456
+ " 'encoder.block.23.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n",
457
+ " 'encoder.block.23.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n",
458
+ " 'encoder.block.23.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n",
459
+ " 'encoder.block.23.layer.0.layer_norm.weight': torch.Size([1024]),\n",
460
+ " 'encoder.block.23.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n",
461
+ " 'encoder.block.23.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n",
462
+ " 'encoder.block.23.layer.1.layer_norm.weight': torch.Size([1024]),\n",
463
+ " 'encoder.final_layer_norm.weight': torch.Size([1024])}"
464
+ ]
465
+ },
466
+ "execution_count": 8,
467
+ "metadata": {},
468
+ "output_type": "execute_result"
469
+ }
470
+ ],
471
+ "source": [
472
+ "tokenizer = AutoTokenizer.from_pretrained(f\"t5-{model_size}\")\n",
473
+ "T5EncoderModel._keys_to_ignore_on_load_unexpected = [\"decoder.*\"]\n",
474
+ "t5 = T5EncoderModel.from_pretrained(f\"t5-{model_size}\") \n",
475
+ "pt_name_shape = {name: weight.shape for name, weight in t5.state_dict().items()}\n",
476
+ "pt_name_shape"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": 9,
482
+ "id": "ced52a5f",
483
+ "metadata": {},
484
+ "outputs": [
485
+ {
486
+ "name": "stdout",
487
+ "output_type": "stream",
488
+ "text": [
489
+ "Remaining weights: {'encoder.embed_tokens.weight'}\n"
490
+ ]
491
+ }
492
+ ],
493
+ "source": [
494
+ "def need_transpose(name, transpose_names=['DenseReluDense', 'relative_attention_bias']):\n",
495
+ " #HF function: https://github.com/huggingface/transformers/blob/c962c2adbff678ae6d2e98378bed5b8d1a9831d9/src/transformers/models/t5/modeling_t5.py#L161\n",
496
+ " return name != \"shared.weight\"\n",
497
+ "\n",
498
+ "\n",
499
+ "#Additional dense layer on top\n",
500
+ "names_to_ignore = {\"projection_layer__kernel:0\"}\n",
501
+ "\n",
502
+ "#Check we used all names\n",
503
+ "pt_all_names = set(t5.state_dict().keys())\n",
504
+ "\n",
505
+ "for var in v:\n",
506
+ " name = var.name\n",
507
+ " if name in names_to_ignore:\n",
508
+ " continue\n",
509
+ " \n",
510
+ " pt_name = convert_name(name)\n",
511
+ " if pt_name not in pt_all_names:\n",
512
+ " print(\"Name not found:\", name, \"=>\", pt_name)\n",
513
+ " else:\n",
514
+ " pt_all_names.remove(pt_name)\n",
515
+ " tf_shape = tf_name_shape[name].as_list()\n",
516
+ " pt_shape = list(pt_name_shape[pt_name])\n",
517
+ " \n",
518
+ " if need_transpose(pt_name):\n",
519
+ " pt_shape = list(reversed(pt_shape))\n",
520
+ " \n",
521
+ " if not equal_shapes(tf_shape, pt_shape):\n",
522
+ " print(\"Different shape:\", name, tf_shape, pt_name, pt_shape )\n",
523
+ " \n",
524
+ "print(\"Remaining weights:\", pt_all_names)\n",
525
+ "#All layers match"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": 10,
531
+ "id": "1190984f",
532
+ "metadata": {},
533
+ "outputs": [
534
+ {
535
+ "name": "stdout",
536
+ "output_type": "stream",
537
+ "text": [
538
+ "encoder__encoder_norm__scale:0 ((1024,)) =transpose=> encoder.final_layer_norm.weight torch.Size([1024])\n",
539
+ "encoder__layers_0__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
540
+ "encoder__layers_0__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
541
+ "encoder__layers_0__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
542
+ "encoder__layers_0__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
543
+ "encoder__layers_0__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.0.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
544
+ "encoder__layers_0__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
545
+ "encoder__layers_0__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.0.layer.0.layer_norm.weight torch.Size([1024])\n",
546
+ "encoder__layers_0__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.0.layer.1.layer_norm.weight torch.Size([1024])\n",
547
+ "encoder__layers_1__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
548
+ "encoder__layers_1__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.1.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
549
+ "encoder__layers_1__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
550
+ "encoder__layers_1__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
551
+ "encoder__layers_1__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.1.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
552
+ "encoder__layers_1__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.1.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
553
+ "encoder__layers_1__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.1.layer.0.layer_norm.weight torch.Size([1024])\n",
554
+ "encoder__layers_1__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.1.layer.1.layer_norm.weight torch.Size([1024])\n",
555
+ "encoder__layers_10__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
556
+ "encoder__layers_10__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.10.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
557
+ "encoder__layers_10__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
558
+ "encoder__layers_10__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
559
+ "encoder__layers_10__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.10.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
560
+ "encoder__layers_10__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.10.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
561
+ "encoder__layers_10__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.10.layer.0.layer_norm.weight torch.Size([1024])\n",
562
+ "encoder__layers_10__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.10.layer.1.layer_norm.weight torch.Size([1024])\n",
563
+ "encoder__layers_11__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
564
+ "encoder__layers_11__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.11.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
565
+ "encoder__layers_11__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
566
+ "encoder__layers_11__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
567
+ "encoder__layers_11__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.11.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
568
+ "encoder__layers_11__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.11.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
569
+ "encoder__layers_11__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.11.layer.0.layer_norm.weight torch.Size([1024])\n",
570
+ "encoder__layers_11__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.11.layer.1.layer_norm.weight torch.Size([1024])\n",
571
+ "encoder__layers_12__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
572
+ "encoder__layers_12__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.12.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
573
+ "encoder__layers_12__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
574
+ "encoder__layers_12__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
575
+ "encoder__layers_12__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.12.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
576
+ "encoder__layers_12__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.12.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
577
+ "encoder__layers_12__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.12.layer.0.layer_norm.weight torch.Size([1024])\n",
578
+ "encoder__layers_12__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.12.layer.1.layer_norm.weight torch.Size([1024])\n",
579
+ "encoder__layers_13__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
580
+ "encoder__layers_13__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.13.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
581
+ "encoder__layers_13__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
582
+ "encoder__layers_13__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
583
+ "encoder__layers_13__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.13.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
584
+ "encoder__layers_13__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.13.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
585
+ "encoder__layers_13__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.13.layer.0.layer_norm.weight torch.Size([1024])\n",
586
+ "encoder__layers_13__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.13.layer.1.layer_norm.weight torch.Size([1024])\n",
587
+ "encoder__layers_14__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
588
+ "encoder__layers_14__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.14.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
589
+ "encoder__layers_14__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
590
+ "encoder__layers_14__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
591
+ "encoder__layers_14__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.14.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
592
+ "encoder__layers_14__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.14.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
593
+ "encoder__layers_14__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.14.layer.0.layer_norm.weight torch.Size([1024])\n",
594
+ "encoder__layers_14__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.14.layer.1.layer_norm.weight torch.Size([1024])\n",
595
+ "encoder__layers_15__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n"
596
+ ]
597
+ },
598
+ {
599
+ "name": "stdout",
600
+ "output_type": "stream",
601
+ "text": [
602
+ "encoder__layers_15__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.15.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
603
+ "encoder__layers_15__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
604
+ "encoder__layers_15__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
605
+ "encoder__layers_15__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.15.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
606
+ "encoder__layers_15__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.15.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
607
+ "encoder__layers_15__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.15.layer.0.layer_norm.weight torch.Size([1024])\n",
608
+ "encoder__layers_15__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.15.layer.1.layer_norm.weight torch.Size([1024])\n",
609
+ "encoder__layers_16__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
610
+ "encoder__layers_16__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.16.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
611
+ "encoder__layers_16__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
612
+ "encoder__layers_16__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
613
+ "encoder__layers_16__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.16.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
614
+ "encoder__layers_16__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.16.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
615
+ "encoder__layers_16__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.16.layer.0.layer_norm.weight torch.Size([1024])\n",
616
+ "encoder__layers_16__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.16.layer.1.layer_norm.weight torch.Size([1024])\n",
617
+ "encoder__layers_17__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
618
+ "encoder__layers_17__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.17.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
619
+ "encoder__layers_17__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
620
+ "encoder__layers_17__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
621
+ "encoder__layers_17__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.17.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
622
+ "encoder__layers_17__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.17.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
623
+ "encoder__layers_17__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.17.layer.0.layer_norm.weight torch.Size([1024])\n",
624
+ "encoder__layers_17__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.17.layer.1.layer_norm.weight torch.Size([1024])\n",
625
+ "encoder__layers_18__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
626
+ "encoder__layers_18__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.18.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
627
+ "encoder__layers_18__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
628
+ "encoder__layers_18__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
629
+ "encoder__layers_18__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.18.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
630
+ "encoder__layers_18__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.18.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
631
+ "encoder__layers_18__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.18.layer.0.layer_norm.weight torch.Size([1024])\n",
632
+ "encoder__layers_18__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.18.layer.1.layer_norm.weight torch.Size([1024])\n",
633
+ "encoder__layers_19__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
634
+ "encoder__layers_19__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.19.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
635
+ "encoder__layers_19__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
636
+ "encoder__layers_19__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
637
+ "encoder__layers_19__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.19.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
638
+ "encoder__layers_19__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.19.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
639
+ "encoder__layers_19__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.19.layer.0.layer_norm.weight torch.Size([1024])\n",
640
+ "encoder__layers_19__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.19.layer.1.layer_norm.weight torch.Size([1024])\n",
641
+ "encoder__layers_2__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
642
+ "encoder__layers_2__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.2.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
643
+ "encoder__layers_2__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
644
+ "encoder__layers_2__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
645
+ "encoder__layers_2__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.2.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
646
+ "encoder__layers_2__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.2.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
647
+ "encoder__layers_2__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.2.layer.0.layer_norm.weight torch.Size([1024])\n",
648
+ "encoder__layers_2__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.2.layer.1.layer_norm.weight torch.Size([1024])\n",
649
+ "encoder__layers_20__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
650
+ "encoder__layers_20__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.20.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
651
+ "encoder__layers_20__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
652
+ "encoder__layers_20__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
653
+ "encoder__layers_20__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.20.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
654
+ "encoder__layers_20__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.20.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
655
+ "encoder__layers_20__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.20.layer.0.layer_norm.weight torch.Size([1024])\n",
656
+ "encoder__layers_20__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.20.layer.1.layer_norm.weight torch.Size([1024])\n",
657
+ "encoder__layers_21__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
658
+ "encoder__layers_21__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.21.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
659
+ "encoder__layers_21__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n"
660
+ ]
661
+ },
662
+ {
663
+ "name": "stdout",
664
+ "output_type": "stream",
665
+ "text": [
666
+ "encoder__layers_21__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
667
+ "encoder__layers_21__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.21.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
668
+ "encoder__layers_21__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.21.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
669
+ "encoder__layers_21__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.21.layer.0.layer_norm.weight torch.Size([1024])\n",
670
+ "encoder__layers_21__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.21.layer.1.layer_norm.weight torch.Size([1024])\n",
671
+ "encoder__layers_22__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
672
+ "encoder__layers_22__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.22.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
673
+ "encoder__layers_22__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
674
+ "encoder__layers_22__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
675
+ "encoder__layers_22__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.22.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
676
+ "encoder__layers_22__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.22.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
677
+ "encoder__layers_22__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.22.layer.0.layer_norm.weight torch.Size([1024])\n",
678
+ "encoder__layers_22__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.22.layer.1.layer_norm.weight torch.Size([1024])\n",
679
+ "encoder__layers_23__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
680
+ "encoder__layers_23__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.23.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
681
+ "encoder__layers_23__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
682
+ "encoder__layers_23__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
683
+ "encoder__layers_23__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.23.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
684
+ "encoder__layers_23__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.23.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
685
+ "encoder__layers_23__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.23.layer.0.layer_norm.weight torch.Size([1024])\n",
686
+ "encoder__layers_23__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.23.layer.1.layer_norm.weight torch.Size([1024])\n",
687
+ "encoder__layers_3__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
688
+ "encoder__layers_3__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.3.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
689
+ "encoder__layers_3__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
690
+ "encoder__layers_3__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
691
+ "encoder__layers_3__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.3.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
692
+ "encoder__layers_3__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.3.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
693
+ "encoder__layers_3__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.3.layer.0.layer_norm.weight torch.Size([1024])\n",
694
+ "encoder__layers_3__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.3.layer.1.layer_norm.weight torch.Size([1024])\n",
695
+ "encoder__layers_4__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
696
+ "encoder__layers_4__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.4.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
697
+ "encoder__layers_4__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
698
+ "encoder__layers_4__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
699
+ "encoder__layers_4__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.4.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
700
+ "encoder__layers_4__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.4.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
701
+ "encoder__layers_4__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.4.layer.0.layer_norm.weight torch.Size([1024])\n",
702
+ "encoder__layers_4__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.4.layer.1.layer_norm.weight torch.Size([1024])\n",
703
+ "encoder__layers_5__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
704
+ "encoder__layers_5__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.5.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
705
+ "encoder__layers_5__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
706
+ "encoder__layers_5__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
707
+ "encoder__layers_5__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.5.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
708
+ "encoder__layers_5__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.5.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
709
+ "encoder__layers_5__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.5.layer.0.layer_norm.weight torch.Size([1024])\n",
710
+ "encoder__layers_5__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.5.layer.1.layer_norm.weight torch.Size([1024])\n",
711
+ "encoder__layers_6__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
712
+ "encoder__layers_6__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.6.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
713
+ "encoder__layers_6__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
714
+ "encoder__layers_6__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
715
+ "encoder__layers_6__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.6.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
716
+ "encoder__layers_6__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.6.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
717
+ "encoder__layers_6__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.6.layer.0.layer_norm.weight torch.Size([1024])\n",
718
+ "encoder__layers_6__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.6.layer.1.layer_norm.weight torch.Size([1024])\n",
719
+ "encoder__layers_7__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
720
+ "encoder__layers_7__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.7.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
721
+ "encoder__layers_7__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
722
+ "encoder__layers_7__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
723
+ "encoder__layers_7__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.7.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n"
724
+ ]
725
+ },
726
+ {
727
+ "name": "stdout",
728
+ "output_type": "stream",
729
+ "text": [
730
+ "encoder__layers_7__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.7.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
731
+ "encoder__layers_7__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.7.layer.0.layer_norm.weight torch.Size([1024])\n",
732
+ "encoder__layers_7__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.7.layer.1.layer_norm.weight torch.Size([1024])\n",
733
+ "encoder__layers_8__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
734
+ "encoder__layers_8__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.8.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
735
+ "encoder__layers_8__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
736
+ "encoder__layers_8__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
737
+ "encoder__layers_8__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.8.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
738
+ "encoder__layers_8__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.8.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
739
+ "encoder__layers_8__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.8.layer.0.layer_norm.weight torch.Size([1024])\n",
740
+ "encoder__layers_8__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.8.layer.1.layer_norm.weight torch.Size([1024])\n",
741
+ "encoder__layers_9__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n",
742
+ "encoder__layers_9__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.9.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n",
743
+ "encoder__layers_9__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n",
744
+ "encoder__layers_9__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n",
745
+ "encoder__layers_9__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.9.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n",
746
+ "encoder__layers_9__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.9.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n",
747
+ "encoder__layers_9__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.9.layer.0.layer_norm.weight torch.Size([1024])\n",
748
+ "encoder__layers_9__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.9.layer.1.layer_norm.weight torch.Size([1024])\n",
749
+ "encoder__relpos_bias__rel_embedding:0 ((128, 32)) =transpose=> encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 128])\n",
750
+ "token_embedder__embedding:0 ((32128, 1024)) => shared.weight torch.Size([32128, 1024])\n",
751
+ "Linear(in_features=1024, out_features=768, bias=False)\n",
752
+ "Remaining weights: set()\n"
753
+ ]
754
+ }
755
+ ],
756
+ "source": [
757
+ "t5_state = t5.state_dict()\n",
758
+ "state_all_names = set(t5_state.keys())\n",
759
+ "\n",
760
+ "\n",
761
+ "for var in v:\n",
762
+ " tf_name = var.name\n",
763
+ " if tf_name in names_to_ignore:\n",
764
+ " continue\n",
765
+ " \n",
766
+ " pt_name = convert_name(tf_name)\n",
767
+ " weights = np.float32(var.numpy())\n",
768
+ " \n",
769
+ " state_all_names.remove(pt_name)\n",
770
+ " \n",
771
+ " tranpose_status = \"=>\"\n",
772
+ " if need_transpose(pt_name, ['DenseReluDense', 'relative_attention_bias',]):\n",
773
+ " tranpose_status = \"=transpose=>\"\n",
774
+ " weights = weights.transpose()\n",
775
+ " \n",
776
+ " print(tf_name, f\"({var.shape})\", tranpose_status, pt_name, t5_state[pt_name].shape)\n",
777
+ " \n",
778
+ " original_shape = t5_state[pt_name].shape\n",
779
+ " t5_state[pt_name] = torch.nn.Parameter(torch.tensor(weights))\n",
780
+ " new_shape = t5_state[pt_name].shape\n",
781
+ " \n",
782
+ " if not equal_shapes(original_shape, new_shape):\n",
783
+ " print(\"Different shape:\", tf_name, original_shape, pt_name, new_shape)\n",
784
+ " break\n",
785
+ "\n",
786
+ "#Encoder Word embeddings\n",
787
+ "t5_state['encoder.embed_tokens.weight'] = t5_state['shared.weight']\n",
788
+ "state_all_names.remove('encoder.embed_tokens.weight')\n",
789
+ " \n",
790
+ "#Load back the weights\n",
791
+ "t5.load_state_dict(t5_state) \n",
792
+ "\n",
793
+ "tf_linear_weight = tf_name_weight[\"projection_layer__kernel:0\"]\n",
794
+ "linear = torch.nn.Linear(tf_linear_weight.shape[0], tf_linear_weight.shape[1], bias=False)\n",
795
+ "original_shape = linear.weight.shape\n",
796
+ "linear.weight = torch.nn.Parameter(torch.tensor(np.float32(tf_linear_weight.numpy()).transpose()))\n",
797
+ "new_shape = linear.weight.shape\n",
798
+ "if not equal_shapes(original_shape, new_shape):\n",
799
+ " print(\"Different shape at linear layer\")\n",
800
+ " \n",
801
+ "print(linear)\n",
802
+ "print(\"Remaining weights:\", state_all_names)\n",
803
+ "assert len(state_all_names) == 0\n"
804
+ ]
805
+ },
806
+ {
807
+ "cell_type": "code",
808
+ "execution_count": 11,
809
+ "id": "d59d5a2c",
810
+ "metadata": {},
811
+ "outputs": [
812
+ {
813
+ "name": "stdout",
814
+ "output_type": "stream",
815
+ "text": [
816
+ "torch.Size([8, 768])\n"
817
+ ]
818
+ },
819
+ {
820
+ "data": {
821
+ "text/plain": [
822
+ "tensor([[1.0000, 0.9279, 0.6404, 0.5968, 0.5420, 0.5442, 0.6099, 0.6318],\n",
823
+ " [0.9279, 1.0000, 0.6629, 0.6098, 0.5562, 0.5687, 0.6382, 0.6262],\n",
824
+ " [0.6404, 0.6629, 1.0000, 0.8351, 0.7101, 0.6953, 0.6265, 0.6390],\n",
825
+ " [0.5968, 0.6098, 0.8351, 1.0000, 0.6877, 0.6716, 0.5902, 0.6102],\n",
826
+ " [0.5420, 0.5562, 0.7101, 0.6877, 1.0000, 0.8924, 0.5701, 0.5661],\n",
827
+ " [0.5442, 0.5687, 0.6953, 0.6716, 0.8924, 1.0000, 0.5665, 0.5457],\n",
828
+ " [0.6099, 0.6382, 0.6265, 0.5902, 0.5701, 0.5665, 1.0000, 0.7950],\n",
829
+ " [0.6318, 0.6262, 0.6390, 0.6102, 0.5661, 0.5457, 0.7950, 1.0000]])"
830
+ ]
831
+ },
832
+ "execution_count": 11,
833
+ "metadata": {},
834
+ "output_type": "execute_result"
835
+ }
836
+ ],
837
+ "source": [
838
+ "english_sentences = [\"Berlin is the capital of Germany\", \"Berlin is a large city in Germany\",\n",
839
+ " \"Tensorflow can be used for deep learning\", \"Pytorch, developed by Facebook AI, is a deep learning framework\",\n",
840
+ " \"Is Scipy or numpy better?\", \"Which is faster: scipy or pandas?\",\n",
841
+ " \"Cats can live for quite a long time\", \"Cats are humans best friend\"]\n",
842
+ "\n",
843
+ "encoded_input = tokenizer(english_sentences, return_tensors=\"pt\", padding=True)\n",
844
+ "\n",
845
+ "with torch.no_grad():\n",
846
+ " model_output = t5(**encoded_input)\n",
847
+ " \n",
848
+ " # Perform pooling\n",
849
+ " hf_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n",
850
+ "\n",
851
+ " # Apply linear layer\n",
852
+ " hf_embeddings = linear(hf_embeddings)\n",
853
+ " \n",
854
+ " print(hf_embeddings.shape)\n",
855
+ "\n",
856
+ " # Normalize embeddings\n",
857
+ " hf_embeddings = F.normalize(hf_embeddings, p=2, dim=1)\n",
858
+ "\n",
859
+ "# Cos\n",
860
+ "hf_scores = util.dot_score(hf_embeddings, hf_embeddings).numpy()\n",
861
+ "hf_scores"
862
+ ]
863
+ },
864
+ {
865
+ "cell_type": "code",
866
+ "execution_count": 12,
867
+ "id": "677a8bab",
868
+ "metadata": {},
869
+ "outputs": [
870
+ {
871
+ "name": "stderr",
872
+ "output_type": "stream",
873
+ "text": [
874
+ "2022-02-01 20:00:27.115638: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n",
875
+ "2022-02-01 20:00:29.328848: I tensorflow/compiler/xla/service/service.cc:171] XLA service 0x7fe9781cd6f0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n",
876
+ "2022-02-01 20:00:29.328894: I tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Host, Default Version\n",
877
+ "2022-02-01 20:00:30.324558: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:210] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
878
+ "2022-02-01 20:01:02.775112: I tensorflow/compiler/jit/xla_compilation_cache.cc:363] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n"
879
+ ]
880
+ },
881
+ {
882
+ "name": "stdout",
883
+ "output_type": "stream",
884
+ "text": [
885
+ "(8, 768)\n"
886
+ ]
887
+ },
888
+ {
889
+ "data": {
890
+ "text/plain": [
891
+ "tensor([[1.0000, 0.9279, 0.6402, 0.5966, 0.5422, 0.5446, 0.6097, 0.6320],\n",
892
+ " [0.9279, 1.0000, 0.6631, 0.6099, 0.5566, 0.5690, 0.6386, 0.6268],\n",
893
+ " [0.6402, 0.6631, 1.0000, 0.8347, 0.7101, 0.6955, 0.6264, 0.6389],\n",
894
+ " [0.5966, 0.6099, 0.8347, 1.0000, 0.6873, 0.6712, 0.5899, 0.6100],\n",
895
+ " [0.5422, 0.5566, 0.7101, 0.6873, 1.0000, 0.8927, 0.5700, 0.5661],\n",
896
+ " [0.5446, 0.5690, 0.6955, 0.6712, 0.8927, 1.0000, 0.5663, 0.5458],\n",
897
+ " [0.6097, 0.6386, 0.6264, 0.5899, 0.5700, 0.5663, 1.0000, 0.7949],\n",
898
+ " [0.6320, 0.6268, 0.6389, 0.6100, 0.5661, 0.5458, 0.7949, 1.0000]])"
899
+ ]
900
+ },
901
+ "execution_count": 12,
902
+ "metadata": {},
903
+ "output_type": "execute_result"
904
+ }
905
+ ],
906
+ "source": [
907
+ "# Test the models - Original embeddings\n",
908
+ "english_embeds = encoder(english_sentences)[0].numpy()\n",
909
+ "print(english_embeds.shape)\n",
910
+ "tf_scores = util.dot_score(english_embeds, english_embeds).numpy()\n",
911
+ "print(tf_scores)\n",
912
+ "print(\"Diff:\", np.sum(np.abs(tf_scores - hf_scores)))"
913
+ ]
914
+ },
915
+ {
916
+ "cell_type": "code",
917
+ "execution_count": 13,
918
+ "id": "34b44ef7",
919
+ "metadata": {},
920
+ "outputs": [
921
+ {
922
+ "ename": "FileNotFoundError",
923
+ "evalue": "[Errno 2] No such file or directory: 'models/sentence-t5-11b/2_Dense/config.json'",
924
+ "output_type": "error",
925
+ "traceback": [
926
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
927
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
928
+ "\u001b[0;32m/tmp/ipykernel_26913/2543044366.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m bias=False, activation_function=torch.nn.Identity())\n\u001b[1;32m 8\u001b[0m \u001b[0mdense\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinear\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mdense\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfolder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'2_Dense'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
929
+ "\u001b[0;32m/home/sbert/sentence-transformers/sentence_transformers/models/Dense.py\u001b[0m in \u001b[0;36msave\u001b[0;34m(self, output_path)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'config.json'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'w'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfOut\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_config_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfOut\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
930
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/sentence-t5-11b/2_Dense/config.json'"
931
+ ]
932
+ }
933
+ ],
934
+ "source": [
935
+ "folder = f'models/sentence-t5-{model_size}'\n",
936
+ "t5.save_pretrained(folder)\n",
937
+ "tokenizer.save_pretrained(folder)\n",
938
+ "\n",
939
+ "import sentence_transformers\n",
940
+ "dense = sentence_transformers.models.Dense(linear.in_features, linear.out_features, \n",
941
+ " bias=False, activation_function=torch.nn.Identity())\n",
942
+ "dense.linear = linear\n",
943
+ "\n",
944
+ "dense_path = os.path.join(folder, '2_Dense')\n",
945
+ "os.makedirs(dense_path, exist_ok=True)\n",
946
+ "dense.save(dense_path)\n"
947
+ ]
948
+ },
949
+ {
950
+ "cell_type": "code",
951
+ "execution_count": 15,
952
+ "id": "f2d561c1",
953
+ "metadata": {},
954
+ "outputs": [],
955
+ "source": [
956
+ "\n"
957
+ ]
958
+ }
959
+ ],
960
+ "metadata": {
961
+ "kernelspec": {
962
+ "display_name": "Python 3 (ipykernel)",
963
+ "language": "python",
964
+ "name": "python3"
965
+ },
966
+ "language_info": {
967
+ "codemirror_mode": {
968
+ "name": "ipython",
969
+ "version": 3
970
+ },
971
+ "file_extension": ".py",
972
+ "mimetype": "text/x-python",
973
+ "name": "python",
974
+ "nbconvert_exporter": "python",
975
+ "pygments_lexer": "ipython3",
976
+ "version": "3.8.8"
977
+ }
978
+ },
979
+ "nbformat": 4,
980
+ "nbformat_minor": 5
981
+ }
convert_to_fp16.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from transformers import T5EncoderModel
3
+
4
+ in_path = sys.argv[1]
5
+ out_path = sys.argv[2]
6
+
7
+ model = T5EncoderModel.from_pretrained(in_path)
8
+ model.half()
9
+ model.save_pretrained(out_path)
modules.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Dense",
18
+ "type": "sentence_transformers.models.Dense"
19
+ },
20
+ {
21
+ "idx": 3,
22
+ "name": "3",
23
+ "path": "3_Normalize",
24
+ "type": "sentence_transformers.models.Normalize"
25
+ }
26
+ ]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b91bd3ded13728f29297a9f2ee2a809acd211f52271a857488e491c4c345208
3
+ size 219303530
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ {
2
+ "max_seq_length": 256,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 100, "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"], "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "t5-base", "tokenizer_class": "T5Tokenizer"}