{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "f60e3c3cbfa44ddfbb35a20c295a7071": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_c56ea1c6adc845238398f0267ecca203", "IPY_MODEL_d8c2363c5aec4b949c72fa119f4b9314", "IPY_MODEL_22cdf3c7f04d4c81a1de6d159a6e7c23" ], "layout": "IPY_MODEL_0742c322e39b4f7f893b274ce440feef" } }, "c56ea1c6adc845238398f0267ecca203": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_bb25a51f5eae4b049db57954d5d6c539", "placeholder": "​", "style": "IPY_MODEL_2a82e457fae047d6aa1cb45de6ccc65d", "value": "model.safetensors: 100%" } }, "d8c2363c5aec4b949c72fa119f4b9314": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_18776047ab7f4b658f651e5d18056d0d", "max": 267954768, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_c92a110c440e4f6489d43a42bcda95a7", "value": 267954768 } }, "22cdf3c7f04d4c81a1de6d159a6e7c23": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_cc2ab102716741f5937b327fd92367f2", "placeholder": "​", "style": "IPY_MODEL_c5a7ff2101c844bb812ccd02a897ba2f", "value": " 268M/268M [00:02<00:00, 316MB/s]" } }, "0742c322e39b4f7f893b274ce440feef": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bb25a51f5eae4b049db57954d5d6c539": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2a82e457fae047d6aa1cb45de6ccc65d": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "18776047ab7f4b658f651e5d18056d0d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c92a110c440e4f6489d43a42bcda95a7": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "cc2ab102716741f5937b327fd92367f2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c5a7ff2101c844bb812ccd02a897ba2f": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "markdown", "source": [ "**https://medium.com/nerd-for-tech/fine-tuning-pretrained-bert-for-sentiment-classification-using-transformers-in-python-931ed142e37t**" ], "metadata": { "id": "DY8CxeND8kUT" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v2yodoX72lbP", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "8c3ba77f-85d5-4f40-fff1-1e3696dcf256" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.41.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.14.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.4)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers) (4.11.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n" ] } ], "source": [ "!pip install transformers" ] }, { "cell_type": "code", "source": [ "import pandas as pd\n", "import numpy as np\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "import warnings\n", "warnings.filterwarnings('ignore')" ], "metadata": { "id": "3D8GpmYl2_eR" }, "execution_count": 1, "outputs": [] }, { "cell_type": "code", "source": [ "df = pd.read_csv('/content/complaint_data.csv')\n", "df = df.set_index(df.columns[0])\n", "df.head()\n" ], "metadata": { "id": "CMXrn4Q92_bf", "colab": { "base_uri": "https://localhost:8080/", "height": 237 }, "outputId": "599e00f8-5b6f-4ee9-987e-ae5f5d84d0a7" }, "execution_count": 2, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " complaint_what_happened Topic\n", "Unnamed: 0 \n", "1 good morning my name is xxxx xxxx and i apprec... 0\n", "2 i upgraded my xxxx xxxx card in xx xx 2018 and... 1\n", "10 chase card was reported on xx xx 2019 however... 3\n", "11 on xx xx 2018 while trying to book a xxxx xx... 3\n", "14 my grand son give me check for 1600 00 i de... 0" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
complaint_what_happenedTopic
Unnamed: 0
1good morning my name is xxxx xxxx and i apprec...0
2i upgraded my xxxx xxxx card in xx xx 2018 and...1
10chase card was reported on xx xx 2019 however...3
11on xx xx 2018 while trying to book a xxxx xx...3
14my grand son give me check for 1600 00 i de...0
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "df", "summary": "{\n \"name\": \"df\",\n \"rows\": 21072,\n \"fields\": [\n {\n \"column\": \"Unnamed: 0\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 23390,\n \"min\": 1,\n \"max\": 78312,\n \"num_unique_values\": 21072,\n \"samples\": [\n 30957,\n 58320,\n 61282\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"complaint_what_happened\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 20928,\n \"samples\": [\n \"this is in regards to a billing dispute that occurred on my chase freedom unlimited card ending in xxxx the billing dispute i will be referencing throughout is with the merchant xxxx xxxx xxxx in the amount of 700 00 posted on xx xx 2018 this charge was disputed because the merchant performed services that xxxx xxxx my dog i took my dog in for xxxx xxxx xxxx with this business after some back and forth the merchant assured me in no uncertain terms that this xxxx would have absolutely no complications for my dog in other words i was guaranteed by the doctor working at this place that my dog will not experience any adverse side effects from xxxx xxxx much less any xxxx xxxx conditions this turned out to be a complete and unequivocal lie after the xxxx my dog xxxx xxxx xxxx and experienced a xxxx xxxx xxxx xxxx xxxx my dog a xxxx pound puppy was xxxx xxxx and howling at xxxx in the morning after the xxxx i immediately rushed over to the nearest emergency room and took my dog in to get checked by a licensed veterinarian i was told that my dog was on the xxxx xxxx xxxx and he very possibly could have xxxx had i gone to the er any later the vet prescribed some temporary medication for my dog and told me to come back later when the hospital would be running with a full staff on hand i later took my dog in to a separate hospital where the vet was able to better diagnose and take care of my dog all in all this fraudulent and evil natured xxxx which nearly took my dog s xxxx costed me an additional 1500 00 in medical expenses this merchant does not deserve a single penny for this xxxx that not only xxxx my dog xxxx xxxx but also costed me an additional 1500 00 in otherwise preventable medical expenses \",\n \"on xx xx xxxx i made a remote deposit to my chase personal checking account with a check for 2800 00 written from my account at xxxx chase indicated that there was a hold on the deposit on xx xx xxxx i called xxxx to get the hold lifted on their end i then called chase and spoke to a representative who put me on hold called xxxx and then told me that the hold on the deposit had been lifted and the funds would be available for me to use immediately \\n\\non xx xx xxxx i logged in to my chase account online and saw that on xx xx xxxx the check had been returned and a 12 00 deposit item returned fee had been assessed i called chase and spoke first with an agent and then with the agent s supervisor who told me that the check had been returned because according to xxxx there was a hold on the check i then called xxxx and they told me that there was no hold as i had been told on xx xx xxxx \\n\\nneither xxxx nor chase will refund the 12 00 xxxx will not refund the fee because they are not the ones who charged it chase will not refund the fee because they claim to have been acting on information provided by xxxx that is to say chase claims that the check was returned on xx xx xxxx because xxxx indicated there was a hold on my xxxx account even though both xxxx and chase told me on xx xx xxxx that the hold had been lifted and xxxx confirmed on xx xx xxxx that there was no hold \",\n \"chase went on my credit reports without written authorization 3 times\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Topic\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 3\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 2 } ] }, { "cell_type": "code", "source": [ "sentences = df['complaint_what_happened'].tolist()\n", "labels = df['Topic'].tolist()" ], "metadata": { "id": "vF0f0J3v2_WZ" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "test = pd.DataFrame({\n", " 'complaints': [\n", " \"I can not get from chase who services my mortgage, who owns it and who has original loan docs\",\n", " \"The bill amount of my credit card was debited twice. Please look into the matter and resolve at the earliest.\",\n", " \"I want to open a salary account at your downtown branch. Please provide me the procedure.\",\n", " \"Yesterday, I received a fraudulent email regarding renewal of my services.\",\n", " \"What is the procedure to know my CIBIL score?\",\n", " \"I need to know the number of bank branches and their locations in the city of Dubai\"\n", " ],\n", " 'labels':[2,1,0,4,0,0]\n", "})" ], "metadata": { "id": "0yB8C6g3409H" }, "execution_count": 4, "outputs": [] }, { "cell_type": "code", "source": [ "test" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 237 }, "id": "oIxaiiC772XR", "outputId": "022bfe26-ecf2-450e-da19-599d0d086882" }, "execution_count": 5, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " complaints labels\n", "0 I can not get from chase who services my mortg... 2\n", "1 The bill amount of my credit card was debited ... 1\n", "2 I want to open a salary account at your downto... 0\n", "3 Yesterday, I received a fraudulent email regar... 4\n", "4 What is the procedure to know my CIBIL score? 0\n", "5 I need to know the number of bank branches and... 0" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
complaintslabels
0I can not get from chase who services my mortg...2
1The bill amount of my credit card was debited ...1
2I want to open a salary account at your downto...0
3Yesterday, I received a fraudulent email regar...4
4What is the procedure to know my CIBIL score?0
5I need to know the number of bank branches and...0
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", " \n", " \n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "test", "summary": "{\n \"name\": \"test\",\n \"rows\": 6,\n \"fields\": [\n {\n \"column\": \"complaints\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 6,\n \"samples\": [\n \"I can not get from chase who services my mortgage, who owns it and who has original loan docs\",\n \"The bill amount of my credit card was debited twice. Please look into the matter and resolve at the earliest.\",\n \"I need to know the number of bank branches and their locations in the city of Dubai\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"labels\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 4,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "code", "source": [ "test_texts = test['complaints'].values.tolist()" ], "metadata": { "id": "O_FLBiXg5BSY" }, "execution_count": 6, "outputs": [] }, { "cell_type": "code", "source": [ "from sklearn.model_selection import train_test_split\n", "train_texts, val_texts, train_labels, val_labels = train_test_split(sentences, labels, test_size=.2,random_state=42,stratify=labels)" ], "metadata": { "id": "tw_68jZn3fZO" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [ "import torch\n", "from torch.utils.data import Dataset\n", "from transformers import DistilBertTokenizerFast,DistilBertForSequenceClassification\n", "from transformers import Trainer,TrainingArguments" ], "metadata": { "id": "4GWB7N313fWc" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "model_name = 'distilbert-base-uncased'" ], "metadata": { "id": "yeYlwFYE3fTj" }, "execution_count": 9, "outputs": [] }, { "cell_type": "code", "source": [ "tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased',num_labels=5)" ], "metadata": { "id": "A-Mtpqf43fQt" }, "execution_count": 10, "outputs": [] }, { "cell_type": "code", "source": [ "train_encodings = tokenizer(train_texts, truncation=True, padding=True,return_tensors = 'pt')\n", "val_encodings = tokenizer(val_texts, truncation=True, padding=True,return_tensors = 'pt')\n", "test_encodings = tokenizer(test_texts, truncation=True, padding=True,return_tensors = 'pt')" ], "metadata": { "id": "5MTfiPRR4UC2" }, "execution_count": 11, "outputs": [] }, { "cell_type": "code", "source": [ "class SentimentDataset(torch.utils.data.Dataset):\n", " def __init__(self, encodings, labels):\n", " self.encodings = encodings\n", " self.labels = labels\n", "\n", " def __getitem__(self, idx):\n", " item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n", " item['labels'] = torch.tensor(self.labels[idx])\n", " return item\n", "\n", " def __len__(self):\n", " return len(self.labels)\n", "## Test Dataset\n", "class SentimentTestDataset(torch.utils.data.Dataset):\n", " def __init__(self, encodings):\n", " self.encodings = encodings\n", "\n", " def __getitem__(self, idx):\n", " item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n", " return item\n", " def __len__(self):\n", " return len(self.encodings)" ], "metadata": { "id": "nyB35Xxs4fg_" }, "execution_count": 12, "outputs": [] }, { "cell_type": "code", "source": [ "train_dataset = SentimentDataset(train_encodings, train_labels)\n", "val_dataset = SentimentDataset(val_encodings, val_labels)\n", "test_dataset = SentimentTestDataset(test_encodings)" ], "metadata": { "id": "iIK86LpH4T_l" }, "execution_count": 13, "outputs": [] }, { "cell_type": "code", "source": [ "from sklearn.metrics import accuracy_score, f1_score\n", "def compute_metrics(p):\n", " pred, labels = p\n", " pred = np.argmax(pred, axis=1)\n", "\n", " accuracy = accuracy_score(y_true=labels, y_pred=pred)\n", " #recall = recall_score(y_true=labels, y_pred=pred)\n", " #precision = precision_score(y_true=labels, y_pred=pred)\n", " f1 = f1_score(labels, pred, average='weighted')\n", "\n", " return {\"accuracy\": accuracy,\"f1_score\":f1}" ], "metadata": { "id": "Cyhy5TIz5Mhr" }, "execution_count": 14, "outputs": [] }, { "cell_type": "code", "source": [ "!pip install accelerate>=0.21.0" ], "metadata": { "id": "BoYNJmFYipDY" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!pip install transformers[torch]" ], "metadata": { "id": "GR4iBTJ3jI8S", "outputId": "37273942-10c8-4413-b51d-f87225fb72f4", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers[torch] in /usr/local/lib/python3.10/dist-packages (4.41.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (3.14.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.23.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2023.12.25)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.19.1)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.4.3)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (4.66.4)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.3.0+cu121)\n", "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.30.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->transformers[torch]) (5.9.5)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers[torch]) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers[torch]) (4.11.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (3.3)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (3.1.4)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (2.20.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (12.1.105)\n", "Requirement already satisfied: triton==2.3.0 in /usr/local/lib/python3.10/dist-packages (from torch->transformers[torch]) (2.3.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->transformers[torch]) (12.5.40)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2024.2.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->transformers[torch]) (2.1.5)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->transformers[torch]) (1.3.0)\n" ] } ] }, { "cell_type": "code", "source": [ "training_args = TrainingArguments(\n", " output_dir='./res', # output directory\n", " evaluation_strategy=\"steps\",\n", " num_train_epochs=5, # total number of training epochs\n", " per_device_train_batch_size=32, # batch size per device during training\n", " per_device_eval_batch_size=64, # batch size for evaluation\n", " warmup_steps=500, # number of warmup steps for learning rate scheduler\n", " weight_decay=0.01, # strength of weight decay\n", " logging_dir='./logs4', # directory for storing logs\n", " #logging_steps=10,\n", " load_best_model_at_end=True,\n", ")" ], "metadata": { "id": "sRYXMvKA5McC" }, "execution_count": 15, "outputs": [] }, { "cell_type": "code", "source": [ "model = DistilBertForSequenceClassification.from_pretrained(\"distilbert-base-uncased\",num_labels=5)\n", "\n", "trainer = Trainer(\n", " model=model,# the instantiated 🤗 Transformers model to be trained\n", " args=training_args, # training arguments, defined above\n", " train_dataset=train_dataset,# training dataset\n", " eval_dataset=val_dataset , # evaluation dataset\n", " compute_metrics=compute_metrics,\n", ")\n", "\n", "trainer.train()" ], "metadata": { "id": "bVXIDe_Q5MZh", "colab": { "base_uri": "https://localhost:8080/", "height": 358, "referenced_widgets": [ "f60e3c3cbfa44ddfbb35a20c295a7071", "c56ea1c6adc845238398f0267ecca203", "d8c2363c5aec4b949c72fa119f4b9314", "22cdf3c7f04d4c81a1de6d159a6e7c23", "0742c322e39b4f7f893b274ce440feef", "bb25a51f5eae4b049db57954d5d6c539", "2a82e457fae047d6aa1cb45de6ccc65d", "18776047ab7f4b658f651e5d18056d0d", "c92a110c440e4f6489d43a42bcda95a7", "cc2ab102716741f5937b327fd92367f2", "c5a7ff2101c844bb812ccd02a897ba2f" ] }, "outputId": "55149643-5008-46bb-ca29-eeeced404f45" }, "execution_count": 16, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "model.safetensors: 0%| | 0.00/268M [00:00" ], "text/html": [ "\n", "
\n", " \n", " \n", " [2635/2635 1:10:13, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation LossAccuracyF1 Score
5000.7069000.2931790.8918150.891679
10000.2612000.2624960.9065240.906258
15000.1630000.3062550.9029660.902720
20000.0974000.2811100.9209960.920973
25000.0473000.3363580.9195730.919342

" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=2635, training_loss=0.24418509228179758, metrics={'train_runtime': 4216.3449, 'train_samples_per_second': 19.99, 'train_steps_per_second': 0.625, 'total_flos': 1.11656120322816e+16, 'train_loss': 0.24418509228179758, 'epoch': 5.0})" ] }, "metadata": {}, "execution_count": 16 } ] }, { "cell_type": "code", "source": [ "trainer.evaluate()" ], "metadata": { "id": "Cei_CCL14T87", "colab": { "base_uri": "https://localhost:8080/", "height": 166 }, "outputId": "8e5e8393-5874-4daf-cad6-d8c1b2095e3f" }, "execution_count": 17, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "

\n", " \n", " \n", " [66/66 01:08]\n", "
\n", " " ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "{'eval_loss': 0.2624961733818054,\n", " 'eval_accuracy': 0.9065243179122182,\n", " 'eval_f1_score': 0.9062577319683326,\n", " 'eval_runtime': 69.9531,\n", " 'eval_samples_per_second': 60.255,\n", " 'eval_steps_per_second': 0.943,\n", " 'epoch': 5.0}" ] }, "metadata": {}, "execution_count": 17 } ] }, { "cell_type": "code", "source": [ "test_label = test['labels'].values.tolist()" ], "metadata": { "id": "LNrOiydM5cba" }, "execution_count": 18, "outputs": [] }, { "cell_type": "code", "source": [ "test_encodings = tokenizer(test_texts, truncation=True, padding=True,return_tensors = 'pt')\n", "test_dataset = SentimentDataset(test_encodings, test_label)\n", "preds = trainer.predict(test_dataset=test_dataset )" ], "metadata": { "id": "QFTsIEPS5s-T", "colab": { "base_uri": "https://localhost:8080/", "height": 17 }, "outputId": "b375656d-56cd-4dff-f03b-20b0c20c8c5f" }, "execution_count": 22, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "probs = torch.from_numpy(preds[0]).softmax(1)\n", "\n", "predictions = probs.numpy()# convert tensors to numpy array" ], "metadata": { "id": "h3V-eUqj5cO-" }, "execution_count": 23, "outputs": [] }, { "cell_type": "code", "source": [ "newdf = pd.DataFrame(predictions,columns=['Account_Services','Others','Mortgage/Loan','Credit card or prepaid card','Theft/Dispute Reporting'])\n", "newdf.head()" ], "metadata": { "id": "EMHIgJHv52m8", "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "outputId": "7b8f7be6-5a36-4fd2-e0c8-fe55ac01f2bb" }, "execution_count": 24, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Account_Services Others Mortgage/Loan Credit card or prepaid card \\\n", "0 0.000832 0.000475 0.996871 0.000974 \n", "1 0.000758 0.992179 0.000380 0.005469 \n", "2 0.984026 0.003422 0.003329 0.005595 \n", "3 0.029483 0.019977 0.013071 0.022676 \n", "4 0.224001 0.104025 0.205154 0.253695 \n", "\n", " Theft/Dispute Reporting \n", "0 0.000847 \n", "1 0.001214 \n", "2 0.003628 \n", "3 0.914794 \n", "4 0.213125 " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Account_ServicesOthersMortgage/LoanCredit card or prepaid cardTheft/Dispute Reporting
00.0008320.0004750.9968710.0009740.000847
10.0007580.9921790.0003800.0054690.001214
20.9840260.0034220.0033290.0055950.003628
30.0294830.0199770.0130710.0226760.914794
40.2240010.1040250.2051540.2536950.213125
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "newdf", "summary": "{\n \"name\": \"newdf\",\n \"rows\": 6,\n \"fields\": [\n {\n \"column\": \"Account_Services\",\n \"properties\": {\n \"dtype\": \"float32\",\n \"num_unique_values\": 6,\n \"samples\": [\n 0.0008318105828948319,\n 0.0007582061225548387,\n 0.9712517261505127\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Others\",\n \"properties\": {\n \"dtype\": \"float32\",\n \"num_unique_values\": 6,\n \"samples\": [\n 0.00047501950757578015,\n 0.9921789169311523,\n 0.0073503153398633\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Mortgage/Loan\",\n \"properties\": {\n \"dtype\": \"float32\",\n \"num_unique_values\": 6,\n \"samples\": [\n 0.9968714118003845,\n 0.0003799396799877286,\n 0.005817148834466934\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Credit card or prepaid card\",\n \"properties\": {\n \"dtype\": \"float32\",\n \"num_unique_values\": 6,\n \"samples\": [\n 0.0009743589325807989,\n 0.005469117779284716,\n 0.005451070610433817\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Theft/Dispute Reporting\",\n \"properties\": {\n \"dtype\": \"float32\",\n \"num_unique_values\": 6,\n \"samples\": [\n 0.0008474614005535841,\n 0.0012138704769313335,\n 0.010129823349416256\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 24 } ] }, { "cell_type": "code", "source": [ "def labels(x):\n", " if x == 0:\n", " return 'Account Services'\n", " elif x == 1:\n", " return 'Others'\n", " elif x == 2:\n", " return 'Mortgage/Loan'\n", " elif x == 3:\n", " return 'Credit card or prepaid card'\n", " else:\n", " return 'Theft/Dispute Reporting'\n", "\n", "results = np.argmax(predictions,axis=1)\n", "test['complaints'] = results\n", "test['complaints'] = test['complaints'].map(labels)\n", "test" ], "metadata": { "id": "ByDEzFit565T", "colab": { "base_uri": "https://localhost:8080/", "height": 237 }, "outputId": "987077a7-ec4e-48d1-9afc-6d84bfa4a6a1" }, "execution_count": 29, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " complaints labels\n", "0 Mortgage/Loan 2\n", "1 Others 1\n", "2 Account Services 0\n", "3 Theft/Dispute Reporting 4\n", "4 Credit card or prepaid card 0\n", "5 Account Services 0" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
complaintslabels
0Mortgage/Loan2
1Others1
2Account Services0
3Theft/Dispute Reporting4
4Credit card or prepaid card0
5Account Services0
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", " \n", " \n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "test", "summary": "{\n \"name\": \"test\",\n \"rows\": 6,\n \"fields\": [\n {\n \"column\": \"complaints\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 5,\n \"samples\": [\n \"Others\",\n \"Credit card or prepaid card\",\n \"Account Services\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"labels\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 4,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 29 } ] }, { "cell_type": "code", "source": [ "import seaborn as sns\n", "sns.countplot(x='complaints',data=test)" ], "metadata": { "id": "cyj4flNd5_DI", "colab": { "base_uri": "https://localhost:8080/", "height": 467 }, "outputId": "1f332090-96ae-425e-81c4-d295f573f517" }, "execution_count": 26, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 26 }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "# prompt: i want to save this fine tuned model\n", "\n", "trainer.save_model('distilbert_model')\n" ], "metadata": { "id": "l3cL1GEtoCaD" }, "execution_count": 28, "outputs": [] } ] }