{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sentiment Classification with FHE\n", "\n", "This notebook tackles sentiment classification with Fully Homomorphic Encryption (FHE). Imagine a client (a user or a company) who wants to predict whether a specific text (e.g., a tweet) contains positive, neutral or negative feedback using a cloud service provider, without revealing the text during the process.\n", "\n", "To do this, we use a machine learning model that can predict over encrypted data thanks to the Concrete-ML library available on [GitHub](https://github.com/zama-ai/concrete-ml).\n", "\n", "The dataset we use in this notebook can be found on [Kaggle](https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment). \n", " \n", "We present two different ways to encode the text:\n", "1. A basic **TF-IDF** approach, which essentially looks at how often a word appears in the text.\n", "2. An advanced **transformer** embedding of the text using the Hugging Face repository.\n", "\n", "The main assumption of this notebook is that clients who want to have their text analyzed in a privacy-preserving manner can encode the text using a predefined representation before encrypting the data. The FHE-friendly model is thus trained in the clear beforehand for the given task, here classification, over these representations using a relevant training set." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import the required packages\n", "import os\n", "import time\n", "from pathlib import Path\n", "\n", "import numpy\n", "import pandas as pd\n", "from sklearn.metrics import average_precision_score\n", "from sklearn.model_selection import GridSearchCV, train_test_split\n", "\n", "from concrete.ml.sklearn import XGBClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Proportion of positive examples: 16.14%\n", "Proportion of negative examples: 62.69%\n", "Proportion of neutral examples: 21.17%\n" ] } ], "source": [ "# Load the dataset\n", "# The dataset can be downloaded through the `download_data.sh` script, which requires setting up\n", "# Kaggle's CLI, or manually at https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment\n", "if not os.path.isfile(\"local_datasets/twitter-airline-sentiment/Tweets.csv\"):\n", " raise ValueError(\"Please launch the `download_data.sh` script to get the dataset\")\n", "\n", "\n", "train = pd.read_csv(\"local_datasets/twitter-airline-sentiment/Tweets.csv\", index_col=0)\n", "text_X = train[\"text\"]\n", "y = train[\"airline_sentiment\"]\n", "# Encode the labels as integers: negative=0, neutral=1, positive=2\n", "y = y.replace([\"negative\", \"neutral\", \"positive\"], [0, 1, 2])\n", "\n", "pos_ratio = y.value_counts()[2] / y.value_counts().sum()\n", "neg_ratio = y.value_counts()[0] / y.value_counts().sum()\n", "neutral_ratio = y.value_counts()[1] / y.value_counts().sum()\n", "print(f\"Proportion of positive examples: {round(pos_ratio * 100, 2)}%\")\n", "print(f\"Proportion of negative examples: {round(neg_ratio * 100, 2)}%\")\n", "print(f\"Proportion of neutral examples: {round(neutral_ratio * 100, 2)}%\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Split into train and test sets\n", "text_X_train, text_X_test, y_train, y_test = train_test_split(\n", " text_X, y, 
test_size=0.1, random_state=42\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Text representation using TF-IDF\n", "\n", "[Term Frequency-Inverse Document Frequency](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) (TF-IDF) is a numerical statistic that measures the importance of a term in a document relative to a collection of documents. The higher the TF-IDF score, the more important the term is to the document.\n", "\n", "We compute it as follows:\n", "\n", "$$ \\mathsf{TF\\textrm{-}IDF}(t,d,D) = \\mathsf{TF}(t,d) \\times \\mathsf{IDF}(t,D) $$\n", "\n", "where $\\mathsf{TF}(t,d)$ is the term frequency of term $t$ in document $d$, and $\\mathsf{IDF}(t,D)$ is the inverse document frequency of term $t$ in document collection $D$.\n", "\n", "Here we use the scikit-learn implementation of the TF-IDF vectorizer." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Let's first build a representation vector from the text\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "\n", "tfidf_vectorizer = TfidfVectorizer(max_features=500, stop_words=\"english\")\n", "X_train = tfidf_vectorizer.fit_transform(text_X_train)\n", "X_test = tfidf_vectorizer.transform(text_X_test)\n", "\n", "# Convert the sparse train and test matrices to dense arrays\n", "X_train = X_train.toarray()\n", "X_test = X_test.toarray()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Let's build our model\n", "model = XGBClassifier()\n", "\n", "# A grid search to find the best parameters\n", "parameters = {\n", " \"n_bits\": [2, 3],\n", " \"max_depth\": [1],\n", " \"n_estimators\": [10, 30, 50],\n", " # \"n_jobs\": [-1],\n", "}" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
GridSearchCV(cv=3, estimator=XGBClassifier(n_jobs=1),\n", " param_grid={'max_depth': [1], 'n_bits': [2, 3],\n", " 'n_estimators': [10, 30, 50]},\n", " scoring='accuracy')
GridSearchCV(cv=3, estimator=XGBClassifier(n_jobs=1), n_jobs=1,\n", " param_grid={'max_depth': [1], 'n_bits': [2, 3],\n", " 'n_estimators': [10, 30, 50]},\n", " scoring='accuracy')
| Model | Accuracy | Average Precision (positive) | Average Precision (negative) | Average Precision (neutral) |
|---|---|---|---|---|
| TF-IDF + XGBoost | 0.711749 | 0.640422 | 0.871891 | 0.43486 |
| Transformer Only | 0.805328 | 0.854827 | 0.954804 | 0.68011 |
| Transformer + XGBoost | 0.846311 | 0.895930 | 0.964674 | 0.74489 |
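
The table above reports one average precision value per class. The sketch below shows, without any dependencies, how such a one-vs-rest score can be computed. The notebook itself imports scikit-learn's `average_precision_score`; we assume here (the evaluation cell is not shown in this excerpt) that each column treats one class as the positive label and uses the model's predicted probability for that class as the score. The helper names `average_precision` and `one_vs_rest_average_precision`, as well as the toy data, are hypothetical and only serve as illustration.

```python
def average_precision(y_true, scores):
    """Average precision for binary labels (1 = positive) and real-valued scores.

    AP = sum_n (R_n - R_{n-1}) * P_n, evaluated at each distinct score threshold.
    """
    total_pos = sum(y_true)
    if total_pos == 0:
        raise ValueError("Average precision is undefined without positive examples")

    # Rank the examples from the highest to the lowest score
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

    ap, true_pos, prev_recall = 0.0, 0, 0.0
    for rank, idx in enumerate(order, start=1):
        true_pos += y_true[idx]
        # Tied scores share one threshold: only evaluate at the end of a tie group
        is_last = rank == len(order)
        if is_last or scores[order[rank]] != scores[idx]:
            precision = true_pos / rank
            recall = true_pos / total_pos
            ap += (recall - prev_recall) * precision
            prev_recall = recall
    return ap


def one_vs_rest_average_precision(y_true, proba, label):
    """Treat `label` as the positive class and score with its predicted probability."""
    binary = [1 if y == label else 0 for y in y_true]
    scores = [row[label] for row in proba]
    return average_precision(binary, scores)


# Toy data (made up): labels follow the notebook's encoding
# negative=0, neutral=1, positive=2
toy_y = [0, 1, 2, 0]
toy_proba = [
    [0.7, 0.2, 0.1],
    [0.2, 0.5, 0.3],
    [0.1, 0.3, 0.6],
    [0.6, 0.3, 0.1],
]
for label, name in enumerate(["negative", "neutral", "positive"]):
    ap = one_vs_rest_average_precision(toy_y, toy_proba, label)
    print(f"Average precision ({name}): {ap:.4f}")
```

Average precision summarizes the whole precision-recall curve for a class, which makes it more informative than accuracy on an imbalanced dataset such as this one, where negative tweets account for 62.69% of the examples.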