{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Notes to discuss with the professor\n", "\n", "1. Since the data is very skewed, it doesn't make sense to do cross-validation that randomly chooses which samples go into training and which into testing: if we do that after duplicating (oversampling) the low-count class, most of the test samples will also appear in the training set, so the test score will be inflated.\n", "2. What is the difference between training the network after oversampling and simply training it on the original data for more epochs?\n", "3. Can we train the model on just 64 examples, keeping the positive-class data unchanged and randomly choosing the remaining 32 examples from the negative class?" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from mygrad import Layer\n", "from mygrad import Value" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "with open('data.pckl', 'rb') as file:\n", "    data = pickle.load(file)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from sklearn.utils import shuffle\n", "data = shuffle(data)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1024" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = [list(number) for number in data['number']]\n", "Y = [label for label in data['label']]\n", "\n", "len(X)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "for ix, row in enumerate(X):\n", "    X[ix] = [Value(float(item)) for item in row]\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "Xtrain, Xtest, Ytrain, Ytest = X[:int(len(X)*0.8)], X[int(len(X)*0.8):], Y[:int(len(X)*0.8)], Y[int(len(X)*0.8):]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout",
"output_type": "stream", "text": [ "[Value(data=-0.4269173833626292), Value(data=1.5598219958582367), Value(data=-0.9060772972552846), Value(data=-1.7986536298166897), Value(data=-1.6629470204105083), Value(data=1.033417037918746), Value(data=-0.6271032437628579), Value(data=-0.08921159991615646), Value(data=-0.9441026687737217), Value(data=-1.1184262058721397), Value(data=0.3875233603513344), Value(data=-1.1191317646049086), Value(data=0.10773546057187167), Value(data=-0.13759482796954892), Value(data=-0.35568237147978543), Value(data=-0.648468795325317), Value(data=-1.5011307241514515), Value(data=-1.690738455164726), Value(data=1.086651726656151), Value(data=-1.2386293546252176), Value(data=1.3113078129753322), Value(data=-1.0788557253441757), Value(data=-1.8867888938773758)]\n" ] } ], "source": [ "hiddenLayer1 = Layer(10, 1, activation='reLu')\n", "outputLayer = Layer(11, 1, activation='sigmoid')\n", "parameters = outputLayer.parameters() + hiddenLayer1.parameters()\n", "print(parameters)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def predict(x):\n", "    # Forward pass with a skip connection: the hidden ReLU neuron's output is\n", "    # prepended to the raw inputs, so the output layer sees 1 + len(x) inputs\n", "    # (11 for the 10-feature rows used here).\n", "    # Assumes a 1-neuron Layer returns a single Value -- TODO confirm in mygrad.\n", "    x1 = hiddenLayer1(x)\n", "    final = outputLayer([x1] + x)\n", "    return final" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score\n", "\n", "# Every metric below runs a full forward pass over X; when several metrics are\n", "# needed for the same split, compute each one once and reuse the value.\n", "def _predictLabels(X, threshold=0.5):\n", "    # Hard 0/1 labels: 1 when the sigmoid output is strictly above `threshold`.\n", "    return [1 if predict(x).data > threshold else 0 for x in X]\n", "\n", "def getAccuracy(X, Y, threshold=0.5):\n", "    return accuracy_score(Y, _predictLabels(X, threshold))\n", "\n", "def getPrecision(X, Y, threshold=0.5):\n", "    return precision_score(Y, _predictLabels(X, threshold))\n", "\n", "def getf1(X, Y, threshold=0.5):\n", "    return f1_score(Y, _predictLabels(X, threshold))\n", "\n", "def getRecall(X, Y, threshold=0.5):\n", "    return recall_score(Y, _predictLabels(X, threshold))\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ {
"data": { "text/plain": [ "0.5219512195121951" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "getAccuracy(Xtest, Ytest)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4 fold cross validation without momentum" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fold no. 1\n", "loss: 0.37573595553895384 epoch: 0\n", "loss: 0.33397055440340734 epoch: 1\n", "loss: 0.3018161127415282 epoch: 2\n", "loss: 0.27680608399065376 epoch: 3\n", "loss: 0.25879246704249065 epoch: 4\n", "loss: 0.2468074865295087 epoch: 5\n", "loss: 0.23748972817642466 epoch: 6\n", "loss: 0.22921592248421788 epoch: 7\n", "loss: 0.22258545130356772 epoch: 8\n", "loss: 0.21719074030855862 epoch: 9\n", "loss: 0.2124606473917233 epoch: 10\n", "loss: 0.20828183852046642 epoch: 11\n", "loss: 0.20459067797450095 epoch: 12\n", "loss: 0.20129295819299053 epoch: 13\n", "loss: 0.198302944113399 epoch: 14\n", "loss: 0.19554594755074953 epoch: 15\n", "loss: 0.19302387597139112 epoch: 16\n", "loss: 0.19068211124109144 epoch: 17\n", "loss: 0.18846860228271742 epoch: 18\n", "loss: 0.18646297932958714 epoch: 19\n", "loss: 0.1845893025363569 epoch: 20\n", "loss: 0.18288100150665237 epoch: 21\n", "loss: 0.1811183858945435 epoch: 22\n", "loss: 0.17952228728423725 epoch: 23\n", "loss: 0.17798156816842317 epoch: 24\n", "loss: 0.17652089873416205 epoch: 25\n", "loss: 0.17537759548297105 epoch: 26\n", "loss: 0.17435246000303037 epoch: 27\n", "loss: 0.17263594849028185 epoch: 28\n", "loss: 0.17149597721569307 epoch: 29\n", "loss: 0.1706257533615121 epoch: 30\n", "loss: 0.16914430082467521 epoch: 31\n", "loss: 0.16810937216249427 epoch: 32\n", "loss: 0.16716654938204006 epoch: 33\n", "loss: 0.16599156503054782 epoch: 34\n", "loss: 0.16510101354642268 epoch: 35\n", "loss: 0.16419733049043764 epoch: 36\n", "loss: 0.16366629273719377 epoch: 37\n", "loss: 0.16311990931859954 epoch: 
38\n", "loss: 0.16207796105830802 epoch: 39\n", "loss: 0.1614108249989092 epoch: 40\n", "loss: 0.1606460448700579 epoch: 41\n", "loss: 0.15974552139859902 epoch: 42\n", "loss: 0.1596599143952281 epoch: 43\n", "loss: 0.15826975912343522 epoch: 44\n", "loss: 0.1584276565803635 epoch: 45\n", "loss: 0.15678038192837696 epoch: 46\n", "loss: 0.15703013795427326 epoch: 47\n", "loss: 0.15566588412309512 epoch: 48\n", "loss: 0.15630263935162528 epoch: 49\n", "loss: 0.15403329987960063 epoch: 50\n", "loss: 0.15405031141825168 epoch: 51\n", "loss: 0.15407512023298076 epoch: 52\n", "loss: 0.15512835889188498 epoch: 53\n", "loss: 0.15200033685381262 epoch: 54\n", "loss: 0.15339022139270628 epoch: 55\n", "loss: 0.15425244464253499 epoch: 56\n", "loss: 0.15118459332737497 epoch: 57\n", "loss: 0.1507747184705248 epoch: 58\n", "loss: 0.15170648332575293 epoch: 59\n", "loss: 0.14880846470120732 epoch: 60\n", "loss: 0.14875118695696332 epoch: 61\n", "loss: 0.15005737402197827 epoch: 62\n", "loss: 0.14734267076488938 epoch: 63\n", "loss: 0.1468527978310063 epoch: 64\n", "loss: 0.14703703896756531 epoch: 65\n", "loss: 0.1505778958303549 epoch: 66\n", "loss: 0.14782963779548441 epoch: 67\n", "loss: 0.14532354052029284 epoch: 68\n", "loss: 0.14572240209885337 epoch: 69\n", "loss: 0.14766963045587533 epoch: 70\n", "loss: 0.149479861279388 epoch: 71\n", "loss: 0.14679302832754168 epoch: 72\n", "loss: 0.14425733859970524 epoch: 73\n", "loss: 0.14597448119384554 epoch: 74\n", "loss: 0.14736700397107766 epoch: 75\n", "loss: 0.14480804167497072 epoch: 76\n", "loss: 0.1424616018303761 epoch: 77\n", "loss: 0.14443726399504458 epoch: 78\n", "loss: 0.1461971091659815 epoch: 79\n", "loss: 0.14373642765048164 epoch: 80\n", "loss: 0.14147056954029877 epoch: 81\n", "loss: 0.1448244152156758 epoch: 82\n", "loss: 0.1448567660467712 epoch: 83\n", "loss: 0.14250982540223295 epoch: 84\n", "loss: 0.14033485944796284 epoch: 85\n", "loss: 0.14191925084633508 epoch: 86\n", "loss: 0.1431900128117084 epoch: 
87\n", "loss: 0.14091185774918394 epoch: 88\n", "loss: 0.1394048482512208 epoch: 89\n", "loss: 0.14139970030956792 epoch: 90\n", "loss: 0.13933014344215056 epoch: 91\n", "loss: 0.1410064939645472 epoch: 92\n", "loss: 0.14291859256194253 epoch: 93\n", "loss: 0.1406950597509687 epoch: 94\n", "loss: 0.13868251712845947 epoch: 95\n", "loss: 0.13892985603088473 epoch: 96\n", "loss: 0.14070626917126233 epoch: 97\n", "loss: 0.13865749240896025 epoch: 98\n", "loss: 0.1369015189042515 epoch: 99\n", "loss: 0.13783138418762694 epoch: 100\n", "loss: 0.13720850013853184 epoch: 101\n", "loss: 0.1398483031082388 epoch: 102\n", "loss: 0.1378487782097135 epoch: 103\n", "loss: 0.136012290286936 epoch: 104\n", "loss: 0.13952711293902167 epoch: 105\n", "loss: 0.14034887865495008 epoch: 106\n", "loss: 0.13829917663543997 epoch: 107\n", "loss: 0.13646420483617036 epoch: 108\n", "loss: 0.1351315026838636 epoch: 109\n", "loss: 0.13586202481625942 epoch: 110\n", "loss: 0.1351469802839512 epoch: 111\n", "loss: 0.13740262587432187 epoch: 112\n", "loss: 0.13562316286528917 epoch: 113\n", "loss: 0.13388518163789864 epoch: 114\n", "loss: 0.1346693384612657 epoch: 115\n", "loss: 0.13525973947235753 epoch: 116\n", "loss: 0.13821507146360093 epoch: 117\n", "loss: 0.13621396710238776 epoch: 118\n", "loss: 0.13451159554352363 epoch: 119\n", "loss: 0.13315645758562347 epoch: 120\n", "loss: 0.13563311462583874 epoch: 121\n", "loss: 0.133962562208412 epoch: 122\n", "loss: 0.13265582880747312 epoch: 123\n", "loss: 0.13493762801086215 epoch: 124\n", "loss: 0.13328611763241505 epoch: 125\n", "loss: 0.13260728378220396 epoch: 126\n", "loss: 0.13644543059782474 epoch: 127\n", "loss: 0.13460765233803126 epoch: 128\n", "loss: 0.13299829379307632 epoch: 129\n", "accuracy test 0.94921875 train 0.9557291666666666\n", "f1score test 0.953405017921147 train 0.9570707070707071\n", "precision test 0.910958904109589 train 0.9176755447941889\n", "recall test 1.0 train 1.0\n", "\n", "fold no. 
2\n", "loss: 0.4866973252558418 epoch: 0\n", "loss: 0.41644531508789806 epoch: 1\n", "loss: 0.3689243389632629 epoch: 2\n", "loss: 0.3337170714299505 epoch: 3\n", "loss: 0.30742804758322 epoch: 4\n", "loss: 0.2869560810726914 epoch: 5\n", "loss: 0.2704587074814627 epoch: 6\n", "loss: 0.25681894780544334 epoch: 7\n", "loss: 0.245433564499355 epoch: 8\n", "loss: 0.2368568720930045 epoch: 9\n", "loss: 0.2299075838558373 epoch: 10\n", "loss: 0.22469050013447614 epoch: 11\n", "loss: 0.2208513559257093 epoch: 12\n", "loss: 0.21755402456146186 epoch: 13\n", "loss: 0.21456878903262075 epoch: 14\n", "loss: 0.21193787539089207 epoch: 15\n", "loss: 0.20955441719053092 epoch: 16\n", "loss: 0.207363022415751 epoch: 17\n", "loss: 0.20533016150921413 epoch: 18\n", "loss: 0.2034326388404791 epoch: 19\n", "loss: 0.20165280899132856 epoch: 20\n", "loss: 0.19997237003820692 epoch: 21\n", "loss: 0.19834516618470963 epoch: 22\n", "loss: 0.19680338976473316 epoch: 23\n", "loss: 0.19531986761328532 epoch: 24\n", "loss: 0.19384167305750583 epoch: 25\n", "loss: 0.1924368315599832 epoch: 26\n", "loss: 0.19109527632696532 epoch: 27\n", "loss: 0.18980929469795216 epoch: 28\n", "loss: 0.18853189787291996 epoch: 29\n", "loss: 0.18730180582761885 epoch: 30\n", "loss: 0.18607610811117803 epoch: 31\n", "loss: 0.1848951901961053 epoch: 32\n", "loss: 0.18375432785111362 epoch: 33\n", "loss: 0.182649994989443 epoch: 34\n", "loss: 0.1815702415591492 epoch: 35\n", "loss: 0.18048558540364895 epoch: 36\n", "loss: 0.1794326711856165 epoch: 37\n", "loss: 0.178409058406972 epoch: 38\n", "loss: 0.17741280668861395 epoch: 39\n", "loss: 0.17644104308585307 epoch: 40\n", "loss: 0.17545533508163685 epoch: 41\n", "loss: 0.17449485893888061 epoch: 42\n", "loss: 0.17355792450682705 epoch: 43\n", "loss: 0.1726432342110939 epoch: 44\n", "loss: 0.17174970008989127 epoch: 45\n", "loss: 0.17087636595695155 epoch: 46\n", "loss: 0.17002237122859362 epoch: 47\n", "loss: 0.16918693162615422 epoch: 48\n", "loss: 
0.1683693270767674 epoch: 49\n", "loss: 0.1675809407759738 epoch: 50\n", "loss: 0.1668237042233324 epoch: 51\n", "loss: 0.16605008732365217 epoch: 52\n", "loss: 0.1653355372867922 epoch: 53\n", "loss: 0.16460301452410614 epoch: 54\n", "loss: 0.16391256773279347 epoch: 55\n", "loss: 0.16321033779463812 epoch: 56\n", "loss: 0.16259276596652725 epoch: 57\n", "loss: 0.1620396720088726 epoch: 58\n", "loss: 0.16128485215605076 epoch: 59\n", "loss: 0.16071922188828108 epoch: 60\n", "loss: 0.16020074341657894 epoch: 61\n", "loss: 0.15954134973882783 epoch: 62\n", "loss: 0.15904574563142063 epoch: 63\n", "loss: 0.15835414773921713 epoch: 64\n", "loss: 0.1577935039307171 epoch: 65\n", "loss: 0.15730158859826893 epoch: 66\n", "loss: 0.1568283297622172 epoch: 67\n", "loss: 0.1563308504311571 epoch: 68\n", "loss: 0.15588183532443461 epoch: 69\n", "loss: 0.15528213311177202 epoch: 70\n", "loss: 0.15485290040377145 epoch: 71\n", "loss: 0.1543431704779687 epoch: 72\n", "loss: 0.15391331959141738 epoch: 73\n", "loss: 0.15341499233645273 epoch: 74\n", "loss: 0.15312342926311223 epoch: 75\n", "loss: 0.1524494194009261 epoch: 76\n", "loss: 0.15209498043537348 epoch: 77\n", "loss: 0.15171328893722327 epoch: 78\n", "loss: 0.15133260958397693 epoch: 79\n", "loss: 0.1508079760237914 epoch: 80\n", "loss: 0.15058659904813881 epoch: 81\n", "loss: 0.15009703157042947 epoch: 82\n", "loss: 0.14994649428833284 epoch: 83\n", "loss: 0.1492487966313337 epoch: 84\n", "loss: 0.14888408836061698 epoch: 85\n", "loss: 0.14873044417577366 epoch: 86\n", "loss: 0.1484341326885224 epoch: 87\n", "loss: 0.1478014692926164 epoch: 88\n", "loss: 0.14768167853283617 epoch: 89\n", "loss: 0.14711927530212318 epoch: 90\n", "loss: 0.1470017556548842 epoch: 91\n", "loss: 0.1464696741361195 epoch: 92\n", "loss: 0.1464472986324771 epoch: 93\n", "loss: 0.1457880117450023 epoch: 94\n", "loss: 0.14562055499047288 epoch: 95\n", "loss: 0.1455201230648576 epoch: 96\n", "loss: 0.14485199019046094 epoch: 97\n", "loss: 
0.14466170255789268 epoch: 98\n", "loss: 0.1444513939128172 epoch: 99\n", "loss: 0.1439710386743098 epoch: 100\n", "loss: 0.14374148068729606 epoch: 101\n", "loss: 0.1435144374980435 epoch: 102\n", "loss: 0.1433731830862707 epoch: 103\n", "loss: 0.14284937207690693 epoch: 104\n", "loss: 0.14275798612421567 epoch: 105\n", "loss: 0.14226737996051972 epoch: 106\n", "loss: 0.14220694961005217 epoch: 107\n", "loss: 0.14210183598540216 epoch: 108\n", "loss: 0.14151698430471588 epoch: 109\n", "loss: 0.14146732248006136 epoch: 110\n", "loss: 0.1409934618188285 epoch: 111\n", "loss: 0.14107636218846376 epoch: 112\n", "loss: 0.14074563085701025 epoch: 113\n", "loss: 0.14063090081942706 epoch: 114\n", "loss: 0.14030350105714756 epoch: 115\n", "loss: 0.14017190273107324 epoch: 116\n", "loss: 0.13982933342833243 epoch: 117\n", "loss: 0.13975283340327838 epoch: 118\n", "loss: 0.13933495611994912 epoch: 119\n", "loss: 0.139275796443558 epoch: 120\n", "loss: 0.1389856564091035 epoch: 121\n", "loss: 0.13938876710905834 epoch: 122\n", "loss: 0.1381564195769227 epoch: 123\n", "loss: 0.1383388044703441 epoch: 124\n", "loss: 0.13862408832590462 epoch: 125\n", "loss: 0.13919155736930047 epoch: 126\n", "loss: 0.13753120149186104 epoch: 127\n", "loss: 0.1384924441418665 epoch: 128\n", "loss: 0.13952732265554277 epoch: 129\n", "accuracy test 0.9609375 train 0.9544270833333334\n", "f1score test 0.963235294117647 train 0.9560853199498118\n", "precision test 0.9290780141843972 train 0.9158653846153846\n", "recall test 1.0 train 1.0\n", "\n", "fold no. 
3\n", "loss: 1.0275432942083498 epoch: 0\n", "loss: 0.9578575931825666 epoch: 1\n", "loss: 0.9075661213592326 epoch: 2\n", "loss: 0.869464937494952 epoch: 3\n", "loss: 0.8364595174395587 epoch: 4\n", "loss: 0.8076249485763958 epoch: 5\n", "loss: 0.7828773863859919 epoch: 6\n", "loss: 0.7627155341727909 epoch: 7\n", "loss: 0.745109591986264 epoch: 8\n", "loss: 0.7297087779531674 epoch: 9\n", "loss: 0.7157146779566056 epoch: 10\n", "loss: 0.7034230301469935 epoch: 11\n", "loss: 0.6920224860492449 epoch: 12\n", "loss: 0.6817921764838848 epoch: 13\n", "loss: 0.6722606098064483 epoch: 14\n", "loss: 0.6634842272212922 epoch: 15\n", "loss: 0.6555368369402651 epoch: 16\n", "loss: 0.6480733238365858 epoch: 17\n", "loss: 0.6409903339739373 epoch: 18\n", "loss: 0.6342627572035959 epoch: 19\n", "loss: 0.6279952160507387 epoch: 20\n", "loss: 0.6223019337210842 epoch: 21\n", "loss: 0.6168465990618027 epoch: 22\n", "loss: 0.6116367354373334 epoch: 23\n", "loss: 0.606628050833703 epoch: 24\n", "loss: 0.6018563645022069 epoch: 25\n", "loss: 0.59733871975793 epoch: 26\n", "loss: 0.5929704010928328 epoch: 27\n", "loss: 0.5887435343817764 epoch: 28\n", "loss: 0.5846506382151457 epoch: 29\n", "loss: 0.5807060667034362 epoch: 30\n", "loss: 0.5769191114350246 epoch: 31\n", "loss: 0.5732545282920427 epoch: 32\n", "loss: 0.5696936943023108 epoch: 33\n", "loss: 0.5662194440747806 epoch: 34\n", "loss: 0.5628386681226026 epoch: 35\n", "loss: 0.5595590749255314 epoch: 36\n", "loss: 0.5563452400516654 epoch: 37\n", "loss: 0.5531938093094334 epoch: 38\n", "loss: 0.5501015394888932 epoch: 39\n", "loss: 0.5470652917454935 epoch: 40\n", "loss: 0.5440820252286511 epoch: 41\n", "loss: 0.5411487909400784 epoch: 42\n", "loss: 0.5382627258091738 epoch: 43\n", "loss: 0.5354210469740196 epoch: 44\n", "loss: 0.5326210462579407 epoch: 45\n", "loss: 0.529860084832287 epoch: 46\n", "loss: 0.527135588057184 epoch: 47\n", "loss: 0.5244450404928289 epoch: 48\n", "loss: 0.5217859810747116 epoch: 49\n", "loss: 
0.5191559984471581 epoch: 50\n", "loss: 0.5165527264505789 epoch: 51\n", "loss: 0.5139738397589606 epoch: 52\n", "loss: 0.5114170496653553 epoch: 53\n", "loss: 0.5088801000146489 epoch: 54\n", "loss: 0.5063607632845634 epoch: 55\n", "loss: 0.503856836817702 epoch: 56\n", "loss: 0.5013661392098921 epoch: 57\n", "loss: 0.49888650686255714 epoch: 58\n", "loss: 0.4964157907101804 epoch: 59\n", "loss: 0.49395185313755563 epoch: 60\n", "loss: 0.4914925651060497 epoch: 61\n", "loss: 0.48903580351341264 epoch: 62\n", "loss: 0.4865794488180946 epoch: 63\n", "loss: 0.4841213829667756 epoch: 64\n", "loss: 0.48165948767309763 epoch: 65\n", "loss: 0.4791916431069109 epoch: 66\n", "loss: 0.4767157270670279 epoch: 67\n", "loss: 0.4742296147271185 epoch: 68\n", "loss: 0.4717311790646303 epoch: 69\n", "loss: 0.4692182921072701 epoch: 70\n", "loss: 0.4666888271616736 epoch: 71\n", "loss: 0.46414066222551276 epoch: 72\n", "loss: 0.46157168482885336 epoch: 73\n", "loss: 0.4589797986049265 epoch: 74\n", "loss: 0.4563629319560573 epoch: 75\n", "loss: 0.45371020021987635 epoch: 76\n", "loss: 0.4510022945294712 epoch: 77\n", "loss: 0.4482432527330932 epoch: 78\n", "loss: 0.44544214469120164 epoch: 79\n", "loss: 0.4425862974426577 epoch: 80\n", "loss: 0.43965910349153453 epoch: 81\n", "loss: 0.43663539248506716 epoch: 82\n", "loss: 0.4335121577125862 epoch: 83\n", "loss: 0.4303190452386826 epoch: 84\n", "loss: 0.4270544304183648 epoch: 85\n", "loss: 0.4237173380613914 epoch: 86\n", "loss: 0.4202598583767015 epoch: 87\n", "loss: 0.4166235441270236 epoch: 88\n", "loss: 0.4128982802818701 epoch: 89\n", "loss: 0.4090731860960545 epoch: 90\n", "loss: 0.4051452915010556 epoch: 91\n", "loss: 0.40107437682276964 epoch: 92\n", "loss: 0.3965781592629817 epoch: 93\n", "loss: 0.39199204237093954 epoch: 94\n", "loss: 0.3873371304277496 epoch: 95\n", "loss: 0.3826394143484092 epoch: 96\n", "loss: 0.377904210129592 epoch: 97\n", "loss: 0.3728845538230497 epoch: 98\n", "loss: 0.36778845778546976 epoch: 
99\n", "loss: 0.36261656856940094 epoch: 100\n", "loss: 0.35748903723359776 epoch: 101\n", "loss: 0.35219209544312613 epoch: 102\n", "loss: 0.34705655029551236 epoch: 103\n", "loss: 0.34201302027358943 epoch: 104\n", "loss: 0.33672520439186315 epoch: 105\n", "loss: 0.33158637118168127 epoch: 106\n", "loss: 0.32668808044990133 epoch: 107\n", "loss: 0.3216823375714773 epoch: 108\n", "loss: 0.31666109748925264 epoch: 109\n", "loss: 0.3116808612615113 epoch: 110\n", "loss: 0.30615912826937103 epoch: 111\n", "loss: 0.30078871860094164 epoch: 112\n", "loss: 0.29538778084529943 epoch: 113\n", "loss: 0.2901180113634108 epoch: 114\n", "loss: 0.28518171706049406 epoch: 115\n", "loss: 0.2805460932939522 epoch: 116\n", "loss: 0.2762602928411815 epoch: 117\n", "loss: 0.2722645887042325 epoch: 118\n", "loss: 0.26852635578584627 epoch: 119\n", "loss: 0.2649899296163774 epoch: 120\n", "loss: 0.2613530818985914 epoch: 121\n", "loss: 0.2567660059300666 epoch: 122\n", "loss: 0.25084530739849675 epoch: 123\n", "loss: 0.24496334351865784 epoch: 124\n", "loss: 0.23966478031836 epoch: 125\n", "loss: 0.23489066068189943 epoch: 126\n", "loss: 0.2305873609363517 epoch: 127\n", "loss: 0.22670628419543995 epoch: 128\n", "loss: 0.22320352659354065 epoch: 129\n", "accuracy test 0.91796875 train 0.9296875\n", "f1score test 0.923076923076923 train 0.9346246973365617\n", "precision test 0.8571428571428571 train 0.8772727272727273\n", "recall test 1.0 train 1.0\n", "\n", "fold no. 
4\n", "loss: 1.2896150996005096 epoch: 0\n", "loss: 1.1098244781718036 epoch: 1\n", "loss: 0.9998571158744017 epoch: 2\n", "loss: 0.9200168892957966 epoch: 3\n", "loss: 0.8530897731101232 epoch: 4\n", "loss: 0.7930237243172769 epoch: 5\n", "loss: 0.7376364782834288 epoch: 6\n", "loss: 0.6860504908563037 epoch: 7\n", "loss: 0.6377987025341514 epoch: 8\n", "loss: 0.5925062255983626 epoch: 9\n", "loss: 0.5497762995577337 epoch: 10\n", "loss: 0.5091863893128747 epoch: 11\n", "loss: 0.4703712931015841 epoch: 12\n", "loss: 0.43316792980827856 epoch: 13\n", "loss: 0.397768307769488 epoch: 14\n", "loss: 0.36478302017286485 epoch: 15\n", "loss: 0.3350603154790905 epoch: 16\n", "loss: 0.31000020516810467 epoch: 17\n", "loss: 0.2900867053332955 epoch: 18\n", "loss: 0.274451320716229 epoch: 19\n", "loss: 0.26277549542965856 epoch: 20\n", "loss: 0.25393478162001887 epoch: 21\n", "loss: 0.2466950588215587 epoch: 22\n", "loss: 0.2405540945265736 epoch: 23\n", "loss: 0.2352398623954112 epoch: 24\n", "loss: 0.23095061477958415 epoch: 25\n", "loss: 0.22723828935213888 epoch: 26\n", "loss: 0.22394044347611716 epoch: 27\n", "loss: 0.22095195938809176 epoch: 28\n", "loss: 0.21820742605082138 epoch: 29\n", "loss: 0.21566239814288812 epoch: 30\n", "loss: 0.21328456283658495 epoch: 31\n", "loss: 0.21104929031477085 epoch: 32\n", "loss: 0.20893720347265204 epoch: 33\n", "loss: 0.206932730118554 epoch: 34\n", "loss: 0.20502316764158696 epoch: 35\n", "loss: 0.20319803715340626 epoch: 36\n", "loss: 0.20144861491087668 epoch: 37\n", "loss: 0.19976758043063156 epoch: 38\n", "loss: 0.19814874595791884 epoch: 39\n", "loss: 0.1965868450943621 epoch: 40\n", "loss: 0.19506788275250345 epoch: 41\n", "loss: 0.19358164310046627 epoch: 42\n", "loss: 0.19215989854272547 epoch: 43\n", "loss: 0.19081388660177723 epoch: 44\n", "loss: 0.18951545444675028 epoch: 45\n", "loss: 0.18827170665617646 epoch: 46\n", "loss: 0.18703368099038217 epoch: 47\n", "loss: 0.18577167573734796 epoch: 48\n", "loss: 
0.1845510124978676 epoch: 49\n", "loss: 0.18338915483269141 epoch: 50\n", "loss: 0.1822559141622845 epoch: 51\n", "loss: 0.18114049953210898 epoch: 52\n", "loss: 0.1800153601507127 epoch: 53\n", "loss: 0.17889632412291268 epoch: 54\n", "loss: 0.17780724821281582 epoch: 55\n", "loss: 0.17674192238224418 epoch: 56\n", "loss: 0.17565505921764607 epoch: 57\n", "loss: 0.17460120080136357 epoch: 58\n", "loss: 0.17356521040369535 epoch: 59\n", "loss: 0.17255047098040383 epoch: 60\n", "loss: 0.17155993470235767 epoch: 61\n", "loss: 0.170584836462083 epoch: 62\n", "loss: 0.16962920647454663 epoch: 63\n", "loss: 0.16869206738322404 epoch: 64\n", "loss: 0.16777269095084252 epoch: 65\n", "loss: 0.16687048787866365 epoch: 66\n", "loss: 0.16599308669804297 epoch: 67\n", "loss: 0.16514099951016217 epoch: 68\n", "loss: 0.16428614031402072 epoch: 69\n", "loss: 0.16346033815551123 epoch: 70\n", "loss: 0.16263388799001327 epoch: 71\n", "loss: 0.1618287421158048 epoch: 72\n", "loss: 0.1610183665012515 epoch: 73\n", "loss: 0.16025629787230405 epoch: 74\n", "loss: 0.1594746189961126 epoch: 75\n", "loss: 0.1587227071121282 epoch: 76\n", "loss: 0.15801892723108416 epoch: 77\n", "loss: 0.1572834736203767 epoch: 78\n", "loss: 0.15658860948581035 epoch: 79\n", "loss: 0.15596102794666303 epoch: 80\n", "loss: 0.1552032770890088 epoch: 81\n", "loss: 0.15452896773911393 epoch: 82\n", "loss: 0.1540579966418598 epoch: 83\n", "loss: 0.15362613212569431 epoch: 84\n", "loss: 0.15265519696812946 epoch: 85\n", "loss: 0.15207420393854093 epoch: 86\n", "loss: 0.15165428181992469 epoch: 87\n", "loss: 0.15126334322263002 epoch: 88\n", "loss: 0.15028359359603352 epoch: 89\n", "loss: 0.14983421700813068 epoch: 90\n", "loss: 0.14932646315469775 epoch: 91\n", "loss: 0.14906663462247666 epoch: 92\n", "loss: 0.1481563007887694 epoch: 93\n", "loss: 0.14788729777647658 epoch: 94\n", "loss: 0.14704809487240297 epoch: 95\n", "loss: 0.14664415968082953 epoch: 96\n", "loss: 0.1463263487725536 epoch: 97\n", "loss: 
0.1459896868058639 epoch: 98\n", "loss: 0.1464982002984429 epoch: 99\n", "loss: 0.14511794515497187 epoch: 100\n", "loss: 0.14525330157685254 epoch: 101\n", "loss: 0.14393735568358107 epoch: 102\n", "loss: 0.1438839969678795 epoch: 103\n", "loss: 0.14344525502632716 epoch: 104\n", "loss: 0.14557580454157476 epoch: 105\n", "loss: 0.1423781674380431 epoch: 106\n", "loss: 0.14272257799875493 epoch: 107\n", "loss: 0.14642802629060592 epoch: 108\n", "loss: 0.14261609636982137 epoch: 109\n", "loss: 0.1431645359955052 epoch: 110\n", "loss: 0.1469681751234988 epoch: 111\n", "loss: 0.14307278839163914 epoch: 112\n", "loss: 0.14070026858408585 epoch: 113\n", "loss: 0.14138468806747653 epoch: 114\n", "loss: 0.14185109799342493 epoch: 115\n", "loss: 0.14591713019519778 epoch: 116\n", "loss: 0.1420075832788537 epoch: 117\n", "loss: 0.13938731281250807 epoch: 118\n", "loss: 0.13991096140069806 epoch: 119\n", "loss: 0.14140511548544157 epoch: 120\n", "loss: 0.1448300629646812 epoch: 121\n", "loss: 0.14102406946282525 epoch: 122\n", "loss: 0.138241318549996 epoch: 123\n", "loss: 0.13792448741285335 epoch: 124\n", "loss: 0.13988027533982947 epoch: 125\n", "loss: 0.14380289645608385 epoch: 126\n", "loss: 0.1400994570662787 epoch: 127\n", "loss: 0.13716470828099886 epoch: 128\n", "loss: 0.13734737657946916 epoch: 129\n", "accuracy test 0.9609375 train 0.9557291666666666\n", "f1score test 0.9606299212598425 train 0.9582309582309582\n", "precision test 0.9242424242424242 train 0.9198113207547169\n", "recall test 1.0 train 1.0\n", "\n" ] } ], "source": [ "# 4-fold cross-validation with plain gradient descent (no momentum).\n", "# predict() and the get* metric helpers read the globals hiddenLayer1 / outputLayer,\n", "# which are rebound every fold, so each fold is scored with its own freshly trained model.\n", "n_folds = 4\n", "accuracies = []\n", "f1scores = []\n", "precisionscores = []\n", "recallscores = []\n", "losss = []  # per-epoch training loss; all folds are appended back to back\n", "for fold in range(n_folds):\n", "    print('fold no.', fold+1)\n", "    # contiguous fold boundaries over the (already shuffled) samples\n", "    start, stop = len(X)*fold//n_folds, len(X)*(fold+1)//n_folds\n", "    Xtrain, Xtest = X[:start] + X[stop:], X[start:stop]\n", "    Ytrain, Ytest = Y[:start] + Y[stop:], Y[start:stop]\n", "\n", "    hiddenLayer1 = Layer(10, 1, activation='reLu')\n", "    outputLayer = Layer(11, 1, activation='sigmoid')\n", "    # fixed symmetric init: first five input weights +5..+1, last five -1..-5\n", "    for k in range(5):\n", "        hiddenLayer1.neurons[0].w[k] = Value(5.0-k)\n", "        hiddenLayer1.neurons[0].w[9-k] = Value(k-5.0)\n", "    # same pattern on the output neuron's weights at indices 1..10; index 0 presumably\n", "    # pairs with the hidden-neuron output, since predict() passes [x1] + x\n", "    for k in range(5):\n", "        outputLayer.neurons[0].w[k+1] = Value(5.0-k)\n", "        outputLayer.neurons[0].w[9-k+1] = Value(k-5.0)\n", "    # collected AFTER the re-init so the replacement Value objects are the ones trained\n", "    parameters = outputLayer.parameters() + hiddenLayer1.parameters()\n", "    prevchange = [0]*len(parameters)  # bookkeeping for the momentum variant; not used by this update\n", "    beta = 0.7 # parameter for momentum update (unused here)\n", "    lr = 0.5\n", "    epochs = 130\n", "\n", "    for _ in range(epochs):\n", "        Y_pred = [predict(x) for x in Xtrain]\n", "        # binary cross-entropy summed over the training fold\n", "        loss = Value(0)\n", "        for y, y_hat in zip(Ytrain, Y_pred):\n", "            if y == 1:\n", "                loss -= y_hat.log()\n", "            else:\n", "                loss -= (Value(1) - y_hat).log()\n", "        # mean over the TRAINING samples; the old divisor len(X) (all folds) shrank the\n", "        # loss and its gradients by 3/4, silently lowering the effective learning rate\n", "        loss = loss/len(Ytrain)\n", "        loss.backward()\n", "\n", "        # plain gradient-descent step\n", "        for ix, p in enumerate(parameters):\n", "            change = lr*p.grad\n", "            p.data = p.data - change\n", "            prevchange[ix] = change\n", "\n", "        # reset gradients before the next backward pass\n", "        for p in parameters:\n", "            p.grad = 0\n", "        losss.append(loss.data)\n", "        print('loss:', loss.data,'epoch:', _)\n", "\n", "    # each get* call re-runs the forward pass, so score the test fold once and reuse it\n", "    test_acc, test_f1 = getAccuracy(Xtest, Ytest), getf1(Xtest, Ytest)\n", "    test_prec, test_rec = getPrecision(Xtest, Ytest), getRecall(Xtest, Ytest)\n", "    print( 'accuracy','test', test_acc,'train', getAccuracy(Xtrain, Ytrain))\n", "    print( 'f1score','test', test_f1,'train', getf1(Xtrain, Ytrain))\n", "    print( 'precision','test', test_prec,'train', getPrecision(Xtrain, Ytrain))\n", "    print( 'recall','test', test_rec,'train', getRecall(Xtrain, Ytrain))\n", "    print()\n", "    accuracies.append(test_acc)\n", "    f1scores.append(test_f1)\n", "    precisionscores.append(test_prec)\n", "    recallscores.append(test_rec)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "final accuracy: 0.9521484375\n", "final f1score 0.9539790139345543\n", "final precisionscore 0.920659892424684\n", "final recallscore 0.9903846153846154\n" ] } ], "source": [ "print('final accuracy:', sum(accuracies)/4)\n", "print('final f1score', 
sum(f1scores)/4)\n", "print('final precisionscore', sum(precisionscores)/4)\n", "print('final recallscore', sum(recallscores)/4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.plot(losss)\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss')\n", "plt.title('Loss over Epochs / without momentum term')\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Network Analysis" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Text(0, 0.5, 'neuron number')" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "neuron1weightsbias = [v.data for v in hiddenLayer1.neurons[0].w] + [hiddenLayer1.neurons[0].b.data]\n", "outputneuronweightsbias = [v.data for v in outputLayer.neurons[0].w] + [outputLayer.neurons[0].b.data]\n", "import matplotlib.pyplot as plt\n", "\n", "plt.imshow([neuron1weightsbias], cmap='hot')\n", "plt.colorbar()\n", "plt.title('Heatmap of Neurons in hidden Layer')\n", "plt.xlabel('neuron number')\n", "plt.ylabel('neuron number')\n", "plt.show()\n", "\n", "plt.imshow([outputneuronweightsbias], cmap='hot')\n", "plt.colorbar()\n", "plt.title('Heatmap of Neurons in Output Layer')\n", "plt.xlabel('neuron number')\n", "plt.ylabel('neuron number')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[4.730257472922613,\n", " 4.291331755432143,\n", " 3.2573443760750886,\n", " 2.1589434171025923,\n", " 1.0555995955264479,\n", " -1.0670729046272482,\n", " -2.1616096967604537,\n", " -3.274597127580454,\n", " -4.281371961476945,\n", " -4.751506477851369,\n", " -0.0151131751890921]" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "neuron1weightsbias" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.bar(range(len(neuron1weightsbias)), neuron1weightsbias)\n", "plt.xlabel('Neuron Number')\n", "plt.ylabel('Weight')\n", "plt.title('Neuron 1 Weights')\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# with momentum term" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cross no. 1\n", "loss: 0.36868003909601843 epoch: 0\n", "loss: 0.30716981515383796 epoch: 1\n", "loss: 0.23775899611343987 epoch: 2\n", "loss: 0.2762243432739998 epoch: 3\n", "loss: 0.20530656588268192 epoch: 4\n", "loss: 0.20279116917715703 epoch: 5\n", "loss: 0.19290094036838987 epoch: 6\n", "loss: 0.1781452120046021 epoch: 7\n", "loss: 0.17445027116802864 epoch: 8\n", "loss: 0.1680095172815048 epoch: 9\n", "loss: 0.16015565597932327 epoch: 10\n", "loss: 0.15746857189608465 epoch: 11\n", "loss: 0.15072209286440752 epoch: 12\n", "loss: 0.14824549843458112 epoch: 13\n", "loss: 0.1438645786614209 epoch: 14\n", "loss: 0.14154395530516983 epoch: 15\n", "loss: 0.14546039626444338 epoch: 16\n", "loss: 0.14724562106565575 epoch: 17\n", "loss: 0.15168217576472595 epoch: 18\n", "loss: 0.14668469305743176 epoch: 19\n", "loss: 0.13920850995528983 epoch: 20\n", "loss: 0.1400466340618168 epoch: 21\n", "loss: 0.13578913273753201 epoch: 22\n", "loss: 0.13581126633783955 epoch: 23\n", "loss: 0.1339873561400119 epoch: 24\n", "loss: 0.13752733123288272 epoch: 25\n", "loss: 0.1387924765212181 epoch: 26\n", "loss: 0.14231715338456016 epoch: 27\n", "loss: 0.13975786380586355 epoch: 28\n", "loss: 0.13482568716160406 epoch: 29\n", "loss: 0.13124416058707272 epoch: 30\n", "loss: 0.12934479832215603 epoch: 31\n", "loss: 0.12996509594533118 epoch: 32\n", "loss: 0.13375543119294675 epoch: 33\n", "loss: 0.1350035756880144 epoch: 34\n", "loss: 0.13317051923653306 epoch: 35\n", "loss: 0.129370414187251 epoch: 36\n", "loss: 
0.13653425853716422 epoch: 37\n", "loss: 0.1343349816737942 epoch: 38\n", "loss: 0.13899113784333378 epoch: 39\n", "loss: 0.13740755339046992 epoch: 40\n", "loss: 0.13229470851706873 epoch: 41\n", "loss: 0.12757312422232836 epoch: 42\n", "loss: 0.13382320622175067 epoch: 43\n", "loss: 0.13619209106958413 epoch: 44\n", "loss: 0.14553472065518536 epoch: 45\n", "loss: 0.14366318138752657 epoch: 46\n", "loss: 0.1352163068425741 epoch: 47\n", "loss: 0.12885727956525114 epoch: 48\n", "loss: 0.13569430384785797 epoch: 49\n", "loss: 0.12635018066965195 epoch: 50\n", "loss: 0.1265760376296105 epoch: 51\n", "loss: 0.12418106545073804 epoch: 52\n", "loss: 0.12773935470054082 epoch: 53\n", "loss: 0.12936879477131555 epoch: 54\n", "loss: 0.1331406099913378 epoch: 55\n", "loss: 0.1313656286249581 epoch: 56\n", "loss: 0.12654761487871613 epoch: 57\n", "loss: 0.12280016884750034 epoch: 58\n", "loss: 0.12105092258384308 epoch: 59\n", "loss: 0.12244958602511412 epoch: 60\n", "loss: 0.12173331538817515 epoch: 61\n", "loss: 0.12154002856765843 epoch: 62\n", "loss: 0.12116815896346547 epoch: 63\n", "loss: 0.11956434345121285 epoch: 64\n", "loss: 0.12651448768672555 epoch: 65\n", "loss: 0.13255248078491216 epoch: 66\n", "loss: 0.14255907602692364 epoch: 67\n", "loss: 0.1403848665688712 epoch: 68\n", "loss: 0.13164292783909343 epoch: 69\n", "loss: 0.1230361136681327 epoch: 70\n", "loss: 0.12785521994530472 epoch: 71\n", "loss: 0.12797962407594643 epoch: 72\n", "loss: 0.13611166071668623 epoch: 73\n", "loss: 0.13610666468084284 epoch: 74\n", "loss: 0.13010062384064133 epoch: 75\n", "loss: 0.12275577968618161 epoch: 76\n", "loss: 0.11877803736309667 epoch: 77\n", "loss: 0.12820777145497536 epoch: 78\n", "loss: 0.13392388926222884 epoch: 79\n", "loss: 0.15205382214401103 epoch: 80\n", "loss: 0.15204467699541 epoch: 81\n", "loss: 0.13967577193648126 epoch: 82\n", "loss: 0.12783521562178088 epoch: 83\n", "loss: 0.11996761102150041 epoch: 84\n", "loss: 0.15272878814399424 epoch: 85\n", "loss: 
0.13593027635414032 epoch: 86\n", "loss: 0.16284574232275503 epoch: 87\n", "loss: 0.169022465956416 epoch: 88\n", "loss: 0.1608563004350432 epoch: 89\n", "loss: 0.14536127747444114 epoch: 90\n", "loss: 0.12951075225978212 epoch: 91\n", "loss: 0.12183265500170419 epoch: 92\n", "loss: 0.13448689123430263 epoch: 93\n", "loss: 0.1267437042610601 epoch: 94\n", "loss: 0.13936271820511442 epoch: 95\n", "loss: 0.14195686488368317 epoch: 96\n", "loss: 0.1330611602863246 epoch: 97\n", "loss: 0.12421631602529942 epoch: 98\n", "loss: 0.11758093570096831 epoch: 99\n", "loss: 0.12771624580596683 epoch: 100\n", "loss: 0.12394220252822441 epoch: 101\n", "loss: 0.13362435842048345 epoch: 102\n", "loss: 0.13589971507386636 epoch: 103\n", "loss: 0.12995998864673372 epoch: 104\n", "loss: 0.12363934179205438 epoch: 105\n", "loss: 0.11867928993935122 epoch: 106\n", "loss: 0.11502914272216506 epoch: 107\n", "loss: 0.11464685113251707 epoch: 108\n", "loss: 0.11963012701832111 epoch: 109\n", "loss: 0.12344770173053593 epoch: 110\n", "loss: 0.12408039072068819 epoch: 111\n", "loss: 0.12184518057389364 epoch: 112\n", "loss: 0.11811872628150506 epoch: 113\n", "loss: 0.11379028259141508 epoch: 114\n", "loss: 0.12296666853343105 epoch: 115\n", "loss: 0.12153531160724884 epoch: 116\n", "loss: 0.12994306734636502 epoch: 117\n", "loss: 0.13137041130812116 epoch: 118\n", "loss: 0.1268467252439156 epoch: 119\n", "loss: 0.12141442989585767 epoch: 120\n", "loss: 0.11735373129119131 epoch: 121\n", "loss: 0.11997133181031398 epoch: 122\n", "loss: 0.11337592847936363 epoch: 123\n", "loss: 0.112500407946854 epoch: 124\n", "loss: 0.10992950927058084 epoch: 125\n", "loss: 0.11121358772423992 epoch: 126\n", "loss: 0.11512866813382663 epoch: 127\n", "loss: 0.11795210358477946 epoch: 128\n", "loss: 0.11557040626073901 epoch: 129\n", "accuracy test 0.9609375 train 0.9635416666666666\n", "f1score test 0.9603174603174602 train 0.9654320987654321\n", "precision test 0.9236641221374046 train 0.9331742243436754\n", 
"recall test 1.0 train 1.0\n", "\n", "cross no. 2\n", "loss: 0.9165482655929494 epoch: 0\n", "loss: 0.72942999904096 epoch: 1\n", "loss: 0.4856930084133062 epoch: 2\n", "loss: 0.3006016218793657 epoch: 3\n", "loss: 0.229297719505487 epoch: 4\n", "loss: 0.2035380933512077 epoch: 5\n", "loss: 0.1884329387905959 epoch: 6\n", "loss: 0.17440150199427282 epoch: 7\n", "loss: 0.16559364265055387 epoch: 8\n", "loss: 0.1671090872026514 epoch: 9\n", "loss: 0.15784783619000395 epoch: 10\n", "loss: 0.15385004711132688 epoch: 11\n", "loss: 0.15056986555805388 epoch: 12\n", "loss: 0.14425287668164297 epoch: 13\n", "loss: 0.144775837388235 epoch: 14\n", "loss: 0.14462250608588476 epoch: 15\n", "loss: 0.14672312506427956 epoch: 16\n", "loss: 0.14095305333634742 epoch: 17\n", "loss: 0.13388962027016418 epoch: 18\n", "loss: 0.142070817415337 epoch: 19\n", "loss: 0.1456456885014303 epoch: 20\n", "loss: 0.15812800996793822 epoch: 21\n", "loss: 0.1544320853629628 epoch: 22\n", "loss: 0.14024559209357193 epoch: 23\n", "loss: 0.1308282922113132 epoch: 24\n", "loss: 0.1362907224150407 epoch: 25\n", "loss: 0.13415669640207478 epoch: 26\n", "loss: 0.14317308854436225 epoch: 27\n", "loss: 0.14236280161740272 epoch: 28\n", "loss: 0.13473532418657544 epoch: 29\n", "loss: 0.12817357150743497 epoch: 30\n", "loss: 0.12162093949378945 epoch: 31\n", "loss: 0.1332858064616225 epoch: 32\n", "loss: 0.1358368217072919 epoch: 33\n", "loss: 0.1517668747382924 epoch: 34\n", "loss: 0.152713284987877 epoch: 35\n", "loss: 0.14354638804842046 epoch: 36\n", "loss: 0.13233323601784902 epoch: 37\n", "loss: 0.12400193311250832 epoch: 38\n", "loss: 0.13161283143789068 epoch: 39\n", "loss: 0.12695252481763505 epoch: 40\n", "loss: 0.133914747443586 epoch: 41\n", "loss: 0.13510942589823066 epoch: 42\n", "loss: 0.1309073752531184 epoch: 43\n", "loss: 0.12341259693926745 epoch: 44\n", "loss: 0.11676911453281041 epoch: 45\n", "loss: 0.12266493179402127 epoch: 46\n", "loss: 0.11827590522585574 epoch: 47\n", "loss: 
0.12099977450325256 epoch: 48\n", "loss: 0.11863684734683867 epoch: 49\n", "loss: 0.11597137303394474 epoch: 50\n", "loss: 0.11764848695544085 epoch: 51\n", "loss: 0.11335054883452948 epoch: 52\n", "loss: 0.11281322545904551 epoch: 53\n", "loss: 0.11147328349211215 epoch: 54\n", "loss: 0.11069135969028722 epoch: 55\n", "loss: 0.11221606797395027 epoch: 56\n", "loss: 0.11171134553670575 epoch: 57\n", "loss: 0.11225438168942115 epoch: 58\n", "loss: 0.11188319849153237 epoch: 59\n", "loss: 0.11102479703800383 epoch: 60\n", "loss: 0.1104333539337151 epoch: 61\n", "loss: 0.1103271691769313 epoch: 62\n", "loss: 0.10980031688289973 epoch: 63\n", "loss: 0.11001263488613786 epoch: 64\n", "loss: 0.11050138264209658 epoch: 65\n", "loss: 0.11057652328902944 epoch: 66\n", "loss: 0.10877268437928819 epoch: 67\n", "loss: 0.11111578147893164 epoch: 68\n", "loss: 0.11365620028613219 epoch: 69\n", "loss: 0.1171775698851146 epoch: 70\n", "loss: 0.11565003068863162 epoch: 71\n", "loss: 0.11152269216921619 epoch: 72\n", "loss: 0.109090253588753 epoch: 73\n", "loss: 0.11663692366704477 epoch: 74\n", "loss: 0.11508334155649544 epoch: 75\n", "loss: 0.12259961802559624 epoch: 76\n", "loss: 0.12300297111041628 epoch: 77\n", "loss: 0.11854419265285605 epoch: 78\n", "loss: 0.11215602242236969 epoch: 79\n", "loss: 0.10995708897702192 epoch: 80\n", "loss: 0.10804462385850468 epoch: 81\n", "loss: 0.10847530181936375 epoch: 82\n", "loss: 0.10866036158339008 epoch: 83\n", "loss: 0.10753578820232068 epoch: 84\n", "loss: 0.10768971006867763 epoch: 85\n", "loss: 0.1069712642760346 epoch: 86\n", "loss: 0.10622612228276418 epoch: 87\n", "loss: 0.10757799579253217 epoch: 88\n", "loss: 0.10824967097655418 epoch: 89\n", "loss: 0.10949167707778833 epoch: 90\n", "loss: 0.10839600678839856 epoch: 91\n", "loss: 0.10599216439143208 epoch: 92\n", "loss: 0.10619293231720403 epoch: 93\n", "loss: 0.10472897259122634 epoch: 94\n", "loss: 0.10466338599052498 epoch: 95\n", "loss: 0.10381208764969022 epoch: 96\n", 
"loss: 0.103263207646071 epoch: 97\n", "loss: 0.10295610606580222 epoch: 98\n", "loss: 0.10520566310944364 epoch: 99\n", "loss: 0.10812127516211963 epoch: 100\n", "loss: 0.11166967401196799 epoch: 101\n", "loss: 0.110904746094558 epoch: 102\n", "loss: 0.10761049472331469 epoch: 103\n", "loss: 0.10384037080175822 epoch: 104\n", "loss: 0.11164487719958505 epoch: 105\n", "loss: 0.10984117426949151 epoch: 106\n", "loss: 0.11693931298549386 epoch: 107\n", "loss: 0.11705036507428887 epoch: 108\n", "loss: 0.11163706418215827 epoch: 109\n", "loss: 0.1059613848329792 epoch: 110\n", "loss: 0.10596521355634281 epoch: 111\n", "loss: 0.10426213398099518 epoch: 112\n", "loss: 0.10302925050203493 epoch: 113\n", "loss: 0.10299099449597665 epoch: 114\n", "loss: 0.10163475861231477 epoch: 115\n", "loss: 0.10876845415404501 epoch: 116\n", "loss: 0.1071947280728179 epoch: 117\n", "loss: 0.11198154422008179 epoch: 118\n", "loss: 0.11261457852825729 epoch: 119\n", "loss: 0.10929629882370472 epoch: 120\n", "loss: 0.10512130260279152 epoch: 121\n", "loss: 0.1023672052320518 epoch: 122\n", "loss: 0.10411206882729991 epoch: 123\n", "loss: 0.10223115353649928 epoch: 124\n", "loss: 0.10279201814677916 epoch: 125\n", "loss: 0.10174845038328884 epoch: 126\n", "loss: 0.10115223916866523 epoch: 127\n", "loss: 0.1003520581942911 epoch: 128\n", "loss: 0.09969450706128662 epoch: 129\n", "accuracy test 0.9609375 train 0.9609375\n", "f1score test 0.962962962962963 train 0.9622166246851386\n", "precision test 0.9285714285714286 train 0.9271844660194175\n", "recall test 1.0 train 1.0\n", "\n", "cross no. 
3\n", "loss: 1.243937970628717 epoch: 0\n", "loss: 1.0197823143075653 epoch: 1\n", "loss: 0.7860238518286085 epoch: 2\n", "loss: 0.6428857890003671 epoch: 3\n", "loss: 0.5661858405028479 epoch: 4\n", "loss: 0.5218729846133345 epoch: 5\n", "loss: 0.4920264445189551 epoch: 6\n", "loss: 0.46747648861860275 epoch: 7\n", "loss: 0.44412802537111323 epoch: 8\n", "loss: 0.41986880781072655 epoch: 9\n", "loss: 0.39358283574377495 epoch: 10\n", "loss: 0.36525723352594436 epoch: 11\n", "loss: 0.336899444414848 epoch: 12\n", "loss: 0.3090511658199733 epoch: 13\n", "loss: 0.2826167309764894 epoch: 14\n", "loss: 0.2582322576918214 epoch: 15\n", "loss: 0.23231777846931498 epoch: 16\n", "loss: 0.2067491965341805 epoch: 17\n", "loss: 0.1826216467451588 epoch: 18\n", "loss: 0.15770367263216456 epoch: 19\n", "loss: 0.1426324373512798 epoch: 20\n", "loss: 0.17682915447708708 epoch: 21\n", "loss: 0.1452902249969128 epoch: 22\n", "loss: 0.1576088929486672 epoch: 23\n", "loss: 0.15527571868273707 epoch: 24\n", "loss: 0.14405203357727125 epoch: 25\n", "loss: 0.1353994569170679 epoch: 26\n", "loss: 0.1527236520754517 epoch: 27\n", "loss: 0.13995572057264077 epoch: 28\n", "loss: 0.14722067136375522 epoch: 29\n", "loss: 0.14663008057794108 epoch: 30\n", "loss: 0.14050429946297335 epoch: 31\n", "loss: 0.13431123587618363 epoch: 32\n", "loss: 0.1354441638508038 epoch: 33\n", "loss: 0.1351923587375112 epoch: 34\n", "loss: 0.13860755298460298 epoch: 35\n", "loss: 0.13745137559683546 epoch: 36\n", "loss: 0.13333656524776769 epoch: 37\n", "loss: 0.1281102312419493 epoch: 38\n", "loss: 0.14775331024378383 epoch: 39\n", "loss: 0.13682918547183684 epoch: 40\n", "loss: 0.14788065810636103 epoch: 41\n", "loss: 0.1469301894650251 epoch: 42\n", "loss: 0.13798428764640994 epoch: 43\n", "loss: 0.1300587367770154 epoch: 44\n", "loss: 0.12866181034546598 epoch: 45\n", "loss: 0.12813915599347503 epoch: 46\n", "loss: 0.12922755346944556 epoch: 47\n", "loss: 0.12749376632262518 epoch: 48\n", "loss: 
0.12379693635409315 epoch: 49\n", "loss: 0.13236356522600265 epoch: 50\n", "loss: 0.13222525721775655 epoch: 51\n", "loss: 0.1401955624707633 epoch: 52\n", "loss: 0.13864098997041335 epoch: 53\n", "loss: 0.13153681924540897 epoch: 54\n", "loss: 0.12469296843634378 epoch: 55\n", "loss: 0.12875084532916908 epoch: 56\n", "loss: 0.12823531446648095 epoch: 57\n", "loss: 0.13425116386077834 epoch: 58\n", "loss: 0.13352186939286795 epoch: 59\n", "loss: 0.12815213884648194 epoch: 60\n", "loss: 0.12203170106509019 epoch: 61\n", "loss: 0.12599114866979444 epoch: 62\n", "loss: 0.12742049512977371 epoch: 63\n", "loss: 0.1347010131683283 epoch: 64\n", "loss: 0.1336342424578307 epoch: 65\n", "loss: 0.1273152744419998 epoch: 66\n", "loss: 0.12083666131579937 epoch: 67\n", "loss: 0.12328984113389031 epoch: 68\n", "loss: 0.12248910889763288 epoch: 69\n", "loss: 0.12629911517469639 epoch: 70\n", "loss: 0.1253043565603647 epoch: 71\n", "loss: 0.12073610651861404 epoch: 72\n", "loss: 0.11685445407433288 epoch: 73\n", "loss: 0.12343029060907694 epoch: 74\n", "loss: 0.12716080628464543 epoch: 75\n", "loss: 0.13777029997788864 epoch: 76\n", "loss: 0.13584128986370303 epoch: 77\n", "loss: 0.12679299933155674 epoch: 78\n", "loss: 0.11887425397554868 epoch: 79\n", "loss: 0.12168561111435384 epoch: 80\n", "loss: 0.1163907234794487 epoch: 81\n", "loss: 0.117230546266443 epoch: 82\n", "loss: 0.11610551091725406 epoch: 83\n", "loss: 0.11452266544355408 epoch: 84\n", "loss: 0.11519613558312052 epoch: 85\n", "loss: 0.11441959510995409 epoch: 86\n", "loss: 0.11460747171577224 epoch: 87\n", "loss: 0.11379650786448829 epoch: 88\n", "loss: 0.11344271001815405 epoch: 89\n", "loss: 0.1122588884616754 epoch: 90\n", "loss: 0.11133161997743117 epoch: 91\n", "loss: 0.1117323857235364 epoch: 92\n", "loss: 0.1120077427813384 epoch: 93\n", "loss: 0.11206086904733695 epoch: 94\n", "loss: 0.11091961505207532 epoch: 95\n", "loss: 0.1111846663839085 epoch: 96\n", "loss: 0.11134385270269638 epoch: 97\n", "loss: 
0.11156065212414144 epoch: 98\n", "loss: 0.11103713660219755 epoch: 99\n", "loss: 0.11012294241451635 epoch: 100\n", "loss: 0.10907474655247146 epoch: 101\n", "loss: 0.10797768794797519 epoch: 102\n", "loss: 0.11581670620554932 epoch: 103\n", "loss: 0.1208964051157461 epoch: 104\n", "loss: 0.13116181000532734 epoch: 105\n", "loss: 0.12813314902338863 epoch: 106\n", "loss: 0.11815182985467738 epoch: 107\n", "loss: 0.10979779703683795 epoch: 108\n", "loss: 0.12726145967455824 epoch: 109\n", "loss: 0.12340343191621314 epoch: 110\n", "loss: 0.13956627014028716 epoch: 111\n", "loss: 0.14168538132408248 epoch: 112\n", "loss: 0.1344182510103709 epoch: 113\n", "loss: 0.12106936303415626 epoch: 114\n", "loss: 0.11130332899760419 epoch: 115\n", "loss: 0.12215656288511659 epoch: 116\n", "loss: 0.11456430601944621 epoch: 117\n", "loss: 0.1225189614176677 epoch: 118\n", "loss: 0.12224122650393168 epoch: 119\n", "loss: 0.11553087513063057 epoch: 120\n", "loss: 0.10818977635916188 epoch: 121\n", "loss: 0.11268851889791999 epoch: 122\n", "loss: 0.11130702874969753 epoch: 123\n", "loss: 0.1160864609441687 epoch: 124\n", "loss: 0.11481989925433321 epoch: 125\n", "loss: 0.1099752177528171 epoch: 126\n", "loss: 0.10534502052241014 epoch: 127\n", "loss: 0.11448328612193452 epoch: 128\n", "loss: 0.11287708846716817 epoch: 129\n", "accuracy test 0.9140625 train 0.9544270833333334\n", "f1score test 0.923076923076923 train 0.9559748427672956\n", "precision test 0.8571428571428571 train 0.9156626506024096\n", "recall test 1.0 train 1.0\n", "\n", "cross no. 
4\n", "loss: 0.7034731463075571 epoch: 0\n", "loss: 0.6065281491302504 epoch: 1\n", "loss: 0.4951887887694784 epoch: 2\n", "loss: 0.39699567997550556 epoch: 3\n", "loss: 0.2799209612182886 epoch: 4\n", "loss: 0.23875123344825736 epoch: 5\n", "loss: 0.23318116861046517 epoch: 6\n", "loss: 0.21100254283567652 epoch: 7\n", "loss: 0.18888816833212038 epoch: 8\n", "loss: 0.17922143869818957 epoch: 9\n", "loss: 0.16942878933128672 epoch: 10\n", "loss: 0.16412827276724087 epoch: 11\n", "loss: 0.158906233678726 epoch: 12\n", "loss: 0.15342571485581122 epoch: 13\n", "loss: 0.14831077315542793 epoch: 14\n", "loss: 0.14219897983941915 epoch: 15\n", "loss: 0.13984728357064477 epoch: 16\n", "loss: 0.13636640862980137 epoch: 17\n", "loss: 0.13479930102318155 epoch: 18\n", "loss: 0.134435794896284 epoch: 19\n", "loss: 0.1343688259230099 epoch: 20\n", "loss: 0.13375693379348824 epoch: 21\n", "loss: 0.13070872471823414 epoch: 22\n", "loss: 0.13577641255732942 epoch: 23\n", "loss: 0.13361862667619914 epoch: 24\n", "loss: 0.13654647394189948 epoch: 25\n", "loss: 0.13417348967924048 epoch: 26\n", "loss: 0.12972932898959824 epoch: 27\n", "loss: 0.1291217440104548 epoch: 28\n", "loss: 0.13198162760869697 epoch: 29\n", "loss: 0.135054765832336 epoch: 30\n", "loss: 0.13320357821353757 epoch: 31\n", "loss: 0.12909628769972906 epoch: 32\n", "loss: 0.12573433574303722 epoch: 33\n", "loss: 0.13477334032582067 epoch: 34\n", "loss: 0.1323425662396159 epoch: 35\n", "loss: 0.14041181270878655 epoch: 36\n", "loss: 0.1387750871453092 epoch: 37\n", "loss: 0.13157804233240256 epoch: 38\n", "loss: 0.1259678057586573 epoch: 39\n", "loss: 0.1269320724503436 epoch: 40\n", "loss: 0.12591662427445863 epoch: 41\n", "loss: 0.12797915811837937 epoch: 42\n", "loss: 0.12756404518885575 epoch: 43\n", "loss: 0.12538937735793515 epoch: 44\n", "loss: 0.123002392646119 epoch: 45\n", "loss: 0.12341501828348579 epoch: 46\n", "loss: 0.12435327613205616 epoch: 47\n", "loss: 0.1265953229802795 epoch: 48\n", "loss: 
0.12620645993895843 epoch: 49\n", "loss: 0.12403969502258806 epoch: 50\n", "loss: 0.1208232715441185 epoch: 51\n", "loss: 0.12618915811980186 epoch: 52\n", "loss: 0.12273889568581745 epoch: 53\n", "loss: 0.12563103674416173 epoch: 54\n", "loss: 0.12520812739757248 epoch: 55\n", "loss: 0.12245431144712171 epoch: 56\n", "loss: 0.11863307268044157 epoch: 57\n", "loss: 0.12051669177016336 epoch: 58\n", "loss: 0.11815964737139971 epoch: 59\n", "loss: 0.11847109269406123 epoch: 60\n", "loss: 0.11630552895028964 epoch: 61\n", "loss: 0.11735549475993035 epoch: 62\n", "loss: 0.1154355936014345 epoch: 63\n", "loss: 0.11534015710817626 epoch: 64\n", "loss: 0.11371787612327316 epoch: 65\n", "loss: 0.1162879590364654 epoch: 66\n", "loss: 0.1186500563144215 epoch: 67\n", "loss: 0.12255232419740734 epoch: 68\n", "loss: 0.1205502966619044 epoch: 69\n", "loss: 0.11559170084589787 epoch: 70\n", "loss: 0.11200897143957601 epoch: 71\n", "loss: 0.11855099712484697 epoch: 72\n", "loss: 0.11811760873269349 epoch: 73\n", "loss: 0.1248936843608191 epoch: 74\n", "loss: 0.12439209442066561 epoch: 75\n", "loss: 0.11878453326081641 epoch: 76\n", "loss: 0.1126947256644929 epoch: 77\n", "loss: 0.11433384713596555 epoch: 78\n", "loss: 0.1107735585308641 epoch: 79\n", "loss: 0.11122956996723496 epoch: 80\n", "loss: 0.11055775876358186 epoch: 81\n", "loss: 0.10947555009104898 epoch: 82\n", "loss: 0.10941025196097363 epoch: 83\n", "loss: 0.10931555748936736 epoch: 84\n", "loss: 0.10910265386609252 epoch: 85\n", "loss: 0.10829084576390395 epoch: 86\n", "loss: 0.1112498249696613 epoch: 87\n", "loss: 0.11139974865295045 epoch: 88\n", "loss: 0.11436936911233192 epoch: 89\n", "loss: 0.11335535788700257 epoch: 90\n", "loss: 0.11028644795969816 epoch: 91\n", "loss: 0.10775048468457629 epoch: 92\n", "loss: 0.11298200724997584 epoch: 93\n", "loss: 0.11244084941551477 epoch: 94\n", "loss: 0.11916647968906746 epoch: 95\n", "loss: 0.11816646304668112 epoch: 96\n", "loss: 0.11235985159849048 epoch: 97\n", "loss: 
0.1078118637173683 epoch: 98\n", "loss: 0.11145819366441052 epoch: 99\n", "loss: 0.1072962394582775 epoch: 100\n", "loss: 0.10849653560807722 epoch: 101\n", "loss: 0.10819595237226295 epoch: 102\n", "loss: 0.10684554318361893 epoch: 103\n", "loss: 0.10542175489508059 epoch: 104\n", "loss: 0.10818050307024418 epoch: 105\n", "loss: 0.10784328295477412 epoch: 106\n", "loss: 0.11100641898539822 epoch: 107\n", "loss: 0.11057081267445168 epoch: 108\n", "loss: 0.10780317381169821 epoch: 109\n", "loss: 0.10531528957710151 epoch: 110\n", "loss: 0.1068801223889791 epoch: 111\n", "loss: 0.10472825251983942 epoch: 112\n", "loss: 0.1051267333169951 epoch: 113\n", "loss: 0.10480307172699281 epoch: 114\n", "loss: 0.10404278473462401 epoch: 115\n", "loss: 0.10331971581441422 epoch: 116\n", "loss: 0.1033296372624746 epoch: 117\n", "loss: 0.10410368328464957 epoch: 118\n", "loss: 0.10480345394765454 epoch: 119\n", "loss: 0.10449037406938307 epoch: 120\n", "loss: 0.1036419131005099 epoch: 121\n", "loss: 0.1039402883442618 epoch: 122\n", "loss: 0.10279433564761 epoch: 123\n", "loss: 0.10252534741848623 epoch: 124\n", "loss: 0.10192324923643555 epoch: 125\n", "loss: 0.10391248552345227 epoch: 126\n", "loss: 0.10679248378189415 epoch: 127\n", "loss: 0.11264800752265242 epoch: 128\n", "loss: 0.11202389350910529 epoch: 129\n", "accuracy test 0.96484375 train 0.9596354166666666\n", "f1score test 0.9662921348314606 train 0.9611041405269761\n", "precision test 0.9347826086956522 train 0.9251207729468599\n", "recall test 1.0 train 1.0\n", "\n" ] } ], "source": [ "accuracies = []\n", "f1scores = []\n", "precisionscores = []\n", "recallscores = []\n", "losss = []\n", "\n", "for i in range(4):\n", " print('cross no.', i+1)\n", " Xtrain, Xtest = X[:len(X)*i//4] + X[len(X)*(i+1)//4:], X[len(X)*i//4:len(X)*(i+1)//4] \n", " Ytrain, Ytest = Y[:len(X)*i//4] + Y[len(X)*(i+1)//4:], Y[len(X)*i//4:len(X)*(i+1)//4] \n", "\n", " hiddenLayer1 = Layer(10, 1, activation='reLu')\n", " outputLayer = Layer(11, 1, 
activation='sigmoid')\n", " for i in range(5):\n", " hiddenLayer1.neurons[0].w[i] = Value(5.0-i)\n", " hiddenLayer1.neurons[0].w[9-i] = Value(i-5.0)\n", " for i in range(5):\n", " outputLayer.neurons[0].w[i+1] = Value(5.0-i)\n", " outputLayer.neurons[0].w[9-i+1] = Value(i-5.0)\n", " parameters = outputLayer.parameters() + hiddenLayer1.parameters()\n", "\n", " prevchange = [0]*len(parameters)\n", " beta = 0.7 # parameter for momentum update\n", " lr = 1\n", " epochs = 130\n", "\n", " for _ in range(epochs):\n", " Y_pred = [predict(x) for x in Xtrain]\n", " loss = Value(0)\n", " for i in range(len(Ytrain)):\n", " if Ytrain[i] == 1:\n", " loss -= Y_pred[i].log()\n", " else:\n", " loss -= (Value(1) - Y_pred[i]).log()\n", " loss = loss/len(X)\n", " loss.backward()\n", "\n", " for ix, p in enumerate(parameters):\n", " change = lr*p.grad + beta*prevchange[ix]\n", " p.data = p.data -change\n", " prevchange[ix] = change\n", "\n", " for p in parameters:\n", " p.grad = 0\n", " losss.append(loss.data)\n", " print('loss:', loss.data,'epoch:', _)\n", " print( 'accuracy','test', getAccuracy(Xtest, Ytest),'train', getAccuracy(Xtrain, Ytrain))\n", " print( 'f1score','test', getf1(Xtest, Ytest),'train', getf1(Xtrain, Ytrain))\n", " print( 'precision','test', getPrecision(Xtest, Ytest),'train', getPrecision(Xtrain, Ytrain))\n", " print( 'recall','test',getRecall(Xtest, Ytest),'train',getRecall(Xtrain, Ytrain))\n", " print()\n", " accuracies.append(getAccuracy(Xtest, Ytest))\n", " f1scores.append(getf1(Xtest, Ytest))\n", " precisionscores.append(getPrecision(Xtest, Ytest))\n", " recallscores.append(getRecall(Xtest, Ytest))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "final accuracy: 0.9501953125\n", "final f1score 0.9531623702972016\n", "final precisionscore 0.9110402541368356\n", "final recallscore 1.0\n" ] } ], "source": [ "print('final accuracy:', sum(accuracies)/4)\n", "print('final 
f1score', sum(f1scores)/4)\n", "print('final precisionscore', sum(precisionscores)/4)\n", "print('final recallscore', sum(recallscores)/4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[4.130229565865318,\n", " 5.156892161282381,\n", " 4.022968498866327,\n", " 2.1792732617579484,\n", " 0.9036316169266277,\n", " -1.0918839311741206,\n", " -2.336481262518848,\n", " -4.136072422567298,\n", " -5.19979213705881,\n", " -4.200395429503537,\n", " -0.13505294147048558]" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[i.data for i in hiddenLayer1.neurons[0].w] + [hiddenLayer1.neurons[0].b.data]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.plot(losss)\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss')\n", "plt.title('Loss over Epochs / with momentum term')\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Network Analysis" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Text(0, 0.5, 'neuron number')" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "neuron1weightsbias = [v.data for v in hiddenLayer1.neurons[0].w] + [hiddenLayer1.neurons[0].b.data]\n", "outputneuronweightsbias = [v.data for v in outputLayer.neurons[0].w] + [outputLayer.neurons[0].b.data]\n", "import matplotlib.pyplot as plt\n", "\n", "plt.imshow([neuron1weightsbias], cmap='hot')\n", "plt.colorbar()\n", "plt.title('Heatmap of Neurons in hidden Layer')\n", "plt.xlabel('neuron number')\n", "plt.ylabel('neuron number')\n", "plt.show()\n", "\n", "plt.imshow([outputneuronweightsbias], cmap='hot')\n", "plt.colorbar()\n", "plt.title('Heatmap of Neurons in Output Layer')\n", "plt.xlabel('neuron number')\n", "plt.ylabel('neuron number')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.bar(range(len(neuron1weightsbias)), neuron1weightsbias)\n", "plt.xlabel('Neuron Number')\n", "plt.ylabel('Weight')\n", "plt.title('Neuron 1 Weights')\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.bar(range(len(outputneuronweightsbias)), outputneuronweightsbias)\n", "plt.xlabel('Neuron Number')\n", "plt.ylabel('Weight')\n", "plt.title('output neuron 1 Weights')\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create a confustion matrix\n", "\n", "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues):\n", " \"\"\"\n", " This function prints and plots the confusion matrix.\n", " Normalization can be applied by setting `normalize=True`.\n", " \"\"\"\n", " if not title:\n", " if normalize:\n", " title = 'Normalized confusion matrix'\n", " else:\n", " title = 'Confusion matrix, without normalization'\n", "\n", " # Compute confusion matrix\n", " cm = confusion_matrix(y_true, y_pred)\n", " # Only use the labels that appear in the data\n", " classes = classes\n", " if normalize:\n", " cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", " print(\"Normalized confusion matrix\")\n", " else:\n", " print('Confusion matrix, without normalization')\n", "\n", " print(cm)\n", "\n", " fig, ax = plt.subplots()\n", " im = ax.imshow(cm, interpolation='nearest', cmap=cmap)\n", " ax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def predictArray(X):\n", " return [1 if predict(x).data > 0.5 else 0 for x in X ]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Confusion matrix" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.metrics import confusion_matrix\n", "\n", "# Example confusion matrix data (replace with your own)\n", "y_true = Y\n", "y_pred = predictArray(X)\n", "\n", "# Compute confusion matrix\n", "cm = confusion_matrix(y_true, y_pred)\n", "\n", "# Plot confusion matrix with numbers\n", "plt.figure(figsize=(8, 6))\n", "sns.set(font_scale=1.2) # Adjust font scale if needed\n", "sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', \n", " xticklabels=['Class 1', 'Class 0'], \n", " yticklabels=['Class 1', 'Class 0'])\n", "plt.xlabel('Predicted')\n", "plt.ylabel('Actual')\n", "plt.title('Confusion Matrix')\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Saving parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle as pkl \n", "\n", "with open('parameters/neuron1weightsbias_fn_reLu.pckl', 'wb') as file:\n", " pkl.dump(neuron1weightsbias, file)\n", "with open('parameters/outputneuronweightsbias_fn_reLu.pckl', 'wb') as file:\n", " pkl.dump(outputneuronweightsbias, file)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "# Load model\n", "\n", "def loadModel():\n", " neuron1weightsbias, outputneuronweightsbias = [], []\n", " with open(f'parameters/neuron1weightsbias_fn_reLu.pckl', 'rb') as file:\n", " neuron1weightsbias = pickle.load(file)\n", " with open('parameters/outputneuronweightsbias_fn_reLu.pckl', 'rb') as file:\n", " outputneuronweightsbias = pickle.load(file)\n", " hiddenLayer1_ = Layer(10, 1, 'reLu')\n", " outputLayer_ = Layer(11, 1, 'sigmoid')\n", "\n", " hiddenLayer1_.neurons[0].w = [Value(i) for i in neuron1weightsbias[:-1]]\n", " hiddenLayer1_.neurons[0].b = Value(neuron1weightsbias[-1])\n", "\n", " outputLayer_.neurons[0].w = [Value(i) for i in 
outputneuronweightsbias[:-1]]\n", " outputLayer_.neurons[0].b = Value(outputneuronweightsbias[-1])\n", " return hiddenLayer1_, outputLayer_, neuron1weightsbias, outputneuronweightsbias" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "import pickle as pkl \n", "\n", "hiddenLayer1, outputLayer, neuron1weightsbias, outputneuronweightsbias = loadModel()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 2 }