{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "id": "l8Y_Fz5_VKUf"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "from random import randint\n",
    "from joblib import dump, load\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import tree\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "import graphviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "id": "mIqh1kxmVQ9o"
   },
   "outputs": [],
   "source": [
    "classifiers = ['DecisionTree']\n",
    "\n",
    "models = [DecisionTreeClassifier(random_state=0)]\n",
    "\n",
    "def split(df, label):\n",
    "    \"\"\"Stratified 75/25 train/test split of (df, label), seeded with random_state=42.\n",
    "\n",
    "    Returns a tuple (X_train, X_test, Y_train, Y_test).\n",
    "    \"\"\"\n",
    "    X_train, X_test, Y_train, Y_test = train_test_split(\n",
    "        df, label, test_size=0.25, random_state=42, stratify=label)\n",
    "    return X_train, X_test, Y_train, Y_test\n",
    "\n",
    "def acc_score(df, label):\n",
    "    \"\"\"Fit each estimator in the global `models` list and tabulate its test accuracy.\n",
    "\n",
    "    The data are split with `split`. Returns a DataFrame with columns\n",
    "    \"Classifier\" (taken from the global `classifiers`) and \"Accuracy\",\n",
    "    sorted by accuracy (descending) with a fresh 0..n-1 index.\n",
    "    Note: the estimators in `models` are fitted in place.\n",
    "    \"\"\"\n",
    "    X_train, X_test, Y_train, Y_test = split(df, label)\n",
    "    acc = []\n",
    "    for model in models:\n",
    "        model.fit(X_train, Y_train)\n",
    "        acc.append(accuracy_score(Y_test, model.predict(X_test)))\n",
    "    score = pd.DataFrame({\"Classifier\": classifiers, \"Accuracy\": acc})\n",
    "    return score.sort_values(by=\"Accuracy\", ascending=False).reset_index(drop=True)\n",
    "\n",
    "def plot(score, x=None, y=None, c=\"b\"):\n",
    "    \"\"\"Point-plot one accuracy value per GA generation.\n",
    "\n",
    "    score: sequence of accuracies, one per generation (plotted at 1..len(score)).\n",
    "    c: matplotlib colour for the points.\n",
    "    `x` and `y` are accepted for backward compatibility only and are ignored.\n",
    "    \"\"\"\n",
    "    # Derive the generation axis from the data instead of the old hard-coded\n",
    "    # [1, 2, 3, 4, 5], which broke whenever n_gen != 5.\n",
    "    gen = list(range(1, len(score) + 1))\n",
    "    plt.figure(figsize=(6, 4))\n",
    "    ax = sns.pointplot(x=gen, y=score, color=c)\n",
    "    ax.set(xlabel=\"Generation\", ylabel=\"Accuracy\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "SYWqktBJVQ7I"
   },
   "outputs": [],
   "source": [
    "def initilization_of_population(size, n_feat):\n",
    "    \"\"\"Create `size` random boolean feature masks (chromosomes) of length `n_feat`.\n",
    "\n",
    "    Each mask has int(0.3 * n_feat) entries set to False (column not selected)\n",
    "    and the rest True, in an order shuffled by np.random.shuffle.\n",
    "    \"\"\"\n",
    "    population = []\n",
    "    for _ in range(size):\n",
    "        chromosome = np.ones(n_feat, bool)\n",
    "        chromosome[:int(0.3 * n_feat)] = False\n",
    "        np.random.shuffle(chromosome)\n",
    "        population.append(chromosome)\n",
    "    return population\n",
    "\n",
    "def fitness_score(population, X_train=None, X_test=None, Y_train=None, Y_test=None):\n",
    "    \"\"\"Score each chromosome by the test accuracy of a decision tree on its columns.\n",
    "\n",
    "    For every boolean mask, a DecisionTreeClassifier(random_state=0) is fitted on\n",
    "    X_train.iloc[:, mask] and evaluated on X_test.iloc[:, mask].\n",
    "\n",
    "    The split was previously read only from notebook globals, so the split passed\n",
    "    to `generations` was silently ignored. It can now be passed explicitly; any\n",
    "    argument left as None falls back to the notebook global of the same name.\n",
    "\n",
    "    Returns three lists ordered best-first: scores, chromosomes, fitted models.\n",
    "    A mask selecting no columns cannot be fitted; it scores 0.0 with model None.\n",
    "    \"\"\"\n",
    "    g = globals()\n",
    "    if X_train is None:\n",
    "        X_train = g['X_train']\n",
    "    if X_test is None:\n",
    "        X_test = g['X_test']\n",
    "    if Y_train is None:\n",
    "        Y_train = g['Y_train']\n",
    "    if Y_test is None:\n",
    "        Y_test = g['Y_test']\n",
    "\n",
    "    scores = []\n",
    "    fitted = []\n",
    "    for chromosome in population:\n",
    "        if not np.any(chromosome):\n",
    "            # A tree cannot be fitted on zero features.\n",
    "            scores.append(0.0)\n",
    "            fitted.append(None)\n",
    "            continue\n",
    "        logmodel = DecisionTreeClassifier(random_state=0)\n",
    "        logmodel.fit(X_train.iloc[:, chromosome], Y_train)\n",
    "        predictions = logmodel.predict(X_test.iloc[:, chromosome])\n",
    "        scores.append(accuracy_score(Y_test, predictions))\n",
    "        fitted.append(logmodel)\n",
    "\n",
    "    scores, population = np.array(scores), np.array(population)\n",
    "    # Ascending argsort reversed -> best first (same ordering as before).\n",
    "    order = np.argsort(scores)[::-1]\n",
    "    return list(scores[order]), list(population[order]), [fitted[k] for k in order]\n",
    "\n",
    "def selection(pop_after_fit, n_parents):\n",
    "    \"\"\"Keep the first `n_parents` chromosomes of a best-first population.\"\"\"\n",
    "    return list(pop_after_fit[:n_parents])\n",
    "\n",
    "def crossover(pop_after_sel):\n",
    "    \"\"\"Midpoint crossover of consecutive parent pairs (0,1), (2,3), ...\n",
    "\n",
    "    Each child takes the first half of the first parent and the second half of\n",
    "    the second; children are appended after the parents. Returns a new list\n",
    "    (the input list is not modified). With an odd number of parents the last\n",
    "    one is carried over unpaired instead of raising IndexError.\n",
    "    \"\"\"\n",
    "    pop_nextgen = list(pop_after_sel)\n",
    "    for i in range(0, len(pop_after_sel) - 1, 2):\n",
    "        parent_1, parent_2 = pop_after_sel[i], pop_after_sel[i + 1]\n",
    "        mid = len(parent_1) // 2\n",
    "        pop_nextgen.append(np.concatenate((parent_1[:mid], parent_2[mid:])))\n",
    "    return pop_nextgen\n",
    "\n",
    "def mutation(pop_after_cross, mutation_rate, n_feat, start=64):\n",
    "    \"\"\"Flip random bits of every chromosome from index `start` onward, in place.\n",
    "\n",
    "    For each such chromosome, int(mutation_rate * n_feat) positions are drawn\n",
    "    with random.randint (with replacement, so a position drawn twice flips back).\n",
    "    Returns the same (mutated) list.\n",
    "\n",
    "    `start` replaces the previously hard-coded 64; `generations` passes n_parents\n",
    "    so exactly the crossover children are mutated and the parents stay intact.\n",
    "    \"\"\"\n",
    "    mutation_range = int(mutation_rate * n_feat)\n",
    "    for n in range(start, len(pop_after_cross)):\n",
    "        chromo = pop_after_cross[n]\n",
    "        for _ in range(mutation_range):\n",
    "            pos = randint(0, n_feat - 1)\n",
    "            chromo[pos] = not chromo[pos]\n",
    "    return pop_after_cross\n",
    "\n",
    "def generations(df, label, size, n_feat, n_parents, mutation_rate, n_gen,\n",
    "                X_train, X_test, Y_train, Y_test, seed=None):\n",
    "    \"\"\"Run the GA feature-selection loop for `n_gen` generations.\n",
    "\n",
    "    Each generation scores the population on the supplied split, keeps the\n",
    "    `n_parents` best, appends crossover children and mutates only the children.\n",
    "\n",
    "    Returns (best_chromo, best_score, best_models): one entry per generation.\n",
    "    `df` and `label` are accepted for backward compatibility but are unused.\n",
    "    seed: if given, seeds np.random and Python's `random` for reproducible runs.\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        import random\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "\n",
    "    best_chromo = []\n",
    "    best_score = []\n",
    "    best_models = []\n",
    "    population_nextgen = initilization_of_population(size, n_feat)\n",
    "    for i in range(n_gen):\n",
    "        scores, pop_after_fit, fitted = fitness_score(\n",
    "            population_nextgen, X_train, X_test, Y_train, Y_test)\n",
    "        print('Best score in generation', i + 1, ':', scores[:1])\n",
    "\n",
    "        best_score.append(scores[0])\n",
    "        best_chromo.append(pop_after_fit[0])\n",
    "        best_models.append(fitted[0])\n",
    "\n",
    "        pop_after_sel = selection(pop_after_fit, n_parents)\n",
    "        pop_after_cross = crossover(pop_after_sel)\n",
    "        population_nextgen = mutation(pop_after_cross, mutation_rate, n_feat, start=n_parents)\n",
    "\n",
    "    return best_chromo, best_score, best_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 273
    },
    "id": "X3Ww2R2wVQ44",
    "outputId": "ccd8b6d6-f6ce-4bf5-c476-8e28c11cc0ee"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(920, 23)\n"
     ]
    },
    {
     "data": {
      "text/html": [ "
\n", " | age | \n", "trestbps | \n", "chol | \n", "thalach | \n", "oldpeak | \n", "ca | \n", "sex | \n", "fbs | \n", "exang | \n", "cp_1.0 | \n", "... | \n", "restecg_0 | \n", "restecg_1 | \n", "restecg_2 | \n", "slope_1 | \n", "slope_2 | \n", "slope_3 | \n", "thal_3.0 | \n", "thal_6.0 | \n", "thal_7.0 | \n", "num | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "63.0 | \n", "145.0 | \n", "233.0 | \n", "150.0 | \n", "2.3 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "0 | \n", "1.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0 | \n", "
1 | \n", "67.0 | \n", "160.0 | \n", "286.0 | \n", "108.0 | \n", "1.5 | \n", "3.0 | \n", "1.0 | \n", "0 | \n", "1 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1 | \n", "
2 | \n", "67.0 | \n", "120.0 | \n", "229.0 | \n", "129.0 | \n", "2.6 | \n", "2.0 | \n", "1.0 | \n", "0 | \n", "1 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "
3 | \n", "37.0 | \n", "130.0 | \n", "250.0 | \n", "187.0 | \n", "3.5 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0 | \n", "
4 | \n", "41.0 | \n", "130.0 | \n", "204.0 | \n", "172.0 | \n", "1.4 | \n", "0.0 | \n", "0.0 | \n", "0 | \n", "0 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0 | \n", "
5 rows × 23 columns
\n", "\n", " | age | \n", "trestbps | \n", "chol | \n", "thalach | \n", "oldpeak | \n", "ca | \n", "sex | \n", "fbs | \n", "exang | \n", "cp_1.0 | \n", "... | \n", "restecg_0 | \n", "restecg_1 | \n", "restecg_2 | \n", "slope_1 | \n", "slope_2 | \n", "slope_3 | \n", "thal_3.0 | \n", "thal_6.0 | \n", "thal_7.0 | \n", "num | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "63.0 | \n", "145.0 | \n", "233.0 | \n", "150.0 | \n", "2.3 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "0 | \n", "1.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0 | \n", "
1 | \n", "67.0 | \n", "160.0 | \n", "286.0 | \n", "108.0 | \n", "1.5 | \n", "3.0 | \n", "1.0 | \n", "0 | \n", "1 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1 | \n", "
2 | \n", "67.0 | \n", "120.0 | \n", "229.0 | \n", "129.0 | \n", "2.6 | \n", "2.0 | \n", "1.0 | \n", "0 | \n", "1 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "
3 | \n", "37.0 | \n", "130.0 | \n", "250.0 | \n", "187.0 | \n", "3.5 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0 | \n", "
4 | \n", "41.0 | \n", "130.0 | \n", "204.0 | \n", "172.0 | \n", "1.4 | \n", "0.0 | \n", "0.0 | \n", "0 | \n", "0 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0 | \n", "
5 rows × 23 columns
\n", "\n", " | age | \n", "trestbps | \n", "chol | \n", "thalach | \n", "oldpeak | \n", "ca | \n", "sex | \n", "fbs | \n", "exang | \n", "cp_1.0 | \n", "... | \n", "cp_4.0 | \n", "restecg_0 | \n", "restecg_1 | \n", "restecg_2 | \n", "slope_1 | \n", "slope_2 | \n", "slope_3 | \n", "thal_3.0 | \n", "thal_6.0 | \n", "thal_7.0 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "63.0 | \n", "145.0 | \n", "233.0 | \n", "150.0 | \n", "2.3 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "0 | \n", "1.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "
1 | \n", "67.0 | \n", "160.0 | \n", "286.0 | \n", "108.0 | \n", "1.5 | \n", "3.0 | \n", "1.0 | \n", "0 | \n", "1 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
2 | \n", "67.0 | \n", "120.0 | \n", "229.0 | \n", "129.0 | \n", "2.6 | \n", "2.0 | \n", "1.0 | \n", "0 | \n", "1 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "
3 | \n", "37.0 | \n", "130.0 | \n", "250.0 | \n", "187.0 | \n", "3.5 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "0 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
4 | \n", "41.0 | \n", "130.0 | \n", "204.0 | \n", "172.0 | \n", "1.4 | \n", "0.0 | \n", "0.0 | \n", "0 | \n", "0 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
915 | \n", "54.0 | \n", "127.0 | \n", "333.0 | \n", "154.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1 | \n", "0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
916 | \n", "62.0 | \n", "130.0 | \n", "139.0 | \n", "140.0 | \n", "0.5 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "0 | \n", "1.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
917 | \n", "55.0 | \n", "122.0 | \n", "223.0 | \n", "100.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "
918 | \n", "58.0 | \n", "130.0 | \n", "385.0 | \n", "140.0 | \n", "0.5 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
919 | \n", "62.0 | \n", "120.0 | \n", "254.0 | \n", "93.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "1 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
920 rows × 22 columns
\n", "\n", " | Classifier | \n", "Accuracy | \n", "
---|---|---|
0 | \n", "DecisionTree | \n", "0.717391 | \n", "
\n", " | sex | \n", "exang | \n", "cp_1.0 | \n", "cp_2.0 | \n", "cp_3.0 | \n", "cp_4.0 | \n", "restecg_1 | \n", "slope_1 | \n", "slope_2 | \n", "thal_3.0 | \n", "thal_6.0 | \n", "thal_7.0 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|
272 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "
59 | \n", "1.0 | \n", "1 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
610 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
328 | \n", "1.0 | \n", "0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
804 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
374 | \n", "0.0 | \n", "0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
590 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
573 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
580 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
308 | \n", "0.0 | \n", "0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
230 rows × 12 columns
\n", "\n", " | sex | \n", "exang | \n", "cp_1.0 | \n", "cp_2.0 | \n", "cp_3.0 | \n", "cp_4.0 | \n", "restecg_1 | \n", "slope_1 | \n", "slope_2 | \n", "thal_3.0 | \n", "thal_6.0 | \n", "thal_7.0 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|
272 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "
59 | \n", "1.0 | \n", "1 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
610 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
328 | \n", "1.0 | \n", "0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
804 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
374 | \n", "0.0 | \n", "0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
590 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
573 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
580 | \n", "1.0 | \n", "1 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
308 | \n", "0.0 | \n", "0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "
230 rows × 12 columns
\n", "