diff --git "a/pkmn-classifier/nlp.ipynb" "b/pkmn-classifier/nlp.ipynb" new file mode 100644--- /dev/null +++ "b/pkmn-classifier/nlp.ipynb" @@ -0,0 +1,2106 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 85, + "metadata": { + "id": "i-B5sPHELBBj" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis\n", + "from sklearn.dummy import DummyClassifier\n", + "from sklearn.ensemble import (\n", + " AdaBoostClassifier,\n", + " GradientBoostingClassifier,\n", + " RandomForestClassifier,\n", + ")\n", + "from sklearn.feature_extraction.text import TfidfVectorizer\n", + "from sklearn.gaussian_process import GaussianProcessClassifier\n", + "from sklearn.gaussian_process.kernels import RBF\n", + "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n", + "from sklearn.model_selection import GridSearchCV, train_test_split\n", + "from sklearn.naive_bayes import GaussianNB, MultinomialNB\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.neural_network import MLPClassifier\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.svm import SVC, LinearSVC\n", + "from sklearn.tree import DecisionTreeClassifier" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": { + "id": "uxhFjfeHLBBr" + }, + "outputs": [], + "source": [ + "# Read the pokedex we scraped in web_scrape.ipynb into a DataFrame\n", + "pkmn = pd.read_csv(\"pokemon.csv\")\n", + "pkmn.rename(columns={\"Unnamed: 0\": \"wiki_index\"}, inplace=True)\n", + "pkmn = pkmn[pkmn.primary_type != \"Bird\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 908 + }, + "id": "1HfxKWKEPpNU", + "outputId": "b3d9ab9d-7cd2-4010-f060-bb7ac2f843f1" + }, + "outputs": [ + { + "data": { + 
"image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "pivot_table = pkmn[[\"primary_type\", \"secondary_type\"]].value_counts().unstack().fillna(0)\n", + "long_form = pkmn[[\"primary_type\", \"secondary_type\"]].value_counts()\n", + "\n", + "ax = sns.heatmap(pivot_table)\n", + "ax.figure.set_size_inches(15, 15)" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RTUdYUv3LBBt", + "outputId": "1098fdc6-500c-4c9a-ade1-57b380a0ef84" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(935, 6229) (935,)\n" + ] + }, + { + "data": { + "text/plain": [ + "['000',\n", + " '01',\n", + " '02',\n", + " '03',\n", + " '04',\n", + " '05',\n", + " '10',\n", + " '100',\n", + " '1000',\n", + " '100x',\n", + " '101',\n", + " '108',\n", + " '11',\n", + " '12',\n", + " '120',\n", + " '13',\n", + " '14',\n", + " '148',\n", + " '15',\n", + " '150',\n", + " '16',\n", + " '17',\n", + " '18',\n", + " '180',\n", + " '19',\n", + " '1950s',\n", + " '1970s',\n", + " '1997',\n", + " '1px',\n", + " '20',\n", + " '200',\n", + " '2000',\n", + " '2003',\n", + " '2005',\n", + " '2007',\n", + " '2008',\n", + " '2009',\n", + " '2013',\n", + " '2014',\n", + " '2015',\n", + " '2016',\n", + " '2018',\n", + " '2019',\n", + " '2020',\n", + " '2021',\n", + " '20th',\n", + " '21',\n", + " '22',\n", + " '23',\n", + " '23rd',\n", + " '24',\n", + " '242',\n", + " '25',\n", + " '26',\n", + " '27',\n", + " '28',\n", + " '29',\n", + " '294',\n", + " '296',\n", + " '2nd',\n", + " '30',\n", + " '300',\n", + " '31',\n", + " '310',\n", + " '32',\n", + " '33',\n", + " '34',\n", + " '35',\n", + " '36',\n", + " '360',\n", + " '36th',\n", + " '37',\n", + " '370',\n", + " '38',\n", + " '39',\n", + " '390',\n", + " '3ds',\n", + " '3rd',\n", + " '40',\n", + " '400',\n", + " '41',\n", + " '42',\n", + " '43',\n", + " '44',\n", + " 
'440',\n", + " '45',\n", + " '46',\n", + " '47',\n", + " '48',\n", + " '49',\n", + " '4ever',\n", + " '50',\n", + " '500',\n", + " '5000',\n", + " '50th',\n", + " '51',\n", + " '52',\n", + " '53',\n", + " '54',\n", + " '55',\n", + " '56',\n", + " '57',\n", + " '58',\n", + " '59',\n", + " '60',\n", + " '600',\n", + " '61',\n", + " '610',\n", + " '62',\n", + " '63',\n", + " '630',\n", + " '64',\n", + " '65',\n", + " '650',\n", + " '66',\n", + " '67',\n", + " '68',\n", + " '69',\n", + " '69th',\n", + " '70',\n", + " '71',\n", + " '719',\n", + " '72',\n", + " '721',\n", + " '73',\n", + " '74',\n", + " '75',\n", + " '76',\n", + " '767',\n", + " '77',\n", + " '78',\n", + " '79',\n", + " '80',\n", + " '82',\n", + " '90',\n", + " '900',\n", + " '95',\n", + " '967',\n", + " 'abandoned',\n", + " 'abbreviation',\n", + " 'abilities',\n", + " 'ability',\n", + " 'able',\n", + " 'abnormalities',\n", + " 'abominable',\n", + " 'abra',\n", + " 'absence',\n", + " 'absolutely',\n", + " 'absorb',\n", + " 'absorbed',\n", + " 'absorbing',\n", + " 'absorbs',\n", + " 'absorption',\n", + " 'abundance',\n", + " 'abundantly',\n", + " 'acanthaster',\n", + " 'accelerate',\n", + " 'accelgor',\n", + " 'accept',\n", + " 'accepted',\n", + " 'accompanies',\n", + " 'according',\n", + " 'account',\n", + " 'accurately',\n", + " 'accustomed',\n", + " 'ace',\n", + " 'acerola',\n", + " 'achieve',\n", + " 'acid',\n", + " 'acknowledge',\n", + " 'acknowledged',\n", + " 'acquire',\n", + " 'acquires',\n", + " 'acrobatic',\n", + " 'act',\n", + " 'acting',\n", + " 'actions',\n", + " 'activate',\n", + " 'activated',\n", + " 'active',\n", + " 'actively',\n", + " 'actor',\n", + " 'acts',\n", + " 'actual',\n", + " 'actually',\n", + " 'added',\n", + " 'adding',\n", + " 'addition',\n", + " 'additional',\n", + " 'additionally',\n", + " 'adept',\n", + " 'adhesive',\n", + " 'adopted',\n", + " 'adorable',\n", + " 'adorn',\n", + " 'adorned',\n", + " 'adornment',\n", + " 'adult',\n", + " 'advance',\n", + " 'advanced',\n", + 
" 'advantage',\n", + " 'adventure',\n", + " 'aegislash',\n", + " 'aether',\n", + " 'affect',\n", + " 'affecting',\n", + " 'affection',\n", + " 'affects',\n", + " 'affinity',\n", + " 'afford',\n", + " 'aforementioned',\n", + " 'afro',\n", + " 'afterlife',\n", + " 'afterward',\n", + " 'agent',\n", + " 'ages',\n", + " 'aggressive',\n", + " 'aggressively',\n", + " 'agile',\n", + " 'agility',\n", + " 'agitated',\n", + " 'agitates',\n", + " 'ago',\n", + " 'agriculture',\n", + " 'aipom',\n", + " 'air',\n", + " 'airborne',\n", + " 'aircraft',\n", + " 'akala',\n", + " 'akin',\n", + " 'alakazam',\n", + " 'albeit',\n", + " 'alchemic',\n", + " 'alcremie',\n", + " 'alien',\n", + " 'aliens',\n", + " 'alike',\n", + " 'alive',\n", + " 'allies',\n", + " 'alligator',\n", + " 'alligators',\n", + " 'allow',\n", + " 'allowed',\n", + " 'allowing',\n", + " 'allows',\n", + " 'alloy',\n", + " 'alludes',\n", + " 'alola',\n", + " 'alolan',\n", + " 'alomomola',\n", + " 'alongside',\n", + " 'alpha',\n", + " 'alphabet',\n", + " 'alright',\n", + " 'altaria',\n", + " 'altered',\n", + " 'altering',\n", + " 'alternate',\n", + " 'alternative',\n", + " 'alternatively',\n", + " 'altitudes',\n", + " 'amargasaurus',\n", + " 'amazing',\n", + " 'amber',\n", + " 'america',\n", + " 'american',\n", + " 'ammonites',\n", + " 'amoeba',\n", + " 'amped',\n", + " 'amperage',\n", + " 'ampharos',\n", + " 'amphibian',\n", + " 'amusing',\n", + " 'anaconda',\n", + " 'ancestor',\n", + " 'anchor',\n", + " 'anchors',\n", + " 'ancient',\n", + " 'andrew',\n", + " 'angel',\n", + " 'anger',\n", + " 'angered',\n", + " 'anglerfish',\n", + " 'angles',\n", + " 'angry',\n", + " 'animal',\n", + " 'anime',\n", + " 'animon',\n", + " 'ankylosaurus',\n", + " 'announcement',\n", + " 'annoyance',\n", + " 'anomalocaris',\n", + " 'anorith',\n", + " 'answer',\n", + " 'answers',\n", + " 'antagonist',\n", + " 'antagonists',\n", + " 'anteater',\n", + " 'antenna',\n", + " 'antennae',\n", + " 'antimatter',\n", + " 'antique',\n", + " 
'antlers',\n", + " 'antlion',\n", + " 'anymore',\n", + " 'aoki',\n", + " 'apart',\n", + " 'apocalypse',\n", + " 'apparent',\n", + " 'apparently',\n", + " 'appealing',\n", + " 'appear',\n", + " 'appearance',\n", + " 'appearances',\n", + " 'appeared',\n", + " 'appearing',\n", + " 'appears',\n", + " 'appendages',\n", + " 'apple',\n", + " 'apples',\n", + " 'appletun',\n", + " 'appliances',\n", + " 'applin',\n", + " 'appreciation',\n", + " 'approaches',\n", + " 'appropriation',\n", + " 'approximately',\n", + " 'april',\n", + " 'aqua',\n", + " 'aquatic',\n", + " 'aramis',\n", + " 'araquanid',\n", + " 'arbok',\n", + " 'arcanine',\n", + " 'arceus',\n", + " 'arch',\n", + " 'archaeopteryx',\n", + " 'archen',\n", + " 'archeops',\n", + " 'archie',\n", + " 'arctic',\n", + " 'arctovish',\n", + " 'arctozolts',\n", + " 'area',\n", + " 'areas',\n", + " 'aren',\n", + " 'aria',\n", + " 'ariados',\n", + " 'arid',\n", + " 'ariga',\n", + " 'arm',\n", + " 'armed',\n", + " 'armless',\n", + " 'armor',\n", + " 'armored',\n", + " 'armour',\n", + " 'arms',\n", + " 'aroma',\n", + " 'aromatisse',\n", + " 'aron',\n", + " 'arranges',\n", + " 'arrive',\n", + " 'arrives',\n", + " 'arrokuda',\n", + " 'arrow',\n", + " 'arrows',\n", + " 'art',\n", + " 'artagnan',\n", + " 'artillery',\n", + " 'artist',\n", + " 'arts',\n", + " 'artwork',\n", + " 'ascend',\n", + " 'ascended',\n", + " 'ash',\n", + " 'asia',\n", + " 'asian',\n", + " 'aside',\n", + " 'asking',\n", + " 'asleep',\n", + " 'aspects',\n", + " 'ass',\n", + " 'assault',\n", + " 'assembly',\n", + " 'assist',\n", + " 'assists',\n", + " 'associated',\n", + " 'association',\n", + " 'assume',\n", + " 'assuming',\n", + " 'astral',\n", + " 'ate',\n", + " 'athos',\n", + " 'atk',\n", + " 'atmosphere',\n", + " 'atop',\n", + " 'atrophied',\n", + " 'atsuko',\n", + " 'attach',\n", + " 'attached',\n", + " 'attaches',\n", + " 'attack',\n", + " 'attacked',\n", + " 'attackers',\n", + " 'attacking',\n", + " 'attacks',\n", + " 'attempt',\n", + " 'attempted',\n", + " 
'attempts',\n", + " 'attention',\n", + " 'attitude',\n", + " 'attract',\n", + " 'attracted',\n", + " 'attractive',\n", + " 'attracts',\n", + " 'audience',\n", + " 'audiences',\n", + " 'audino',\n", + " 'augite',\n", + " 'augurite',\n", + " 'august',\n", + " 'aura',\n", + " 'auras',\n", + " 'aurorus',\n", + " 'australia',\n", + " 'authentic',\n", + " 'authenticity',\n", + " 'automatically',\n", + " 'available',\n", + " 'avoid',\n", + " 'avoiding',\n", + " 'avoids',\n", + " 'awaits',\n", + " 'awake',\n", + " 'away',\n", + " 'awe',\n", + " 'awesome',\n", + " 'awful',\n", + " 'awhile',\n", + " 'awkward',\n", + " 'awoken',\n", + " 'az',\n", + " 'azelf',\n", + " 'azumarill',\n", + " 'azure',\n", + " 'azurill',\n", + " 'babies',\n", + " 'baboon',\n", + " 'baby',\n", + " 'backfish',\n", + " 'backstory',\n", + " 'backwards',\n", + " 'bacteria',\n", + " 'bad',\n", + " 'bag',\n", + " 'bagon',\n", + " 'bags',\n", + " 'baile',\n", + " 'baku',\n", + " 'balance',\n", + " 'bald',\n", + " 'ball',\n", + " 'balloon',\n", + " 'balloons',\n", + " 'balls',\n", + " 'bamboo',\n", + " 'band',\n", + " 'bandai',\n", + " 'banette',\n", + " 'bangs',\n", + " 'banished',\n", + " 'banned',\n", + " 'bar',\n", + " 'baragon',\n", + " 'barb',\n", + " 'barbaracle',\n", + " 'barbed',\n", + " 'barboach',\n", + " 'bare',\n", + " 'barely',\n", + " 'bares',\n", + " 'bark',\n", + " 'barks',\n", + " 'barn',\n", + " 'barnacle',\n", + " 'barnacles',\n", + " 'barracuda',\n", + " 'barrage',\n", + " 'barraskewda',\n", + " 'barreled',\n", + " 'barriers',\n", + " 'basculegion',\n", + " 'basculin',\n", + " 'base',\n", + " 'based',\n", + " 'bash',\n", + " 'bashir',\n", + " 'basic',\n", + " 'basis',\n", + " 'basking',\n", + " 'basks',\n", + " 'bass',\n", + " 'bat',\n", + " 'bathed',\n", + " 'bathes',\n", + " 'battery',\n", + " 'battle',\n", + " 'battled',\n", + " 'battles',\n", + " 'battling',\n", + " 'bay',\n", + " 'bayleef',\n", + " 'beach',\n", + " 'beaches',\n", + " 'beady',\n", + " 'beak',\n", + " 'beaks',\n", + 
" 'beam',\n", + " 'beams',\n", + " 'bean',\n", + " 'beans',\n", + " 'bear',\n", + " 'beard',\n", + " 'bearing',\n", + " 'bears',\n", + " 'beast',\n", + " 'beasts',\n", + " 'beat',\n", + " 'beaten',\n", + " 'beats',\n", + " 'beautifly',\n", + " 'beautiful',\n", + " 'beauty',\n", + " 'beaver',\n", + " 'becalms',\n", + " 'bee',\n", + " 'beedrill',\n", + " 'beefy',\n", + " 'beehive',\n", + " 'bees',\n", + " 'beetle',\n", + " 'befriends',\n", + " 'befuddle',\n", + " 'began',\n", + " 'begging',\n", + " 'begin',\n", + " 'begins',\n", + " 'behavior',\n", + " 'behemoth',\n", + " 'beings',\n", + " 'belch',\n", + " 'beldum',\n", + " 'beldums',\n", + " 'belief',\n", + " 'believe',\n", + " 'believed',\n", + " 'bell',\n", + " 'belligerently',\n", + " 'bellossom',\n", + " 'bells',\n", + " 'bellsprout',\n", + " 'belly',\n", + " 'belonged',\n", + " 'belonging',\n", + " 'belongs',\n", + " 'belt',\n", + " 'belts',\n", + " 'bem',\n", + " 'ben',\n", + " 'beneath',\n", + " 'benefits',\n", + " 'berries',\n", + " 'berry',\n", + " 'berserk',\n", + " 'best',\n", + " 'bestowing',\n", + " 'beta',\n", + " 'betobebi',\n", + " 'betray',\n", + " 'better',\n", + " 'bewear',\n", + " 'bewitching',\n", + " 'bibarel',\n", + " 'bidoof',\n", + " 'big',\n", + " 'bigger',\n", + " 'billions',\n", + " 'binacle',\n", + " 'binnacle',\n", + " 'biped',\n", + " 'bipedal',\n", + " 'bird',\n", + " 'birds',\n", + " 'birthed',\n", + " 'bisharp',\n", + " 'bison',\n", + " 'bit',\n", + " 'bite',\n", + " 'bites',\n", + " 'biting',\n", + " 'bits',\n", + " 'bitterly',\n", + " 'bivalve',\n", + " 'blacephalon',\n", + " 'black',\n", + " 'blackened',\n", + " 'blackface',\n", + " 'blade',\n", + " 'blades',\n", + " 'blaine',\n", + " 'blast',\n", + " 'blaster',\n", + " 'blasting',\n", + " 'blastoise',\n", + " 'blasts',\n", + " 'blaze',\n", + " 'blaziken',\n", + " 'bleaching',\n", + " 'blend',\n", + " 'blessings',\n", + " 'blind',\n", + " 'blinding',\n", + " 'blipbug',\n", + " 'blissey',\n", + " 'blizzard',\n", + " 
'blizzards',\n", + " 'blob',\n", + " 'block',\n", + " 'blocking',\n", + " 'blocks',\n", + " 'blocky',\n", + " 'blond',\n", + " 'blood',\n", + " 'bloom',\n", + " 'bloomed',\n", + " 'blooms',\n", + " 'blossom',\n", + " 'blossoms',\n", + " 'blow',\n", + " 'blowing',\n", + " 'blown',\n", + " 'blows',\n", + " 'blubber',\n", + " 'blue',\n", + " 'blueish',\n", + " 'boats',\n", + " 'bodied',\n", + " 'bodies',\n", + " 'bodily',\n", + " 'body',\n", + " 'boeing',\n", + " 'bogos',\n", + " 'boil',\n", + " 'bolts',\n", + " 'boltund',\n", + " 'bomb',\n", + " 'bombs',\n", + " 'bond',\n", + " 'bonded',\n", + " 'bone',\n", + " 'bonemurang',\n", + " 'bones',\n", + " 'bonnets',\n", + " 'bonsly',\n", + " 'bonus',\n", + " 'boomerang',\n", + " 'boost',\n", + " 'boosting',\n", + " 'boosts',\n", + " 'border',\n", + " 'bordering',\n", + " 'born',\n", + " 'borrows',\n", + " 'boss',\n", + " 'bother',\n", + " 'bothered',\n", + " 'bottle',\n", + " 'bottles',\n", + " 'bouffalant',\n", + " 'boulder',\n", + " 'boulders',\n", + " 'bounce',\n", + " 'bounces',\n", + " 'bouncing',\n", + " 'bouncy',\n", + " 'bounding',\n", + " 'bounsweet',\n", + " 'bountiful',\n", + " 'bow',\n", + " 'bowie',\n", + " 'bowl',\n", + " 'bows',\n", + " 'boxer',\n", + " 'boxers',\n", + " 'boxing',\n", + " 'boy',\n", + " 'bracelets',\n", + " 'braids',\n", + " 'brain',\n", + " 'brains',\n", + " 'brainwaves',\n", + " 'braise',\n", + " 'braixen',\n", + " 'bramble',\n", + " 'branch',\n", + " 'branches',\n", + " 'branching',\n", + " 'brass',\n", + " 'brave',\n", + " 'bravery',\n", + " 'brawl',\n", + " 'brawly',\n", + " 'break',\n", + " 'breaking',\n", + " 'breaks',\n", + " 'breath',\n", + " 'breathe',\n", + " 'breathing',\n", + " 'breeding',\n", + " 'bridge',\n", + " 'brief',\n", + " 'brigadier',\n", + " 'bright',\n", + " 'brighter',\n", + " 'brightly',\n", + " 'brightness',\n", + " 'brilliant',\n", + " 'briney',\n", + " 'bring',\n", + " 'bringer',\n", + " 'brings',\n", + " 'brionne',\n", + " 'british',\n", + " 'brittle',\n", + " 
'broad',\n", + " 'brock',\n", + " 'broke',\n", + " 'broken',\n", + " 'broom',\n", + " 'bros',\n", + " 'brought',\n", + " 'brown',\n", + " 'browser',\n", + " 'bruce',\n", + " 'brush',\n", + " 'brutal',\n", + " 'bruxish',\n", + " 'bubble',\n", + " 'bubbles',\n", + " 'buck',\n", + " 'bucktoothed',\n", + " 'bud',\n", + " 'budew',\n", + " 'bug',\n", + " 'bugs',\n", + " 'build',\n", + " 'building',\n", + " 'buildings',\n", + " 'builds',\n", + " 'built',\n", + " 'bulb',\n", + " 'bulbasaur',\n", + " 'bulbous',\n", + " 'bulgasari',\n", + " 'bulge',\n", + " 'bulky',\n", + " 'bull',\n", + " 'bullied',\n", + " 'bullies',\n", + " 'bully',\n", + " 'bulu',\n", + " 'bumps',\n", + " 'bunch',\n", + " 'bunnelby',\n", + " 'burglars',\n", + " 'buried',\n", + " 'buries',\n", + " 'burmy',\n", + " 'burn',\n", + " 'burning',\n", + " 'burns',\n", + " 'burnt',\n", + " 'burrow',\n", + " 'burrows',\n", + " 'burst',\n", + " 'bursting',\n", + " 'bushes',\n", + " 'bushy',\n", + " 'busts',\n", + " 'butlers',\n", + " 'butt',\n", + " 'butterfly',\n", + " 'butterflyfish',\n", + " 'buzzwole',\n", + " 'bōzu',\n", + " 'cackles',\n", + " 'cacnea',\n", + " 'cactus',\n", + " 'cafe',\n", + " 'cafes',\n", + " 'cage',\n", + " 'cake',\n", + " 'calculated',\n", + " 'called',\n", + " 'calling',\n", + " 'calls',\n", + " 'calm',\n", + " 'calorie',\n", + " 'calyrex',\n", + " 'came',\n", + " 'camels',\n", + " 'cameo',\n", + " 'camerupt',\n", + " 'camouflage',\n", + " 'camouflaged',\n", + " 'camouflaging',\n", + " 'campers',\n", + " 'canceled',\n", + " 'cancelled',\n", + " 'candies',\n", + " 'candy',\n", + " 'cane',\n", + " 'cannon',\n", + " 'cannonball',\n", + " 'cannons',\n", + " 'canonically',\n", + " 'cantankerous',\n", + " 'canyon',\n", + " 'canyons',\n", + " 'cap',\n", + " 'capabilities',\n", + " 'capable',\n", + " 'caped',\n", + " 'capoeira',\n", + " 'caps',\n", + " 'capture',\n", + " 'captured',\n", + " 'carapace',\n", + " 'caravaggio',\n", + " 'carbink',\n", + " 'carcolh',\n", + " 'card',\n", + " 
'cards',\n", + " 'care',\n", + " 'carefree',\n", + " 'careful',\n", + " 'carefully',\n", + " 'carelessly',\n", + " 'cares',\n", + " 'caring',\n", + " 'carkol',\n", + " 'carnotaur',\n", + " 'carp',\n", + " 'carracosta',\n", + " 'carried',\n", + " 'carriers',\n", + " 'carries',\n", + " 'carrot',\n", + " 'carrots',\n", + " 'carry',\n", + " 'carrying',\n", + " 'cars',\n", + " 'cartoonist',\n", + " 'carvanha',\n", + " 'cascoon',\n", + " 'case',\n", + " 'cast',\n", + " 'caste',\n", + " 'castform',\n", + " 'castle',\n", + " 'cat',\n", + " 'catch',\n", + " 'catcher',\n", + " 'catches',\n", + " 'category',\n", + " 'catfish',\n", + " 'caught',\n", + " 'cause',\n", + " 'caused',\n", + " 'causes',\n", + " 'causing',\n", + " 'cautionary',\n", + " 'cave',\n", + " 'cavernous',\n", + " 'caves',\n", + " 'ceilings',\n", + " 'celebi',\n", + " 'celebration',\n", + " 'celesteela',\n", + " 'cell',\n", + " 'cells',\n", + " 'celsius',\n", + " 'cemetery',\n", + " 'center',\n", + " 'centers',\n", + " 'centipede',\n", + " 'centiskorch',\n", + " 'central',\n", + " 'centres',\n", + " 'centrifugal',\n", + " 'ceo',\n", + " 'ceremonies',\n", + " 'certain',\n", + " 'chains',\n", + " 'challenge',\n", + " 'chameleon',\n", + " 'champion',\n", + " 'championship',\n", + " 'championships',\n", + " 'chan',\n", + " 'chance',\n", + " 'chandelure',\n", + " 'change',\n", + " 'changed',\n", + " 'changes',\n", + " 'changing',\n", + " 'chansey',\n", + " 'chaos',\n", + " 'chaplin',\n", + " 'char',\n", + " 'character',\n", + " 'characterized',\n", + " 'characters',\n", + " 'charge',\n", + " 'charged',\n", + " 'charges',\n", + " 'charging',\n", + " 'charizard',\n", + " 'charjabug',\n", + " 'charlie',\n", + " 'charm',\n", + " 'charmander',\n", + " 'charmeleon',\n", + " 'chase',\n", + " 'chases',\n", + " 'chatter',\n", + " 'check',\n", + " 'cheek',\n", + " 'cheeks',\n", + " 'cheer',\n", + " 'cheerful',\n", + " 'chefs',\n", + " 'chemical',\n", + " 'chemically',\n", + " 'cherish',\n", + " 'chesnaught',\n", + " 
'chespin',\n", + " 'chess',\n", + " 'chest',\n", + " 'chestnut',\n", + " 'chewable',\n", + " 'chewing',\n", + " 'chewtle',\n", + " 'chick',\n", + " 'chickadee',\n", + " 'chicken',\n", + " 'chicks',\n", + " 'chief',\n", + " 'chikorita',\n", + " 'child',\n", + " 'children',\n", + " 'chilling',\n", + " 'chills',\n", + " 'chimchar',\n", + " 'chime',\n", + " 'chimera',\n", + " 'chimes',\n", + " 'chinchou',\n", + " 'chinese',\n", + " 'chipped',\n", + " 'chipping',\n", + " 'choice',\n", + " 'chomps',\n", + " 'choose',\n", + " 'chooses',\n", + " 'chop',\n", + " 'chose',\n", + " 'chosen',\n", + " 'church',\n", + " 'chōchin',\n", + " 'cicada',\n", + " 'cinccino',\n", + " 'cinderace',\n", + " 'cinders',\n", + " 'circle',\n", + " 'circles',\n", + " 'circulates',\n", + " 'circulating',\n", + " 'circumstances',\n", + " 'citation',\n", + " 'cited',\n", + " 'cites',\n", + " 'cities',\n", + " 'city',\n", + " 'civilization',\n", + " 'civilizations',\n", + " 'claim',\n", + " 'claimed',\n", + " 'claiming',\n", + " 'claims',\n", + " ...]" + ] + }, + "execution_count": 80, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a TF-IDF vectorizer to transform the text into a matrix of features\n", + "vectorizer = TfidfVectorizer(stop_words=\"english\")\n", + "X = vectorizer.fit_transform(pkmn[\"Notes\"]).toarray()\n", + "y = pkmn[\"primary_type\"]\n", + "print(X.shape, y.shape)\n", + "list(vectorizer.get_feature_names_out())" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": { + "id": "ICefCkn6LBBv" + }, + "outputs": [], + "source": [ + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GpuzXYA0LBBw", + "outputId": "2216c89a-d8f8-46e8-c1ee-442a6961aa29" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dummy : 
0.12455516014234876\n", + "Nearest Neighbors : 0.3665480427046263\n", + "Linear SVM : 0.12455516014234876\n", + "RBF SVM : 0.17437722419928825\n", + "LinearSVC : 0.47330960854092524\n", + "Decision Tree : 0.16370106761565836\n", + "Random Forest : 0.13167259786476868\n", + "Neural Net : 0.298932384341637\n", + "Gradient Boosting : 0.33451957295373663\n", + "AdaBoost : 0.1494661921708185\n", + "Naive Bayes : 0.3701067615658363\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/sklearn/discriminant_analysis.py:878: UserWarning: Variables are collinear\n", + " warnings.warn(\"Variables are collinear\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "QDA : 0.05693950177935943\n", + "Best classifier: LinearSVC with score: 0.47330960854092524\n" + ] + } + ], + "source": [ + "names = [\n", + " \"Dummy\",\n", + " \"Nearest Neighbors\",\n", + " \"Linear SVM\",\n", + " \"RBF SVM\",\n", + " \"LinearSVC\",\n", + " # \"Gaussian Process\",\n", + " \"Decision Tree\",\n", + " \"Random Forest\",\n", + " \"Neural Net\",\n", + " \"Gradient Boosting\",\n", + " \"AdaBoost\",\n", + " \"Naive Bayes\",\n", + " \"QDA\",\n", + "]\n", + "\n", + "classifiers = [\n", + " DummyClassifier(),\n", + " KNeighborsClassifier(3),\n", + " SVC(kernel=\"linear\", C=0.025),\n", + " SVC(gamma=2, C=1),\n", + " LinearSVC(),\n", + " # GaussianProcessClassifier(1.0 * RBF(1.0)),\n", + " DecisionTreeClassifier(max_depth=5),\n", + " RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),\n", + " MLPClassifier(alpha=1, max_iter=1000),\n", + " GradientBoostingClassifier(),\n", + " AdaBoostClassifier(),\n", + " GaussianNB(),\n", + " QuadraticDiscriminantAnalysis(),\n", + "]\n", + "\n", + "\n", + "max_so_far = 0\n", + "for name, clf in zip(names, classifiers):\n", + " clf.fit(X_train, y_train)\n", + " print(name, \": \", clf.score(X_test, y_test))\n", + "\n", + " if clf.score(X_test, y_test) > max_so_far:\n", 
+ " max_so_far = clf.score(X_test, y_test)\n", + " best_clf = clf\n", + " best_clf_name = name\n", + "\n", + "clf = best_clf\n", + "print(\"Best classifier: \", best_clf_name, \" with score: \", max_so_far)" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "w6v8QJ3xWLfw", + "outputId": "0aae6cf5-f26b-4e7c-b93e-cd5215225d93" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/sklearn/model_selection/_validation.py:372: FitFailedWarning: \n", + "20 fits failed out of a total of 80.\n", + "The score on these train-test partitions for these parameters will be set to nan.\n", + "If these failures are not expected, you can try to debug them by setting error_score='raise'.\n", + "\n", + "Below are more details about the failures:\n", + "--------------------------------------------------------------------------------\n", + "20 fits failed with the following error:\n", + "Traceback (most recent call last):\n", + " File \"/usr/local/lib/python3.7/dist-packages/sklearn/model_selection/_validation.py\", line 680, in _fit_and_score\n", + " estimator.fit(X_train, y_train, **fit_params)\n", + " File \"/usr/local/lib/python3.7/dist-packages/sklearn/svm/_classes.py\", line 272, in fit\n", + " sample_weight=sample_weight,\n", + " File \"/usr/local/lib/python3.7/dist-packages/sklearn/svm/_base.py\", line 1185, in _fit_liblinear\n", + " solver_type = _get_liblinear_solver_type(multi_class, penalty, loss, dual)\n", + " File \"/usr/local/lib/python3.7/dist-packages/sklearn/svm/_base.py\", line 1026, in _get_liblinear_solver_type\n", + " % (error_string, penalty, loss, dual)\n", + "ValueError: Unsupported set of arguments: The combination of penalty='l1' and loss='squared_hinge' are not supported when dual=True, Parameters: penalty='l1', loss='squared_hinge', dual=True\n", + "\n", + " warnings.warn(some_fits_failed_message, 
FitFailedWarning)\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/model_selection/_search.py:972: UserWarning: One or more of the test scores are non-finite: [ nan 0.40510863 0.41276571 0.41276571 nan 0.40971227\n", + " 0.41276571 0.41276571 nan 0.40971227 0.41276571 0.41276571\n", + " nan 0.41122725 0.41276571 0.41276571]\n", + " category=UserWarning,\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/svm/_base.py:1208: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n", + " ConvergenceWarning,\n" + ] + }, + { + "data": { + "text/plain": [ + "0.41276570757486786" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "parameters = {\n", + " \"penalty\": [\"l1\", \"l2\"],\n", + " \"C\": [1, 10, 100, 1000],\n", + " \"multi_class\": [\"ovr\", \"crammer_singer\"],\n", + "}\n", + "\n", + "gs_clf = GridSearchCV(clf, parameters, cv=5, n_jobs=-1)\n", + "gs_clf.fit(X_train, y_train)\n", + "gs_clf.best_score_" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ct1IDpeghQan", + "outputId": "1159ea06-a44d-4528-eb9e-4a158d71f05e" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'C': 1, 'multi_class': 'crammer_singer', 'penalty': 'l1'}" + ] + }, + "execution_count": 90, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gs_clf.best_params_" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "AUubKCsvhuC9", + "outputId": "36862862-bf28-4d05-f8f3-a57b21684c89" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
mean_fit_timestd_fit_timemean_score_timestd_score_timeparam_Cparam_multi_classparam_penaltyparamssplit0_test_scoresplit1_test_scoresplit2_test_scoresplit3_test_scoresplit4_test_scoremean_test_scorestd_test_scorerank_test_score
00.0225430.0045290.0000000.0000001ovrl1{'C': 1, 'multi_class': 'ovr', 'penalty': 'l1'}NaNNaNNaNNaNNaNNaNNaN13
10.0836520.0016970.0053570.0004841ovrl2{'C': 1, 'multi_class': 'ovr', 'penalty': 'l2'}0.4656490.4045800.4122140.3969470.3461540.4051090.03810012
20.1161120.0147200.0081720.0038831crammer_singerl1{'C': 1, 'multi_class': 'crammer_singer', 'pen...0.4732820.3969470.4198470.4122140.3615380.4127660.0362971
30.1202340.0121570.0072470.0031751crammer_singerl2{'C': 1, 'multi_class': 'crammer_singer', 'pen...0.4732820.3969470.4198470.4122140.3615380.4127660.0362971
40.0150910.0023330.0000000.00000010ovrl1{'C': 10, 'multi_class': 'ovr', 'penalty': 'l1'}NaNNaNNaNNaNNaNNaNNaN14
50.2464300.0180340.0052650.00030110ovrl2{'C': 10, 'multi_class': 'ovr', 'penalty': 'l2'}0.4809160.4122140.4045800.3893130.3615380.4097120.03959710
60.1326180.0054600.0050940.00008010crammer_singerl1{'C': 10, 'multi_class': 'crammer_singer', 'pe...0.4809160.4122140.4122140.3969470.3615380.4127660.0387801
70.1463670.0140930.0050400.00005310crammer_singerl2{'C': 10, 'multi_class': 'crammer_singer', 'pe...0.4809160.4122140.4122140.3969470.3615380.4127660.0387801
80.0141140.0010500.0000000.000000100ovrl1{'C': 100, 'multi_class': 'ovr', 'penalty': 'l1'}NaNNaNNaNNaNNaNNaNNaN15
90.8341560.0308550.0050860.000078100ovrl2{'C': 100, 'multi_class': 'ovr', 'penalty': 'l2'}0.4732820.4045800.4122140.3969470.3615380.4097120.03621410
100.5195920.0219080.0051760.000095100crammer_singerl1{'C': 100, 'multi_class': 'crammer_singer', 'p...0.4809160.4122140.4122140.3969470.3615380.4127660.0387801
110.5132600.0136590.0049890.000057100crammer_singerl2{'C': 100, 'multi_class': 'crammer_singer', 'p...0.4809160.4122140.4122140.3969470.3615380.4127660.0387801
120.0138910.0007940.0000000.0000001000ovrl1{'C': 1000, 'multi_class': 'ovr', 'penalty': '...NaNNaNNaNNaNNaNNaNNaN16
130.6260050.1065140.0050550.0000831000ovrl2{'C': 1000, 'multi_class': 'ovr', 'penalty': '...0.4732820.3969470.4274810.4045800.3538460.4112270.0391309
144.0005730.1463040.0050130.0000321000crammer_singerl1{'C': 1000, 'multi_class': 'crammer_singer', '...0.4809160.4122140.4122140.3969470.3615380.4127660.0387801
153.9543070.1848810.0047610.0006061000crammer_singerl2{'C': 1000, 'multi_class': 'crammer_singer', '...0.4809160.4122140.4122140.3969470.3615380.4127660.0387801
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " mean_fit_time std_fit_time mean_score_time std_score_time param_C \\\n", + "0 0.022543 0.004529 0.000000 0.000000 1 \n", + "1 0.083652 0.001697 0.005357 0.000484 1 \n", + "2 0.116112 0.014720 0.008172 0.003883 1 \n", + "3 0.120234 0.012157 0.007247 0.003175 1 \n", + "4 0.015091 0.002333 0.000000 0.000000 10 \n", + "5 0.246430 0.018034 0.005265 0.000301 10 \n", + "6 0.132618 0.005460 0.005094 0.000080 10 \n", + "7 0.146367 0.014093 0.005040 0.000053 10 \n", + "8 0.014114 0.001050 0.000000 0.000000 100 \n", + "9 0.834156 0.030855 0.005086 0.000078 100 \n", + "10 0.519592 0.021908 0.005176 0.000095 100 \n", + "11 0.513260 0.013659 0.004989 0.000057 100 \n", + "12 0.013891 0.000794 0.000000 0.000000 1000 \n", + "13 0.626005 0.106514 0.005055 0.000083 1000 \n", + "14 4.000573 0.146304 0.005013 0.000032 1000 \n", + "15 3.954307 0.184881 0.004761 0.000606 1000 \n", + "\n", + " param_multi_class param_penalty \\\n", + "0 ovr l1 \n", + "1 ovr l2 \n", + "2 crammer_singer l1 \n", + "3 crammer_singer l2 \n", + "4 ovr l1 \n", + "5 ovr l2 \n", + "6 crammer_singer l1 \n", + "7 crammer_singer l2 \n", + "8 ovr l1 \n", + "9 ovr l2 \n", + "10 crammer_singer l1 \n", + "11 crammer_singer l2 \n", + "12 ovr l1 \n", + "13 ovr l2 \n", + "14 crammer_singer l1 \n", + "15 crammer_singer l2 \n", + "\n", + " params split0_test_score \\\n", + "0 {'C': 1, 'multi_class': 'ovr', 'penalty': 'l1'} NaN \n", + "1 {'C': 1, 'multi_class': 'ovr', 'penalty': 'l2'} 0.465649 \n", + "2 {'C': 1, 'multi_class': 'crammer_singer', 'pen... 0.473282 \n", + "3 {'C': 1, 'multi_class': 'crammer_singer', 'pen... 0.473282 \n", + "4 {'C': 10, 'multi_class': 'ovr', 'penalty': 'l1'} NaN \n", + "5 {'C': 10, 'multi_class': 'ovr', 'penalty': 'l2'} 0.480916 \n", + "6 {'C': 10, 'multi_class': 'crammer_singer', 'pe... 0.480916 \n", + "7 {'C': 10, 'multi_class': 'crammer_singer', 'pe... 
0.480916 \n", + "8 {'C': 100, 'multi_class': 'ovr', 'penalty': 'l1'} NaN \n", + "9 {'C': 100, 'multi_class': 'ovr', 'penalty': 'l2'} 0.473282 \n", + "10 {'C': 100, 'multi_class': 'crammer_singer', 'p... 0.480916 \n", + "11 {'C': 100, 'multi_class': 'crammer_singer', 'p... 0.480916 \n", + "12 {'C': 1000, 'multi_class': 'ovr', 'penalty': '... NaN \n", + "13 {'C': 1000, 'multi_class': 'ovr', 'penalty': '... 0.473282 \n", + "14 {'C': 1000, 'multi_class': 'crammer_singer', '... 0.480916 \n", + "15 {'C': 1000, 'multi_class': 'crammer_singer', '... 0.480916 \n", + "\n", + " split1_test_score split2_test_score split3_test_score \\\n", + "0 NaN NaN NaN \n", + "1 0.404580 0.412214 0.396947 \n", + "2 0.396947 0.419847 0.412214 \n", + "3 0.396947 0.419847 0.412214 \n", + "4 NaN NaN NaN \n", + "5 0.412214 0.404580 0.389313 \n", + "6 0.412214 0.412214 0.396947 \n", + "7 0.412214 0.412214 0.396947 \n", + "8 NaN NaN NaN \n", + "9 0.404580 0.412214 0.396947 \n", + "10 0.412214 0.412214 0.396947 \n", + "11 0.412214 0.412214 0.396947 \n", + "12 NaN NaN NaN \n", + "13 0.396947 0.427481 0.404580 \n", + "14 0.412214 0.412214 0.396947 \n", + "15 0.412214 0.412214 0.396947 \n", + "\n", + " split4_test_score mean_test_score std_test_score rank_test_score \n", + "0 NaN NaN NaN 13 \n", + "1 0.346154 0.405109 0.038100 12 \n", + "2 0.361538 0.412766 0.036297 1 \n", + "3 0.361538 0.412766 0.036297 1 \n", + "4 NaN NaN NaN 14 \n", + "5 0.361538 0.409712 0.039597 10 \n", + "6 0.361538 0.412766 0.038780 1 \n", + "7 0.361538 0.412766 0.038780 1 \n", + "8 NaN NaN NaN 15 \n", + "9 0.361538 0.409712 0.036214 10 \n", + "10 0.361538 0.412766 0.038780 1 \n", + "11 0.361538 0.412766 0.038780 1 \n", + "12 NaN NaN NaN 16 \n", + "13 0.353846 0.411227 0.039130 9 \n", + "14 0.361538 0.412766 0.038780 1 \n", + "15 0.361538 0.412766 0.038780 1 " + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(gs_clf.cv_results_)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 94, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KKBzq6-gLBBy", + "outputId": "45199ff5-8d74-4610-f782-ff61d30f05d3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "'In the green springtime, this creature photosynthesizes under the sun.' => Grass\n", + "'This create is native to caves in the Arctic. They make friends with polar bears, and sometimes drink hot cocoa.' => Ice\n" + ] + } + ], + "source": [ + "docs_new = [\n", + " \"In the green springtime, this creature photosynthesizes under the sun.\",\n", + " \"This create is native to caves in the Arctic. They make friends with polar bears, and sometimes drink hot cocoa.\",\n", + "]\n", + "\n", + "X_new_counts = vectorizer.transform(docs_new)\n", + "predicted = clf.predict(X_new_counts.toarray())\n", + "\n", + "for doc, cat in zip(docs_new, predicted):\n", + " print(f\"{doc!r} => {cat}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 282 + }, + "id": "aF8K0Ku5LBB1", + "outputId": "b2492636-9610-4606-f976-aff09b460b6d" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 95, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOYUlEQVR4nO3dfYwd113G8efZXb/UL6F23dolThpTJ1VCRZLiui9AaTCUtLS4SAglUpALFVsh0hZUqXJAEMQ/RFAoSFRFbusmUkOjqCQ0RaGJZSgIUhxvUruJ46RZ4jix45dEEa3jpF6v98cf9wZtnb3xnpkzc+/u+X6kaO+9czLzm737eO7MPXOOI0IA5r+hfhcAoB2EHSgEYQcKQdiBQhB2oBAjbW5sePnSGHndilm3X3TwxQaraYeHmv33NKamGl2/F6T9icTpyYYq6fCihUnt49REQ5W0J+U9eGnyhCamXvJMy1oN+8jrVmjNH39i1u0v+Z3dDVbTjqElSxtd/9TJk42uf2TV6qT2k0ePNVRJx/CF65Lanxk/0FAl7Ul5D+577vaey/gYDxSiVthtX237MdvjtrfmKgpAfpXDbntY0uckvV/SZZKutX1ZrsIA5FXnyL5R0nhEPBERE5Juk7Q5T1kAcqsT9vMlPT3t+aHuaz/C9qjtMdtjZ040ezEJQG+NX6CLiG0RsSEiNgwvb/bKNIDe6oT9sKQLpj1f230NwACqE/bdki62vc72QknXSLorT1kAcqvcqSYiJm1fL+keScOStkfEvmyVAciqVg+6iLhb0t2ZagHQoFa7yy46+GJSF9ihyy9NWv/U3v2pJTWu6e6spTl56euT2i+eB91lc6G7LFAIwg4UgrADhSDsQCEIO1AIwg4UgrADhSDsQCEIO1AIwg4UgrADhWi1b3yq1L7uQ0vTBseg3/q5pQ4N3fR7sGz3waT2zY5iP7dwZAcKQdiBQtQZSvoC2/9m+xHb+2x/MmdhAPKqc84+KelTEfGg7eWSHrC9IyIeyVQbgIwqH9kj4khEPNh9fELSfs0wlDSAwZDlnN32RZKulLQrx/oA5Ff7qzfbyyT9o6Tfj4gfzLB8VNKoJC3WkrqbA1BR3YkdF6gT9Fsj4o6Z2kyfJGKBFtXZHIAa6lyNt6QvSdofEX+dryQATahzZP8ZSb8p6Rds7+n+94FMdQHIrM4kEf8pyRlrAdCgge4bnyq1n/Xkpp9Oaj+y84Gk9vPBoN1vEMvKu8g7deKF2Tc+M9VzEd1lgUIQdqAQhB0oBGEHCkHYgUIQdqAQhB0oBGEHCkHYgUIQdqAQhB0oxLzqG58qta97aj9xafDGph9Zszqp/aCNGx9Hjie1r6Lp31GqlN9RBH3jgeIRdqAQhB0oRO2w2x62/R3b/5yjIADNyHFk/6Q6Y8YDGGB1R5ddK+lXJH0xTzkAmlL3yP43kj4tqef1ftujtsdsj53WqZqbA1BVnaGkPyjpeES86pfVjBsPDIa6Q0n/qu0nJd2mzpDSX8lSFYDs6kzseENErI2IiyRdI+lfI+K6bJUByIrv2YFCZOkbHxHfkvStHOsC0Iyib4RpYwKE4fXrktqfGT+QvI0Ug3TTRiXrL0xrvze9C0jSpAwV9GviDT7GA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhCDtQiFb7xntoSENLZt8veNAmWKgita/7Dz+0Man94m/cn9Q+1aD17ffJHza6/jak/l2nvAd+amHPZRzZgUIQdqAQdUeXfa3tr9l+1PZ+2+/KVRiAvOqes/+tpG9GxK/bXihpSYaaADSgctht/5ik90j6iCRFxISkiTxlAcitzsf4dZKelfTl7vRPX7T9ikvt08eNn4i5fyUVmKvqhH1E0tskfT4irpR0UtLWsxtNHzd+oRfX2ByAOuqE/ZCkQxGxq/v8a+qEH8AAqjNu/FFJT9t+S/elTZIeyVIVgOzqXo3/uKR
bu1fin5D0W/VLAtCEWmGPiD2SNmSqBUCDHBGtbew8r4x3eFNr2zuXfo3fDTRlV+zUD+J5z7SM7rJAIQg7UAjCDhSCsAOFIOxAIQg7UAjCDhSCsAOFIOxAIQg7UAjCDhSi6HHj2+jrPrJmdVL7yaPHGqqkI3Uc+DhyPKl907/TNn6f8/WeCY7sQCEIO1CIuuPG/4HtfbYftv1Vm0HmgEFVOey2z5f0CUkbIuKtkoYlXZOrMAB51f0YPyLpNbZH1Jkg4pn6JQFoQp0BJw9L+oykpyQdkfT9iLj37HaMGw8Mhjof41dI2qzOZBE/Lmmp7evObse48cBgqPMx/hclHYiIZyPitKQ7JL07T1kAcqsT9qckvdP2EttWZ9z4/XnKApBbnXP2XerMAvOgpIe669qWqS4AmdUdN/5GSTdmqgVAg1rtG6/hIQ0tXzbr5k33OZ6vfaBfzZnxA0ntx79yZVL79dd9J6l96nuQqun1V9lG6t/R0OWXzrqtH/uv3utJ2iqAOYuwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhCDtQCMIOFIKwA4Ug7EAhWu0bH6cnGx8XPUUbfd0HaX+l9H7cl3zse2kbSOjHLUlTexPvil5/Ydr6W/j9N96//9jzs288Odl7PRlqATAHEHagEOcMu+3tto/bfnjaaytt77D9ePfnimbLBFDXbI7sN0u6+qzXtkraGREXS9rZfQ5ggJ0z7BHxH5LOvkKwWdIt3ce3SPpw5roAZFb1avzqiDjSfXxUUs+pNW2PShqVpMVaUnFzAOqqfYEuIkJSvMry/x83foEW1d0cgIqqhv2Y7TdKUvdn2iTeAFpXNex3SdrSfbxF0tfzlAOgKbP56u2rkr4t6S22D9n+qKSbJP2S7cfVmRnmpmbLBFDXOS/QRcS1PRZtylwLgAa1O248zmlkTc8vNmaU2ve+8fsBEvu6p4yJLlXoSz8PxLKEb7Ge7/1hne6yQCEIO1AIwg4UgrADhSDsQCEIO1AIwg4UgrADhSDsQCEIO1AIwg4UotW+8V4wopFVs+/73fSY66njfbcxzvzUiRcaXX9q3/up1SvT2if2XU9tf/DP3pXU/k1/8u2k9lU0/Xcx+YbzZt02nhnuuYwjO1AIwg4Uouq48X9p+1Hb37V9p+3XNlsmgLqqjhu/Q9JbI+KnJH1P0g2Z6wKQWaVx4yPi3oh4eQa5/5a0toHaAGSU45z9tyX9S6+Ftkdtj9kem5h6KcPmAFRRK+y2/0jSpKRbe7WZPm78wqHX1NkcgBoqf89u+yOSPihpU3eiCAADrFLYbV8t6dOSfj4iXsxbEoAmVB03/u8kLZe0w/Ye23/fcJ0Aaqo6bvyXGqgFQIPc5un2eV4Z7zBzS6C61PsZtP7C5G2k9tdv+h6LlPsZ7nvudn1/4rhnWkZ3WaAQhB0oBGEHCkHYgUIQdqAQhB0oBGEHCkHYgUIQdqAQhB0oBGEHCkHYgUIUPUnEIBrEiSsGSuKNLak3tUjSoRvendT+oi//T1L7fr1nHNmBQlQaN37ask/ZDturmikPQC5Vx42X7QskvU/SU5lrAtCASuPGd31WnXHoGGwSmAMqnbPb3izpcETszVwPgIYkX423vUTSH6rzEX427UcljUrS4uFlqZsDkEmVI/ubJa2TtNf2k+pM/fSg7TUzNWaSCGAwJB/ZI+IhSW94+Xk38Bsi4rmMdQHIrOq48QDmmKrjxk9fflG2agA0hh50QCFa7RuvoSHFsiWtbrLf5npf95QJCqQW7mcYb74P19o/vy+p/Qsf2pjUftnupOaaWr1y9o3/t3ekObIDhSDsQCEIO1AIwg4UgrADhSDsQCEIO1AIwg4UgrADhSDsQCEIO1CIVvvGx6kJnRk/0OYm+27Q+rqnmjrxQr9L+FGJ48arwrjxqRZ/4/60/yHxfoOk+wFOTfRcxJEdKARhBwpReZII2x+3/ajtfbb/orkSAeRQaZII21dJ2izp8oj4SUmfyV8agJyqThL
xu5JuiohT3TbHG6gNQEZVz9kvkfRztnfZ/nfbb+/V0Pao7THbY6d1quLmANRV9au3EUkrJb1T0tsl3W77JyLiFVNBRcQ2Sdsk6TyvZKoooE+qHtkPSbojOu6XNCWJmVyBAVY17P8k6SpJsn2JpIWSmCQCGGDn/BjfnSTivZJW2T4k6UZJ2yVt734dNyFpy0wf4QEMjjqTRFyXuRYADWq1b7yHhjS0ZPbjqM/1fuUlanqc/KFjZ38LfI71J7XubqPhfUgdW/+eZ/bMuu3GX+59LwPdZYFCEHagEIQdKARhBwpB2IFCEHagEIQdKARhBwpB2IFCEHagEIQdKITbvFnN9rOSDs6waJXKukW2tP2Vytvnfu3vmyLi9TMtaDXsvdgei4gN/a6jLaXtr1TePg/i/vIxHigEYQcKMShh39bvAlpW2v5K5e3zwO3vQJyzA2jeoBzZATSMsAOF6GvYbV9t+zHb47a39rOWtth+0vZDtvfYHut3PU2YaTJQ2ytt77D9ePfnin7WmFOP/f1T24e77/Me2x/oZ41SH8Nue1jS5yS9X9Jlkq61fVm/6mnZVRFxxaB9D5vRzTprMlBJWyXtjIiLJe3sPp8vbtYr91eSPtt9n6+IiLtbrukV+nlk3yhpPCKeiIgJSbepMzMs5rgek4FulnRL9/Etkj7calEN6rG/A6efYT9f0tPTnh/qvjbfhaR7bT9ge7TfxbRodUQc6T4+Kml1P4tpyfW2v9v9mN/30xYu0LXvZyPibeqcvvye7ff0u6C2dWcPmu/f+X5e0pslXSHpiKS/6m85/Q37YUkXTHu+tvvavBYRh7s/j0u6U53TmRIcs/1GSer+PN7nehoVEcci4kxETEn6ggbgfe5n2HdLutj2OtsLJV0j6a4+1tM420ttL3/5saT3SXr41f+veeMuSVu6j7dI+nofa2ncy/+wdf2aBuB9bnX6p+kiYtL29ZLukTQsaXtE7OtXPS1ZLelO21Lnd/8PEfHN/paUX4/JQG+SdLvtj6pzm/Nv9K/CvHrs73ttX6HO6cqTkj7WtwK76C4LFIILdEAhCDtQCMIOFIKwA4Ug7EAhCDtQCMIOFOL/ADLR+Kg8vVraAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.imshow(confusion_matrix(y_test, clf.predict(X_test)))" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lUVDLT01LBB2", + "outputId": "d4fd1914-ceb1-4c39-d6c6-a8acc794fddf" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Accuracy: 0.47330960854092524\n", + "\n", + " Classification Report\n", + "==================================\n", + "\n", + " precision recall f1-score support\n", + "\n", + " Bug 0.57 0.64 0.60 25\n", + " Dark 1.00 0.27 0.42 15\n", + " Dragon 1.00 0.17 0.29 6\n", + " Electric 0.55 0.55 0.55 11\n", + " Fairy 1.00 0.14 0.25 7\n", + " Fighting 0.33 0.30 0.32 10\n", + " Fire 0.58 0.37 0.45 19\n", + " Flying 0.00 0.00 0.00 1\n", + " Ghost 0.38 0.30 0.33 10\n", + " Grass 0.49 0.53 0.51 32\n", + " Ground 0.57 0.31 0.40 13\n", + " Ice 0.80 0.40 0.53 10\n", + " Normal 0.23 0.46 0.31 28\n", + " Poison 0.80 0.50 0.62 8\n", + " Psychic 0.63 0.52 0.57 23\n", + " Rock 0.54 0.35 0.42 20\n", + " Steel 0.40 0.25 0.31 8\n", + " Water 0.47 0.83 0.60 35\n", + "\n", + " accuracy 0.47 281\n", + " macro avg 0.57 0.38 0.42 281\n", + "weighted avg 0.55 0.47 0.47 281\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. 
Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n" + ] + } + ], + "source": [ + "print(\"\\n Accuracy: \", accuracy_score(y_test, clf.predict(X_test)))\n", + "print(\"\\n Classification Report\")\n", + "print(\"==================================\")\n", + "print(\"\\n\", classification_report(y_test, clf.predict(X_test)))" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XamB14MKPJs-", + "outputId": "f758f37c-7945-4139-e067-b19ef8656510" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Accuracy: 0.49466192170818507\n", + "\n", + " Classification Report\n", + "==================================\n", + "\n", + " precision recall f1-score support\n", + "\n", + " Bug 0.57 0.64 0.60 25\n", + " Dark 1.00 0.40 0.57 15\n", + " Dragon 0.50 0.17 0.25 6\n", + " Electric 0.50 0.55 0.52 11\n", + " Fairy 0.50 0.14 0.22 7\n", + " Fighting 0.50 0.30 0.37 10\n", + " Fire 0.41 0.37 0.39 19\n", + " Flying 0.00 0.00 0.00 1\n", + " Ghost 0.42 0.50 0.45 10\n", + " Grass 0.50 0.50 0.50 32\n", + " Ground 0.67 0.31 0.42 13\n", + " Ice 0.67 0.40 0.50 10\n", + " Normal 0.30 0.50 0.38 28\n", + " Poison 0.80 0.50 0.62 8\n", + " Psychic 0.67 0.61 0.64 23\n", + " Rock 0.47 0.35 0.40 20\n", + " Steel 0.33 0.25 0.29 8\n", + " Water 0.49 0.83 0.62 35\n", + "\n", + " accuracy 0.49 281\n", + " macro avg 0.52 0.41 0.43 281\n", + "weighted avg 0.53 0.49 0.49 281\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n" + ] + } + ], + "source": [ + "print(\"\\n Accuracy: \", accuracy_score(y_test, gs_clf.predict(X_test)))\n", + "print(\"\\n Classification Report\")\n", + "print(\"==================================\")\n", + "print(\"\\n\", classification_report(y_test, gs_clf.predict(X_test)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In summary, the grid search classifier performs slightly better than the previous best classifier (an accuracy of about 0.49 versus 0.47); however, an accuracy of around 0.49 can still be improved.\n", + "We're going to pick this analysis up in the next notebook, where we'll look at using HuggingFace Transformers to improve the result." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "colab": { + "name": "nlp.ipynb", + "provenance": [] + }, + "interpreter": { + "hash": "45e1260056979d5382785f386f12ee00f44622d9a136ee7663e9a61a67ca2a68" + }, + "kernelspec": { + "display_name": "Python 3.10.0 ('projects-vBrzsZbN-py3.10')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 0 +}