{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "17bffc12", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "from sentence_transformers import util\n", "import os\n", "import numpy as np\n", "import torch.nn.functional as F\n", "from transformers import T5EncoderModel\n", "import sentence_transformers" ] }, { "cell_type": "code", "execution_count": 2, "id": "160d8ce6", "metadata": {}, "outputs": [], "source": [ "\n", "#Mean Pooling - Take attention mask into account for correct averaging\n", "def mean_pooling(model_output, attention_mask):\n", " token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n", " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n", " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n", "\n", "\n", " " ] }, { "cell_type": "code", "execution_count": 3, "id": "2f67f426", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Using /tmp/tfhub_modules to cache modules.\n", "2022-02-01 20:04:53.747606: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "2022-02-01 20:04:53.747647: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1835] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", "Skipping registering GPU devices...\n", "2022-02-01 20:04:53.747987: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n", "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "WARNING:absl:Importing a function (__inference_closure_12264) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n", "WARNING:absl:Importing a function (__inference_closure_8418) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n", "WARNING:absl:Importing a function (__inference_closure_4202) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested.\n" ] } ], "source": [ "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "import tensorflow_text as text \n", "\n", "model_size_tf, model_size_hf = \"base\", \"base\"\n", "hub_url = f\"https://tfhub.dev/google/gtr/gtr-{model_size_tf}/1\"\n", "encoder = hub.load(hub_url)\n", "\n", "v = encoder.signatures['serving_default'].variables" ] }, { "cell_type": "code", "execution_count": 4, "id": "5f4c8d94", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'encoder__encoder_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_0__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_0__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_0__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_0__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_0__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_0__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_0__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_0__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_1__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_1__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_1__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_1__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_1__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_1__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_1__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_1__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_10__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_10__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_10__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_10__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_10__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_10__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_10__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_10__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_11__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_11__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_11__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_11__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_11__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_11__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_11__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_11__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_2__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_2__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_2__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_2__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_2__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_2__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 
'encoder__layers_2__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_2__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_3__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_3__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_3__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_3__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_3__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_3__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_3__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_3__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_4__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_4__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_4__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_4__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_4__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_4__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_4__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_4__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_5__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_5__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_5__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_5__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_5__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_5__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_5__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_5__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_6__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_6__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_6__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_6__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_6__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_6__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_6__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_6__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_7__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_7__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_7__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_7__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_7__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_7__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_7__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_7__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_8__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_8__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_8__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_8__attention__value__kernel:0': TensorShape([768, 
768]),\n", " 'encoder__layers_8__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_8__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_8__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_8__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_9__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_9__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_9__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_9__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_9__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_9__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_9__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_9__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__relpos_bias__rel_embedding:0': TensorShape([12, 32]),\n", " 'projection_layer__kernel:0': TensorShape([768, 768]),\n", " 'token_embedder__embedding:0': TensorShape([32128, 768])}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf_name_weight = {var.name: var for var in v}\n", "tf_name_shape = {var.name: var.shape for var in v}\n", "tf_name_shape" ] }, { "cell_type": "code", "execution_count": 7, "id": "6d223b07", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e4b637f4a8f847fcaa7b09fb729227b0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=45229452544.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at t5-11b were not used when initializing T5EncoderModel: ['decoder.block.13.layer.1.EncDecAttention.o.weight', 'decoder.block.14.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.0.layer_norm.weight', 'decoder.block.6.layer.1.EncDecAttention.v.weight', 'decoder.block.15.layer.0.SelfAttention.v.weight', 'decoder.block.3.layer.1.layer_norm.weight', 'decoder.block.11.layer.2.DenseReluDense.wi.weight', 'decoder.block.11.layer.2.DenseReluDense.wo.weight', 'decoder.block.3.layer.0.SelfAttention.o.weight', 'decoder.block.12.layer.2.DenseReluDense.wo.weight', 'decoder.block.8.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.1.layer_norm.weight', 'decoder.block.9.layer.2.DenseReluDense.wi.weight', 'decoder.block.15.layer.0.SelfAttention.q.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.14.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'decoder.block.14.layer.0.SelfAttention.q.weight', 'decoder.block.7.layer.0.SelfAttention.o.weight', 'decoder.block.9.layer.1.EncDecAttention.v.weight', 'decoder.block.21.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.19.layer.2.layer_norm.weight', 'decoder.block.14.layer.1.EncDecAttention.k.weight', 'decoder.block.12.layer.1.EncDecAttention.o.weight', 'decoder.block.18.layer.2.DenseReluDense.wi.weight', 'decoder.block.3.layer.0.SelfAttention.q.weight', 'decoder.block.9.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.o.weight', 
'decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight', 'decoder.block.23.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.v.weight', 'decoder.block.14.layer.1.EncDecAttention.v.weight', 'decoder.block.8.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.3.layer.0.layer_norm.weight', 'decoder.block.10.layer.2.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.21.layer.0.layer_norm.weight', 'decoder.block.12.layer.2.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.2.layer_norm.weight', 'decoder.block.5.layer.0.layer_norm.weight', 'decoder.block.1.layer.1.EncDecAttention.o.weight', 'decoder.block.17.layer.0.SelfAttention.v.weight', 'decoder.block.11.layer.2.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.k.weight', 'decoder.block.15.layer.1.EncDecAttention.v.weight', 'decoder.block.12.layer.0.SelfAttention.q.weight', 'decoder.block.23.layer.0.SelfAttention.o.weight', 'decoder.block.16.layer.0.SelfAttention.q.weight', 'decoder.block.21.layer.0.SelfAttention.o.weight', 'decoder.block.19.layer.0.SelfAttention.q.weight', 'decoder.block.19.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.13.layer.0.SelfAttention.o.weight', 'decoder.block.20.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.2.DenseReluDense.wi.weight', 'decoder.block.5.layer.1.EncDecAttention.o.weight', 'decoder.block.17.layer.2.layer_norm.weight', 'decoder.block.21.layer.1.EncDecAttention.k.weight', 'decoder.block.16.layer.1.EncDecAttention.q.weight', 'decoder.block.8.layer.1.layer_norm.weight', 'decoder.block.5.layer.2.DenseReluDense.wi.weight', 'decoder.block.9.layer.1.EncDecAttention.o.weight', 'decoder.block.22.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.4.layer.1.EncDecAttention.k.weight', 'decoder.block.22.layer.2.DenseReluDense.wi.weight', 'decoder.block.23.layer.2.layer_norm.weight', 'decoder.block.2.layer.2.DenseReluDense.wo.weight', 'decoder.block.23.layer.0.SelfAttention.q.weight', 'decoder.block.20.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.17.layer.1.EncDecAttention.v.weight', 'decoder.block.9.layer.1.EncDecAttention.k.weight', 'decoder.block.12.layer.0.SelfAttention.k.weight', 'decoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.16.layer.2.DenseReluDense.wi.weight', 'decoder.block.2.layer.0.layer_norm.weight', 'decoder.block.17.layer.1.EncDecAttention.q.weight', 'decoder.block.13.layer.0.SelfAttention.k.weight', 'decoder.block.21.layer.2.layer_norm.weight', 'decoder.block.4.layer.1.layer_norm.weight', 'decoder.block.13.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.2.DenseReluDense.wi.weight', 'decoder.block.13.layer.2.DenseReluDense.wi.weight', 'decoder.block.16.layer.1.EncDecAttention.v.weight', 'decoder.block.15.layer.1.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.19.layer.1.EncDecAttention.v.weight', 'decoder.block.3.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.2.DenseReluDense.wo.weight', 'decoder.block.20.layer.0.SelfAttention.v.weight', 'decoder.block.4.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.1.layer_norm.weight', 
'decoder.block.11.layer.0.SelfAttention.o.weight', 'decoder.block.19.layer.1.EncDecAttention.o.weight', 'decoder.block.23.layer.2.DenseReluDense.wo.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.18.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.16.layer.1.layer_norm.weight', 'decoder.block.12.layer.1.EncDecAttention.v.weight', 'decoder.block.17.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.k.weight', 'decoder.block.11.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.2.layer_norm.weight', 'decoder.block.8.layer.1.EncDecAttention.q.weight', 'decoder.block.16.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.15.layer.1.EncDecAttention.k.weight', 'decoder.block.19.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.0.SelfAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.q.weight', 'decoder.block.2.layer.1.EncDecAttention.q.weight', 'decoder.block.17.layer.2.DenseReluDense.wi.weight', 'decoder.block.5.layer.2.layer_norm.weight', 'decoder.block.13.layer.2.layer_norm.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.16.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.1.EncDecAttention.q.weight', 'decoder.block.12.layer.1.layer_norm.weight', 'decoder.block.10.layer.1.EncDecAttention.o.weight', 'decoder.block.9.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.20.layer.1.EncDecAttention.v.weight', 'decoder.block.20.layer.0.SelfAttention.q.weight', 'decoder.block.22.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.1.layer_norm.weight', 'decoder.block.4.layer.1.EncDecAttention.v.weight', 'decoder.block.22.layer.0.SelfAttention.v.weight', 'decoder.block.15.layer.2.layer_norm.weight', 'decoder.block.23.layer.2.DenseReluDense.wi.weight', 'decoder.block.23.layer.1.EncDecAttention.v.weight', 'decoder.block.8.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.0.SelfAttention.v.weight', 'decoder.block.4.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.1.EncDecAttention.k.weight', 'decoder.block.3.layer.2.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.22.layer.0.SelfAttention.q.weight', 'decoder.block.18.layer.2.DenseReluDense.wo.weight', 'decoder.block.10.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.q.weight', 'decoder.block.20.layer.2.DenseReluDense.wo.weight', 'decoder.block.11.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.2.DenseReluDense.wi.weight', 'decoder.block.10.layer.0.SelfAttention.q.weight', 'decoder.block.17.layer.1.layer_norm.weight', 'decoder.block.20.layer.1.layer_norm.weight', 'decoder.block.18.layer.0.layer_norm.weight', 'decoder.block.21.layer.0.SelfAttention.v.weight', 'decoder.block.20.layer.0.SelfAttention.o.weight', 'decoder.block.22.layer.1.EncDecAttention.o.weight', 'decoder.block.21.layer.1.EncDecAttention.q.weight', 'decoder.block.16.layer.2.layer_norm.weight', 'decoder.block.2.layer.1.EncDecAttention.k.weight', 'decoder.block.10.layer.1.EncDecAttention.v.weight', 'decoder.block.10.layer.1.layer_norm.weight', 'decoder.block.3.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.1.EncDecAttention.o.weight', 'decoder.block.16.layer.1.EncDecAttention.o.weight', 
'decoder.block.17.layer.1.EncDecAttention.k.weight', 'decoder.block.15.layer.0.SelfAttention.k.weight', 'decoder.block.11.layer.0.layer_norm.weight', 'decoder.block.23.layer.1.EncDecAttention.q.weight', 'decoder.block.13.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.13.layer.0.layer_norm.weight', 'decoder.block.9.layer.0.SelfAttention.o.weight', 'decoder.block.19.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.1.EncDecAttention.q.weight', 'decoder.block.16.layer.0.layer_norm.weight', 'decoder.block.8.layer.2.DenseReluDense.wi.weight', 'decoder.block.17.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.1.EncDecAttention.k.weight', 'decoder.block.14.layer.2.DenseReluDense.wi.weight', 'decoder.block.6.layer.1.layer_norm.weight', 'decoder.block.17.layer.0.SelfAttention.o.weight', 'decoder.block.19.layer.1.layer_norm.weight', 'decoder.block.13.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.20.layer.1.EncDecAttention.q.weight', 'decoder.block.23.layer.0.SelfAttention.v.weight', 'decoder.block.3.layer.0.SelfAttention.v.weight', 'decoder.block.18.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.o.weight', 'decoder.block.22.layer.0.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.v.weight', 'decoder.block.12.layer.1.EncDecAttention.q.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.15.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.o.weight', 'decoder.block.15.layer.1.EncDecAttention.q.weight', 'decoder.block.19.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.0.SelfAttention.o.weight', 'decoder.block.14.layer.2.layer_norm.weight', 'decoder.block.22.layer.0.SelfAttention.k.weight', 'decoder.block.13.layer.0.SelfAttention.q.weight', 'decoder.block.11.layer.1.EncDecAttention.v.weight', 'decoder.block.22.layer.1.layer_norm.weight', 'decoder.block.21.layer.0.SelfAttention.q.weight', 'decoder.block.6.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.0.SelfAttention.v.weight', 'decoder.block.21.layer.0.SelfAttention.k.weight', 'decoder.block.19.layer.0.SelfAttention.k.weight', 'decoder.block.10.layer.0.layer_norm.weight', 'decoder.block.2.layer.2.DenseReluDense.wi.weight', 'decoder.block.17.layer.0.SelfAttention.k.weight', 'decoder.block.23.layer.0.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wi.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'decoder.block.19.layer.0.SelfAttention.o.weight', 'decoder.block.5.layer.0.SelfAttention.k.weight', 'decoder.block.10.layer.2.DenseReluDense.wo.weight', 'decoder.block.2.layer.0.SelfAttention.q.weight', 'decoder.block.22.layer.1.EncDecAttention.v.weight', 'decoder.block.23.layer.1.layer_norm.weight', 'decoder.block.5.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.o.weight', 'decoder.block.14.layer.0.SelfAttention.o.weight', 'decoder.block.17.layer.0.layer_norm.weight', 'decoder.final_layer_norm.weight', 'decoder.block.10.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.0.layer_norm.weight', 'decoder.block.23.layer.1.EncDecAttention.k.weight', 'decoder.block.21.layer.1.EncDecAttention.o.weight', 'decoder.block.13.layer.1.layer_norm.weight', 'decoder.block.1.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.0.SelfAttention.v.weight', 
'decoder.block.1.layer.0.layer_norm.weight', 'decoder.block.5.layer.0.SelfAttention.v.weight', 'decoder.block.13.layer.2.DenseReluDense.wo.weight', 'decoder.block.8.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.23.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.15.layer.0.layer_norm.weight', 'decoder.block.7.layer.2.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.k.weight', 'decoder.block.15.layer.2.DenseReluDense.wo.weight', 'decoder.block.8.layer.1.EncDecAttention.o.weight', 'decoder.block.22.layer.0.SelfAttention.o.weight', 'decoder.block.17.layer.0.SelfAttention.q.weight', 'decoder.block.9.layer.0.SelfAttention.v.weight', 'decoder.block.9.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.0.layer_norm.weight', 'decoder.block.14.layer.0.layer_norm.weight', 'decoder.block.9.layer.0.SelfAttention.q.weight', 'decoder.block.16.layer.0.SelfAttention.o.weight', 'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.20.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.2.layer_norm.weight', 'decoder.block.13.layer.1.EncDecAttention.v.weight', 'decoder.block.7.layer.2.DenseReluDense.wo.weight', 'decoder.block.21.layer.2.DenseReluDense.wo.weight', 'decoder.block.15.layer.2.DenseReluDense.wi.weight', 'decoder.block.10.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.0.SelfAttention.v.weight', 'decoder.block.11.layer.1.EncDecAttention.k.weight', 'decoder.block.22.layer.2.layer_norm.weight', 'decoder.block.2.layer.1.layer_norm.weight', 'decoder.block.8.layer.2.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.q.weight', 'decoder.block.12.layer.1.EncDecAttention.k.weight', 'decoder.block.11.layer.0.SelfAttention.v.weight', 'decoder.block.22.layer.1.EncDecAttention.q.weight', 'decoder.block.5.layer.1.EncDecAttention.q.weight', 'decoder.block.11.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.20.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.21.layer.1.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.20.layer.2.layer_norm.weight', 'decoder.block.19.layer.1.EncDecAttention.q.weight', 'decoder.block.10.layer.1.EncDecAttention.k.weight', 'decoder.block.20.layer.0.layer_norm.weight', 'decoder.block.18.layer.0.SelfAttention.k.weight', 'decoder.block.21.layer.2.DenseReluDense.wi.weight', 'decoder.block.11.layer.1.EncDecAttention.q.weight', 'decoder.block.15.layer.0.SelfAttention.o.weight', 'decoder.block.5.layer.2.DenseReluDense.wo.weight', 'decoder.block.10.layer.1.EncDecAttention.q.weight', 'decoder.block.9.layer.2.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.v.weight', 'decoder.block.9.layer.0.layer_norm.weight', 'decoder.block.16.layer.2.DenseReluDense.wo.weight', 'decoder.block.9.layer.1.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.q.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.18.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.1.EncDecAttention.v.weight', 'decoder.block.14.layer.1.EncDecAttention.q.weight', 'decoder.block.10.layer.0.SelfAttention.o.weight', 'decoder.block.16.layer.0.SelfAttention.k.weight', 'decoder.block.18.layer.1.EncDecAttention.o.weight', 'decoder.block.11.layer.1.layer_norm.weight', 'decoder.block.2.layer.1.EncDecAttention.v.weight', 
'decoder.block.4.layer.0.SelfAttention.k.weight', 'decoder.block.19.layer.0.layer_norm.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.8.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.layer_norm.weight']\n", "- This IS expected if you are initializing T5EncoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing T5EncoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of T5EncoderModel were not initialized from the model checkpoint at t5-11b and are newly initialized: ['encoder.embed_tokens.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/plain": [ "{'shared.weight': torch.Size([32128, 1024]),\n", " 'encoder.embed_tokens.weight': torch.Size([32128, 1024]),\n", " 'encoder.block.0.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.0.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.0.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.0.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight': torch.Size([32, 128]),\n", " 'encoder.block.0.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.0.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.0.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.0.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.1.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.1.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.1.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.1.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.1.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.1.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.1.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.1.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.2.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.2.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.2.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.2.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.2.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.2.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.2.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.2.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.3.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.3.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.3.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.3.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 
'encoder.block.3.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.3.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.3.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.3.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.4.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.4.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.4.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.4.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.4.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.4.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.4.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.4.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.5.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.5.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.5.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.5.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.5.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.5.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.5.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.5.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.6.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.6.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.6.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.6.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.6.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.6.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.6.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.6.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.7.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.7.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.7.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.7.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.7.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.7.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.7.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.7.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.8.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.8.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.8.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.8.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.8.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.8.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.8.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.8.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 
'encoder.block.9.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.9.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.9.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.9.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.9.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.9.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.9.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.9.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.10.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.10.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.10.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.10.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.10.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.10.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.10.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.10.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.11.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.11.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.11.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.11.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.11.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.11.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.11.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.11.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.12.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.12.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.12.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.12.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.12.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.12.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.12.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.12.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.13.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.13.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.13.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.13.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.13.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.13.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.13.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.13.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.14.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.14.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.14.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.14.layer.0.SelfAttention.o.weight': 
torch.Size([1024, 16384]),\n", " 'encoder.block.14.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.14.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.14.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.14.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.15.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.15.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.15.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.15.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.15.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.15.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.15.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.15.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.16.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.16.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.16.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.16.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.16.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.16.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.16.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.16.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.17.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.17.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.17.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.17.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.17.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.17.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.17.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.17.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.18.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.18.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.18.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.18.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.18.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.18.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.18.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.18.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.19.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.19.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.19.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.19.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.19.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.19.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.19.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 
'encoder.block.19.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.20.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.20.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.20.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.20.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.20.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.20.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.20.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.20.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.21.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.21.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.21.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.21.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.21.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.21.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.21.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.21.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.22.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.22.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.22.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.22.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.22.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.22.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.22.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.22.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.23.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.23.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.23.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.23.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.23.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.23.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.23.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.23.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.final_layer_norm.weight': torch.Size([1024])}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(f\"t5-{model_size_hf}\")\n", "t5 = T5EncoderModel.from_pretrained(f\"t5-{model_size_hf}\") \n", "pt_name_shape = {name: weight.shape for name, weight in t5.state_dict().items()}\n", "pt_name_shape" ] }, { "cell_type": "code", "execution_count": 8, "id": "1d3c9865", "metadata": {}, "outputs": [], "source": [ "def convert_name(name):\n", " fct_map = {\n", " \"attention\": \"SelfAttention\",\n", " \"mlp\": \"DenseReluDense\",\n", " \"pre_attention_layer_norm\": \"layer_norm\",\n", " \"pre_mlp_layer_norm\": \"layer_norm\",\n", " }\n", " name_map = {\n", " 'key': 'k',\n", " 'out': 'o',\n", " 'query': 'q',\n", " 'value': 'v'\n", " }\n", " \n", " fixed_names = {\n", " 
\"token_embedder__embedding:0\": \"shared.weight\",\n", " \"encoder__encoder_norm__scale:0\": \"encoder.final_layer_norm.weight\",\n", " \"encoder__relpos_bias__rel_embedding:0\": \"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\"\n", " }\n", " \n", " if name in fixed_names:\n", " return fixed_names[name]\n", " \n", " out = \"\"\n", " splits = name.split(\"__\")\n", " layer = splits[1].split(\"_\")[1]\n", " fct = fct_map.get(splits[2], splits[2])\n", " if 'layer_norm' in name:\n", " sublayer = \"1\" if \"pre_mlp_layer_norm\" in name else \"0\" #Not sure on the right setting here\n", " #sublayer = \"0\" if \"pre_mlp_layer_norm\" in name else \"1\" #Not sure on the right setting here\n", " out = f\"encoder.block.{layer}.layer.{sublayer}.{fct}.weight\"\n", " elif name.startswith(\"encoder__layers_\"):\n", " sublayer = \"0\" if fct == \"SelfAttention\" else \"1\"\n", " name = name_map.get(splits[3], splits[3])\n", " out = f\"encoder.block.{layer}.layer.{sublayer}.{fct}.{name}.weight\"\n", " \n", " return out" ] }, { "cell_type": "code", "execution_count": 9, "id": "1ca9590e", "metadata": {}, "outputs": [], "source": [ "def equal_shapes(shape1, shape2):\n", " if len(shape1) != len(shape2):\n", " return False\n", " \n", " for idx in range(len(shape1)):\n", " if shape1[idx] != shape2[idx]:\n", " return False\n", " \n", " return True" ] }, { "cell_type": "code", "execution_count": 10, "id": "ced52a5f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Remaining weights: {'encoder.embed_tokens.weight'}\n" ] } ], "source": [ "def need_transpose(name):\n", " #HF function: https://github.com/huggingface/transformers/blob/c962c2adbff678ae6d2e98378bed5b8d1a9831d9/src/transformers/models/t5/modeling_t5.py#L161\n", " return name != \"shared.weight\"\n", "\n", " \n", "\n", "names_to_ignore = {\"projection_layer__kernel:0\"}\n", "#Additional dense layer on top\n", "\n", "#Check we used all names\n", "pt_all_names = set(t5.state_dict().keys())\n", "\n", "for var in v:\n", " name = var.name\n", " if name in names_to_ignore:\n", " continue\n", " \n", " pt_name = convert_name(name)\n", " if pt_name not in pt_all_names:\n", " print(\"Name not found:\", name, \"=>\", pt_name)\n", " else:\n", " pt_all_names.remove(pt_name)\n", " tf_shape = tf_name_shape[name].as_list()\n", " pt_shape = list(pt_name_shape[pt_name])\n", " \n", " if need_transpose(pt_name):\n", " pt_shape = list(reversed(pt_shape))\n", " \n", " if not equal_shapes(tf_shape, pt_shape):\n", " print(\"Different shape:\", name, tf_shape, pt_name, pt_shape )\n", " \n", "print(\"Remaining weights:\", pt_all_names)\n", "#All layers match" ] }, { "cell_type": "code", "execution_count": 11, "id": "1190984f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of T5EncoderModel were not initialized from the model checkpoint at t5-11b and are newly initialized: ['encoder.embed_tokens.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "encoder__encoder_norm__scale:0 ((1024,)) =transpose=> encoder.final_layer_norm.weight torch.Size([1024])\n", "encoder__layers_0__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_0__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([1024, 
16384])\n", "encoder__layers_0__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_0__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_0__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.0.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_0__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_0__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.0.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_0__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.0.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_1__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_1__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.1.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_1__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_1__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_1__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.1.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_1__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.1.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_1__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.1.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_1__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.1.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_10__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_10__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.10.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_10__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_10__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_10__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.10.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_10__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.10.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_10__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.10.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_10__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.10.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_11__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_11__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.11.layer.0.SelfAttention.o.weight torch.Size([1024, 
16384])\n", "encoder__layers_11__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_11__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_11__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.11.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_11__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.11.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_11__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.11.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_11__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.11.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_12__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_12__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.12.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_12__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_12__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_12__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.12.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_12__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.12.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_12__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.12.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_12__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.12.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_13__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_13__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.13.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_13__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_13__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_13__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.13.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_13__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.13.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_13__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.13.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_13__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.13.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_14__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_14__attention__out__kernel:0 ((16384, 1024)) =transpose=> 
encoder.block.14.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_14__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_14__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_14__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.14.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_14__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.14.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_14__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.14.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_14__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.14.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_15__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "encoder__layers_15__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.15.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_15__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_15__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_15__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.15.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_15__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.15.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_15__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.15.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_15__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.15.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_16__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_16__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.16.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_16__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_16__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_16__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.16.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_16__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.16.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_16__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.16.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_16__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.16.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_17__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.k.weight torch.Size([16384, 
1024])\n", "encoder__layers_17__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.17.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_17__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_17__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_17__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.17.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_17__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.17.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_17__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.17.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_17__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.17.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_18__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_18__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.18.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_18__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_18__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_18__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.18.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_18__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.18.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_18__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.18.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_18__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.18.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_19__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_19__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.19.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_19__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_19__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_19__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.19.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_19__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.19.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_19__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.19.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_19__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.19.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_2__attention__key__kernel:0 ((1024, 16384)) =transpose=> 
encoder.block.2.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_2__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.2.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_2__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_2__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_2__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.2.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_2__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.2.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_2__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.2.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_2__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.2.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_20__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_20__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.20.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_20__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_20__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_20__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.20.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_20__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.20.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_20__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.20.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_20__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.20.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_21__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_21__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.21.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "encoder__layers_21__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_21__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_21__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.21.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_21__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.21.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_21__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.21.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_21__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.21.layer.1.layer_norm.weight torch.Size([1024])\n", 
"encoder__layers_22__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_22__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.22.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_22__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_22__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_22__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.22.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_22__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.22.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_22__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.22.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_22__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.22.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_23__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_23__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.23.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_23__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_23__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_23__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.23.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_23__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.23.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_23__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.23.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_23__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.23.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_3__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_3__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.3.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_3__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_3__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_3__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.3.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_3__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.3.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_3__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.3.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_3__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.3.layer.1.layer_norm.weight 
torch.Size([1024])\n", "encoder__layers_4__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_4__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.4.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_4__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_4__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_4__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.4.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_4__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.4.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_4__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.4.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_4__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.4.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_5__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_5__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.5.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_5__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_5__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_5__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.5.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_5__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.5.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_5__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.5.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_5__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.5.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_6__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_6__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.6.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_6__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_6__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_6__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.6.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_6__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.6.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_6__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.6.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_6__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.6.layer.1.layer_norm.weight 
torch.Size([1024])\n", "encoder__layers_7__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_7__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.7.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_7__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_7__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "encoder__layers_7__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.7.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_7__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.7.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_7__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.7.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_7__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.7.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_8__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_8__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.8.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_8__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_8__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_8__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.8.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_8__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.8.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_8__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.8.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_8__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.8.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_9__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_9__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.9.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_9__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_9__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_9__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.9.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_9__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.9.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_9__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.9.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_9__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> 
encoder.block.9.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__relpos_bias__rel_embedding:0 ((128, 32)) =transpose=> encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 128])\n", "token_embedder__embedding:0 ((32128, 1024)) => shared.weight torch.Size([32128, 1024])\n", "Linear(in_features=1024, out_features=768, bias=False)\n", "Remaining weights: set()\n" ] } ], "source": [ "import torch\n", "tokenizer = AutoTokenizer.from_pretrained(f\"t5-{model_size_hf}\")\n", "T5EncoderModel._keys_to_ignore_on_load_unexpected = [\"decoder.*\"]\n", "t5 = T5EncoderModel.from_pretrained(f\"t5-{model_size_hf}\")\n", "t5_state = t5.state_dict()\n", "\n", "state_all_names = set(t5_state.keys())\n", "\n", "for var in v:\n", " tf_name = var.name\n", " if tf_name in names_to_ignore:\n", " continue\n", " \n", " pt_name = convert_name(tf_name)\n", " weights = np.float32(var.numpy())\n", " \n", " state_all_names.remove(pt_name)\n", " \n", " tranpose_status = \"=>\"\n", " if need_transpose(pt_name):\n", " tranpose_status = \"=transpose=>\"\n", " weights = weights.transpose()\n", " \n", " print(tf_name, f\"({var.shape})\", tranpose_status, pt_name, t5_state[pt_name].shape)\n", " \n", " original_shape = t5_state[pt_name].shape\n", " t5_state[pt_name] = torch.nn.Parameter(torch.tensor(weights))\n", " new_shape = t5_state[pt_name].shape\n", " \n", " if not equal_shapes(original_shape, new_shape):\n", " print(\"Different shape:\", tf_name, original_shape, pt_name, new_shape)\n", " break\n", "\n", "#Encoder Word embeddings\n", "t5_state['encoder.embed_tokens.weight'] = t5_state['shared.weight']\n", "state_all_names.remove('encoder.embed_tokens.weight')\n", " \n", "#Load back the weights\n", "t5.load_state_dict(t5_state) \n", "\n", "tf_linear_weight = tf_name_weight[\"projection_layer__kernel:0\"]\n", "linear = torch.nn.Linear(tf_linear_weight.shape[0], tf_linear_weight.shape[1], bias=False)\n", "original_shape = linear.weight.shape\n", "linear.weight = torch.nn.Parameter(torch.tensor(np.float32(tf_linear_weight.numpy()).transpose()))\n", "new_shape = linear.weight.shape\n", "if not equal_shapes(original_shape, new_shape):\n", " print(\"Different shape at linear layer\")\n", " \n", "print(linear)\n", "print(\"Remaining weights:\", state_all_names)\n", "assert len(state_all_names) == 0\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "d59d5a2c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 768])\n" ] }, { "data": { "text/plain": [ "tensor([[1.0000, 0.8303, 0.2995, 0.3906, 0.2986, 0.3062, 0.3430, 0.3734],\n", " [0.8303, 1.0000, 0.3455, 0.4187, 0.3043, 0.3464, 0.4388, 0.3959],\n", " [0.2995, 0.3455, 1.0000, 0.6648, 0.4726, 0.4597, 0.3798, 0.3454],\n", " [0.3906, 0.4187, 0.6648, 1.0000, 0.5167, 0.5195, 0.3746, 0.4006],\n", " [0.2986, 0.3043, 0.4726, 0.5167, 1.0000, 0.7602, 0.3923, 0.3550],\n", " [0.3062, 0.3464, 0.4597, 0.5195, 0.7602, 1.0000, 0.4338, 0.3432],\n", " [0.3430, 0.4388, 0.3798, 0.3746, 0.3923, 0.4338, 1.0000, 0.6090],\n", " [0.3734, 0.3959, 0.3454, 0.4006, 0.3550, 0.3432, 0.6090, 1.0000]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "english_sentences = [\"Berlin is the capital of Germany\", \"Berlin is a large city in Germany\",\n", " \"Tensorflow can be used for deep learning\", \"Pytorch, developed by Facebook AI, is a deep learning framework\",\n", " \"Is Scipy or numpy better?\", \"Which is faster: scipy or pandas?\",\n", " \"Cats can live for quite 
a long time\", \"Cats are humans best friend\"]\n", "\n", "encoded_input = tokenizer(english_sentences, return_tensors=\"pt\", padding=True)\n", "\n", "with torch.no_grad():\n", " model_output = t5(**encoded_input)\n", " \n", " # Perform pooling\n", " hf_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n", "\n", " # Apply linear layer\n", " hf_embeddings = linear(hf_embeddings)\n", " \n", " print(hf_embeddings.shape)\n", "\n", " # Normalize embeddings\n", " hf_embeddings = F.normalize(hf_embeddings, p=2, dim=1)\n", "\n", "# Cos\n", "util.dot_score(hf_embeddings, hf_embeddings)" ] }, { "cell_type": "code", "execution_count": 13, "id": "677a8bab", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-01-31 23:13:39.702310: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n", "2022-01-31 23:13:41.448337: I tensorflow/compiler/xla/service/service.cc:171] XLA service 0x7f41641cf460 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n", "2022-01-31 23:13:41.448385: I tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Host, Default Version\n", "2022-01-31 23:13:44.375222: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:210] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", "2022-01-31 23:14:17.816928: I tensorflow/compiler/jit/xla_compilation_cache.cc:363] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n", "2022-01-31 23:14:17.866550: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 3089104896 exceeds 10% of free system memory.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(8, 768)\n" ] }, { "data": { "text/plain": [ "tensor([[1.0000, 0.8303, 0.2996, 0.3908, 0.2984, 0.3062, 0.3428, 0.3735],\n", " [0.8303, 1.0000, 0.3453, 0.4187, 0.3044, 0.3462, 0.4387, 0.3961],\n", " [0.2996, 0.3453, 1.0000, 0.6643, 0.4724, 0.4596, 0.3803, 0.3454],\n", " [0.3908, 0.4187, 0.6643, 1.0000, 0.5169, 0.5196, 0.3744, 0.4003],\n", " [0.2984, 0.3044, 0.4724, 0.5169, 1.0000, 0.7603, 0.3920, 0.3550],\n", " [0.3062, 0.3462, 0.4596, 0.5196, 0.7603, 1.0000, 0.4333, 0.3427],\n", " [0.3428, 0.4387, 0.3803, 0.3744, 0.3920, 0.4333, 1.0000, 0.6087],\n", " [0.3735, 0.3961, 0.3454, 0.4003, 0.3550, 0.3427, 0.6087, 1.0000]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test the models - Original embeddings\n", "english_embeds = encoder(english_sentences)[0].numpy()\n", "print(english_embeds.shape)\n", "util.dot_score(english_embeds, english_embeds)" ] }, { "cell_type": "code", "execution_count": 14, "id": "34b44ef7", "metadata": {}, "outputs": [], "source": [ "folder = f'models/gtr-t5-{model_size_hf}'\n", "t5.save_pretrained(folder)\n", "tokenizer.save_pretrained(folder)\n", "os.makedirs(os.path.join(folder, '2_Dense'), exist_ok=True)\n", "\n", "\n", "dense = sentence_transformers.models.Dense(linear.in_features, linear.out_features, \n", " bias=False, activation_function=torch.nn.Identity())\n", "dense.linear = linear\n", "dense.save(os.path.join(folder, '2_Dense'))\n" ] }, { "cell_type": "markdown", "id": "8f6e006b", "metadata": {}, "source": [ "# FP16 experiment" ] }, { "cell_type": "code", "execution_count": null, "id": "38b1b35e", "metadata": {}, "outputs": [], "source": [ "#FP16 experiment\n", "#t5 = 
T5EncoderModel.from_pretrained('models/gtr-t5-base')\n", "#t5.half()\n", "#t5.save_pretrained('models/gtr-t5-base-fp16')" ] },
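{ "cell_type": "markdown", "id": "9a7c3e1f", "metadata": {}, "source": [ "# Load the converted model with sentence-transformers (sketch)\n", "A minimal sketch, not part of the original conversion, of how the artefacts saved above (the converted `T5EncoderModel`, the tokenizer and the `2_Dense` projection) could be assembled into a `SentenceTransformer` that mirrors the GTR pipeline used in this notebook: mean pooling, the linear projection, then L2 normalization. The `models.Transformer`, `models.Dense.load` and `models.Normalize` calls are assumptions about the sentence-transformers API rather than something verified here; the published `sentence-transformers/gtr-t5-*` checkpoints remain the reference." ] }, { "cell_type": "code", "execution_count": null, "id": "b2d4f6a8", "metadata": {}, "outputs": [], "source": [ "# Sketch only: assemble the saved artefacts into a SentenceTransformer.\n", "# Assumes the cells above have been run, so `folder`, `english_sentences`, `os` and `util` are in scope.\n", "from sentence_transformers import SentenceTransformer, models\n", "\n", "word_embedding_model = models.Transformer(folder)  # loads the converted T5 encoder and tokenizer\n", "pooling = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True)\n", "dense_head = models.Dense.load(os.path.join(folder, '2_Dense'))  # the projection layer saved above\n", "normalize = models.Normalize()  # GTR embeddings are L2-normalized\n", "\n", "st_model = SentenceTransformer(modules=[word_embedding_model, pooling, dense_head, normalize])\n", "st_embeddings = st_model.encode(english_sentences, convert_to_tensor=True)\n", "util.dot_score(st_embeddings, st_embeddings)  # should be close to the similarity matrices above" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }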