{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "1aafbf18-de38-4fcf-8245-e2e9a584971f", "metadata": { "tags": [] }, "outputs": [], "source": [ "# ! pip install pymilvus==2.3.4\n", "# ! pip install pyarrow==12.0.0\n", "# !pip install -U transformers" ] }, { "cell_type": "code", "execution_count": null, "id": "f1d8f101-f51b-4a50-b150-86e87c50c453", "metadata": { "tags": [] }, "outputs": [], "source": [ "from transformers import DistilBertTokenizerFast\n", "from tensorflow.keras.models import load_model, Model\n", "import numpy as np\n", "import tensorflow as tf\n", "from tqdm import tqdm\n", "from dotenv import load_dotenv\n", "import os\n", "import pandas as pd\n", "from pymilvus import connections, utility\n", "from pymilvus import Collection, DataType, FieldSchema, CollectionSchema\n", "import multiprocessing" ] }, { "cell_type": "code", "execution_count": null, "id": "4ad4e3ac-9685-4f12-8043-5fbcc373d3e1", "metadata": {}, "outputs": [], "source": [ "tf.config.list_physical_devices('GPU')" ] }, { "cell_type": "code", "execution_count": 2, "id": "da71d832-b8a7-452b-b736-538a3c069b54", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
indexcategoryshort_description
00SCIENCEA closer look at water-splitting's solar fuel ...
11SCIENCEAn irresistible scent makes locusts swarm, stu...
22SCIENCEArtificial intelligence warning: AI will know ...
33SCIENCEGlaciers Could Have Sculpted Mars Valleys: Study
44SCIENCEPerseid meteor shower 2020: What time and how ...
............
311171311171TECHRIM CEO Thorsten Heins' 'Significant' Plans Fo...
311172311172SPORTSMaria Sharapova Stunned By Victoria Azarenka I...
311173311173SPORTSGiants Over Patriots, Jets Over Colts Among M...
311174311174SPORTSAldon Smith Arrested: 49ers Linebacker Busted ...
311175311175SPORTSDwight Howard Rips Teammates After Magic Loss ...
\n", "

311176 rows × 3 columns

\n", "
" ], "text/plain": [ " index category short_description\n", "0 0 SCIENCE A closer look at water-splitting's solar fuel ...\n", "1 1 SCIENCE An irresistible scent makes locusts swarm, stu...\n", "2 2 SCIENCE Artificial intelligence warning: AI will know ...\n", "3 3 SCIENCE Glaciers Could Have Sculpted Mars Valleys: Study\n", "4 4 SCIENCE Perseid meteor shower 2020: What time and how ...\n", "... ... ... ...\n", "311171 311171 TECH RIM CEO Thorsten Heins' 'Significant' Plans Fo...\n", "311172 311172 SPORTS Maria Sharapova Stunned By Victoria Azarenka I...\n", "311173 311173 SPORTS Giants Over Patriots, Jets Over Colts Among M...\n", "311174 311174 SPORTS Aldon Smith Arrested: 49ers Linebacker Busted ...\n", "311175 311175 SPORTS Dwight Howard Rips Teammates After Magic Loss ...\n", "\n", "[311176 rows x 3 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_csv('labelled_newscatcher_dataset.csv', sep=\";\", usecols=['title', 'topic'])\n", "json_data=pd.read_json('News_Category_Dataset_v3.json', lines=True)\n", "data.drop_duplicates(subset=['title'], inplace=True)\n", "json_data.drop_duplicates(subset=['headline'], inplace=True)\n", "json_data = json_data[['headline', 'category']].copy()\n", "json_data.rename(columns={'headline': 'title'}, inplace=True)\n", "data.rename(columns={'topic': 'category'}, inplace=True)\n", "data = pd.concat([data, json_data], axis=0)\n", "data.drop_duplicates(subset=['title'], inplace=True)\n", "data.reset_index(drop=True, inplace=True)\n", "data.reset_index(inplace=True)\n", "data.rename(columns={'title': 'short_description'}, inplace=True)\n", "data" ] }, { "cell_type": "code", "execution_count": 3, "id": "df6d1d63-63e8-4571-8553-dc6662008848", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "any(data['short_description'].duplicated())" ] }, { "cell_type": "code", "execution_count": 4, "id": "5f46251b-156a-4a72-ab89-6abb6d810006", "metadata": { "tags": [] }, "outputs": [], "source": [ "data.to_csv('news_processed.csv', index=False)" ] }, { "cell_type": "code", "execution_count": 5, "id": "b15612af-6022-474c-ba8f-2a040ed6af52", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n" ] } ], "source": [ "model_checkpoint = \"distilbert-base-uncased\"\n", "tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)\n", "interpreter = tf.lite.Interpreter(model_path=\"news_classification_hf_distilbert.tflite\")\n", "interpreter.allocate_tensors()\n", "input_details = interpreter.get_input_details()" ] }, { "cell_type": "code", "execution_count": 6, "id": "3affadbc-ace6-4b65-a20b-8bee676c837b", "metadata": { "tags": [] }, "outputs": [], "source": [ "class TextVectorizer:\n", " '''\n", " sentence transformers to extract sentence embeddings\n", " '''\n", " def vectorize(self, text, tokenizer): # need to have tokenizer as argument to prevent tokenizer error while using multiprocessing\n", " '''\n", " This code block of initializing tokenizer within the method is essential, else tokenizer will throw an error while using multiprocessing\n", " START\n", " '''\n", " model_checkpoint = \"distilbert-base-uncased\"\n", " tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)\n", " '''\n", " END\n", " '''\n", " tokens = tokenizer(text, max_length=80, padding=\"max_length\", truncation=True, return_tensors=\"tf\")\n", " attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']\n", " interpreter.set_tensor(input_details[0][\"index\"], attention_mask)\n", " interpreter.set_tensor(input_details[1][\"index\"], input_ids)\n", " interpreter.invoke()\n", " tflite_embeds = interpreter.get_tensor(711)[0]\n", " return [*tflite_embeds]" ] }, { "cell_type": "code", "execution_count": 7, "id": "47a714f3-8948-470b-9caf-93ed2bbf4894", "metadata": { "tags": [] }, "outputs": [], "source": [ "vectorizer = TextVectorizer()" ] }, { "cell_type": "code", "execution_count": 8, "id": "8b1586e5-2923-4632-a3db-fd2364124d6f", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "320" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# getting max length of article descriptions to be used for VARCHAR while defining schema\n", "max_desc_len = max([len(s) for s in data['short_description']])\n", "max_desc_len" ] }, { "cell_type": "code", "execution_count": 9, "id": "debe0ef4-b877-495a-872e-47f720b758a9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "14" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# getting max length of article categories to be used for VARCHAR while defining schema\n", "max_cat_len = max([len(s) for s in data['category']])\n", "max_cat_len" ] }, { "cell_type": "code", "execution_count": 10, "id": "80489f00-e59f-46ab-a933-97145928176c", "metadata": { "tags": [] }, "outputs": [], "source": [ "# # Reading milvus URI & API token from secrets.env\n", "load_dotenv('secrets.env')\n", "uri = os.environ.get(\"URI\")\n", "token = os.environ.get(\"TOKEN\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "0bf69f22-e113-43a5-be81-77224cafd856", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Connected to DB\n" ] } ], "source": [ "connections.connect(\"default\", uri=uri, token=token)\n", "print(f\"Connected to DB\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "8da06a3b-2005-4c02-a168-dc84bcde7064", "metadata": { "tags": [] }, "outputs": [], "source": [ "collection_name = 'news_collection_full'\n", "check_collection = utility.has_collection(collection_name)" ] }, { "cell_type": "code", "execution_count": 13, "id": "33342612-1380-4d1a-a8e7-931476e07979", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Droped Existing collection\n" ] } ], "source": [ "if check_collection:\n", " drop_result = utility.drop_collection(collection_name)\n", " print(\"Droped Existing collection\")" ] }, { "cell_type": "code", "execution_count": 14, "id": "fc8ae048-d586-41e7-9678-75e1752c1693", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating the collection\n", "Schema: {'auto_id': False, 'description': 'collection of news articles', 'fields': [{'name': 'article_id', 'description': 'primary id', 'type': , 'is_primary': True, 'auto_id': False}, {'name': 'article_embed', 'description': '', 'type': , 'params': {'dim': 768}}, {'name': 'article_desc', 'description': 'short description of the article', 'type': , 'params': {'max_length': 370}}, {'name': 'article_category', 'description': 'category of the article', 'type': , 'params': {'max_length': 64}}]}\n", "Success!\n" ] } ], "source": [ "# Creating collection schema\n", "dim = 768 # embeddings dim\n", "article_id = FieldSchema(name=\"article_id\", dtype=DataType.INT64, is_primary=True, description=\"primary id\") # primary key\n", "article_embed_field = FieldSchema(name=\"article_embed\", dtype=DataType.FLOAT_VECTOR, dim=dim) # description embeddings\n", "article_desc = FieldSchema(name=\"article_desc\", dtype=DataType.VARCHAR, max_length=(max_desc_len + 50), # using max_desc_len to specify VARCHAR len \n", " is_primary=False, description=\"short description of the article\") # short description of article\n", "article_cat = FieldSchema(name=\"article_category\", dtype=DataType.VARCHAR, max_length=(max_cat_len + 50), # using max_desc_len to specify VARCHAR len \n", " is_primary=False, description=\"category of the article\") # category of article\n", "schema = CollectionSchema(fields=[article_id, article_embed_field, article_desc, article_cat], \n", " auto_id=False, description=\"collection of news articles\")\n", "print(f\"Creating the collection\")\n", "collection = Collection(name=collection_name, schema=schema)\n", "print(f\"Schema: {schema}\")\n", "print(\"Success!\")" ] }, { "cell_type": "code", "execution_count": 15, "id": "cca82380-98f6-4c44-aac6-86d4ae3484d0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000, 19000, 20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000, 29000, 30000, 31000, 32000, 33000, 34000, 35000, 36000, 37000, 38000, 39000, 40000, 41000, 42000, 43000, 44000, 45000, 46000, 47000, 48000, 49000, 50000, 51000, 52000, 53000, 54000, 55000, 56000, 57000, 58000, 59000, 60000, 61000, 62000, 63000, 64000, 65000, 66000, 67000, 68000, 69000, 70000, 71000, 72000, 73000, 74000, 75000, 76000, 77000, 78000, 79000, 80000, 81000, 82000, 83000, 84000, 85000, 86000, 87000, 88000, 89000, 90000, 91000, 92000, 93000, 94000, 95000, 96000, 97000, 98000, 99000, 100000, 101000, 102000, 103000, 104000, 105000, 106000, 107000, 108000, 109000, 110000, 111000, 112000, 113000, 114000, 115000, 116000, 117000, 118000, 119000, 120000, 121000, 122000, 123000, 124000, 125000, 126000, 127000, 128000, 129000, 130000, 131000, 132000, 133000, 134000, 135000, 136000, 137000, 138000, 139000, 140000, 141000, 142000, 143000, 144000, 145000, 146000, 147000, 148000, 149000, 150000, 151000, 152000, 153000, 154000, 155000, 156000, 157000, 158000, 159000, 160000, 161000, 162000, 163000, 164000, 165000, 166000, 167000, 168000, 169000, 170000, 171000, 172000, 173000, 174000, 175000, 176000, 177000, 178000, 179000, 180000, 181000, 182000, 183000, 184000, 185000, 186000, 187000, 188000, 189000, 190000, 191000, 192000, 193000, 194000, 195000, 196000, 197000, 198000, 199000, 200000, 201000, 202000, 203000, 204000, 205000, 206000, 207000, 208000, 209000, 210000, 211000, 212000, 213000, 214000, 215000, 216000, 217000, 218000, 219000, 220000, 221000, 222000, 223000, 224000, 225000, 226000, 227000, 228000, 229000, 230000, 231000, 232000, 233000, 234000, 235000, 236000, 237000, 238000, 239000, 240000, 241000, 242000, 243000, 244000, 245000, 246000, 247000, 248000, 249000, 250000, 251000, 252000, 253000, 254000, 255000, 256000, 257000, 258000, 259000, 260000, 261000, 262000, 263000, 264000, 265000, 266000, 267000, 268000, 269000, 270000, 271000, 272000, 273000, 274000, 275000, 276000, 277000, 278000, 279000, 280000, 281000, 282000, 283000, 284000, 285000, 286000, 287000, 288000, 289000, 290000, 291000, 292000, 293000, 294000, 295000, 296000, 297000, 298000, 299000, 300000, 301000, 302000, 303000, 304000, 305000, 306000, 307000, 308000, 309000, 310000, 311000, 311176]\n" ] } ], "source": [ "cuts = [*range(0, len(data), 1000)]\n", "cuts.append(len(data))\n", "print(cuts)" ] }, { "cell_type": "code", "execution_count": 16, "id": "e28b2351-e333-44e8-bac4-96686abda113", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "8" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "multiprocessing.cpu_count()" ] }, { "cell_type": "code", "execution_count": null, "id": "066c67ac-01a6-4151-8e85-5869ddce1c0a", "metadata": {}, "outputs": [], "source": [ "article_id = []\n", "article_desc = []\n", "article_embed = []\n", "article_cat = []\n", "pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()-2)\n", "try:\n", " for i in tqdm(range(len(cuts)-1)):\n", " df = data.iloc[cuts[i]: cuts[i+1]].copy()\n", " article_id = [*df['index']]\n", " article_desc = [*df['short_description']]\n", " article_cat = [*df['category']]\n", " results = []\n", " for doc in article_desc:\n", " f = pool.apply_async(vectorizer.vectorize, args=(doc, tokenizer)) # need to pass tokenizer as argument\n", " results.append(f) # appending result to results\n", " for f in results:\n", " emb = f.get(timeout=120)\n", " article_embed.append(emb)\n", " docs = [article_id, article_embed, article_desc, article_cat]\n", " ins_resp = collection.insert(docs)\n", " print(ins_resp)\n", " article_id = []\n", " article_desc = []\n", " article_embed = []\n", " article_cat = []\n", " if i == 0:\n", " index_params = {\"index_type\": \"AUTOINDEX\", \"metric_type\": \"L2\", \"params\": {}} \n", " collection.create_index(field_name='article_embed', index_params=index_params)\n", " collection = Collection(name=collection_name)\n", " collection.load()\n", " pool.close()\n", " pool.join()\n", "except:\n", " pool.close()\n", " pool.join()\n", " raise" ] }, { "cell_type": "code", "execution_count": null, "id": "d50177fa-fd0c-48ad-bc9b-a7bdc826a628", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "81e074e4-2cb2-44fc-8180-1f74ba79d5c6", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python (tf_gpu)", "language": "python", "name": "tf_gpu" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }