{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "OrxTxENAIO_g" }, "source": [ "# Spam Detection Project README\n", "\n", "Spam Detection project, for users to classify messages as spam or not.\n", "\n", "## Table of Contents\n", "\n", "- [Data Collection and Preprocessing](#data-collection-and-preprocessing)\n", " - [Mount Google Drive](#mount-google-drive)\n", " - [Install Required Libraries](#install-required-libraries)\n", " - [Load and Prepare Data](#load-and-prepare-data)\n", " - [Prepare Data Labels](#prepare-data-labels)\n", " - [Split Data](#split-data)\n", "- [Model Building and Training](#model-building-and-training)\n", " - [Initialize Tokenizer](#initialize-tokenizer)\n", " - [Tokenize Data](#tokenize-data)\n", " - [Create TensorFlow Datasets](#create-tensorflow-datasets)\n", " - [Define Training Arguments](#define-training-arguments)\n", " - [Initialize and Train Model](#initialize-and-train-model)\n", "- [Model Evaluation and Inference](#model-evaluation-and-inference)\n", " - [Evaluate Model](#evaluate-model)\n", " - [Generate Predictions](#generate-predictions)\n", " - [Save Trained Model](#save-trained-model)\n", "- [Interactive Gradio Interface](#interactive-gradio-interface)\n", " - [Inference on Sample Text](#inference-on-sample-text)\n", " - [Create Gradio Interface](#create-gradio-interface)" ] }, { "cell_type": "markdown", "metadata": { "id": "c_5oaNFSI1oB" }, "source": [ "## Data Collection and Preprocessing" ] }, { "cell_type": "markdown", "metadata": { "id": "UAerPRTUCyny" }, "source": [ "### Mount Google Drive\n", "Mounting Google Drive in Google Colab to access files and data stored in Google Drive within the notebook." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "dEGvOVWFsQcf" }, "outputs": [], "source": [ "# from google.colab import drive\n", "# drive.mount('/content/drive')" ] }, { "cell_type": "markdown", "metadata": { "id": "QFpdIglgDBZO" }, "source": [ "### Install Required Libraries\n", "Installing the necessary libraries datasets, transformers, and gradio using the pip package manager." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gzn-RVdH1LGD", "outputId": "cbca1211-e2ae-4175-b8bb-062419955eff" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting datasets\n", " Downloading datasets-2.14.4-py3-none-any.whl (519 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.3/519.3 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting transformers\n", " Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m20.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting gradio\n", " Downloading gradio-3.40.1-py3-none-any.whl (20.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.0/20.0 MB\u001b[0m \u001b[31m44.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n", "Collecting dill<0.3.8,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m13.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.0)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m23.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting multiprocess (from datasets)\n", " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m16.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n", "Collecting huggingface-hub<1.0.0,>=0.14.0 (from datasets)\n", " Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m31.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", 
"Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)\n", " Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m86.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting safetensors>=0.3.1 (from transformers)\n", " Downloading safetensors-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m80.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)\n", " Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)\n", "Requirement already satisfied: altair<6.0,>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.2.2)\n", "Collecting fastapi (from gradio)\n", " Downloading fastapi-0.101.0-py3-none-any.whl (65 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.7/65.7 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting ffmpy (from gradio)\n", " Downloading ffmpy-0.3.1.tar.gz (5.5 kB)\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "Collecting gradio-client>=0.4.0 (from gradio)\n", " Downloading gradio_client-0.4.0-py3-none-any.whl (297 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m297.4/297.4 kB\u001b[0m \u001b[31m29.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting httpx (from gradio)\n", " Downloading httpx-0.24.1-py3-none-any.whl (75 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.4/75.4 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: importlib-resources<7.0,>=1.3 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.0.1)\n", "Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.1.2)\n", "Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.0.0)\n", "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.1.3)\n", "Requirement already satisfied: matplotlib~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.7.1)\n", "Collecting mdit-py-plugins<=0.3.3 (from gradio)\n", " Downloading mdit_py_plugins-0.3.3-py3-none-any.whl (50 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.5/50.5 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting orjson~=3.0 (from gradio)\n", " Downloading orjson-3.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (140 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.3/140.3 kB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pillow<11.0,>=8.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (9.4.0)\n", "Requirement already satisfied: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.1.1)\n", "Collecting pydub (from gradio)\n", " Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n", "Collecting python-multipart (from gradio)\n", " Downloading 
python_multipart-0.0.6-py3-none-any.whl (45 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.7/45.7 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting semantic-version~=2.0 (from gradio)\n", " Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\n", "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.7.1)\n", "Collecting uvicorn>=0.14.0 (from gradio)\n", " Downloading uvicorn-0.23.2-py3-none-any.whl (59 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.5/59.5 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting websockets<12.0,>=10.0 (from gradio)\n", " Downloading websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.9/129.9 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.2.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.2)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: entrypoints in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.4)\n", "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (4.19.0)\n", "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.12.0)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (0.1.2)\n", "Requirement already satisfied: linkify-it-py<3,>=1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (2.0.2)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.1.0)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (0.11.0)\n", "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (4.42.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.4.4)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (3.1.1)\n", "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) 
(2.8.2)\n", "INFO: pip is looking at multiple versions of mdit-py-plugins to determine which version is compatible with other requirements. This could take a while.\n", "Collecting mdit-py-plugins<=0.3.3 (from gradio)\n", " Downloading mdit_py_plugins-0.3.2-py3-none-any.whl (50 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.4/50.4 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.3.1-py3-none-any.whl (46 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.5/46.5 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.3.0-py3-none-any.whl (43 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.7/43.7 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.2.8-py3-none-any.whl (41 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.0/41.0 kB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.2.7-py3-none-any.whl (41 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.0/41.0 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading mdit_py_plugins-0.2.6-py3-none-any.whl (39 kB)\n", " Downloading mdit_py_plugins-0.2.5-py3-none-any.whl (39 kB)\n", "INFO: pip is looking at multiple versions of mdit-py-plugins to determine which version is compatible with other requirements. This could take a while.\n", " Downloading mdit_py_plugins-0.2.4-py3-none-any.whl (39 kB)\n", " Downloading mdit_py_plugins-0.2.3-py3-none-any.whl (39 kB)\n", " Downloading mdit_py_plugins-0.2.2-py3-none-any.whl (39 kB)\n", " Downloading mdit_py_plugins-0.2.1-py3-none-any.whl (38 kB)\n", " Downloading mdit_py_plugins-0.2.0-py3-none-any.whl (38 kB)\n", "INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. 
If you want to abort this run, press Ctrl + C.\n", " Downloading mdit_py_plugins-0.1.0-py3-none-any.whl (37 kB)\n", "Collecting markdown-it-py[linkify]>=2.0.0 (from gradio)\n", " Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.5/87.5 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Downloading markdown_it_py-2.2.0-py3-none-any.whl (84 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.5/84.5 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0,>=1.7.4->gradio) (0.5.0)\n", "Requirement already satisfied: pydantic-core==2.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0,>=1.7.4->gradio) (2.4.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n", "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.14.0->gradio) (8.1.6)\n", "Collecting h11>=0.8 (from uvicorn>=0.14.0->gradio)\n", " Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting starlette<0.28.0,>=0.27.0 (from fastapi->gradio)\n", " Downloading starlette-0.27.0-py3-none-any.whl (66 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.0/67.0 kB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting httpcore<0.18.0,>=0.15.0 (from httpx->gradio)\n", " Downloading httpcore-0.17.3-py3-none-any.whl (74 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.5/74.5 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (1.3.0)\n", "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx->gradio) (3.7.1)\n", "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (2023.7.1)\n", "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.30.2)\n", "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.9.2)\n", "Requirement already satisfied: uc-micro-py in /usr/local/lib/python3.10/dist-packages (from linkify-it-py<3,>=1->markdown-it-py[linkify]>=2.0.0->gradio) (1.0.2)\n", "Requirement already satisfied: six>=1.5 in 
"Successfully installed aiofiles-23.2.1 datasets-2.14.4 dill-0.3.7 fastapi-0.101.0 ffmpy-0.3.1 gradio-3.40.1 gradio-client-0.4.0 h11-0.14.0 httpcore-0.17.3 httpx-0.24.1 huggingface-hub-0.16.4 markdown-it-py-2.2.0 mdit-py-plugins-0.3.3 multiprocess-0.70.15 orjson-3.9.4 pydub-0.25.1 python-multipart-0.0.6 safetensors-0.3.2 semantic-version-2.10.0 starlette-0.27.0 tokenizers-0.13.3 transformers-4.31.0 uvicorn-0.23.2 websockets-11.0.3 xxhash-3.3.0\n" ] } ], "source": [ "! pip install datasets transformers gradio" ] },
{ "cell_type": "markdown", "metadata": { "id": "YK0DqnKcDLMZ" }, "source": [ "### Load and Prepare Data\n", "Read and prepare the data from three sources: a Kaggle SMS spam file (kaggle.txt), the Deysi/spam-detection-dataset from Hugging Face, and local train/test CSV files. All sources are concatenated into a single DataFrame with 'label' and 'message' columns." ] }
, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "awPXefiYqQsF", "outputId": "ec530e3f-c076-4901-f01e-3445a0fa9549" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "not_spam    15625\n", "spam        11747\n", "Name: label, dtype: int64\n" ] } ], "source": [ "import pandas as pd\n", "from datasets import load_dataset\n", "\n", "drive_data_path = \"/content/\"\n", "\n", "def read_data():\n", "    # Read the Kaggle SMS spam file and normalize its 'ham' label\n", "    df_kaggle = pd.read_csv(drive_data_path + 'kaggle.txt', sep='\\t', names=[\"label\", \"message\"])\n", "    df_kaggle.loc[df_kaggle['label'] == 'ham', 'label'] = 'not_spam'\n", "\n", "    # Load the train and test splits of the Hugging Face dataset\n", "    data_ = load_dataset(\"Deysi/spam-detection-dataset\")\n", "    texts_train = [item['text'] for item in data_[\"train\"]]\n", "    labels_train = [item['label'] for item in data_[\"train\"]]\n", "    df_hugging_face_train = pd.DataFrame({'label': labels_train, 'message': texts_train})\n", "    texts_test = [item['text'] for item in data_[\"test\"]]\n", "    labels_test = [item['label'] for item in data_[\"test\"]]\n", "    df_hugging_face_test = pd.DataFrame({'label': labels_test, 'message': texts_test})\n", "\n", "    # Concatenate the Hugging Face train and test splits\n", "    df_hugging_face = pd.concat([df_hugging_face_train, df_hugging_face_test], ignore_index=True)\n", "\n", "    # Read the local CSV files and align their columns\n", "    df_csv_train = pd.read_csv(drive_data_path + \"train.csv\")\n", "    df_csv_train = df_csv_train[['label', 'text']]\n", "    df_csv_train = df_csv_train.rename(columns={'text': 'message'})\n", "    df_csv_test = pd.read_csv(drive_data_path + \"test.csv\")\n", "    df_csv_test = df_csv_test[['label', 'text']]\n", "    df_csv_test = df_csv_test.rename(columns={'text': 'message'})\n", "\n", "    # Concatenate all dataframes into one\n", "    df = pd.concat([df_kaggle, df_hugging_face, df_csv_train, df_csv_test], ignore_index=True)\n", "\n", "    return df\n", "\n", "df = read_data()\n", "\n", "# Display the class balance\n", "print(df['label'].value_counts())" ] }
encoding." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "EsvEwG98qq2b" }, "outputs": [], "source": [ "X,y=list(df['message']),list(df['label'])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "aZpBn3MtssZL" }, "outputs": [], "source": [ "y=list(pd.get_dummies(y,drop_first=True)['spam'])" ] }, { "cell_type": "markdown", "metadata": { "id": "PRQHUraUDfQN" }, "source": [ "### Split Data\n", "Splitting the data into training and testing sets using train_test_split from scikit-learn." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dLFDWda0rIKw", "outputId": "fd45b292-deb9-493d-fbb8-2b8e9c27eef4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X,y shape: 27372 27372\n" ] } ], "source": [ "from sklearn.model_selection import train_test_split\n", "print(\"X,y shape: \",len(X),len(y))\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.20, random_state = 0)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZvAONDqeL-BT" }, "source": [ "## Model Building and Training" ] }, { "cell_type": "markdown", "metadata": { "id": "TRS4Vm32Do1N" }, "source": [ "### Initialize Tokenizer\n", "Initializing the DistilBERT tokenizer from the Hugging Face library." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145, "referenced_widgets": [ "ecf70d4a4dfc4d5e9781659bdb588f5a", "3267e1033edf49d09a7217cedfa075d8", "5736306b44f545f38fc864ff2b1d81ed", "cee6708b0d5b49659bc55a7300155264", "37f41deeb065420c98ecf5b190117afd", "dbdf6c76f5dd4de5964f93feefbdfa6f", "131d3787f2964986b79896403d499f90", "9cd6f61892284fd9abf32cecda1a0c35", "3fc7ed1926cf49fdb55eed3689c48a3d", "e28205180ccf4f3c96aae517a4dd2c3c", "5272b52221c04feb90d24fd19cfebeb6", "e3a121298df44bd2ad9bb75726bf25f7", "2df220e9bda94360b0b11c17b174fc3c", "9a998bfe8c8e4b2f90786a7ee738de64", "e78a1378a3374d39968b63c088c206a7", "c1e0ac1370cd433a9ce49634c74d8bee", "f608b15c9f4f411993e2cbff85f37964", "da728996b4784ec6bfccb9b051a4116f", "a0b848e270324c9aa187adea72eeade1", "1f4432d776c045cd9af870de7434f1c5", "61aad21e779a41ba91e28c0f4f07f5eb", "d6018e255b75415f8695c602bbf6316e", "1738f7695d4442218e0a96f2e54087e0", "95738592221a4655a09936bc8ef6d267", "cbf28380c02a4dd588e13c27e08e9165", "e1d1fe16ac7e4f54a7684e3551d22f69", "6eeebf51e344420e93fd878b23170a5e", "34a199dfb8af4bc4a4450a6cb84434fe", "3f989643030941f19a6c59029d19cce2", "ad77cea2398e42879b2c86818291a5ca", "00c3fe311b334582b12d2865b4b54472", "086ff943676e41508450c94094bbba66", "2b0209ab8fb14cd8aa413b297c728b36", "6cd7b346ec6e4414b16ea1bbd5256597", "5d4c06978def4945b38ddfdd23bc6eb9", "c00521673b4f49a2a94ed80978fd5f08", "6750fc5bff734ff5acfc634920aa6ca8", "48878f905607485aaa3f7c13f47c26f4", "137e3613bcfc423b9c6466401864b081", "3172380edff54b5a94fb8061bf39c5d1", "23c2f6e7c0f34bfc8c4b9a0a82e530f8", "a811f2ff237f4fb693d0b7b6e887d9e8", "0ee8001cd0b6433388ab375237222490", "237dad4062014d6185ec00dfd26d6545" ] }, "id": "bcNEJ6perOSs", "outputId": "1f8c4da5-c0fb-44f8-fd8c-c07b4058f25e" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ecf70d4a4dfc4d5e9781659bdb588f5a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)okenizer_config.json: 0%| | 0.00/28.0 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"e3a121298df44bd2ad9bb75726bf25f7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/232k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1738f7695d4442218e0a96f2e54087e0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)/main/tokenizer.json: 0%| | 0.00/466k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6cd7b346ec6e4414b16ea1bbd5256597", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)lve/main/config.json: 0%| | 0.00/483 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import DistilBertTokenizerFast\n", "tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')" ] }, { "cell_type": "markdown", "metadata": { "id": "ogQsUAw-Dt-R" }, "source": [ "### Tokenize Data\n", "Tokenizing the training and testing data using the initialized tokenizer." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "-OL3fgLvrXvH" }, "outputs": [], "source": [ "train_encodings = tokenizer(X_train, truncation=True, padding=True)\n", "test_encodings = tokenizer(X_test, truncation=True, padding=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "d9VAEWTcD1YJ" }, "source": [ "### Create TensorFlow Datasets\n", "Creating TensorFlow datasets using the tokenized encodings and labels." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "9B42CTCnrrEx" }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "train_dataset = tf.data.Dataset.from_tensor_slices((\n", " dict(train_encodings),\n", " y_train\n", "))\n", "\n", "test_dataset = tf.data.Dataset.from_tensor_slices((\n", " dict(test_encodings),\n", " y_test\n", "))" ] }, { "cell_type": "markdown", "metadata": { "id": "mWu5rzb6D6ah" }, "source": [ "### Define Training Arguments\n", "Defining the training arguments for the TFTrainer from the Hugging Face library." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "NH1dupK0rzfn" }, "outputs": [], "source": [ "from transformers import TFDistilBertForSequenceClassification, TFTrainer, TFTrainingArguments\n", "\n", "# Define training arguments\n", "training_args = TFTrainingArguments(\n", " output_dir='./results', # Directory to save model checkpoints and results\n", " num_train_epochs=2, # Number of training epochs\n", " per_device_train_batch_size=8, # Batch size for training\n", " per_device_eval_batch_size=16, # Batch size for evaluation\n", " warmup_steps=500, # Number of warmup steps for learning rate scheduling\n", " weight_decay=0.01, # Weight decay for regularization\n", " logging_dir='./logs', # Directory for storing logs\n", " logging_steps=10, # Log every specified number of steps\n", " evaluation_strategy=\"steps\", # Evaluation strategy (\"steps\" or \"epoch\")\n", " eval_steps=500, # Number of steps between evaluations\n", " save_total_limit=1, # Limit the number of checkpoints saved\n", " metric_for_best_model=\"eval_accuracy\", # Metric for saving the best model checkpoint\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "J68L0_JeD_L0" }, "source": [ "### Initialize and Train Model\n", "Initializes the DistilBERT model, initializes the TFTrainer, and trains the model." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PZvTrEcfr7k-", "outputId": "7807a2d0-8f53-4ddd-bcd9-3a4b9726f82a" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight']\n", "- This IS expected if you are initializing TFDistilBertForSequenceClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights or buffers of the TF 2.0 model TFDistilBertForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "# Create the DistilBERT model within the strategy scope\n", "with training_args.strategy.scope():\n", " model = TFDistilBertForSequenceClassification.from_pretrained(\"distilbert-base-uncased\")\n", "\n", "# Set the eval_steps value\n", "training_args.eval_steps = 500\n", "\n", "# Initialize the Trainer for training\n", "trainer = TFTrainer(\n", " model=model, # the instantiated 🤗 Transformers model to be trained\n", " args=training_args, # training arguments, defined above\n", " train_dataset=train_dataset, # training dataset\n", " eval_dataset=test_dataset # evaluation dataset\n", ")\n", "\n", "# Train the model\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "id": "TgJ6_2YbMUBO" }, "source": [ "## Model Evaluation and Inference" ] }, { "cell_type": "markdown", "metadata": { "id": "fUjJkKBxEYRr" }, "source": [ "### Evaluate Model\n", "Evaluating the trained model using the testing dataset." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "R534aDi3xD0s", "outputId": "7f7fd521-69b0-46a5-cb3d-84c91ad6ac75" }, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 0.007443295970950113}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate(test_dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "Fn336SjyHixq" }, "source": [ "### Save Trained Model\n", "Saving the trained model." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "okD5we1NwhQW" }, "outputs": [], "source": [ "trainer.save_model('spam_detection_model')" ] }, { "cell_type": "markdown", "metadata": { "id": "WyPfTOKaMdem" }, "source": [ "## Interactive Gradio Interface" ] }, { "cell_type": "markdown", "metadata": { "id": "MW9ioqk7HpkN" }, "source": [ "### Inference on Sample Text\n", "Performing inference on a sample text using the trained model." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f9MbFGrEyNTS", "outputId": "d5b335c2-7dba-4a2a-c3f6-9976da72807f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sample Message: Hi there, how's it going?\n", "Predicted Label: No need to worry, Not a spam message.\n" ] } ], "source": [ "# Sample text you want to classify\n", "sample_text = \"Hi there, how's it going?\"\n", "\n", "# Preprocess the sample text using the tokenizer\n", "sample_encodings = tokenizer(sample_text, truncation=True, padding=True, return_tensors=\"tf\")\n", "\n", "# Perform inference\n", "with training_args.strategy.scope():\n", " logits = model(sample_encodings.input_ids).logits\n", "\n", "# Convert logits to probabilities using softmax\n", "probabilities = tf.nn.softmax(logits, axis=-1)\n", "\n", "# Get the predicted class\n", "predicted_class = tf.argmax(probabilities, axis=-1).numpy()[0]\n", "\n", "# Map the predicted class to label\n", "label_mapping = {0: \"No need to worry, Not a spam message.\", 1: \"This message has been identified as spam.\"}\n", "predicted_label = label_mapping[predicted_class]\n", "\n", "print(\"Sample Message:\", sample_text)\n", "print(\"Predicted Label:\", predicted_label)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Rcvb9ADvHuvx" }, "source": [ "### Create Gradio Interface\n", "Creating a Gradio interface for interactive spam detection using the trained model." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 616 }, "id": "uk3gHaanqq2W", "outputId": "b06ff574-2fb3-4b01-bed7-148030c0d72d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n", "Note: opening Chrome Inspector may crash demo inside Colab notebooks.\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n