{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import logging\n", "import warnings\n", "\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppresses INFO and WARNING messages\n", "\n", "# Configure logging to suppress TensorFlow messages\n", "logging.getLogger('tensorflow').setLevel(logging.ERROR) # Set level to ERROR to suppress INFO and WARNING messages\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "from datasets import load_dataset\n", "\n", "import tensorflow as tf\n", "from tensorflow.keras import layers\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"x-g85/x_g85_fn_dataset\", streaming=True)\n", "\n", "train = pd.DataFrame(dataset[\"train\"])\n", "valid = pd.DataFrame(dataset[\"valid\"])\n", "test = pd.DataFrame(dataset[\"test\"])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "X_train = train[\"text\"]\n", "y_train = train[\"label\"]\n", "\n", "X_valid = valid[\"text\"]\n", "y_vaild = valid[\"label\"]\n", "\n", "X_test = test[\"text\"]\n", "y_test = test[\"label\"]\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['John',\n", " 'McCain',\n", " 'says',\n", " 'NSA',\n", " 'chief',\n", " 'Keith',\n", " 'Alexander',\n", " \"'should\",\n", " 'resign',\n", " 'or',\n", " 'be',\n", " \"fired'.\",\n", " 'Senator',\n", " 'gives',\n", " 'interview',\n", " 'to',\n", " 'Der',\n", " 'Spiegel,',\n", " 'saying',\n", " 'general',\n", " 'should',\n", " \"'be\",\n", " 'held',\n", " \"accountable'\",\n", " 'for',\n", " 'Edward',\n", " 'Snowden',\n", " 'leaks.']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train[0].split()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "207" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Calculate the rounded average number of words per sentence in one line\n", "rounded_average_words_per_sentence = round(sum(len(sentence.split()) for sentence in X_train) / len(X_train))\n", "\n", "rounded_average_words_per_sentence" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Text Tokenization | Vectorization Parameters\n", "max_vocab_length = 5000 # how many unique words to use (i.e num rows in embedding vector)\n", "max_length = 300 # max number of words in a comment to use; default = 300\n", "embed_dim = 256 # how big is each word vector" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Setup Text Vectorization\n", "# Serialization Issue: https://github.com/onnx/tensorflow-onnx/issues/1886\n", "\n", "text_vectorizer = layers.TextVectorization(\n", " max_tokens= max_vocab_length,\n", " output_mode=\"int\",\n", " output_sequence_length=max_length,\n", " name=\"TextVec\",\n", " \n", ")\n", "\n", "text_vectorizer.adapt(X_train, batch_size=32)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# # Setup Text Tokenizer\n", "# from tensorflow.keras.preprocessing.text import Tokenizer\n", "# from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "\n", "# tokenizer = Tokenizer(num_words=max_vocab_length)\n", "# 
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# # Setup Text Tokenizer\n", "# from tensorflow.keras.preprocessing.text import Tokenizer\n", "# from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "\n", "# tokenizer = Tokenizer(num_words=max_vocab_length)\n", "# tokenizer.fit_on_texts(X_train)\n", "\n", "# tokenized_train = tokenizer.texts_to_sequences(X_train)\n", "# tokenized_valid = tokenizer.texts_to_sequences(X_valid)\n", "# tokenized_test = tokenizer.texts_to_sequences(X_test)\n", "\n", "# X_train = pad_sequences(tokenized_train, maxlen=max_length)\n", "# X_valid = pad_sequences(tokenized_valid, maxlen=max_length)\n", "# X_test = pad_sequences(tokenized_test, maxlen=max_length)\n", "\n", "# # Save the tokenizer to a JSON file\n", "# tokenizer_json = tokenizer.to_json()\n", "# with open('model/tokenizer.json', 'w') as file:\n", "#     file.write(tokenizer_json)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Model Creation\n", "# Issue (Serialization): https://github.com/tflearn/tflearn/issues/605\n", "\n", "# Input\n", "inputs = layers.Input(shape=(1,), dtype=tf.string, name=\"InputLayer\") # For TextVectorization\n", "x = text_vectorizer(inputs) # For TextVectorization\n", "\n", "# inputs = layers.Input(shape=(max_length,), name=\"InputLayer\")\n", "\n", "# Embedding layer\n", "x = layers.Embedding(input_dim=max_vocab_length, output_dim=embed_dim)(x) # For TextVectorization\n", "# x = layers.Embedding(input_dim=max_vocab_length, output_dim=embed_dim)(inputs)\n", "\n", "# LSTM layer (return_sequences=False, so only the final hidden state is passed on)\n", "x = layers.LSTM(100, use_cudnn=False)(x)\n", "x = layers.Dropout(0.5)(x) # Regularize the LSTM output\n", "\n", "# Fully connected layers\n", "x = layers.Dense(64, activation=\"relu\")(x)\n", "x = layers.Dropout(0.3)(x)\n", "\n", "x = layers.Dense(32, activation=\"relu\")(x)\n", "x = layers.Dropout(0.2)(x)\n", "\n", "# Output layer: single sigmoid unit for binary classification\n", "outputs = layers.Dense(1, activation=\"sigmoid\")(x)\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "model_01 = tf.keras.Model(inputs, outputs, name=\"model_01\")" ] },
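{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Worked cross-check (added sketch) of the parameter counts reported by\n", "# `model_01.summary()` below. The embedding stores one row per vocabulary\n", "# entry; the LSTM has 4 gates, each with an input kernel, a recurrent\n", "# kernel and a bias vector.\n", "units = 100\n", "print(max_vocab_length * embed_dim)                     # embedding: 1,280,000\n", "print(4 * (embed_dim * units + units * units + units))  # lstm: 142,800\n" ] },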
Model: \"model_01\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"model_01\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ InputLayer (InputLayer) │ (None, 1) │ 0 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ TextVec (TextVectorization) │ (None, 300) │ 0 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ embedding (Embedding) │ (None, 300, 256) │ 1,280,000 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ lstm (LSTM) │ (None, 100) │ 142,800 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout (Dropout) │ (None, 100) │ 0 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (Dense) │ (None, 64) │ 6,464 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_1 (Dropout) │ (None, 64) │ 0 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_1 (Dense) │ (None, 32) │ 2,080 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_2 (Dropout) │ (None, 32) │ 0 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_2 (Dense) │ (None, 1) │ 33 │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ InputLayer (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ TextVec (\u001b[38;5;33mTextVectorization\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m300\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m300\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m1,280,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m142,800\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m6,464\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_1 
(\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m2,080\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_2 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m33\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 1,431,377 (5.46 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,431,377\u001b[0m (5.46 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 1,431,377 (5.46 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,431,377\u001b[0m (5.46 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Get the summary\n", "model_01.summary()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Model Compile\n", "from tensorflow.keras.metrics import AUC, Precision \n", "\n", "model_01.compile(loss=\"binary_crossentropy\",\n", " optimizer = tf.keras.optimizers.Adam(learning_rate=0.001),\n", " metrics = [\"accuracy\", Precision(), AUC()])" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Folder already exists at: model_logs\n" ] } ], "source": [ "import os\n", "\n", "model_logs = \"model_logs\"\n", "\n", "# Check if the `model_logs` directory exists, create it if not\n", "if not os.path.exists(model_logs):\n", " os.makedirs(model_logs)\n", " print(f\"Folder created at: {model_logs}\")\n", "else:\n", " print(f\"Folder already exists at: {model_logs}\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "\u001b[1m2674/2674\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1773s\u001b[0m 659ms/step - accuracy: 0.5918 - auc: 0.6342 - loss: 0.6441 - precision: 0.5722 - val_accuracy: 0.7242 - val_auc: 0.8294 - val_loss: 0.4581 - val_precision: 0.6573\n", "Epoch 2/10\n", "\u001b[1m2674/2674\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1743s\u001b[0m 652ms/step - accuracy: 0.7226 - auc: 0.8324 - loss: 0.4403 - precision: 0.6735 - val_accuracy: 0.7320 - val_auc: 0.8443 - val_loss: 0.4240 - val_precision: 0.6555\n", "Epoch 3/10\n", "\u001b[1m2674/2674\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1795s\u001b[0m 671ms/step - accuracy: 0.7393 - auc: 0.8508 - loss: 0.3992 - precision: 0.6759 - val_accuracy: 0.7465 - val_auc: 0.8536 - val_loss: 0.3860 - val_precision: 0.6696\n", "Epoch 4/10\n", "\u001b[1m2674/2674\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1776s\u001b[0m 664ms/step - accuracy: 0.7488 - auc: 0.8573 - loss: 0.3855 - precision: 0.6789 - val_accuracy: 0.7474 - val_auc: 0.8587 - val_loss: 0.3812 - val_precision: 0.6682\n", "Epoch 5/10\n", "\u001b[1m2674/2674\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1795s\u001b[0m 671ms/step - accuracy: 0.7519 - auc: 0.8631 - loss: 0.3734 - precision: 0.6765 - val_accuracy: 0.7493 - val_auc: 0.8576 - val_loss: 0.3811 - val_precision: 0.6726\n", "Epoch 6/10\n", "\u001b[1m2674/2674\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1810s\u001b[0m 677ms/step - accuracy: 0.7557 - auc: 0.8653 - loss: 0.3688 - precision: 0.6772 - val_accuracy: 0.7472 - val_auc: 0.8545 - val_loss: 0.3816 - val_precision: 0.6679\n", "Epoch 7/10\n", "\u001b[1m2674/2674\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2580s\u001b[0m 965ms/step - accuracy: 0.7593 - auc: 0.8674 - loss: 0.3672 - precision: 0.6814 - val_accuracy: 0.7459 - val_auc: 0.8539 - val_loss: 0.3945 - val_precision: 0.6731\n", "Epoch 8/10\n", "\u001b[1m2674/2674\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1879s\u001b[0m 703ms/step - accuracy: 0.7632 - auc: 0.8724 - loss: 0.3648 - precision: 0.6854 - val_accuracy: 0.7499 - val_auc: 0.8559 
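{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sketch (added; not used in the training run below): early stopping\n", "# with best-weight restoration. The validation loss below plateaus around\n", "# epoch 4-5 while training runs to epoch 10, so passing this callback via\n", "# `model_01.fit(..., callbacks=[early_stopping, ...])` could save most of that time.\n", "from tensorflow.keras.callbacks import EarlyStopping\n", "\n", "early_stopping = EarlyStopping(monitor=\"val_loss\", patience=3, restore_best_weights=True)\n" ] },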
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1773s 659ms/step - accuracy: 0.5918 - auc: 0.6342 - loss: 0.6441 - precision: 0.5722 - val_accuracy: 0.7242 - val_auc: 0.8294 - val_loss: 0.4581 - val_precision: 0.6573\n", "Epoch 2/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1743s 652ms/step - accuracy: 0.7226 - auc: 0.8324 - loss: 0.4403 - precision: 0.6735 - val_accuracy: 0.7320 - val_auc: 0.8443 - val_loss: 0.4240 - val_precision: 0.6555\n", "Epoch 3/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1795s 671ms/step - accuracy: 0.7393 - auc: 0.8508 - loss: 0.3992 - precision: 0.6759 - val_accuracy: 0.7465 - val_auc: 0.8536 - val_loss: 0.3860 - val_precision: 0.6696\n", "Epoch 4/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1776s 664ms/step - accuracy: 0.7488 - auc: 0.8573 - loss: 0.3855 - precision: 0.6789 - val_accuracy: 0.7474 - val_auc: 0.8587 - val_loss: 0.3812 - val_precision: 0.6682\n", "Epoch 5/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1795s 671ms/step - accuracy: 0.7519 - auc: 0.8631 - loss: 0.3734 - precision: 0.6765 - val_accuracy: 0.7493 - val_auc: 0.8576 - val_loss: 0.3811 - val_precision: 0.6726\n", "Epoch 6/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1810s 677ms/step - accuracy: 0.7557 - auc: 0.8653 - loss: 0.3688 - precision: 0.6772 - val_accuracy: 0.7472 - val_auc: 0.8545 - val_loss: 0.3816 - val_precision: 0.6679\n", "Epoch 7/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 2580s 965ms/step - accuracy: 0.7593 - auc: 0.8674 - loss: 0.3672 - precision: 0.6814 - val_accuracy: 0.7459 - val_auc: 0.8539 - val_loss: 0.3945 - val_precision: 0.6731\n", "Epoch 8/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1879s 703ms/step - accuracy: 0.7632 - auc: 0.8724 - loss: 0.3648 - precision: 0.6854 - val_accuracy: 0.7499 - val_auc: 0.8559 - val_loss: 0.3984 - val_precision: 0.6770\n", "Epoch 9/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1897s 709ms/step - accuracy: 0.7656 - auc: 0.8766 - loss: 0.3627 - precision: 0.6930 - val_accuracy: 0.7480 - val_auc: 0.8540 - val_loss: 0.3958 - val_precision: 0.6733\n", "Epoch 10/10\n", "2674/2674 ━━━━━━━━━━━━━━━━━━━━ 1933s 723ms/step - accuracy: 0.7743 - auc: 0.8889 - loss: 0.3518 - precision: 0.7123 - val_accuracy: 0.7478 - val_auc: 0.8558 - val_loss: 0.3971 - val_precision: 0.6801\n" ] } ], "source": [ "# Early Stopping\n", "# early_stopping = EarlyStopping(monitor='val_loss', patience=5)\n", "\n", "# Model Fit\n", "history_model_01 = model_01.fit(X_train, y_train, epochs=10, batch_size=32,\n", "                                validation_data=(X_valid, y_valid),\n", "                                callbacks=[tf.keras.callbacks.TensorBoard(\"model_logs\")])" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "149/149 ━━━━━━━━━━━━━━━━━━━━ 33s 211ms/step\n" ] }, { "data": { "text/plain": [ "array([[0.9999999 ],\n", "       [0.5088127 ],\n", "       [0.51265997],\n", "       [0.51331687],\n", "       [0.51750326]], dtype=float32)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model Prediction\n", "model_01_pred_probs = model_01.predict(X_test)\n", "model_01_pred_probs[:5]" ] },
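{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Added sketch: the sigmoid head outputs probabilities, so round them at an\n", "# assumed 0.5 threshold to get hard 0/1 labels for the evaluation helpers below.\n", "model_01_preds = tf.squeeze(tf.round(model_01_pred_probs))\n", "model_01_preds[:5]\n" ] },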
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Helper Functions\n", "\n", "import itertools\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from sklearn.metrics import confusion_matrix, accuracy_score, precision_recall_fscore_support\n", "\n", "\n", "def calculate_results(y_true, y_pred):\n", "    \"\"\"\n", "    Calculates the accuracy, precision, recall and F1 score of a binary classification model.\n", "\n", "    Args:\n", "        y_true: true labels in the form of a 1D array\n", "        y_pred: predicted labels in the form of a 1D array\n", "\n", "    Returns a dictionary of accuracy, precision, recall, f1-score.\n", "    \"\"\"\n", "    # Calculate model accuracy\n", "    model_accuracy = accuracy_score(y_true, y_pred) * 100\n", "    # Calculate model precision, recall and F1 score using the \"weighted\" average\n", "    model_precision, model_recall, model_f1, _ = precision_recall_fscore_support(y_true, y_pred, average=\"weighted\")\n", "    model_results = {\"accuracy\": model_accuracy,\n", "                     \"precision\": model_precision,\n", "                     \"recall\": model_recall,\n", "                     \"f1\": model_f1}\n", "    return model_results\n", "\n", "\n", "def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False):\n", "    \"\"\"Makes a labelled confusion matrix comparing predictions and ground truth labels.\n", "\n", "    If classes is passed, the confusion matrix will be labelled; if not, integer class values\n", "    will be used.\n", "\n", "    Args:\n", "        y_true: Array of truth labels (must be same shape as y_pred).\n", "        y_pred: Array of predicted labels (must be same shape as y_true).\n", "        classes: Array of class labels (e.g. string form). If `None`, integer labels are used.\n", "        figsize: Size of output figure (default=(10, 10)).\n", "        text_size: Size of output figure text (default=15).\n", "        norm: normalize values or not (default=False).\n", "        savefig: save confusion matrix to file (default=False).\n", "\n", "    Returns:\n", "        A labelled confusion matrix plot comparing y_true and y_pred.\n", "\n", "    Example usage:\n", "        make_confusion_matrix(y_true=test_labels, # ground truth test labels\n", "                              y_pred=y_preds, # predicted labels\n", "                              classes=class_names, # array of class label names\n", "                              figsize=(15, 15),\n", "                              text_size=10)\n", "    \"\"\"\n", "    # Create the confusion matrix\n", "    cm = confusion_matrix(y_true, y_pred)\n", "    cm_norm = cm.astype(\"float\") / cm.sum(axis=1)[:, np.newaxis] # normalize it\n", "    n_classes = cm.shape[0] # find the number of classes we're dealing with\n", "\n", "    # Plot the figure and make it pretty\n", "    fig, ax = plt.subplots(figsize=figsize)\n", "    cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors represent how 'correct' a class is, darker == better\n", "    fig.colorbar(cax)\n", "\n", "    # Is there a list of classes?\n", "    if classes:\n", "        labels = classes\n", "    else:\n", "        labels = np.arange(cm.shape[0])\n", "\n", "    # Label the axes\n", "    ax.set(title=\"Confusion Matrix\",\n", "           xlabel=\"Predicted label\",\n", "           ylabel=\"True label\",\n", "           xticks=np.arange(n_classes), # create enough axis slots for each class\n", "           yticks=np.arange(n_classes),\n", "           xticklabels=labels, # axes will be labeled with class names (if they exist) or ints\n", "           yticklabels=labels)\n", "\n", "    # Make x-axis labels appear on the bottom\n", "    ax.xaxis.set_label_position(\"bottom\")\n", "    ax.xaxis.tick_bottom()\n", "\n", "    # Set the threshold for different colors\n", "    threshold = (cm.max() + cm.min()) / 2.\n", "\n", "    # Plot the text on each cell\n", "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", "        if norm:\n", "            plt.text(j, i, f\"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)\",\n", "                     horizontalalignment=\"center\",\n", "                     color=\"white\" if cm[i, j] > threshold else \"black\",\n", "                     size=text_size)\n", "        else:\n", "            plt.text(j, i, f\"{cm[i, j]}\",\n", "                     horizontalalignment=\"center\",\n", "                     color=\"white\" if cm[i, j] > threshold else \"black\",\n", "                     size=text_size)\n", "\n", "    # Save the figure to the current working directory\n", "    if savefig:\n", "        fig.savefig(\"confusion_matrix.png\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "