{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os; os.chdir('..')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import pandas as pd\n", "\n", "from datasets import Dataset, load_dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
categorylabellabel_id
3982Citation context relevance assessment platformsReference12
24651Geology fieldworkScience2
28113Password management for individualsComputers_and_Electronics7
10999Real estate market statisticsReal Estate24
17096Running gear for womenBeauty_and_Fitness9
2374Sports Team Fan PrideSports26
9932Wine and food eventsFood_and_Drink15
2953College admissions for aspiring dancersJobs_and_Education21
25038Software development best practices forumsOnline Communities8
29703Quantum physics theoriesScience2
\n", "
" ], "text/plain": [ " category \\\n", "3982 Citation context relevance assessment platforms \n", "24651 Geology fieldwork \n", "28113 Password management for individuals \n", "10999 Real estate market statistics \n", "17096 Running gear for women \n", "2374 Sports Team Fan Pride \n", "9932 Wine and food events \n", "2953 College admissions for aspiring dancers \n", "25038 Software development best practices forums \n", "29703 Quantum physics theories \n", "\n", " label label_id \n", "3982 Reference 12 \n", "24651 Science 2 \n", "28113 Computers_and_Electronics 7 \n", "10999 Real Estate 24 \n", "17096 Beauty_and_Fitness 9 \n", "2374 Sports 26 \n", "9932 Food_and_Drink 15 \n", "2953 Jobs_and_Education 21 \n", "25038 Online Communities 8 \n", "29703 Science 2 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df= pd.read_csv('data_categories/Final_Category_Data_With_Labels.csv')\n", "\n", "\n", "df.sample(10)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
categorylabel_id
0Internet usage monitoring25
1Food safety guidelines and regulations15
2Internet protocols and edge computing in finance25
3Online grocery shopping15
4Writing retreats for poets and novelists17
\n", "
" ], "text/plain": [ " category label_id\n", "0 Internet usage monitoring 25\n", "1 Food safety guidelines and regulations 15\n", "2 Internet protocols and edge computing in finance 25\n", "3 Online grocery shopping 15\n", "4 Writing retreats for poets and novelists 17" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_new= df[['category', 'label_id']]\n", "df_new.head()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False 22474\n", "True 11138\n", "Name: count, dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_new.duplicated().value_counts() # 10837 duplicate values" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_139501/984288843.py:1: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " df_new.rename(\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textlabel
2925Kids' toy stores online13
31108Birdwatching apps for bird behavior5
6817Legal developments1
20037Citation context relevance assessment tools12
18928Orchid care guide20
33358Scientific publications and journals2
16499Service dog etiquette5
26484Social media trends analysis25
15543Troubleshooting computer issues7
15854large23
\n", "
" ], "text/plain": [ " text label\n", "2925 Kids' toy stores online 13\n", "31108 Birdwatching apps for bird behavior 5\n", "6817 Legal developments 1\n", "20037 Citation context relevance assessment tools 12\n", "18928 Orchid care guide 20\n", "33358 Scientific publications and journals 2\n", "16499 Service dog etiquette 5\n", "26484 Social media trends analysis 25\n", "15543 Troubleshooting computer issues 7\n", "15854 large 23" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_new.rename(\n", " columns={\n", " \"category\": \"text\", \n", " \"label_id\": \"label\"\n", "}, \n", " inplace=True\n", ")\n", "\n", "df_new.sample(10)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/pyarrow/pandas_compat.py:373: FutureWarning: is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.\n", " if _pandas_api.is_sparse(col):\n" ] }, { "data": { "text/plain": [ "Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 33612\n", "})" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_df= Dataset.from_pandas(df_new)\n", "dataset_df" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 26889\n", " })\n", " test: Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 6723\n", " })\n", "})" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_data= dataset_df.train_test_split(test_size=0.2)\n", "new_data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def preprocess_function(examples):\n", " return tokenizer(examples[\"text\"], truncation=True)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map: 48%|████▊ | 13000/26889 [00:00<00:00, 32226.42 examples/s]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 26889/26889 [00:00<00:00, 34388.34 examples/s]\n", "Map: 100%|██████████| 6723/6723 [00:00<00:00, 41978.69 examples/s]\n" ] } ], "source": [ "tokenized_df = new_data.map(preprocess_function, batched=True)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-13 10:29:49.212220: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2023-10-13 10:29:50.573292: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "# from transformers import DataCollatorWithPadding\n", "\n", "# data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors=\"tf\")\n", "\n", "\n", "\n", "\n", "from transformers import DataCollatorWithPadding\n", "\n", "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "\n", "accuracy = evaluate.load(\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " predictions = np.argmax(predictions, axis=1)\n", " return accuracy.compute(predictions=predictions, references=labels)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "import json\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Hobbies_and_Leisure': 0,\n", " 'News': 1,\n", " 'Science': 2,\n", " 'Autos_and_Vehicles': 3,\n", " 'Health': 4,\n", " 'Pets_and_Animals': 5,\n", " 'Adult': 6,\n", " 'Computers_and_Electronics': 7,\n", " 'Online Communities': 8,\n", " 'Beauty_and_Fitness': 9,\n", " 'People_and_Society': 10,\n", " 'Business_and_Industrial': 11,\n", " 'Reference': 12,\n", " 'Shopping': 13,\n", " 'Travel_and_Transportation': 14,\n", " 'Food_and_Drink': 15,\n", " 'Law_and_Government': 16,\n", " 'Books_and_Literature': 17,\n", " 'Finance': 18,\n", " 'Games': 19,\n", " 'Home_and_Garden': 20,\n", " 'Jobs_and_Education': 21,\n", " 'Arts_and_Entertainment': 22,\n", " 'Sensitive Subjects': 23,\n", " 'Real Estate': 24,\n", " 'Internet_and_Telecom': 25,\n", " 'Sports': 26}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "label2id= json.load(\n", " open('data/categories_refined.json', 'r')\n", ")\n", "label2id" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{0: 'Hobbies_and_Leisure',\n", " 1: 'News',\n", " 2: 'Science',\n", " 3: 'Autos_and_Vehicles',\n", " 4: 'Health',\n", " 5: 'Pets_and_Animals',\n", " 6: 'Adult',\n", " 7: 'Computers_and_Electronics',\n", " 8: 'Online Communities',\n", " 9: 'Beauty_and_Fitness',\n", " 10: 'People_and_Society',\n", " 11: 'Business_and_Industrial',\n", " 12: 'Reference',\n", " 13: 'Shopping',\n", " 14: 'Travel_and_Transportation',\n", " 15: 'Food_and_Drink',\n", " 16: 'Law_and_Government',\n", " 17: 'Books_and_Literature',\n", " 18: 'Finance',\n", " 19: 'Games',\n", " 20: 'Home_and_Garden',\n", " 21: 'Jobs_and_Education',\n", " 22: 'Arts_and_Entertainment',\n", " 23: 'Sensitive Subjects',\n", " 24: 'Real Estate',\n", " 25: 'Internet_and_Telecom',\n", " 26: 'Sports'}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "id2label= {}\n", "for key in label2id.keys():\n", " id2label[label2id[key]] = key\n", " \n", "id2label" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", " \"finetuned_entity_categorical_classification/checkpoint-3346\", num_labels=27, id2label=id2label, label2id=label2id\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [3362/3362 01:52, Epoch 2/2]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracy
10.1023000.0776520.975309
20.0834000.0862910.974714

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=3362, training_loss=0.08880683540376008, metrics={'train_runtime': 113.5357, 'train_samples_per_second': 473.666, 'train_steps_per_second': 29.612, 'total_flos': 213673546900476.0, 'train_loss': 0.08880683540376008, 'epoch': 2.0})" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "training_args = TrainingArguments(\n", " output_dir=\"finetuned_entity_categorical_classification\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=16,\n", " num_train_epochs=2,\n", " weight_decay=0.01,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " load_best_model_at_end=True,\n", " # push_to_hub=True,\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_df[\"train\"],\n", " eval_dataset=tokenized_df[\"test\"],\n", " tokenizer=tokenizer,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", ")\n", "\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }