{ "cells": [ { "cell_type": "code", "execution_count": 43, "metadata": { "id": "l8Y_Fz5_VKUf" }, "outputs": [], "source": [
 "import numpy as np\n",
 "import pandas as pd\n",
 "import seaborn as sns\n",
 "import matplotlib.pyplot as plt\n",
 "%matplotlib inline\n",
 "import random\n",
 "from random import randint\n",
 "from joblib import dump, load\n",
 "import warnings\n",
 "warnings.filterwarnings(\"ignore\")\n",
 "\n",
 "from sklearn.model_selection import train_test_split\n",
 "from sklearn import tree\n",
 "from sklearn.tree import DecisionTreeClassifier\n",
 "from sklearn.metrics import accuracy_score\n",
 "\n",
 "import graphviz"
 ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "mIqh1kxmVQ9o" }, "outputs": [], "source": [
 "# Baseline models evaluated on the full feature set; names and instances are index-aligned.\n",
 "classifiers = ['DecisionTree']\n",
 "\n",
 "models = [DecisionTreeClassifier(random_state=0)]\n",
 "\n",
 "\n",
 "def split(df, label):\n",
 "    \"\"\"Split features and labels into stratified train/test partitions.\n",
 "\n",
 "    The split is fixed (test_size=0.25, random_state=42, stratified on\n",
 "    `label`), so repeated calls on the same data return the same rows.\n",
 "\n",
 "    Returns:\n",
 "        Tuple (X_train, X_test, Y_train, Y_test).\n",
 "    \"\"\"\n",
 "    return tuple(train_test_split(df, label, test_size=0.25, random_state=42, stratify=label))\n",
 "\n",
 "\n",
 "def acc_score(df, label):\n",
 "    \"\"\"Fit every baseline model on all features and tabulate its test accuracy.\n",
 "\n",
 "    Note that the shared instances in the module-level `models` list are\n",
 "    fitted in place.\n",
 "\n",
 "    Returns:\n",
 "        DataFrame with columns 'Classifier' and 'Accuracy', sorted by\n",
 "        accuracy (best first) with a fresh 0..n-1 index.\n",
 "    \"\"\"\n",
 "    X_train, X_test, Y_train, Y_test = split(df, label)\n",
 "    acc = []\n",
 "    for model in models:\n",
 "        model.fit(X_train, Y_train)\n",
 "        acc.append(accuracy_score(Y_test, model.predict(X_test)))\n",
 "    score = pd.DataFrame({\"Classifier\": classifiers, \"Accuracy\": acc})\n",
 "    # Chain instead of inplace=True so the cell stays idempotent on re-run.\n",
 "    return score.sort_values(by=\"Accuracy\", ascending=False).reset_index(drop=True)\n",
 "\n",
 "\n",
 "def plot(score, x, y, c=\"b\"):\n",
 "    \"\"\"Plot the best accuracy reached in each generation.\n",
 "\n",
 "    Args:\n",
 "        score: Sequence with one best-accuracy value per generation.\n",
 "        x, y: Unused; kept so existing calls such as plot(score, 0.9, 1.0)\n",
 "            keep working.\n",
 "        c: Colour passed to seaborn.pointplot.\n",
 "    \"\"\"\n",
 "    # Derive the generation axis from the data rather than assuming exactly\n",
 "    # five generations, so any n_gen can be plotted.\n",
 "    gen = list(range(1, len(score) + 1))\n",
 "    plt.figure(figsize=(6, 4))\n",
 "    ax = sns.pointplot(x=gen, y=score, color=c)\n",
 "    ax.set(xlabel=\"Generation\", ylabel=\"Accuracy\")\n",
 "    plt.show()"
 ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "SYWqktBJVQ7I" }, "outputs": [], "source": [
 "def initilization_of_population(size, n_feat):\n",
 "    \"\"\"Create `size` random boolean feature masks of length `n_feat`.\n",
 "\n",
 "    Every mask starts with int(0.3 * n_feat) entries set to False and the\n",
 "    rest True; the entries are then shuffled with numpy's global RNG.\n",
 "    \"\"\"\n",
 "    population = []\n",
 "    for _ in range(size):\n",
 "        chromosome = np.ones(n_feat, bool)\n",
 "        chromosome[:int(0.3 * n_feat)] = False\n",
 "        np.random.shuffle(chromosome)\n",
 "        population.append(chromosome)\n",
 "    return population\n",
 "\n",
 "\n",
 "def fitness_score(population, X_train=None, X_test=None, Y_train=None, Y_test=None):\n",
 "    \"\"\"Score each chromosome by the test accuracy of a tree trained on its features.\n",
 "\n",
 "    Args:\n",
 "        population: Iterable of boolean masks over the feature columns.\n",
 "        X_train, X_test, Y_train, Y_test: Data to train and evaluate on. Any\n",
 "            argument left as None falls back to the notebook-level variable\n",
 "            of the same name, which is what earlier versions always read.\n",
 "\n",
 "    Returns:\n",
 "        Tuple (scores, population, models); each is a list ordered from the\n",
 "        best-scoring chromosome to the worst.\n",
 "    \"\"\"\n",
 "    notebook_vars = globals()\n",
 "    if X_train is None:\n",
 "        X_train = notebook_vars['X_train']\n",
 "    if X_test is None:\n",
 "        X_test = notebook_vars['X_test']\n",
 "    if Y_train is None:\n",
 "        Y_train = notebook_vars['Y_train']\n",
 "    if Y_test is None:\n",
 "        Y_test = notebook_vars['Y_test']\n",
 "\n",
 "    population = list(population)\n",
 "    scores, fitted = [], []\n",
 "    for chromosome in population:\n",
 "        mask = np.asarray(chromosome, dtype=bool)\n",
 "        clf = DecisionTreeClassifier(random_state=0)\n",
 "        clf.fit(X_train.iloc[:, mask], Y_train)\n",
 "        scores.append(accuracy_score(Y_test, clf.predict(X_test.iloc[:, mask])))\n",
 "        fitted.append(clf)\n",
 "\n",
 "    # Ascending argsort, reversed: best first, same tie order as before.\n",
 "    order = np.argsort(scores)[::-1]\n",
 "    return ([scores[i] for i in order],\n",
 "            [np.array(population[i], dtype=bool) for i in order],\n",
 "            [fitted[i] for i in order])\n",
 "\n",
 "\n",
 "def selection(pop_after_fit, n_parents):\n",
 "    \"\"\"Return the first `n_parents` chromosomes of a best-first population.\n",
 "\n",
 "    Asking for more parents than exist returns the whole population instead\n",
 "    of raising IndexError.\n",
 "    \"\"\"\n",
 "    return list(pop_after_fit[:max(n_parents, 0)])\n",
 "\n",
 "\n",
 "def crossover(pop_after_sel):\n",
 "    \"\"\"Return the parents followed by one child per consecutive parent pair.\n",
 "\n",
 "    Parents are paired (0, 1), (2, 3), ...; each child is the first half of\n",
 "    the first parent joined to the second half of the second. With an odd\n",
 "    number of parents the last one is paired with parent 0. The input list\n",
 "    is not modified.\n",
 "    \"\"\"\n",
 "    parents = list(pop_after_sel)\n",
 "    pop_nextgen = list(parents)\n",
 "    for i in range(0, len(parents), 2):\n",
 "        first = parents[i]\n",
 "        second = parents[i + 1] if i + 1 < len(parents) else parents[0]\n",
 "        cut = len(first) // 2\n",
 "        pop_nextgen.append(np.concatenate((first[:cut], second[cut:])))\n",
 "    return pop_nextgen\n",
 "\n",
 "\n",
 "def mutation(pop_after_cross, mutation_rate, n_feat, n_parents=64):\n",
 "    \"\"\"Flip randomly chosen genes in every chromosome after the first `n_parents`.\n",
 "\n",
 "    Args:\n",
 "        pop_after_cross: List of boolean arrays; entries from index\n",
 "            `n_parents` onwards are modified in place.\n",
 "        mutation_rate: int(mutation_rate * n_feat) positions are drawn (with\n",
 "            replacement) per chromosome, so a position drawn twice is\n",
 "            flipped back.\n",
 "        n_feat: Chromosome length; positions are drawn from 0..n_feat-1.\n",
 "        n_parents: Number of leading chromosomes left untouched. Defaults to\n",
 "            64, the value that used to be hard-coded.\n",
 "\n",
 "    Returns:\n",
 "        The same list object that was passed in.\n",
 "    \"\"\"\n",
 "    mutation_range = int(mutation_rate * n_feat)\n",
 "    for n in range(n_parents, len(pop_after_cross)):\n",
 "        chromo = pop_after_cross[n]\n",
 "        positions = [randint(0, n_feat - 1) for _ in range(mutation_range)]\n",
 "        for j in positions:\n",
 "            chromo[j] = not chromo[j]\n",
 "    return pop_after_cross\n",
 "\n",
 "\n",
 "def generations(df, label, size, n_feat, n_parents, mutation_rate, n_gen,\n",
 "                X_train, X_test, Y_train, Y_test, seed=None):\n",
 "    \"\"\"Run the genetic feature-selection loop for `n_gen` generations.\n",
 "\n",
 "    Args:\n",
 "        df, label: Unused; kept so existing calls keep working.\n",
 "        size: Number of chromosomes in the initial population.\n",
 "        n_feat: Number of feature columns (chromosome length).\n",
 "        n_parents: Chromosomes kept each generation; they are never mutated.\n",
 "        mutation_rate: Fraction of n_feat drawn as flip positions per child.\n",
 "        n_gen: Number of generations to run.\n",
 "        X_train, X_test, Y_train, Y_test: Data used to score chromosomes.\n",
 "        seed: Optional int; when given, seeds both numpy's global RNG and the\n",
 "            stdlib `random` module so the run is reproducible.\n",
 "\n",
 "    Returns:\n",
 "        Tuple (best_chromo, best_score, best_models) holding the top\n",
 "        chromosome, its accuracy and its fitted tree for each generation.\n",
 "    \"\"\"\n",
 "    if seed is not None:\n",
 "        # numpy shuffles the initial masks; `random` picks mutation positions.\n",
 "        np.random.seed(seed)\n",
 "        random.seed(seed)\n",
 "\n",
 "    best_chromo = []\n",
 "    best_score = []\n",
 "    best_models = []\n",
 "    population_nextgen = initilization_of_population(size, n_feat)\n",
 "    for i in range(n_gen):\n",
 "        # Score on the data that was passed in, not on notebook globals.\n",
 "        scores, pop_after_fit, models = fitness_score(\n",
 "            population_nextgen, X_train, X_test, Y_train, Y_test)\n",
 "        print('Best score in generation',i+1,':',scores[:1])\n",
 "\n",
 "        pop_after_sel = selection(pop_after_fit, n_parents)\n",
 "        pop_after_cross = crossover(pop_after_sel)\n",
 "        # Only the children appended after the selected parents are mutated.\n",
 "        population_nextgen = mutation(pop_after_cross, mutation_rate, n_feat,\n",
 "                                      n_parents=len(pop_after_sel))\n",
 "\n",
 "        best_score.append(scores[0])\n",
 "        best_chromo.append(pop_after_fit[0])\n",
 "        best_models.append(models[0])\n",
 "\n",
 "    return best_chromo, best_score, best_models"
 ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 273 }, "id": "X3Ww2R2wVQ44", "outputId": "ccd8b6d6-f6ce-4bf5-c476-8e28c11cc0ee" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(920, 23)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agetrestbpscholthalacholdpeakcasexfbsexangcp_1.0...restecg_0restecg_1restecg_2slope_1slope_2slope_3thal_3.0thal_6.0thal_7.0num
063.0145.0233.0150.02.30.01.0101.0...0.00.01.00.00.01.00.01.00.00
167.0160.0286.0108.01.53.01.0010.0...0.00.01.00.01.00.01.00.00.01
267.0120.0229.0129.02.62.01.0010.0...0.00.01.00.01.00.00.00.01.01
337.0130.0250.0187.03.50.01.0000.0...1.00.00.00.00.01.01.00.00.00
441.0130.0204.0172.01.40.00.0000.0...0.00.01.01.00.00.01.00.00.00
\n", "

5 rows × 23 columns

\n", "
" ], "text/plain": [ " age trestbps chol thalach oldpeak ca sex fbs exang cp_1.0 ... \\\n", "0 63.0 145.0 233.0 150.0 2.3 0.0 1.0 1 0 1.0 ... \n", "1 67.0 160.0 286.0 108.0 1.5 3.0 1.0 0 1 0.0 ... \n", "2 67.0 120.0 229.0 129.0 2.6 2.0 1.0 0 1 0.0 ... \n", "3 37.0 130.0 250.0 187.0 3.5 0.0 1.0 0 0 0.0 ... \n", "4 41.0 130.0 204.0 172.0 1.4 0.0 0.0 0 0 0.0 ... \n", "\n", " restecg_0 restecg_1 restecg_2 slope_1 slope_2 slope_3 thal_3.0 \\\n", "0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 \n", "1 0.0 0.0 1.0 0.0 1.0 0.0 1.0 \n", "2 0.0 0.0 1.0 0.0 1.0 0.0 0.0 \n", "3 1.0 0.0 0.0 0.0 0.0 1.0 1.0 \n", "4 0.0 0.0 1.0 1.0 0.0 0.0 1.0 \n", "\n", " thal_6.0 thal_7.0 num \n", "0 1.0 0.0 0 \n", "1 0.0 0.0 1 \n", "2 0.0 1.0 1 \n", "3 0.0 0.0 0 \n", "4 0.0 0.0 0 \n", "\n", "[5 rows x 23 columns]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_hd = pd.read_csv('./encoded_heart_disease.csv')\n", "\n", "print(data_hd.shape)\n", "data_hd.head()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 273 }, "id": "Zu6Gr1HGiS6Q", "outputId": "deec92ca-74f3-4e15-c89e-c0bb9858e586" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(920, 23)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agetrestbpscholthalacholdpeakcasexfbsexangcp_1.0...restecg_0restecg_1restecg_2slope_1slope_2slope_3thal_3.0thal_6.0thal_7.0num
063.0145.0233.0150.02.30.01.0101.0...0.00.01.00.00.01.00.01.00.00
167.0160.0286.0108.01.53.01.0010.0...0.00.01.00.01.00.01.00.00.01
267.0120.0229.0129.02.62.01.0010.0...0.00.01.00.01.00.00.00.01.01
337.0130.0250.0187.03.50.01.0000.0...1.00.00.00.00.01.01.00.00.00
441.0130.0204.0172.01.40.00.0000.0...0.00.01.01.00.00.01.00.00.00
\n", "

5 rows × 23 columns

\n", "
" ], "text/plain": [ " age trestbps chol thalach oldpeak ca sex fbs exang cp_1.0 ... \\\n", "0 63.0 145.0 233.0 150.0 2.3 0.0 1.0 1 0 1.0 ... \n", "1 67.0 160.0 286.0 108.0 1.5 3.0 1.0 0 1 0.0 ... \n", "2 67.0 120.0 229.0 129.0 2.6 2.0 1.0 0 1 0.0 ... \n", "3 37.0 130.0 250.0 187.0 3.5 0.0 1.0 0 0 0.0 ... \n", "4 41.0 130.0 204.0 172.0 1.4 0.0 0.0 0 0 0.0 ... \n", "\n", " restecg_0 restecg_1 restecg_2 slope_1 slope_2 slope_3 thal_3.0 \\\n", "0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 \n", "1 0.0 0.0 1.0 0.0 1.0 0.0 1.0 \n", "2 0.0 0.0 1.0 0.0 1.0 0.0 0.0 \n", "3 1.0 0.0 0.0 0.0 0.0 1.0 1.0 \n", "4 0.0 0.0 1.0 1.0 0.0 0.0 1.0 \n", "\n", " thal_6.0 thal_7.0 num \n", "0 1.0 0.0 0 \n", "1 0.0 0.0 1 \n", "2 0.0 1.0 1 \n", "3 0.0 0.0 0 \n", "4 0.0 0.0 0 \n", "\n", "[5 rows x 23 columns]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(data_hd.shape)\n", "data_hd.head()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LJWVXyIwVQ2Q", "outputId": "72b73b8f-5709-495a-9f81-d4f9926700a5" }, "outputs": [ { "data": { "text/plain": [ "Index(['age', 'trestbps', 'chol', 'thalach', 'oldpeak', 'ca', 'sex', 'fbs',\n", " 'exang', 'cp_1.0', 'cp_2.0', 'cp_3.0', 'cp_4.0', 'restecg_0',\n", " 'restecg_1', 'restecg_2', 'slope_1', 'slope_2', 'slope_3', 'thal_3.0',\n", " 'thal_6.0', 'thal_7.0', 'num'],\n", " dtype='object')" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_hd.columns" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 443 }, "id": "cbByJdLxkSls", "outputId": "48b3ba1d-8b66-4393-9400-1068e51e780f" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agetrestbpscholthalacholdpeakcasexfbsexangcp_1.0...cp_4.0restecg_0restecg_1restecg_2slope_1slope_2slope_3thal_3.0thal_6.0thal_7.0
063.0145.0233.0150.02.30.01.0101.0...0.00.00.01.00.00.01.00.01.00.0
167.0160.0286.0108.01.53.01.0010.0...1.00.00.01.00.01.00.01.00.00.0
267.0120.0229.0129.02.62.01.0010.0...1.00.00.01.00.01.00.00.00.01.0
337.0130.0250.0187.03.50.01.0000.0...0.01.00.00.00.00.01.01.00.00.0
441.0130.0204.0172.01.40.00.0000.0...0.00.00.01.01.00.00.01.00.00.0
..................................................................
91554.0127.0333.0154.00.00.00.0100.0...1.00.01.00.00.01.00.01.00.00.0
91662.0130.0139.0140.00.50.01.0001.0...0.00.01.00.00.01.00.01.00.00.0
91755.0122.0223.0100.00.00.01.0100.0...1.00.01.00.00.01.00.00.01.00.0
91858.0130.0385.0140.00.50.01.0100.0...1.00.00.01.00.01.00.01.00.00.0
91962.0120.0254.093.00.00.01.0010.0...0.00.00.01.00.01.00.01.00.00.0
\n", "

920 rows × 22 columns

\n", "
" ], "text/plain": [ " age trestbps chol thalach oldpeak ca sex fbs exang cp_1.0 \\\n", "0 63.0 145.0 233.0 150.0 2.3 0.0 1.0 1 0 1.0 \n", "1 67.0 160.0 286.0 108.0 1.5 3.0 1.0 0 1 0.0 \n", "2 67.0 120.0 229.0 129.0 2.6 2.0 1.0 0 1 0.0 \n", "3 37.0 130.0 250.0 187.0 3.5 0.0 1.0 0 0 0.0 \n", "4 41.0 130.0 204.0 172.0 1.4 0.0 0.0 0 0 0.0 \n", ".. ... ... ... ... ... ... ... ... ... ... \n", "915 54.0 127.0 333.0 154.0 0.0 0.0 0.0 1 0 0.0 \n", "916 62.0 130.0 139.0 140.0 0.5 0.0 1.0 0 0 1.0 \n", "917 55.0 122.0 223.0 100.0 0.0 0.0 1.0 1 0 0.0 \n", "918 58.0 130.0 385.0 140.0 0.5 0.0 1.0 1 0 0.0 \n", "919 62.0 120.0 254.0 93.0 0.0 0.0 1.0 0 1 0.0 \n", "\n", " ... cp_4.0 restecg_0 restecg_1 restecg_2 slope_1 slope_2 slope_3 \\\n", "0 ... 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "1 ... 1.0 0.0 0.0 1.0 0.0 1.0 0.0 \n", "2 ... 1.0 0.0 0.0 1.0 0.0 1.0 0.0 \n", "3 ... 0.0 1.0 0.0 0.0 0.0 0.0 1.0 \n", "4 ... 0.0 0.0 0.0 1.0 1.0 0.0 0.0 \n", ".. ... ... ... ... ... ... ... ... \n", "915 ... 1.0 0.0 1.0 0.0 0.0 1.0 0.0 \n", "916 ... 0.0 0.0 1.0 0.0 0.0 1.0 0.0 \n", "917 ... 1.0 0.0 1.0 0.0 0.0 1.0 0.0 \n", "918 ... 1.0 0.0 0.0 1.0 0.0 1.0 0.0 \n", "919 ... 0.0 0.0 0.0 1.0 0.0 1.0 0.0 \n", "\n", " thal_3.0 thal_6.0 thal_7.0 \n", "0 0.0 1.0 0.0 \n", "1 1.0 0.0 0.0 \n", "2 0.0 0.0 1.0 \n", "3 1.0 0.0 0.0 \n", "4 1.0 0.0 0.0 \n", ".. ... ... ... \n", "915 1.0 0.0 0.0 \n", "916 1.0 0.0 0.0 \n", "917 0.0 1.0 0.0 \n", "918 1.0 0.0 0.0 \n", "919 1.0 0.0 0.0 \n", "\n", "[920 rows x 22 columns]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_hd.iloc[:, :-1]" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 80 }, "id": "c7JNfRKoVQz4", "outputId": "3d8e474f-eb2a-44ac-e958-5a9afc5f041f" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ClassifierAccuracy
0DecisionTree0.717391
\n", "
" ], "text/plain": [ " Classifier Accuracy\n", "0 DecisionTree 0.717391" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "score1 = acc_score(data_hd.iloc[:, :-1], data_hd['num'])\n", "score1" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DdSzy-GbX0ER", "outputId": "a98e5841-1d78-4f19-8052-928ba716f992" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(690, 22) (230, 22) (690,) (230,)\n", "Best score in generation 1 : [0.8]\n", "Best score in generation 2 : [0.808695652173913]\n", "Best score in generation 3 : [0.8130434782608695]\n", "Best score in generation 4 : [0.8217391304347826]\n", "Best score in generation 5 : [0.8217391304347826]\n" ] } ], "source": [
 "# Every column except the last is a feature; 'num' is the label.\n",
 "features, target = data_hd.iloc[:, :-1], data_hd['num']\n",
 "\n",
 "X_train, X_test, Y_train, Y_test = split(features, target)\n",
 "print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)\n",
 "chromo_df, score, best_models = generations(features,\n",
 "                                            target,\n",
 "                                            size=96,\n",
 "                                            n_feat=features.shape[1],\n",
 "                                            n_parents=64,\n",
 "                                            mutation_rate=0.20,\n",
 "                                            n_gen=5,\n",
 "                                            X_train=X_train,\n",
 "                                            X_test=X_test,\n",
 "                                            Y_train=Y_train,\n",
 "                                            Y_test=Y_test)"
 ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 388 }, "id": "I5rtmJGvX14N", "outputId": "d97250f0-9eed-4cce-dfb0-3a929c47ac34" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot(score, 0.9, 1.0,c = \"gold\")" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "HQrzrFeuz0yG" }, "outputs": [], "source": [
 "# Persist the best tree of every generation as model-<generation index>.joblib.\n",
 "for index, model in enumerate(best_models):\n",
 "    dump(model, 'model-{}.joblib'.format(index))"
 ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "id": "fGbUe1WJYbxp" }, "outputs": [], "source": [
 "# Reload the tree from the best-scoring generation (first one on ties)\n",
 "# instead of a hard-coded index that only matches one particular run.\n",
 "best_gen = int(np.argmax(score))\n",
 "clf = load('model-{}.joblib'.format(best_gen))"
 ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 423 }, "id": "bA8Qf-orbDnu", "outputId": "409bf4c6-5668-4f58-ff16-5f597771fb28" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sexexangcp_1.0cp_2.0cp_3.0cp_4.0restecg_1slope_1slope_2thal_3.0thal_6.0thal_7.0
2721.010.00.00.01.00.00.01.00.00.01.0
591.011.00.00.00.00.01.00.01.00.00.0
6101.010.00.00.01.00.00.01.01.00.00.0
3281.000.00.00.01.00.00.01.01.00.00.0
8041.010.00.00.01.00.00.00.01.00.00.0
.......................................
3740.000.01.00.00.00.00.01.01.00.00.0
5901.010.00.00.01.01.00.01.01.00.00.0
5731.010.00.00.01.00.00.01.01.00.00.0
5801.010.00.00.01.00.00.01.01.00.00.0
3080.000.01.00.00.00.00.01.01.00.00.0
\n", "

230 rows × 12 columns

\n", "
" ], "text/plain": [ " sex exang cp_1.0 cp_2.0 cp_3.0 cp_4.0 restecg_1 slope_1 slope_2 \\\n", "272 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "59 1.0 1 1.0 0.0 0.0 0.0 0.0 1.0 0.0 \n", "610 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "328 1.0 0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "804 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 0.0 \n", ".. ... ... ... ... ... ... ... ... ... \n", "374 0.0 0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 \n", "590 1.0 1 0.0 0.0 0.0 1.0 1.0 0.0 1.0 \n", "573 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "580 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "308 0.0 0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 \n", "\n", " thal_3.0 thal_6.0 thal_7.0 \n", "272 0.0 0.0 1.0 \n", "59 1.0 0.0 0.0 \n", "610 1.0 0.0 0.0 \n", "328 1.0 0.0 0.0 \n", "804 1.0 0.0 0.0 \n", ".. ... ... ... \n", "374 1.0 0.0 0.0 \n", "590 1.0 0.0 0.0 \n", "573 1.0 0.0 0.0 \n", "580 1.0 0.0 0.0 \n", "308 1.0 0.0 0.0 \n", "\n", "[230 rows x 12 columns]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test[clf.feature_names_in_]" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KkBlxHuBY9cM", "outputId": "80420dc6-cd49-4f0c-fa5f-a466df712df7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8217391304347826\n" ] } ], "source": [ "predictions = clf.predict(X_test[clf.feature_names_in_])\n", "print(accuracy_score(Y_test, predictions))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sexexangcp_1.0cp_2.0cp_3.0cp_4.0restecg_1slope_1slope_2thal_3.0thal_6.0thal_7.0
2721.010.00.00.01.00.00.01.00.00.01.0
591.011.00.00.00.00.01.00.01.00.00.0
6101.010.00.00.01.00.00.01.01.00.00.0
3281.000.00.00.01.00.00.01.01.00.00.0
8041.010.00.00.01.00.00.00.01.00.00.0
.......................................
3740.000.01.00.00.00.00.01.01.00.00.0
5901.010.00.00.01.01.00.01.01.00.00.0
5731.010.00.00.01.00.00.01.01.00.00.0
5801.010.00.00.01.00.00.01.01.00.00.0
3080.000.01.00.00.00.00.01.01.00.00.0
\n", "

230 rows × 12 columns

\n", "
" ], "text/plain": [ " sex exang cp_1.0 cp_2.0 cp_3.0 cp_4.0 restecg_1 slope_1 slope_2 \\\n", "272 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "59 1.0 1 1.0 0.0 0.0 0.0 0.0 1.0 0.0 \n", "610 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "328 1.0 0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "804 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 0.0 \n", ".. ... ... ... ... ... ... ... ... ... \n", "374 0.0 0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 \n", "590 1.0 1 0.0 0.0 0.0 1.0 1.0 0.0 1.0 \n", "573 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "580 1.0 1 0.0 0.0 0.0 1.0 0.0 0.0 1.0 \n", "308 0.0 0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 \n", "\n", " thal_3.0 thal_6.0 thal_7.0 \n", "272 0.0 0.0 1.0 \n", "59 1.0 0.0 0.0 \n", "610 1.0 0.0 0.0 \n", "328 1.0 0.0 0.0 \n", "804 1.0 0.0 0.0 \n", ".. ... ... ... \n", "374 1.0 0.0 0.0 \n", "590 1.0 0.0 0.0 \n", "573 1.0 0.0 0.0 \n", "580 1.0 0.0 0.0 \n", "308 1.0 0.0 0.0 \n", "\n", "[230 rows x 12 columns]" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test[clf.feature_names_in_]" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1,\n", " 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n", " 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1,\n", " 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1,\n", " 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n", " 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1,\n", " 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0,\n", " 0, 0, 1, 0, 1, 0, 1, 1, 1, 0], dtype=int64)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"predictions" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "cp_4.0 <= 0.5\n", "gini = 0.494\n", "samples = 690\n", "value = [308.0, 382.0]\n", "\n", "\n", "\n", "1\n", "\n", "sex <= 0.5\n", "gini = 0.415\n", "samples = 320\n", "value = [226, 94]\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "96\n", "\n", "exang <= 0.5\n", "gini = 0.345\n", "samples = 370\n", "value = [82, 288]\n", "\n", "\n", "\n", "0->96\n", "\n", "\n", "False\n", "\n", "\n", "\n", "2\n", "\n", "thal_7.0 <= 0.5\n", "gini = 0.172\n", "samples = 95\n", "value = [86, 9]\n", "\n", "\n", "\n", "1->2\n", "\n", "\n", "\n", "\n", "\n", "31\n", "\n", "cp_2.0 <= 0.5\n", "gini = 0.47\n", "samples = 225\n", "value = [140, 85]\n", "\n", "\n", "\n", "1->31\n", "\n", "\n", "\n", "\n", "\n", "3\n", "\n", "cp_2.0 <= 0.5\n", "gini = 0.126\n", "samples = 89\n", "value = [83, 6]\n", "\n", "\n", "\n", "2->3\n", "\n", "\n", "\n", "\n", "\n", "24\n", "\n", "slope_1 <= 0.5\n", "gini = 0.5\n", "samples = 6\n", "value = [3, 3]\n", "\n", "\n", "\n", "2->24\n", "\n", "\n", "\n", "\n", "\n", "4\n", "\n", "slope_2 <= 0.5\n", "gini = 0.162\n", "samples = 45\n", "value = [41, 4]\n", "\n", "\n", "\n", "3->4\n", "\n", "\n", "\n", "\n", "\n", "13\n", "\n", "slope_1 <= 0.5\n", "gini = 0.087\n", "samples = 44\n", "value = [42, 2]\n", "\n", "\n", "\n", "3->13\n", "\n", "\n", "\n", "\n", "\n", "5\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.087\n", "samples = 22\n", "value = [21, 1]\n", "\n", "\n", "\n", "4->5\n", "\n", "\n", "\n", "\n", "\n", "8\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.227\n", "samples = 23\n", "value = [20, 3]\n", "\n", "\n", "\n", "4->8\n", "\n", "\n", "\n", "\n", "\n", "6\n", "\n", "gini = 0.0\n", "samples = 21\n", "value = [21, 0]\n", "\n", "\n", "\n", "5->6\n", "\n", "\n", "\n", "\n", "\n", "7\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [0, 
1]\n", "\n", "\n", "\n", "5->7\n", "\n", "\n", "\n", "\n", "\n", "9\n", "\n", "cp_1.0 <= 0.5\n", "gini = 0.278\n", "samples = 18\n", "value = [15, 3]\n", "\n", "\n", "\n", "8->9\n", "\n", "\n", "\n", "\n", "\n", "12\n", "\n", "gini = 0.0\n", "samples = 5\n", "value = [5, 0]\n", "\n", "\n", "\n", "8->12\n", "\n", "\n", "\n", "\n", "\n", "10\n", "\n", "gini = 0.245\n", "samples = 14\n", "value = [12, 2]\n", "\n", "\n", "\n", "9->10\n", "\n", "\n", "\n", "\n", "\n", "11\n", "\n", "gini = 0.375\n", "samples = 4\n", "value = [3, 1]\n", "\n", "\n", "\n", "9->11\n", "\n", "\n", "\n", "\n", "\n", "14\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.061\n", "samples = 32\n", "value = [31, 1]\n", "\n", "\n", "\n", "13->14\n", "\n", "\n", "\n", "\n", "\n", "19\n", "\n", "exang <= 0.5\n", "gini = 0.153\n", "samples = 12\n", "value = [11, 1]\n", "\n", "\n", "\n", "13->19\n", "\n", "\n", "\n", "\n", "\n", "15\n", "\n", "exang <= 0.5\n", "gini = 0.077\n", "samples = 25\n", "value = [24, 1]\n", "\n", "\n", "\n", "14->15\n", "\n", "\n", "\n", "\n", "\n", "18\n", "\n", "gini = 0.0\n", "samples = 7\n", "value = [7, 0]\n", "\n", "\n", "\n", "14->18\n", "\n", "\n", "\n", "\n", "\n", "16\n", "\n", "gini = 0.08\n", "samples = 24\n", "value = [23, 1]\n", "\n", "\n", "\n", "15->16\n", "\n", "\n", "\n", "\n", "\n", "17\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [1, 0]\n", "\n", "\n", "\n", "15->17\n", "\n", "\n", "\n", "\n", "\n", "20\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.18\n", "samples = 10\n", "value = [9, 1]\n", "\n", "\n", "\n", "19->20\n", "\n", "\n", "\n", "\n", "\n", "23\n", "\n", "gini = 0.0\n", "samples = 2\n", "value = [2, 0]\n", "\n", "\n", "\n", "19->23\n", "\n", "\n", "\n", "\n", "\n", "21\n", "\n", "gini = 0.198\n", "samples = 9\n", "value = [8, 1]\n", "\n", "\n", "\n", "20->21\n", "\n", "\n", "\n", "\n", "\n", "22\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [1, 0]\n", "\n", "\n", "\n", "20->22\n", "\n", "\n", "\n", "\n", "\n", "25\n", "\n", "exang <= 0.5\n", 
"gini = 0.48\n", "samples = 5\n", "value = [2, 3]\n", "\n", "\n", "\n", "24->25\n", "\n", "\n", "\n", "\n", "\n", "30\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [1, 0]\n", "\n", "\n", "\n", "24->30\n", "\n", "\n", "\n", "\n", "\n", "26\n", "\n", "gini = 0.5\n", "samples = 2\n", "value = [1, 1]\n", "\n", "\n", "\n", "25->26\n", "\n", "\n", "\n", "\n", "\n", "27\n", "\n", "cp_3.0 <= 0.5\n", "gini = 0.444\n", "samples = 3\n", "value = [1, 2]\n", "\n", "\n", "\n", "25->27\n", "\n", "\n", "\n", "\n", "\n", "28\n", "\n", "gini = 0.5\n", "samples = 2\n", "value = [1, 1]\n", "\n", "\n", "\n", "27->28\n", "\n", "\n", "\n", "\n", "\n", "29\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [0, 1]\n", "\n", "\n", "\n", "27->29\n", "\n", "\n", "\n", "\n", "\n", "32\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.5\n", "samples = 137\n", "value = [70, 67]\n", "\n", "\n", "\n", "31->32\n", "\n", "\n", "\n", "\n", "\n", "73\n", "\n", "exang <= 0.5\n", "gini = 0.325\n", "samples = 88\n", "value = [70, 18]\n", "\n", "\n", "\n", "31->73\n", "\n", "\n", "\n", "\n", "\n", "33\n", "\n", "slope_2 <= 0.5\n", "gini = 0.489\n", "samples = 110\n", "value = [63, 47]\n", "\n", "\n", "\n", "32->33\n", "\n", "\n", "\n", "\n", "\n", "62\n", "\n", "exang <= 0.5\n", "gini = 0.384\n", "samples = 27\n", "value = [7, 20]\n", "\n", "\n", "\n", "32->62\n", "\n", "\n", "\n", "\n", "\n", "34\n", "\n", "thal_3.0 <= 0.5\n", "gini = 0.401\n", "samples = 36\n", "value = [26, 10]\n", "\n", "\n", "\n", "33->34\n", "\n", "\n", "\n", "\n", "\n", "49\n", "\n", "thal_6.0 <= 0.5\n", "gini = 0.5\n", "samples = 74\n", "value = [37, 37]\n", "\n", "\n", "\n", "33->49\n", "\n", "\n", "\n", "\n", "\n", "35\n", "\n", "cp_1.0 <= 0.5\n", "gini = 0.198\n", "samples = 9\n", "value = [8, 1]\n", "\n", "\n", "\n", "34->35\n", "\n", "\n", "\n", "\n", "\n", "42\n", "\n", "exang <= 0.5\n", "gini = 0.444\n", "samples = 27\n", "value = [18, 9]\n", "\n", "\n", "\n", "34->42\n", "\n", "\n", "\n", "\n", "\n", "36\n", "\n", "exang <= 
0.5\n", "gini = 0.278\n", "samples = 6\n", "value = [5, 1]\n", "\n", "\n", "\n", "35->36\n", "\n", "\n", "\n", "\n", "\n", "41\n", "\n", "gini = 0.0\n", "samples = 3\n", "value = [3, 0]\n", "\n", "\n", "\n", "35->41\n", "\n", "\n", "\n", "\n", "\n", "37\n", "\n", "slope_1 <= 0.5\n", "gini = 0.32\n", "samples = 5\n", "value = [4, 1]\n", "\n", "\n", "\n", "36->37\n", "\n", "\n", "\n", "\n", "\n", "40\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [1, 0]\n", "\n", "\n", "\n", "36->40\n", "\n", "\n", "\n", "\n", "\n", "38\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [1, 0]\n", "\n", "\n", "\n", "37->38\n", "\n", "\n", "\n", "\n", "\n", "39\n", "\n", "gini = 0.375\n", "samples = 4\n", "value = [3, 1]\n", "\n", "\n", "\n", "37->39\n", "\n", "\n", "\n", "\n", "\n", "43\n", "\n", "cp_1.0 <= 0.5\n", "gini = 0.426\n", "samples = 26\n", "value = [18, 8]\n", "\n", "\n", "\n", "42->43\n", "\n", "\n", "\n", "\n", "\n", "48\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [0, 1]\n", "\n", "\n", "\n", "42->48\n", "\n", "\n", "\n", "\n", "\n", "44\n", "\n", "slope_1 <= 0.5\n", "gini = 0.363\n", "samples = 21\n", "value = [16, 5]\n", "\n", "\n", "\n", "43->44\n", "\n", "\n", "\n", "\n", "\n", "47\n", "\n", "gini = 0.48\n", "samples = 5\n", "value = [2, 3]\n", "\n", "\n", "\n", "43->47\n", "\n", "\n", "\n", "\n", "\n", "45\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [2, 1]\n", "\n", "\n", "\n", "44->45\n", "\n", "\n", "\n", "\n", "\n", "46\n", "\n", "gini = 0.346\n", "samples = 18\n", "value = [14, 4]\n", "\n", "\n", "\n", "44->46\n", "\n", "\n", "\n", "\n", "\n", "50\n", "\n", "cp_3.0 <= 0.5\n", "gini = 0.5\n", "samples = 72\n", "value = [37, 35]\n", "\n", "\n", "\n", "49->50\n", "\n", "\n", "\n", "\n", "\n", "61\n", "\n", "gini = 0.0\n", "samples = 2\n", "value = [0, 2]\n", "\n", "\n", "\n", "49->61\n", "\n", "\n", "\n", "\n", "\n", "51\n", "\n", "thal_3.0 <= 0.5\n", "gini = 0.444\n", "samples = 12\n", "value = [4, 8]\n", "\n", "\n", "\n", "50->51\n", "\n", 
"\n", "\n", "\n", "\n", "54\n", "\n", "thal_3.0 <= 0.5\n", "gini = 0.495\n", "samples = 60\n", "value = [33, 27]\n", "\n", "\n", "\n", "50->54\n", "\n", "\n", "\n", "\n", "\n", "52\n", "\n", "gini = 0.48\n", "samples = 5\n", "value = [3, 2]\n", "\n", "\n", "\n", "51->52\n", "\n", "\n", "\n", "\n", "\n", "53\n", "\n", "gini = 0.245\n", "samples = 7\n", "value = [1, 6]\n", "\n", "\n", "\n", "51->53\n", "\n", "\n", "\n", "\n", "\n", "55\n", "\n", "exang <= 0.5\n", "gini = 0.444\n", "samples = 12\n", "value = [4, 8]\n", "\n", "\n", "\n", "54->55\n", "\n", "\n", "\n", "\n", "\n", "58\n", "\n", "exang <= 0.5\n", "gini = 0.478\n", "samples = 48\n", "value = [29, 19]\n", "\n", "\n", "\n", "54->58\n", "\n", "\n", "\n", "\n", "\n", "56\n", "\n", "gini = 0.408\n", "samples = 7\n", "value = [2, 5]\n", "\n", "\n", "\n", "55->56\n", "\n", "\n", "\n", "\n", "\n", "57\n", "\n", "gini = 0.48\n", "samples = 5\n", "value = [2, 3]\n", "\n", "\n", "\n", "55->57\n", "\n", "\n", "\n", "\n", "\n", "59\n", "\n", "gini = 0.444\n", "samples = 36\n", "value = [24, 12]\n", "\n", "\n", "\n", "58->59\n", "\n", "\n", "\n", "\n", "\n", "60\n", "\n", "gini = 0.486\n", "samples = 12\n", "value = [5, 7]\n", "\n", "\n", "\n", "58->60\n", "\n", "\n", "\n", "\n", "\n", "63\n", "\n", "slope_1 <= 0.5\n", "gini = 0.434\n", "samples = 22\n", "value = [7.0, 15.0]\n", "\n", "\n", "\n", "62->63\n", "\n", "\n", "\n", "\n", "\n", "72\n", "\n", "gini = 0.0\n", "samples = 5\n", "value = [0, 5]\n", "\n", "\n", "\n", "62->72\n", "\n", "\n", "\n", "\n", "\n", "64\n", "\n", "slope_2 <= 0.5\n", "gini = 0.42\n", "samples = 20\n", "value = [6, 14]\n", "\n", "\n", "\n", "63->64\n", "\n", "\n", "\n", "\n", "\n", "71\n", "\n", "gini = 0.5\n", "samples = 2\n", "value = [1, 1]\n", "\n", "\n", "\n", "63->71\n", "\n", "\n", "\n", "\n", "\n", "65\n", "\n", "gini = 0.0\n", "samples = 2\n", "value = [0, 2]\n", "\n", "\n", "\n", "64->65\n", "\n", "\n", "\n", "\n", "\n", "66\n", "\n", "cp_1.0 <= 0.5\n", "gini = 0.444\n", "samples = 
18\n", "value = [6, 12]\n", "\n", "\n", "\n", "64->66\n", "\n", "\n", "\n", "\n", "\n", "67\n", "\n", "thal_6.0 <= 0.5\n", "gini = 0.444\n", "samples = 15\n", "value = [5, 10]\n", "\n", "\n", "\n", "66->67\n", "\n", "\n", "\n", "\n", "\n", "70\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [1, 2]\n", "\n", "\n", "\n", "66->70\n", "\n", "\n", "\n", "\n", "\n", "68\n", "\n", "gini = 0.444\n", "samples = 12\n", "value = [4, 8]\n", "\n", "\n", "\n", "67->68\n", "\n", "\n", "\n", "\n", "\n", "69\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [1, 2]\n", "\n", "\n", "\n", "67->69\n", "\n", "\n", "\n", "\n", "\n", "74\n", "\n", "thal_7.0 <= 0.5\n", "gini = 0.242\n", "samples = 78\n", "value = [67, 11]\n", "\n", "\n", "\n", "73->74\n", "\n", "\n", "\n", "\n", "\n", "89\n", "\n", "slope_2 <= 0.5\n", "gini = 0.42\n", "samples = 10\n", "value = [3, 7]\n", "\n", "\n", "\n", "73->89\n", "\n", "\n", "\n", "\n", "\n", "75\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.185\n", "samples = 68\n", "value = [61, 7]\n", "\n", "\n", "\n", "74->75\n", "\n", "\n", "\n", "\n", "\n", "84\n", "\n", "slope_2 <= 0.5\n", "gini = 0.48\n", "samples = 10\n", "value = [6, 4]\n", "\n", "\n", "\n", "74->84\n", "\n", "\n", "\n", "\n", "\n", "76\n", "\n", "slope_2 <= 0.5\n", "gini = 0.1\n", "samples = 57\n", "value = [54, 3]\n", "\n", "\n", "\n", "75->76\n", "\n", "\n", "\n", "\n", "\n", "81\n", "\n", "slope_2 <= 0.5\n", "gini = 0.463\n", "samples = 11\n", "value = [7, 4]\n", "\n", "\n", "\n", "75->81\n", "\n", "\n", "\n", "\n", "\n", "77\n", "\n", "gini = 0.0\n", "samples = 16\n", "value = [16, 0]\n", "\n", "\n", "\n", "76->77\n", "\n", "\n", "\n", "\n", "\n", "78\n", "\n", "thal_3.0 <= 0.5\n", "gini = 0.136\n", "samples = 41\n", "value = [38, 3]\n", "\n", "\n", "\n", "76->78\n", "\n", "\n", "\n", "\n", "\n", "79\n", "\n", "gini = 0.0\n", "samples = 2\n", "value = [2, 0]\n", "\n", "\n", "\n", "78->79\n", "\n", "\n", "\n", "\n", "\n", "80\n", "\n", "gini = 0.142\n", "samples = 39\n", "value = 
[36, 3]\n", "\n", "\n", "\n", "78->80\n", "\n", "\n", "\n", "\n", "\n", "82\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [2, 1]\n", "\n", "\n", "\n", "81->82\n", "\n", "\n", "\n", "\n", "\n", "83\n", "\n", "gini = 0.469\n", "samples = 8\n", "value = [5, 3]\n", "\n", "\n", "\n", "81->83\n", "\n", "\n", "\n", "\n", "\n", "85\n", "\n", "slope_1 <= 0.5\n", "gini = 0.5\n", "samples = 6\n", "value = [3, 3]\n", "\n", "\n", "\n", "84->85\n", "\n", "\n", "\n", "\n", "\n", "88\n", "\n", "gini = 0.375\n", "samples = 4\n", "value = [3, 1]\n", "\n", "\n", "\n", "84->88\n", "\n", "\n", "\n", "\n", "\n", "86\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [0, 1]\n", "\n", "\n", "\n", "85->86\n", "\n", "\n", "\n", "\n", "\n", "87\n", "\n", "gini = 0.48\n", "samples = 5\n", "value = [3, 2]\n", "\n", "\n", "\n", "85->87\n", "\n", "\n", "\n", "\n", "\n", "90\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [1, 0]\n", "\n", "\n", "\n", "89->90\n", "\n", "\n", "\n", "\n", "\n", "91\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.346\n", "samples = 9\n", "value = [2, 7]\n", "\n", "\n", "\n", "89->91\n", "\n", "\n", "\n", "\n", "\n", "92\n", "\n", "thal_6.0 <= 0.5\n", "gini = 0.408\n", "samples = 7\n", "value = [2, 5]\n", "\n", "\n", "\n", "91->92\n", "\n", "\n", "\n", "\n", "\n", "95\n", "\n", "gini = 0.0\n", "samples = 2\n", "value = [0, 2]\n", "\n", "\n", "\n", "91->95\n", "\n", "\n", "\n", "\n", "\n", "93\n", "\n", "gini = 0.444\n", "samples = 6\n", "value = [2, 4]\n", "\n", "\n", "\n", "92->93\n", "\n", "\n", "\n", "\n", "\n", "94\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [0, 1]\n", "\n", "\n", "\n", "92->94\n", "\n", "\n", "\n", "\n", "\n", "97\n", "\n", "thal_3.0 <= 0.5\n", "gini = 0.462\n", "samples = 160\n", "value = [58, 102]\n", "\n", "\n", "\n", "96->97\n", "\n", "\n", "\n", "\n", "\n", "130\n", "\n", "sex <= 0.5\n", "gini = 0.202\n", "samples = 210\n", "value = [24, 186]\n", "\n", "\n", "\n", "96->130\n", "\n", "\n", "\n", "\n", "\n", "98\n", "\n", 
"slope_1 <= 0.5\n", "gini = 0.236\n", "samples = 44\n", "value = [6, 38]\n", "\n", "\n", "\n", "97->98\n", "\n", "\n", "\n", "\n", "\n", "111\n", "\n", "sex <= 0.5\n", "gini = 0.495\n", "samples = 116\n", "value = [52, 64]\n", "\n", "\n", "\n", "97->111\n", "\n", "\n", "\n", "\n", "\n", "99\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.111\n", "samples = 34\n", "value = [2, 32]\n", "\n", "\n", "\n", "98->99\n", "\n", "\n", "\n", "\n", "\n", "106\n", "\n", "thal_6.0 <= 0.5\n", "gini = 0.48\n", "samples = 10\n", "value = [4, 6]\n", "\n", "\n", "\n", "98->106\n", "\n", "\n", "\n", "\n", "\n", "100\n", "\n", "thal_6.0 <= 0.5\n", "gini = 0.147\n", "samples = 25\n", "value = [2, 23]\n", "\n", "\n", "\n", "99->100\n", "\n", "\n", "\n", "\n", "\n", "105\n", "\n", "gini = 0.0\n", "samples = 9\n", "value = [0, 9]\n", "\n", "\n", "\n", "99->105\n", "\n", "\n", "\n", "\n", "\n", "101\n", "\n", "sex <= 0.5\n", "gini = 0.095\n", "samples = 20\n", "value = [1, 19]\n", "\n", "\n", "\n", "100->101\n", "\n", "\n", "\n", "\n", "\n", "104\n", "\n", "gini = 0.32\n", "samples = 5\n", "value = [1, 4]\n", "\n", "\n", "\n", "100->104\n", "\n", "\n", "\n", "\n", "\n", "102\n", "\n", "gini = 0.0\n", "samples = 4\n", "value = [0, 4]\n", "\n", "\n", "\n", "101->102\n", "\n", "\n", "\n", "\n", "\n", "103\n", "\n", "gini = 0.117\n", "samples = 16\n", "value = [1, 15]\n", "\n", "\n", "\n", "101->103\n", "\n", "\n", "\n", "\n", "\n", "107\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.444\n", "samples = 9\n", "value = [3, 6]\n", "\n", "\n", "\n", "106->107\n", "\n", "\n", "\n", "\n", "\n", "110\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [1, 0]\n", "\n", "\n", "\n", "106->110\n", "\n", "\n", "\n", "\n", "\n", "108\n", "\n", "gini = 0.469\n", "samples = 8\n", "value = [3, 5]\n", "\n", "\n", "\n", "107->108\n", "\n", "\n", "\n", "\n", "\n", "109\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [0, 1]\n", "\n", "\n", "\n", "107->109\n", "\n", "\n", "\n", "\n", "\n", "112\n", "\n", "slope_2 <= 
0.5\n", "gini = 0.397\n", "samples = 22\n", "value = [16, 6]\n", "\n", "\n", "\n", "111->112\n", "\n", "\n", "\n", "\n", "\n", "119\n", "\n", "slope_1 <= 0.5\n", "gini = 0.473\n", "samples = 94\n", "value = [36, 58]\n", "\n", "\n", "\n", "111->119\n", "\n", "\n", "\n", "\n", "\n", "113\n", "\n", "slope_1 <= 0.5\n", "gini = 0.48\n", "samples = 10\n", "value = [6, 4]\n", "\n", "\n", "\n", "112->113\n", "\n", "\n", "\n", "\n", "\n", "116\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.278\n", "samples = 12\n", "value = [10, 2]\n", "\n", "\n", "\n", "112->116\n", "\n", "\n", "\n", "\n", "\n", "114\n", "\n", "gini = 0.0\n", "samples = 2\n", "value = [0, 2]\n", "\n", "\n", "\n", "113->114\n", "\n", "\n", "\n", "\n", "\n", "115\n", "\n", "gini = 0.375\n", "samples = 8\n", "value = [6, 2]\n", "\n", "\n", "\n", "113->115\n", "\n", "\n", "\n", "\n", "\n", "117\n", "\n", "gini = 0.198\n", "samples = 9\n", "value = [8, 1]\n", "\n", "\n", "\n", "116->117\n", "\n", "\n", "\n", "\n", "\n", "118\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [2, 1]\n", "\n", "\n", "\n", "116->118\n", "\n", "\n", "\n", "\n", "\n", "120\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.477\n", "samples = 74\n", "value = [29, 45]\n", "\n", "\n", "\n", "119->120\n", "\n", "\n", "\n", "\n", "\n", "127\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.455\n", "samples = 20\n", "value = [7, 13]\n", "\n", "\n", "\n", "119->127\n", "\n", "\n", "\n", "\n", "\n", "121\n", "\n", "slope_2 <= 0.5\n", "gini = 0.46\n", "samples = 53\n", "value = [19, 34]\n", "\n", "\n", "\n", "120->121\n", "\n", "\n", "\n", "\n", "\n", "124\n", "\n", "slope_2 <= 0.5\n", "gini = 0.499\n", "samples = 21\n", "value = [10, 11]\n", "\n", "\n", "\n", "120->124\n", "\n", "\n", "\n", "\n", "\n", "122\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [1, 2]\n", "\n", "\n", "\n", "121->122\n", "\n", "\n", "\n", "\n", "\n", "123\n", "\n", "gini = 0.461\n", "samples = 50\n", "value = [18, 32]\n", "\n", "\n", "\n", "121->123\n", "\n", "\n", "\n", "\n", 
"\n", "125\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [1, 0]\n", "\n", "\n", "\n", "124->125\n", "\n", "\n", "\n", "\n", "\n", "126\n", "\n", "gini = 0.495\n", "samples = 20\n", "value = [9, 11]\n", "\n", "\n", "\n", "124->126\n", "\n", "\n", "\n", "\n", "\n", "128\n", "\n", "gini = 0.492\n", "samples = 16\n", "value = [7, 9]\n", "\n", "\n", "\n", "127->128\n", "\n", "\n", "\n", "\n", "\n", "129\n", "\n", "gini = 0.0\n", "samples = 4\n", "value = [0, 4]\n", "\n", "\n", "\n", "127->129\n", "\n", "\n", "\n", "\n", "\n", "131\n", "\n", "thal_3.0 <= 0.5\n", "gini = 0.384\n", "samples = 27\n", "value = [7, 20]\n", "\n", "\n", "\n", "130->131\n", "\n", "\n", "\n", "\n", "\n", "140\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.169\n", "samples = 183\n", "value = [17, 166]\n", "\n", "\n", "\n", "130->140\n", "\n", "\n", "\n", "\n", "\n", "132\n", "\n", "gini = 0.0\n", "samples = 9\n", "value = [0, 9]\n", "\n", "\n", "\n", "131->132\n", "\n", "\n", "\n", "\n", "\n", "133\n", "\n", "slope_1 <= 0.5\n", "gini = 0.475\n", "samples = 18\n", "value = [7, 11]\n", "\n", "\n", "\n", "131->133\n", "\n", "\n", "\n", "\n", "\n", "134\n", "\n", "slope_2 <= 0.5\n", "gini = 0.444\n", "samples = 15\n", "value = [5, 10]\n", "\n", "\n", "\n", "133->134\n", "\n", "\n", "\n", "\n", "\n", "139\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [2, 1]\n", "\n", "\n", "\n", "133->139\n", "\n", "\n", "\n", "\n", "\n", "135\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [0, 1]\n", "\n", "\n", "\n", "134->135\n", "\n", "\n", "\n", "\n", "\n", "136\n", "\n", "restecg_1 <= 0.5\n", "gini = 0.459\n", "samples = 14\n", "value = [5, 9]\n", "\n", "\n", "\n", "134->136\n", "\n", "\n", "\n", "\n", "\n", "137\n", "\n", "gini = 0.463\n", "samples = 11\n", "value = [4, 7]\n", "\n", "\n", "\n", "136->137\n", "\n", "\n", "\n", "\n", "\n", "138\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [1, 2]\n", "\n", "\n", "\n", "136->138\n", "\n", "\n", "\n", "\n", "\n", "141\n", "\n", "slope_2 <= 
0.5\n", "gini = 0.146\n", "samples = 139\n", "value = [11, 128]\n", "\n", "\n", "\n", "140->141\n", "\n", "\n", "\n", "\n", "\n", "156\n", "\n", "slope_2 <= 0.5\n", "gini = 0.236\n", "samples = 44\n", "value = [6, 38]\n", "\n", "\n", "\n", "140->156\n", "\n", "\n", "\n", "\n", "\n", "142\n", "\n", "slope_1 <= 0.5\n", "gini = 0.206\n", "samples = 43\n", "value = [5, 38]\n", "\n", "\n", "\n", "141->142\n", "\n", "\n", "\n", "\n", "\n", "151\n", "\n", "thal_3.0 <= 0.5\n", "gini = 0.117\n", "samples = 96\n", "value = [6, 90]\n", "\n", "\n", "\n", "141->151\n", "\n", "\n", "\n", "\n", "\n", "143\n", "\n", "thal_7.0 <= 0.5\n", "gini = 0.172\n", "samples = 21\n", "value = [2, 19]\n", "\n", "\n", "\n", "142->143\n", "\n", "\n", "\n", "\n", "\n", "148\n", "\n", "thal_3.0 <= 0.5\n", "gini = 0.236\n", "samples = 22\n", "value = [3, 19]\n", "\n", "\n", "\n", "142->148\n", "\n", "\n", "\n", "\n", "\n", "144\n", "\n", "thal_6.0 <= 0.5\n", "gini = 0.142\n", "samples = 13\n", "value = [1, 12]\n", "\n", "\n", "\n", "143->144\n", "\n", "\n", "\n", "\n", "\n", "147\n", "\n", "gini = 0.219\n", "samples = 8\n", "value = [1, 7]\n", "\n", "\n", "\n", "143->147\n", "\n", "\n", "\n", "\n", "\n", "145\n", "\n", "gini = 0.153\n", "samples = 12\n", "value = [1, 11]\n", "\n", "\n", "\n", "144->145\n", "\n", "\n", "\n", "\n", "\n", "146\n", "\n", "gini = 0.0\n", "samples = 1\n", "value = [0, 1]\n", "\n", "\n", "\n", "144->146\n", "\n", "\n", "\n", "\n", "\n", "149\n", "\n", "gini = 0.231\n", "samples = 15\n", "value = [2, 13]\n", "\n", "\n", "\n", "148->149\n", "\n", "\n", "\n", "\n", "\n", "150\n", "\n", "gini = 0.245\n", "samples = 7\n", "value = [1, 6]\n", "\n", "\n", "\n", "148->150\n", "\n", "\n", "\n", "\n", "\n", "152\n", "\n", "thal_6.0 <= 0.5\n", "gini = 0.046\n", "samples = 42\n", "value = [1, 41]\n", "\n", "\n", "\n", "151->152\n", "\n", "\n", "\n", "\n", "\n", "155\n", "\n", "gini = 0.168\n", "samples = 54\n", "value = [5, 49]\n", "\n", "\n", "\n", "151->155\n", "\n", "\n", "\n", 
"\n", "\n", "153\n", "\n", "gini = 0.056\n", "samples = 35\n", "value = [1, 34]\n", "\n", "\n", "\n", "152->153\n", "\n", "\n", "\n", "\n", "\n", "154\n", "\n", "gini = 0.0\n", "samples = 7\n", "value = [0, 7]\n", "\n", "\n", "\n", "152->154\n", "\n", "\n", "\n", "\n", "\n", "157\n", "\n", "gini = 0.0\n", "samples = 12\n", "value = [0, 12]\n", "\n", "\n", "\n", "156->157\n", "\n", "\n", "\n", "\n", "\n", "158\n", "\n", "thal_6.0 <= 0.5\n", "gini = 0.305\n", "samples = 32\n", "value = [6, 26]\n", "\n", "\n", "\n", "156->158\n", "\n", "\n", "\n", "\n", "\n", "159\n", "\n", "thal_7.0 <= 0.5\n", "gini = 0.32\n", "samples = 30\n", "value = [6, 24]\n", "\n", "\n", "\n", "158->159\n", "\n", "\n", "\n", "\n", "\n", "162\n", "\n", "gini = 0.0\n", "samples = 2\n", "value = [0, 2]\n", "\n", "\n", "\n", "158->162\n", "\n", "\n", "\n", "\n", "\n", "160\n", "\n", "gini = 0.302\n", "samples = 27\n", "value = [5, 22]\n", "\n", "\n", "\n", "159->160\n", "\n", "\n", "\n", "\n", "\n", "161\n", "\n", "gini = 0.444\n", "samples = 3\n", "value = [1, 2]\n", "\n", "\n", "\n", "159->161\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.export_graphviz(clf, feature_names=clf.feature_names_in_, rounded=True, out_file='decision.dot')\n", "\n", "graphviz.Source(open('./decision.dot').read())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 0 }