{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"GI4Sz98ItJW7"},"outputs":[],"source":["# TPU\n","# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl"]},{"cell_type":"code","execution_count":9,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3348,"status":"ok","timestamp":1637680969592,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"97-OsdFhlD20","outputId":"c47a98a7-f016-4a4f-827b-edc9229c5eca"},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.12.5)\n","Requirement already satisfied: sentence_transformers in /usr/local/lib/python3.7/dist-packages (2.1.0)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n","Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.3)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.1.2)\n","Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.46)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.4.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.2)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (3.10.0.2)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.6)\n","Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (3.2.5)\n","Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.0.1)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.4.1)\n","Requirement already satisfied: sentencepiece in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.1.96)\n","Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.11.1+cu111)\n","Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.10.0+cu111)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n","Requirement already satisfied: six in 
{"cell_type":"code","execution_count":10,"metadata":{"id":"3-jkyQkdkdPQ"},"outputs":[],"source":["from transformers import AutoTokenizer, AutoModel\n","import torch\n","import pickle\n","from sentence_transformers import util\n","from datetime import datetime"]},{"cell_type":"code","execution_count":3,"metadata":{"id":"kA2h5mH8m-n8"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/drive\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","#drive.mount('/content/drive', force_remount=True)"]},{"cell_type":"markdown","metadata":{"id":"b8SkQGWuB1z7"},"source":["# Load pretrained model\n","\n","- the multilingual sentence-transformers model from the checkpoint\n","- the tokenizer from the same checkpoint"]},{"cell_type":"code","execution_count":3,"metadata":{"id":"1R83LLVAk98K"},"outputs":[],"source":["multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'\n","tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)\n","model = AutoModel.from_pretrained(multilingual_checkpoint)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"wcdik3tQpkyi"},"outputs":[],"source":["# GPU\n","device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n","model.to(device)\n","print(device)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-YzAkemLsrC9"},"outputs":[],"source":["# TPU\n","# Unfortunately the torch_xla 1.10 wheel above is incompatible with this environment,\n","# so the TPU code path stays commented out.\n","#import torch_xla.core.xla_model as xm\n","#device = xm.xla_device()\n","#print(device)\n","#pip list | grep torch"]},{"cell_type":"code","execution_count":23,"metadata":{"id":"dfeEQJOglxdw"},"outputs":[],"source":["# Mean pooling: average the token embeddings, taking the attention mask into account\n","# so that padding tokens do not contribute to the sentence embedding.\n","def mean_pooling(model_output, attention_mask):\n","    token_embeddings = model_output[0]  # first element holds all token embeddings (last hidden state)\n","    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n","    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)\n","    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)  # avoid division by zero\n","    return sum_embeddings / sum_mask\n","\n","def calculateEmbeddings(sentences, tokenizer, model, device='cpu'):\n","    tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')\n","    tokenized_sentences = tokenized_sentences.to(device)\n","    with torch.no_grad():\n","        model_output = model(**tokenized_sentences)\n","    sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])\n","    del tokenized_sentences\n","    torch.cuda.empty_cache()\n","    return sentence_embeddings\n","\n","def findTopKMostSimilar(query_embedding, embeddings, k):\n","    cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)\n","    cosine_scores_list = cosine_scores.squeeze().tolist()\n","    pairs = [{'index': idx, 'score': score} for idx, score in enumerate(cosine_scores_list)]\n","    pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)\n","    return pairs[:k]\n","\n","def saveToDisc(embeddings, output_filename):\n","    # append mode: each call adds one pickled batch to the end of the file\n","    with open(output_filename, \"ab\") as f:\n","        pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)"]},
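{"cell_type":"markdown","metadata":{},"source":["A quick smoke test of the helpers above (an added, illustrative cell; the sentences are made up). It embeds three sentences and retrieves the two most similar to a query; the model produces 384-dimensional embeddings, matching the tensor sizes printed later in this notebook."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Illustrative smoke test of calculateEmbeddings/findTopKMostSimilar (not part of the original pipeline)\n","demo_sentences = ['Budapest is the capital of Hungary.',\n","                  'Paris is the capital of France.',\n","                  'The cat sat on the mat.']\n","demo_embeddings = calculateEmbeddings(demo_sentences, tokenizer, model, device)\n","print(demo_embeddings.size())  # expected: torch.Size([3, 384])\n","\n","demo_query = calculateEmbeddings(['What is the capital of Hungary?'], tokenizer, model, device)\n","for pair in findTopKMostSimilar(demo_query, demo_embeddings, 2):\n","    print(demo_sentences[pair['index']], round(pair['score'], 4))"]},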
{"cell_type":"markdown","metadata":{"id":"MddjkKfMCH81"},"source":["# Create sentence embeddings\n","\n","* Load sentences from the raw text file\n","* Precalculate in batches of 1000 to avoid running out of memory\n","* Save to disk incrementally so the embeddings can be reused later (5 files in total, roughly 100,000 embeddings each)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yfOsCAVImIAl"},"outputs":[],"source":["batch_size = 1000\n","\n","raw_text_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01.txt'\n","datetime_formatted = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')\n","output_embeddings_file_batched = f'/content/drive/MyDrive/huggingface/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'\n","output_embeddings_file = f'/content/drive/MyDrive/huggingface/embeddings_at_{datetime_formatted}.pkl'\n","\n","print(datetime.now())\n","concated_sentence_embeddings = None\n","all_sentences = []\n","line = 'init'\n","total_read = 0\n","total_read_limit = 500000\n","skip_index = 400000  # lines before this index were already embedded by earlier runs of this cell\n","with open(raw_text_file) as f:\n","    while line and total_read < total_read_limit:\n","        count = 0\n","        sentence_batch = []\n","        while line and count < batch_size:\n","            line = f.readline()\n","            if line:  # readline() returns '' at EOF; don't embed it\n","                sentence_batch.append(line)\n","            count += 1\n","\n","        all_sentences.extend(sentence_batch)\n","\n","        if total_read >= skip_index and sentence_batch:\n","            sentence_embeddings = calculateEmbeddings(sentence_batch, tokenizer, model, device)\n","            if concated_sentence_embeddings is None:\n","                concated_sentence_embeddings = sentence_embeddings\n","            else:\n","                concated_sentence_embeddings = torch.cat([concated_sentence_embeddings, sentence_embeddings], dim=0)\n","            print(concated_sentence_embeddings.size())\n","            saveToDisc(sentence_embeddings, output_embeddings_file_batched)\n","        total_read += count\n","print(datetime.now())"]},
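{"cell_type":"markdown","metadata":{},"source":["Because every batch is appended to the same pickle file, an interrupted run can be resumed. The sketch below (an added, illustrative cell; `countSavedEmbeddings` is a new helper, not part of the original pipeline) counts the embeddings already on disk so `skip_index` can be moved past them. It relies only on the append-mode layout used by `saveToDisc` above."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Illustrative sketch for resuming an interrupted run.\n","def countSavedEmbeddings(filename):\n","    saved = 0\n","    try:\n","        with open(filename, 'rb') as f:\n","            while True:\n","                saved += pickle.load(f).size()[0]  # each pickle.load returns one batch tensor\n","    except FileNotFoundError:\n","        pass  # nothing saved yet\n","    except EOFError:\n","        pass  # reached the end of the appended pickles\n","    return saved\n","\n","# e.g. resume with: skip_index = skip_index + countSavedEmbeddings(output_embeddings_file_batched)"]},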
{"cell_type":"markdown","metadata":{"id":"1rGQc9GRCuNy"},"source":["# Test: Query embeddings"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"FT7CwpM0Bwhi"},"outputs":[],"source":["# Hungarian query: 'Which is the most populous city in the world?'\n","query_embedding = calculateEmbeddings(['Melyik a legnépesebb város a világon?'], tokenizer, model, device)\n","top_pairs = findTopKMostSimilar(query_embedding, concated_sentence_embeddings, 5)\n","\n","for pair in top_pairs:\n","    i = pair['index']  # index into the embedded slice, which starts at skip_index\n","    score = pair['score']\n","    print(\"{} \\t\\t Score: {:.4f}\".format(all_sentences[skip_index+i], score))"]},{"cell_type":"markdown","metadata":{"id":"6Hdu_5FiDYJr"},"source":["# Test: Load pre-calculated embeddings\n","\n","* Load the embeddings from the batch files and stitch them together\n","* Save them into one file"]},{"cell_type":"code","execution_count":20,"metadata":{"id":"gkWt0Uj_Ddsp"},"outputs":[{"name":"stdout","output_type":"stream","text":["Processing file 0\n","Read file using 100 read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 1\n","Read file using 100 read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 2\n","Read file using 100 read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 3\n","Read file using 100 read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 4\n","Read file using 67 read+unpickle operations\n","torch.Size([66529, 384])\n","torch.Size([466529, 384])\n"]}],"source":["def concatTensors(new_tensor, acc_tensor=None):\n","    # accumulate batches into one tensor; 'is None', not '== None', to avoid tensor comparison\n","    if acc_tensor is None:\n","        acc_tensor = new_tensor\n","    else:\n","        acc_tensor = torch.cat([acc_tensor, new_tensor], dim=0)\n","    return acc_tensor\n","\n","def loadFromDisc(batch_size, number_of_batches, filename):\n","    concated_sentence_embeddings = None\n","    count = 0\n","    with open(filename, \"rb\") as f:\n","        # dummy full-sized tensor so the loop body runs at least once;\n","        # a shorter batch signals the final, partial pickle in the file\n","        loaded_embeddings = torch.empty([batch_size])\n","        while count < number_of_batches and loaded_embeddings.size()[0] == batch_size:\n","            loaded_embeddings = pickle.load(f)\n","            count += 1\n","            concated_sentence_embeddings = concatTensors(loaded_embeddings, concated_sentence_embeddings)\n","    print(f'Read file using {count} read+unpickle operations')\n","    print(concated_sentence_embeddings.size())\n","    return concated_sentence_embeddings\n","\n","\n","output_embeddings_file = 'data/preprocessed/DBpedia_shortened_abstracts_hu_embeddings.pkl'  # unused below; the consolidated file is saved in the next cell\n","\n","embeddings_files = [\n","    '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:17:17.pkl',\n","    '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:28:46.pkl',\n","    '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:40:54.pkl',\n","    '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:56:26.pkl',\n","    '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_09:31:47.pkl'\n","]\n","\n","all_embeddings = None\n","for idx, emb_file in enumerate(embeddings_files):\n","    print(f'Processing file {idx}')\n","    file_embeddings = loadFromDisc(1000, 100, emb_file)\n","    all_embeddings = concatTensors(file_embeddings, all_embeddings)\n","\n","print(all_embeddings.size())"]},{"cell_type":"code","execution_count":28,"metadata":{"id":"M_8RHpNnIU7o"},"outputs":[],"source":["all_embeddings_output_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01_embedded.pt'\n","#saveToDisc(all_embeddings, all_embeddings_output_file)\n","torch.save(all_embeddings, all_embeddings_output_file)"]},
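{"cell_type":"markdown","metadata":{},"source":["As a final check (an added, illustrative cell, assuming the `torch.save` above succeeded and the model and tokenizer are still loaded), the consolidated tensor can be reloaded with `torch.load` and queried directly; `map_location` brings it onto the current device so the cosine similarity runs without a device mismatch."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Illustrative: reload the consolidated embeddings and re-run a query against them\n","reloaded = torch.load(all_embeddings_output_file, map_location=device)\n","print(reloaded.size())  # expected: torch.Size([466529, 384])\n","\n","q = calculateEmbeddings(['Melyik a legnépesebb város a világon?'], tokenizer, model, device)\n","for pair in findTopKMostSimilar(q, reloaded, 3):\n","    print(pair['index'], round(pair['score'], 4))"]},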
{"cell_type":"code","execution_count":null,"metadata":{"id":"LYCwyDpMjsXg"},"outputs":[],"source":[]}],"metadata":{"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyN3TvKBRyS+wRVSLWNFgC+f","collapsed_sections":[],"mount_file_id":"1e_NcpgIuSh8rfI_Xf16ltcybK8TbgJWB","name":"QA_retrieval_huggingface_couser_2021_Nov.ipynb","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}