{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "17bffc12", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "from sentence_transformers import util\n", "import os\n", "import numpy as np\n", "import torch.nn.functional as F\n", "from transformers import T5EncoderModel\n", "import torch" ] }, { "cell_type": "code", "execution_count": 3, "id": "160d8ce6", "metadata": {}, "outputs": [], "source": [ "#Mean Pooling - Take attention mask into account for correct averaging\n", "def mean_pooling(model_output, attention_mask):\n", " token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n", " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n", " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "2f67f426", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Importing a function (__inference__9720) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n", "WARNING:absl:Importing a function (__inference__3354) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n", "WARNING:absl:Importing a function (__inference__6722) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n" ] } ], "source": [ "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "import tensorflow_text as text \n", "\n", "model_size = \"base\"\n", "hub_url = f\"https://tfhub.dev/google/sentence-t5/st5-{model_size}/1\"\n", "encoder = hub.load(hub_url)\n", "\n", "v = encoder.signatures['serving_default'].variables" ] }, { "cell_type": "code", "execution_count": 17, "id": "5f4c8d94", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'encoder__encoder_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_0__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_0__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_0__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_0__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_0__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_0__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_0__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_0__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_1__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_1__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_1__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_1__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_1__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_1__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_1__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_1__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_10__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_10__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_10__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_10__attention__value__kernel:0': TensorShape([768, 
768]),\n", " 'encoder__layers_10__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_10__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_10__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_10__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_11__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_11__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_11__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_11__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_11__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_11__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_11__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_11__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_2__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_2__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_2__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_2__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_2__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_2__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_2__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_2__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_3__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_3__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_3__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_3__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_3__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_3__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_3__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_3__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_4__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_4__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_4__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_4__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_4__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_4__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_4__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_4__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_5__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_5__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_5__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_5__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_5__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_5__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_5__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_5__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_6__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_6__attention__out__kernel:0': 
TensorShape([768, 768]),\n", " 'encoder__layers_6__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_6__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_6__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_6__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_6__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_6__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_7__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_7__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_7__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_7__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_7__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_7__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_7__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_7__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_8__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_8__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_8__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_8__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_8__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_8__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_8__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_8__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_9__attention__key__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_9__attention__out__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_9__attention__query__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_9__attention__value__kernel:0': TensorShape([768, 768]),\n", " 'encoder__layers_9__mlp__wi__kernel:0': TensorShape([768, 3072]),\n", " 'encoder__layers_9__mlp__wo__kernel:0': TensorShape([3072, 768]),\n", " 'encoder__layers_9__pre_attention_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__layers_9__pre_mlp_layer_norm__scale:0': TensorShape([768]),\n", " 'encoder__relpos_bias__rel_embedding:0': TensorShape([12, 32]),\n", " 'projection_layer__kernel:0': TensorShape([768, 768]),\n", " 'token_embedder__embedding:0': TensorShape([32128, 768])}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf_name_weight = {var.name: var for var in v}\n", "tf_name_shape = {var.name: var.shape for var in v}\n", "tf_name_shape" ] }, { "cell_type": "code", "execution_count": 6, "id": "1d3c9865", "metadata": {}, "outputs": [], "source": [ "def convert_name(name):\n", " fct_map = {\n", " \"attention\": \"SelfAttention\",\n", " \"mlp\": \"DenseReluDense\",\n", " \"pre_attention_layer_norm\": \"layer_norm\",\n", " \"pre_mlp_layer_norm\": \"layer_norm\",\n", " }\n", " name_map = {\n", " 'key': 'k',\n", " 'out': 'o',\n", " 'query': 'q',\n", " 'value': 'v'\n", " }\n", " \n", " fixed_names = {\n", " \"token_embedder__embedding:0\": \"shared.weight\",\n", " \"encoder__encoder_norm__scale:0\": \"encoder.final_layer_norm.weight\",\n", " \"encoder__relpos_bias__rel_embedding:0\": \"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\"\n", " }\n", " \n", " if name in fixed_names:\n", " return 
fixed_names[name]\n", " \n", " out = \"\"\n", " splits = name.split(\"__\")\n", " layer = splits[1].split(\"_\")[1]\n", " fct = fct_map.get(splits[2], splits[2])\n", " if 'layer_norm' in name:\n", " sublayer = \"1\" if \"pre_mlp_layer_norm\" in name else \"0\" #Not sure on the right setting here\n", " #sublayer = \"0\" if \"pre_mlp_layer_norm\" in name else \"1\" #Not sure on the right setting here\n", " out = f\"encoder.block.{layer}.layer.{sublayer}.{fct}.weight\"\n", " elif name.startswith(\"encoder__layers_\"):\n", " sublayer = \"0\" if fct == \"SelfAttention\" else \"1\"\n", " name = name_map.get(splits[3], splits[3])\n", " out = f\"encoder.block.{layer}.layer.{sublayer}.{fct}.{name}.weight\"\n", " \n", " return out" ] }, { "cell_type": "code", "execution_count": 7, "id": "1ca9590e", "metadata": {}, "outputs": [], "source": [ "def equal_shapes(shape1, shape2):\n", " if len(shape1) != len(shape2):\n", " return False\n", " \n", " for idx in range(len(shape1)):\n", " if shape1[idx] != shape2[idx]:\n", " return False\n", " \n", " return True" ] }, { "cell_type": "code", "execution_count": 8, "id": "6d223b07", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of T5EncoderModel were not initialized from the model checkpoint at t5-11b and are newly initialized: ['encoder.embed_tokens.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/plain": [ "{'shared.weight': torch.Size([32128, 1024]),\n", " 'encoder.embed_tokens.weight': torch.Size([32128, 1024]),\n", " 'encoder.block.0.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.0.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.0.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.0.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight': torch.Size([32, 128]),\n", " 'encoder.block.0.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.0.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.0.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.0.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.1.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.1.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.1.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.1.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.1.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.1.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.1.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.1.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.2.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.2.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.2.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.2.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.2.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.2.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 
'encoder.block.2.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.2.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.3.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.3.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.3.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.3.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.3.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.3.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.3.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.3.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.4.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.4.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.4.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.4.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.4.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.4.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.4.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.4.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.5.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.5.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.5.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.5.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.5.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.5.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.5.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.5.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.6.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.6.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.6.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.6.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.6.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.6.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.6.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.6.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.7.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.7.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.7.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.7.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.7.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.7.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.7.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.7.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.8.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.8.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 
'encoder.block.8.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.8.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.8.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.8.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.8.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.8.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.9.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.9.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.9.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.9.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.9.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.9.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.9.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.9.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.10.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.10.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.10.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.10.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.10.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.10.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.10.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.10.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.11.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.11.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.11.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.11.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.11.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.11.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.11.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.11.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.12.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.12.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.12.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.12.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.12.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.12.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.12.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.12.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.13.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.13.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.13.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.13.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.13.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.13.layer.1.DenseReluDense.wi.weight': 
torch.Size([65536, 1024]),\n", " 'encoder.block.13.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.13.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.14.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.14.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.14.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.14.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.14.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.14.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.14.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.14.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.15.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.15.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.15.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.15.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.15.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.15.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.15.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.15.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.16.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.16.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.16.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.16.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.16.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.16.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.16.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.16.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.17.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.17.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.17.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.17.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.17.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.17.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.17.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.17.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.18.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.18.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.18.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.18.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.18.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.18.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.18.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.18.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.19.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 
'encoder.block.19.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.19.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.19.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.19.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.19.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.19.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.19.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.20.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.20.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.20.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.20.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.20.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.20.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.20.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.20.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.21.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.21.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.21.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.21.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.21.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.21.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.21.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.21.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.22.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.22.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.22.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.22.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.22.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.22.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.22.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.22.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.23.layer.0.SelfAttention.q.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.23.layer.0.SelfAttention.k.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.23.layer.0.SelfAttention.v.weight': torch.Size([16384, 1024]),\n", " 'encoder.block.23.layer.0.SelfAttention.o.weight': torch.Size([1024, 16384]),\n", " 'encoder.block.23.layer.0.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.block.23.layer.1.DenseReluDense.wi.weight': torch.Size([65536, 1024]),\n", " 'encoder.block.23.layer.1.DenseReluDense.wo.weight': torch.Size([1024, 65536]),\n", " 'encoder.block.23.layer.1.layer_norm.weight': torch.Size([1024]),\n", " 'encoder.final_layer_norm.weight': torch.Size([1024])}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(f\"t5-{model_size}\")\n", "T5EncoderModel._keys_to_ignore_on_load_unexpected = [\"decoder.*\"]\n", "t5 = T5EncoderModel.from_pretrained(f\"t5-{model_size}\") \n", "pt_name_shape = {name: 
weight.shape for name, weight in t5.state_dict().items()}\n", "pt_name_shape" ] }, { "cell_type": "code", "execution_count": 9, "id": "ced52a5f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Remaining weights: {'encoder.embed_tokens.weight'}\n" ] } ], "source": [ "def need_transpose(name, transpose_names=['DenseReluDense', 'relative_attention_bias']):\n", " #HF function: https://github.com/huggingface/transformers/blob/c962c2adbff678ae6d2e98378bed5b8d1a9831d9/src/transformers/models/t5/modeling_t5.py#L161\n", " return name != \"shared.weight\"\n", "\n", "\n", "#Additional dense layer on top\n", "names_to_ignore = {\"projection_layer__kernel:0\"}\n", "\n", "#Check we used all names\n", "pt_all_names = set(t5.state_dict().keys())\n", "\n", "for var in v:\n", " name = var.name\n", " if name in names_to_ignore:\n", " continue\n", " \n", " pt_name = convert_name(name)\n", " if pt_name not in pt_all_names:\n", " print(\"Name not found:\", name, \"=>\", pt_name)\n", " else:\n", " pt_all_names.remove(pt_name)\n", " tf_shape = tf_name_shape[name].as_list()\n", " pt_shape = list(pt_name_shape[pt_name])\n", " \n", " if need_transpose(pt_name):\n", " pt_shape = list(reversed(pt_shape))\n", " \n", " if not equal_shapes(tf_shape, pt_shape):\n", " print(\"Different shape:\", name, tf_shape, pt_name, pt_shape )\n", " \n", "print(\"Remaining weights:\", pt_all_names)\n", "#All layers match" ] }, { "cell_type": "code", "execution_count": 10, "id": "1190984f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "encoder__encoder_norm__scale:0 ((1024,)) =transpose=> encoder.final_layer_norm.weight torch.Size([1024])\n", "encoder__layers_0__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_0__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.0.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_0__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_0__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.0.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_0__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.0.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_0__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.0.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_0__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.0.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_0__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.0.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_1__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_1__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.1.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_1__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_1__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.1.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_1__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> 
encoder.block.1.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_1__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.1.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_1__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.1.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_1__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.1.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_10__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_10__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.10.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_10__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_10__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.10.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_10__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.10.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_10__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.10.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_10__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.10.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_10__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.10.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_11__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_11__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.11.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_11__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_11__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.11.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_11__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.11.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_11__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.11.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_11__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.11.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_11__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.11.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_12__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_12__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.12.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_12__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_12__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.12.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_12__mlp__wi__kernel:0 ((1024, 
65536)) =transpose=> encoder.block.12.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_12__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.12.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_12__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.12.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_12__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.12.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_13__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_13__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.13.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_13__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_13__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.13.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_13__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.13.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_13__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.13.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_13__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.13.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_13__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.13.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_14__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_14__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.14.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_14__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_14__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.14.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_14__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.14.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_14__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.14.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_14__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.14.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_14__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.14.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_15__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "encoder__layers_15__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.15.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_15__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.15.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_15__attention__value__kernel:0 ((1024, 16384)) =transpose=> 
encoder.block.15.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_15__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.15.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_15__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.15.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_15__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.15.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_15__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.15.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_16__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_16__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.16.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_16__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_16__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.16.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_16__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.16.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_16__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.16.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_16__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.16.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_16__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.16.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_17__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_17__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.17.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_17__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_17__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.17.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_17__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.17.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_17__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.17.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_17__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.17.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_17__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.17.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_18__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_18__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.18.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_18__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_18__attention__value__kernel:0 
((1024, 16384)) =transpose=> encoder.block.18.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_18__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.18.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_18__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.18.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_18__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.18.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_18__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.18.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_19__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_19__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.19.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_19__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_19__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.19.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_19__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.19.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_19__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.19.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_19__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.19.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_19__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.19.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_2__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_2__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.2.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_2__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_2__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.2.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_2__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.2.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_2__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.2.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_2__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.2.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_2__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.2.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_20__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_20__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.20.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_20__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", 
"encoder__layers_20__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.20.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_20__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.20.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_20__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.20.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_20__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.20.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_20__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.20.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_21__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_21__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.21.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_21__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "encoder__layers_21__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.21.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_21__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.21.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_21__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.21.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_21__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.21.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_21__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.21.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_22__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_22__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.22.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_22__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_22__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.22.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_22__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.22.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_22__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.22.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_22__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.22.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_22__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.22.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_23__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_23__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.23.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_23__attention__query__kernel:0 ((1024, 16384)) 
=transpose=> encoder.block.23.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_23__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.23.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_23__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.23.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_23__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.23.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_23__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.23.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_23__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.23.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_3__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_3__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.3.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_3__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_3__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.3.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_3__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.3.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_3__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.3.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_3__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.3.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_3__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.3.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_4__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_4__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.4.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_4__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_4__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.4.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_4__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.4.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_4__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.4.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_4__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.4.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_4__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.4.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_5__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_5__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.5.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_5__attention__query__kernel:0 ((1024, 16384)) 
=transpose=> encoder.block.5.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_5__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.5.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_5__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.5.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_5__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.5.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_5__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.5.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_5__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.5.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_6__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_6__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.6.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_6__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_6__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.6.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_6__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.6.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_6__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.6.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_6__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.6.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_6__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.6.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_7__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_7__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.7.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_7__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_7__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.7.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_7__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.7.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "encoder__layers_7__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.7.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_7__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.7.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_7__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.7.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_8__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_8__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.8.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", 
"encoder__layers_8__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_8__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.8.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_8__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.8.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_8__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.8.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_8__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.8.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_8__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.8.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__layers_9__attention__key__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.k.weight torch.Size([16384, 1024])\n", "encoder__layers_9__attention__out__kernel:0 ((16384, 1024)) =transpose=> encoder.block.9.layer.0.SelfAttention.o.weight torch.Size([1024, 16384])\n", "encoder__layers_9__attention__query__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.q.weight torch.Size([16384, 1024])\n", "encoder__layers_9__attention__value__kernel:0 ((1024, 16384)) =transpose=> encoder.block.9.layer.0.SelfAttention.v.weight torch.Size([16384, 1024])\n", "encoder__layers_9__mlp__wi__kernel:0 ((1024, 65536)) =transpose=> encoder.block.9.layer.1.DenseReluDense.wi.weight torch.Size([65536, 1024])\n", "encoder__layers_9__mlp__wo__kernel:0 ((65536, 1024)) =transpose=> encoder.block.9.layer.1.DenseReluDense.wo.weight torch.Size([1024, 65536])\n", "encoder__layers_9__pre_attention_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.9.layer.0.layer_norm.weight torch.Size([1024])\n", "encoder__layers_9__pre_mlp_layer_norm__scale:0 ((1024,)) =transpose=> encoder.block.9.layer.1.layer_norm.weight torch.Size([1024])\n", "encoder__relpos_bias__rel_embedding:0 ((128, 32)) =transpose=> encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight torch.Size([32, 128])\n", "token_embedder__embedding:0 ((32128, 1024)) => shared.weight torch.Size([32128, 1024])\n", "Linear(in_features=1024, out_features=768, bias=False)\n", "Remaining weights: set()\n" ] } ], "source": [ "t5_state = t5.state_dict()\n", "state_all_names = set(t5_state.keys())\n", "\n", "\n", "for var in v:\n", " tf_name = var.name\n", " if tf_name in names_to_ignore:\n", " continue\n", " \n", " pt_name = convert_name(tf_name)\n", " weights = np.float32(var.numpy())\n", " \n", " state_all_names.remove(pt_name)\n", " \n", " tranpose_status = \"=>\"\n", " if need_transpose(pt_name, ['DenseReluDense', 'relative_attention_bias',]):\n", " tranpose_status = \"=transpose=>\"\n", " weights = weights.transpose()\n", " \n", " print(tf_name, f\"({var.shape})\", tranpose_status, pt_name, t5_state[pt_name].shape)\n", " \n", " original_shape = t5_state[pt_name].shape\n", " t5_state[pt_name] = torch.nn.Parameter(torch.tensor(weights))\n", " new_shape = t5_state[pt_name].shape\n", " \n", " if not equal_shapes(original_shape, new_shape):\n", " print(\"Different shape:\", tf_name, original_shape, pt_name, new_shape)\n", " break\n", "\n", "#Encoder Word embeddings\n", "t5_state['encoder.embed_tokens.weight'] = t5_state['shared.weight']\n", "state_all_names.remove('encoder.embed_tokens.weight')\n", " \n", "#Load back the weights\n", 
"t5.load_state_dict(t5_state) \n", "\n", "tf_linear_weight = tf_name_weight[\"projection_layer__kernel:0\"]\n", "linear = torch.nn.Linear(tf_linear_weight.shape[0], tf_linear_weight.shape[1], bias=False)\n", "original_shape = linear.weight.shape\n", "linear.weight = torch.nn.Parameter(torch.tensor(np.float32(tf_linear_weight.numpy()).transpose()))\n", "new_shape = linear.weight.shape\n", "if not equal_shapes(original_shape, new_shape):\n", " print(\"Different shape at linear layer\")\n", " \n", "print(linear)\n", "print(\"Remaining weights:\", state_all_names)\n", "assert len(state_all_names) == 0\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "d59d5a2c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 768])\n" ] }, { "data": { "text/plain": [ "tensor([[1.0000, 0.9279, 0.6404, 0.5968, 0.5420, 0.5442, 0.6099, 0.6318],\n", " [0.9279, 1.0000, 0.6629, 0.6098, 0.5562, 0.5687, 0.6382, 0.6262],\n", " [0.6404, 0.6629, 1.0000, 0.8351, 0.7101, 0.6953, 0.6265, 0.6390],\n", " [0.5968, 0.6098, 0.8351, 1.0000, 0.6877, 0.6716, 0.5902, 0.6102],\n", " [0.5420, 0.5562, 0.7101, 0.6877, 1.0000, 0.8924, 0.5701, 0.5661],\n", " [0.5442, 0.5687, 0.6953, 0.6716, 0.8924, 1.0000, 0.5665, 0.5457],\n", " [0.6099, 0.6382, 0.6265, 0.5902, 0.5701, 0.5665, 1.0000, 0.7950],\n", " [0.6318, 0.6262, 0.6390, 0.6102, 0.5661, 0.5457, 0.7950, 1.0000]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "english_sentences = [\"Berlin is the capital of Germany\", \"Berlin is a large city in Germany\",\n", " \"Tensorflow can be used for deep learning\", \"Pytorch, developed by Facebook AI, is a deep learning framework\",\n", " \"Is Scipy or numpy better?\", \"Which is faster: scipy or pandas?\",\n", " \"Cats can live for quite a long time\", \"Cats are humans best friend\"]\n", "\n", "encoded_input = tokenizer(english_sentences, return_tensors=\"pt\", padding=True)\n", "\n", "with torch.no_grad():\n", " model_output = t5(**encoded_input)\n", " \n", " # Perform pooling\n", " hf_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n", "\n", " # Apply linear layer\n", " hf_embeddings = linear(hf_embeddings)\n", " \n", " print(hf_embeddings.shape)\n", "\n", " # Normalize embeddings\n", " hf_embeddings = F.normalize(hf_embeddings, p=2, dim=1)\n", "\n", "# Cos\n", "hf_scores = util.dot_score(hf_embeddings, hf_embeddings).numpy()\n", "hf_scores" ] }, { "cell_type": "code", "execution_count": 12, "id": "677a8bab", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-02-01 20:00:27.115638: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n", "2022-02-01 20:00:29.328848: I tensorflow/compiler/xla/service/service.cc:171] XLA service 0x7fe9781cd6f0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n", "2022-02-01 20:00:29.328894: I tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Host, Default Version\n", "2022-02-01 20:00:30.324558: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:210] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", "2022-02-01 20:01:02.775112: I tensorflow/compiler/jit/xla_compilation_cache.cc:363] Compiled cluster using XLA! 
This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(8, 768)\n" ] }, { "data": { "text/plain": [ "tensor([[1.0000, 0.9279, 0.6402, 0.5966, 0.5422, 0.5446, 0.6097, 0.6320],\n", " [0.9279, 1.0000, 0.6631, 0.6099, 0.5566, 0.5690, 0.6386, 0.6268],\n", " [0.6402, 0.6631, 1.0000, 0.8347, 0.7101, 0.6955, 0.6264, 0.6389],\n", " [0.5966, 0.6099, 0.8347, 1.0000, 0.6873, 0.6712, 0.5899, 0.6100],\n", " [0.5422, 0.5566, 0.7101, 0.6873, 1.0000, 0.8927, 0.5700, 0.5661],\n", " [0.5446, 0.5690, 0.6955, 0.6712, 0.8927, 1.0000, 0.5663, 0.5458],\n", " [0.6097, 0.6386, 0.6264, 0.5899, 0.5700, 0.5663, 1.0000, 0.7949],\n", " [0.6320, 0.6268, 0.6389, 0.6100, 0.5661, 0.5458, 0.7949, 1.0000]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test the models - Original embeddings\n", "english_embeds = encoder(english_sentences)[0].numpy()\n", "print(english_embeds.shape)\n", "tf_scores = util.dot_score(english_embeds, english_embeds).numpy()\n", "print(tf_scores)\n", "print(\"Diff:\", np.sum(np.abs(tf_scores - hf_scores)))" ] }, { "cell_type": "code", "execution_count": 13, "id": "34b44ef7", "metadata": {}, "outputs": [ { "ename": "FileNotFoundError", "evalue": "[Errno 2] No such file or directory: 'models/sentence-t5-11b/2_Dense/config.json'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipykernel_26913/2543044366.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m bias=False, activation_function=torch.nn.Identity())\n\u001b[1;32m 8\u001b[0m \u001b[0mdense\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinear\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mdense\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfolder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'2_Dense'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/home/sbert/sentence-transformers/sentence_transformers/models/Dense.py\u001b[0m in \u001b[0;36msave\u001b[0;34m(self, output_path)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'config.json'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'w'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfOut\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m 
\u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_config_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfOut\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/sentence-t5-11b/2_Dense/config.json'" ] } ], "source": [ "folder = f'models/sentence-t5-{model_size}'\n", "t5.save_pretrained(folder)\n", "tokenizer.save_pretrained(folder)\n", "\n", "import sentence_transformers\n", "dense = sentence_transformers.models.Dense(linear.in_features, linear.out_features,\n", "                                           bias=False, activation_function=torch.nn.Identity())\n", "dense.linear = linear\n", "\n", "dense_path = os.path.join(folder, '2_Dense')\n", "os.makedirs(dense_path, exist_ok=True)\n", "dense.save(dense_path)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }