{ "cells": [ { "cell_type": "markdown", "id": "fce8933f-4594-4bb5-bffe-86fcb9ddd684", "metadata": {}, "source": [ "# MLE of a Gaussian $p_{model}(x|w)$" ] }, { "cell_type": "code", "execution_count": 4, "id": "f6cd23f0-e755-48af-be5e-aaee83dda1e7", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "\n", "data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]\n", "\n", "\n", "## imports\n", "import numpy as np\n", "import pandas as pd\n", "from scipy.optimize import minimize\n", "from scipy.stats import norm\n", "import math\n", "\n", "\n", "## Problem 1\n", "data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]\n", "\n", "data_mean = np.mean(data)\n", "data_variance = np.var(data)\n", "\n", "\n", "mu = 0.5\n", "sigma = 0.5\n", "w = np.array([mu, sigma])\n", "\n", "w_star = np.array([data_mean, data_variance])\n", "mu_star = data_mean\n", "sigma_star = np.sqrt(data_variance)\n", "offset = 10 * np.random.random(2)\n", "\n", "w1p = w_star + 0.5 * offset\n", "w1n = w_star - 0.5 * offset\n", "w2p = w_star + 0.25 * offset\n", "w2n = w_star - 0.25 * offset" ] }, { "cell_type": "markdown", "id": "f3d8587b-3862-4e98-bbcc-99d57bb313c1", "metadata": {}, "source": [ "Negative Log Likelihood is defined as follows: $-\\ln(\\frac{1}{\\sqrt{2\\pi\\sigma^2}}\\exp(-\\frac{1}{2}\\frac{(x-\\mu)}{\\sigma}^2))$. Ignoring the contribution of the constant, we find that $\\frac{\\delta}{\\delta \\mu} \\mathcal{N} = \\frac{\\mu-x}{\\sigma^2}$ and $\\frac{\\delta}{\\delta \\sigma} \\mathcal{N} = \\frac{\\sigma^2 + (\\mu-x)^2 - \\sigma^2}{\\sigma^3}$. We apply these as our step functions for our SGD. " ] }, { "cell_type": "code", "execution_count": 5, "id": "27bf27ad-031e-4b65-a44d-53c5c1a09d91", "metadata": { "tags": [] }, "outputs": [], "source": [ "loss = lambda mu, sigma, x: np.sum(\n", " [-np.log(norm.pdf(xi, loc=mu, scale=sigma)) for xi in x]\n", ")\n", "\n", "loss_2_alternative = lambda mu, sigma, x: -len(x) / 2 * np.log(\n", " 2 * np.pi * sigma**2\n", ") - 1 / (2 * sigma**2) * np.sum((x - mu) ** 2)\n", "\n", "\n", "dmu = lambda mu, sigma, x: -np.sum([mu - xi for xi in x]) / (sigma**2)\n", "dsigma = lambda mu, sigma, x: -len(x) / sigma + np.sum([(mu - xi) ** 2 for xi in x]) / (sigma**3)\n", "\n", "log = []\n", "def SGD_problem1(mu, sigma, x, learning_rate=0.01, n_epochs=1000):\n", " global log\n", " log = []\n", " for epoch in range(n_epochs):\n", " mu += learning_rate * dmu(mu, sigma, x)\n", " sigma += learning_rate * dsigma(mu, sigma, x)\n", "\n", " # print(f\"Epoch {epoch}, Loss: {loss(mu, sigma, x)}, New mu: {mu}, New sigma: {sigma}\")\n", " log.append(\n", " {\n", " \"Epoch\": epoch,\n", " \"Loss\": loss(mu, sigma, x),\n", " \"Loss 2 Alternative\": loss_2_alternative(mu, sigma, x),\n", " \"New mu\": mu,\n", " \"New sigma\": sigma,\n", " }\n", " )\n", " return np.array([mu, sigma])\n", "\n", "\n", "def debug_SGD_1(wnn, data):\n", " print(\"SGD Problem 1\")\n", " print(\"wnn\", SGD_problem1(*wnn, data))\n", " dflog = pd.DataFrame(log)\n", " dflog[\"mu_star\"] = mu_star\n", " dflog[\"mu_std\"] = sigma_star\n", " print(f\"mu diff at start {dflog.iloc[0]['New mu'] - dflog.iloc[0]['mu_star']}\")\n", " print(f\"mu diff at end {dflog.iloc[-1]['New mu'] - dflog.iloc[-1]['mu_star']}\")\n", " if np.abs(dflog.iloc[-1][\"New mu\"] - dflog.iloc[-1][\"mu_star\"]) < np.abs(\n", " dflog.iloc[0][\"New mu\"] - dflog.iloc[0][\"mu_star\"]\n", " ):\n", " print(\"mu is improving\")\n", " else:\n", " print(\"mu is not improving\")\n", "\n", " print(f\"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['mu_std']}\")\n", " print(f\"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['mu_std']}\")\n", " if np.abs(dflog.iloc[-1][\"New sigma\"] - dflog.iloc[-1][\"mu_std\"]) < np.abs(\n", " dflog.iloc[0][\"New sigma\"] - dflog.iloc[0][\"mu_std\"]\n", " ):\n", " print(\"sigma is improving\")\n", " else:\n", " print(\"sigma is not improving\")\n", "\n", " return dflog" ] }, { "cell_type": "code", "execution_count": 6, "id": "27dd3bc6-b96e-4f8b-9118-01ad344dfd6a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SGD Problem 1\n", "wnn [6.2142858 2.42541812]\n", "mu diff at start 0.27610721776969527\n", "mu diff at end 8.978893806244059e-08\n", "mu is improving\n", "sigma diff at start 8.134860851821205\n", "sigma diff at end 1.7124079931818414e-12\n", "sigma is improving\n", "SGD Problem 1\n", "wnn [6.21428571 2.42541812]\n", "mu diff at start -0.24923650064862635\n", "mu diff at end -6.602718372050731e-12\n", "mu is improving\n", "sigma diff at start -0.859536014291925\n", "sigma diff at end -3.552713678800501e-15\n", "sigma is improving\n", "SGD Problem 1\n", "wnn [6.21428572 2.42541812]\n", "mu diff at start 0.13794086144778994\n", "mu diff at end 1.0008935902305893e-09\n", "mu is improving\n", "sigma diff at start 5.786783512688555\n", "sigma diff at end 4.440892098500626e-15\n", "sigma is improving\n", "SGD Problem 1\n", "wnn [6.21428571 2.42541812]\n", "mu diff at start -0.13668036978891251\n", "mu diff at end -8.528289185960602e-12\n", "mu is improving\n", "sigma diff at start 1.091241177336173\n", "sigma diff at end 4.440892098500626e-15\n", "sigma is improving\n" ] } ], "source": [ "_ = debug_SGD_1(w1p, data)\n", "_ = debug_SGD_1(w1n, data)\n", "_ = debug_SGD_1(w2p, data)\n", "_ = debug_SGD_1(w2n, data)" ] }, { "cell_type": "markdown", "id": "30096401-0bd5-4cf6-b093-a688476e16f1", "metadata": { "tags": [] }, "source": [ "# MLE of Conditional Gaussian" ] }, { "cell_type": "markdown", "id": "101a3c5e-1e02-41e6-9eab-aba65c39627a", "metadata": {}, "source": [ "dsigma = $-\\frac{n}{\\sigma}+\\frac{1}{\\sigma^3}\\sum_{i=1}^n(y_i - (mx+c))^2$ \n", "dc = $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(y_i - (mx+c))$ \n", "dm = $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(x_i(y_i - (mx+c)))$ " ] }, { "cell_type": "code", "execution_count": 8, "id": "21969012-f81b-43d4-975d-13411e975f8f", "metadata": { "collapsed": true, "jupyter": { "outputs_hidden": true }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0, Loss: 297.82677563555086\n", "Epoch 1, Loss: 297.8267749215061\n", "Epoch 2, Loss: 297.82677420752475\n", "Epoch 3, Loss: 297.82677349360694\n", "Epoch 4, Loss: 297.8267727797526\n", "Epoch 5, Loss: 297.8267720659618\n", "Epoch 6, Loss: 297.82677135223446\n", "Epoch 7, Loss: 297.82677063857074\n", "Epoch 8, Loss: 297.8267699249706\n", "Epoch 9, Loss: 297.826769211434\n", "Epoch 10, Loss: 297.8267684979611\n", "Epoch 11, Loss: 297.82676778455175\n", "Epoch 12, Loss: 297.82676707120623\n", "Epoch 13, Loss: 297.82676635792427\n", "Epoch 14, Loss: 297.8267656447061\n", "Epoch 15, Loss: 297.8267649315517\n", "Epoch 16, Loss: 297.82676421846105\n", "Epoch 17, Loss: 297.82676350543414\n", "Epoch 18, Loss: 297.82676279247113\n", "Epoch 19, Loss: 297.82676207957184\n", "Epoch 20, Loss: 297.82676136673643\n", "Epoch 21, Loss: 297.826760653965\n", "Epoch 22, Loss: 297.82675994125736\n", "Epoch 23, Loss: 297.8267592286137\n", "Epoch 24, Loss: 297.8267585160339\n", "Epoch 25, Loss: 297.8267578035182\n", "Epoch 26, Loss: 297.82675709106655\n", "Epoch 27, Loss: 297.8267563786788\n", "Epoch 28, Loss: 297.82675566635504\n", "Epoch 29, Loss: 297.82675495409546\n", "Epoch 30, Loss: 297.82675424189983\n", "Epoch 31, Loss: 297.82675352976844\n", "Epoch 32, Loss: 297.8267528177011\n", "Epoch 33, Loss: 297.82675210569795\n", "Epoch 34, Loss: 297.826751393759\n", "Epoch 35, Loss: 297.8267506818843\n", "Epoch 36, Loss: 297.8267499700737\n", "Epoch 37, Loss: 297.8267492583274\n", "Epoch 38, Loss: 297.82674854664543\n", "Epoch 39, Loss: 297.8267478350277\n", "Epoch 40, Loss: 297.8267471234743\n", "Epoch 41, Loss: 297.82674641198514\n", "Epoch 42, Loss: 297.8267457005605\n", "Epoch 43, Loss: 297.82674498920017\n", "Epoch 44, Loss: 297.82674427790425\n", "Epoch 45, Loss: 297.82674356667275\n", "Epoch 46, Loss: 297.82674285550564\n", "Epoch 47, Loss: 297.8267421444032\n", "Epoch 48, Loss: 297.82674143336516\n", "Epoch 49, Loss: 297.8267407223916\n", "Epoch 50, Loss: 297.8267400114827\n", "Epoch 51, Loss: 297.82673930063817\n", "Epoch 52, Loss: 297.8267385898584\n", "Epoch 53, Loss: 297.8267378791432\n", "Epoch 54, Loss: 297.8267371684925\n", "Epoch 55, Loss: 297.8267364579067\n", "Epoch 56, Loss: 297.8267357473855\n", "Epoch 57, Loss: 297.82673503692894\n", "Epoch 58, Loss: 297.82673432653723\n", "Epoch 59, Loss: 297.8267336162103\n", "Epoch 60, Loss: 297.826732905948\n", "Epoch 61, Loss: 297.82673219575054\n", "Epoch 62, Loss: 297.826731485618\n", "Epoch 63, Loss: 297.82673077555023\n", "Epoch 64, Loss: 297.82673006554734\n", "Epoch 65, Loss: 297.8267293556093\n", "Epoch 66, Loss: 297.82672864573624\n", "Epoch 67, Loss: 297.82672793592815\n", "Epoch 68, Loss: 297.82672722618497\n", "Epoch 69, Loss: 297.8267265165068\n", "Epoch 70, Loss: 297.8267258068936\n", "Epoch 71, Loss: 297.8267250973455\n", "Epoch 72, Loss: 297.8267243878624\n", "Epoch 73, Loss: 297.8267236784444\n", "Epoch 74, Loss: 297.82672296909146\n", "Epoch 75, Loss: 297.8267222598038\n", "Epoch 76, Loss: 297.8267215505811\n", "Epoch 77, Loss: 297.8267208414237\n", "Epoch 78, Loss: 297.82672013233156\n", "Epoch 79, Loss: 297.8267194233045\n", "Epoch 80, Loss: 297.8267187143427\n", "Epoch 81, Loss: 297.8267180054462\n", "Epoch 82, Loss: 297.82671729661496\n", "Epoch 83, Loss: 297.82671658784903\n", "Epoch 84, Loss: 297.8267158791485\n", "Epoch 85, Loss: 297.82671517051335\n", "Epoch 86, Loss: 297.8267144619435\n", "Epoch 87, Loss: 297.8267137534391\n", "Epoch 88, Loss: 297.82671304500013\n", "Epoch 89, Loss: 297.82671233662654\n", "Epoch 90, Loss: 297.82671162831855\n", "Epoch 91, Loss: 297.8267109200759\n", "Epoch 92, Loss: 297.82671021189896\n", "Epoch 93, Loss: 297.8267095037876\n", "Epoch 94, Loss: 297.8267087957417\n", "Epoch 95, Loss: 297.82670808776135\n", "Epoch 96, Loss: 297.82670737984665\n", "Epoch 97, Loss: 297.8267066719976\n", "Epoch 98, Loss: 297.82670596421417\n", "Epoch 99, Loss: 297.8267052564966\n", "Epoch 100, Loss: 297.82670454884465\n", "Epoch 101, Loss: 297.82670384125834\n", "Epoch 102, Loss: 297.8267031337379\n", "Epoch 103, Loss: 297.8267024262833\n", "Epoch 104, Loss: 297.82670171889436\n", "Epoch 105, Loss: 297.8267010115713\n", "Epoch 106, Loss: 297.82670030431416\n", "Epoch 107, Loss: 297.82669959712285\n", "Epoch 108, Loss: 297.82669888999743\n", "Epoch 109, Loss: 297.8266981829379\n", "Epoch 110, Loss: 297.8266974759444\n", "Epoch 111, Loss: 297.8266967690169\n", "Epoch 112, Loss: 297.8266960621552\n", "Epoch 113, Loss: 297.8266953553597\n", "Epoch 114, Loss: 297.82669464863017\n", "Epoch 115, Loss: 297.82669394196677\n", "Epoch 116, Loss: 297.8266932353694\n", "Epoch 117, Loss: 297.82669252883824\n", "Epoch 118, Loss: 297.8266918223731\n", "Epoch 119, Loss: 297.8266911159742\n", "Epoch 120, Loss: 297.82669040964146\n", "Epoch 121, Loss: 297.8266897033749\n", "Epoch 122, Loss: 297.8266889971746\n", "Epoch 123, Loss: 297.82668829104057\n", "Epoch 124, Loss: 297.8266875849728\n", "Epoch 125, Loss: 297.8266868789714\n", "Epoch 126, Loss: 297.8266861730362\n", "Epoch 127, Loss: 297.8266854671674\n", "Epoch 128, Loss: 297.826684761365\n", "Epoch 129, Loss: 297.82668405562896\n", "Epoch 130, Loss: 297.8266833499594\n", "Epoch 131, Loss: 297.8266826443563\n", "Epoch 132, Loss: 297.8266819388196\n", "Epoch 133, Loss: 297.82668123334935\n", "Epoch 134, Loss: 297.8266805279458\n", "Epoch 135, Loss: 297.8266798226087\n", "Epoch 136, Loss: 297.82667911733813\n", "Epoch 137, Loss: 297.8266784121342\n", "Epoch 138, Loss: 297.82667770699686\n", "Epoch 139, Loss: 297.82667700192616\n", "Epoch 140, Loss: 297.8266762969221\n", "Epoch 141, Loss: 297.8266755919847\n", "Epoch 142, Loss: 297.82667488711405\n", "Epoch 143, Loss: 297.8266741823101\n", "Epoch 144, Loss: 297.82667347757297\n", "Epoch 145, Loss: 297.82667277290255\n", "Epoch 146, Loss: 297.82667206829893\n", "Epoch 147, Loss: 297.8266713637622\n", "Epoch 148, Loss: 297.8266706592923\n", "Epoch 149, Loss: 297.8266699548893\n", "Epoch 150, Loss: 297.8266692505533\n", "Epoch 151, Loss: 297.82666854628405\n", "Epoch 152, Loss: 297.82666784208175\n", "Epoch 153, Loss: 297.8266671379465\n", "Epoch 154, Loss: 297.8266664338782\n", "Epoch 155, Loss: 297.82666572987705\n", "Epoch 156, Loss: 297.8266650259427\n", "Epoch 157, Loss: 297.82666432207566\n", "Epoch 158, Loss: 297.82666361827546\n", "Epoch 159, Loss: 297.8266629145426\n", "Epoch 160, Loss: 297.8266622108768\n", "Epoch 161, Loss: 297.8266615072782\n", "Epoch 162, Loss: 297.8266608037467\n", "Epoch 163, Loss: 297.8266601002824\n", "Epoch 164, Loss: 297.8266593968855\n", "Epoch 165, Loss: 297.8266586935557\n", "Epoch 166, Loss: 297.82665799029326\n", "Epoch 167, Loss: 297.82665728709804\n", "Epoch 168, Loss: 297.82665658397036\n", "Epoch 169, Loss: 297.8266558809098\n", "Epoch 170, Loss: 297.82665517791673\n", "Epoch 171, Loss: 297.8266544749911\n", "Epoch 172, Loss: 297.82665377213283\n", "Epoch 173, Loss: 297.826653069342\n", "Epoch 174, Loss: 297.8266523666187\n", "Epoch 175, Loss: 297.82665166396293\n", "Epoch 176, Loss: 297.82665096137464\n", "Epoch 177, Loss: 297.8266502588538\n", "Epoch 178, Loss: 297.82664955640064\n", "Epoch 179, Loss: 297.82664885401516\n", "Epoch 180, Loss: 297.82664815169716\n", "Epoch 181, Loss: 297.8266474494469\n", "Epoch 182, Loss: 297.8266467472642\n", "Epoch 183, Loss: 297.8266460451493\n", "Epoch 184, Loss: 297.82664534310203\n", "Epoch 185, Loss: 297.82664464112264\n", "Epoch 186, Loss: 297.82664393921095\n", "Epoch 187, Loss: 297.82664323736697\n", "Epoch 188, Loss: 297.8266425355909\n", "Epoch 189, Loss: 297.8266418338827\n", "Epoch 190, Loss: 297.82664113224223\n", "Epoch 191, Loss: 297.8266404306698\n", "Epoch 192, Loss: 297.82663972916515\n", "Epoch 193, Loss: 297.82663902772856\n", "Epoch 194, Loss: 297.82663832635984\n", "Epoch 195, Loss: 297.8266376250591\n", "Epoch 196, Loss: 297.8266369238264\n", "Epoch 197, Loss: 297.82663622266176\n", "Epoch 198, Loss: 297.8266355215652\n", "Epoch 199, Loss: 297.8266348205367\n", "Epoch 200, Loss: 297.8266341195762\n", "Epoch 201, Loss: 297.82663341868397\n", "Epoch 202, Loss: 297.82663271785987\n", "Epoch 203, Loss: 297.82663201710386\n", "Epoch 204, Loss: 297.8266313164161\n", "Epoch 205, Loss: 297.82663061579655\n", "Epoch 206, Loss: 297.8266299152454\n", "Epoch 207, Loss: 297.8266292147624\n", "Epoch 208, Loss: 297.8266285143477\n", "Epoch 209, Loss: 297.82662781400137\n", "Epoch 210, Loss: 297.82662711372336\n", "Epoch 211, Loss: 297.8266264135138\n", "Epoch 212, Loss: 297.8266257133725\n", "Epoch 213, Loss: 297.8266250132998\n", "Epoch 214, Loss: 297.82662431329544\n", "Epoch 215, Loss: 297.8266236133595\n", "Epoch 216, Loss: 297.8266229134922\n", "Epoch 217, Loss: 297.8266222136934\n", "Epoch 218, Loss: 297.8266215139631\n", "Epoch 219, Loss: 297.8266208143014\n", "Epoch 220, Loss: 297.8266201147082\n", "Epoch 221, Loss: 297.8266194151838\n", "Epoch 222, Loss: 297.82661871572793\n", "Epoch 223, Loss: 297.82661801634066\n", "Epoch 224, Loss: 297.82661731702217\n", "Epoch 225, Loss: 297.82661661777234\n", "Epoch 226, Loss: 297.8266159185914\n", "Epoch 227, Loss: 297.8266152194791\n", "Epoch 228, Loss: 297.82661452043567\n", "Epoch 229, Loss: 297.82661382146097\n", "Epoch 230, Loss: 297.8266131225552\n", "Epoch 231, Loss: 297.82661242371825\n", "Epoch 232, Loss: 297.82661172495017\n", "Epoch 233, Loss: 297.82661102625104\n", "Epoch 234, Loss: 297.8266103276209\n", "Epoch 235, Loss: 297.8266096290597\n", "Epoch 236, Loss: 297.82660893056743\n", "Epoch 237, Loss: 297.8266082321442\n", "Epoch 238, Loss: 297.82660753379\n", "Epoch 239, Loss: 297.8266068355049\n", "Epoch 240, Loss: 297.82660613728893\n", "Epoch 241, Loss: 297.82660543914204\n", "Epoch 242, Loss: 297.8266047410642\n", "Epoch 243, Loss: 297.82660404305557\n", "Epoch 244, Loss: 297.82660334511615\n", "Epoch 245, Loss: 297.826602647246\n", "Epoch 246, Loss: 297.826601949445\n", "Epoch 247, Loss: 297.8266012517133\n", "Epoch 248, Loss: 297.82660055405086\n", "Epoch 249, Loss: 297.82659985645773\n", "Epoch 250, Loss: 297.82659915893396\n", "Epoch 251, Loss: 297.8265984614796\n", "Epoch 252, Loss: 297.8265977640945\n", "Epoch 253, Loss: 297.8265970667789\n", "Epoch 254, Loss: 297.8265963695328\n", "Epoch 255, Loss: 297.82659567235606\n", "Epoch 256, Loss: 297.8265949752488\n", "Epoch 257, Loss: 297.8265942782112\n", "Epoch 258, Loss: 297.82659358124295\n", "Epoch 259, Loss: 297.82659288434434\n", "Epoch 260, Loss: 297.82659218751525\n", "Epoch 261, Loss: 297.82659149075585\n", "Epoch 262, Loss: 297.826590794066\n", "Epoch 263, Loss: 297.8265900974459\n", "Epoch 264, Loss: 297.8265894008955\n", "Epoch 265, Loss: 297.82658870441463\n", "Epoch 266, Loss: 297.8265880080037\n", "Epoch 267, Loss: 297.8265873116626\n", "Epoch 268, Loss: 297.82658661539097\n", "Epoch 269, Loss: 297.8265859191893\n", "Epoch 270, Loss: 297.8265852230575\n", "Epoch 271, Loss: 297.8265845269956\n", "Epoch 272, Loss: 297.82658383100346\n", "Epoch 273, Loss: 297.82658313508136\n", "Epoch 274, Loss: 297.8265824392291\n", "Epoch 275, Loss: 297.8265817434468\n", "Epoch 276, Loss: 297.82658104773463\n", "Epoch 277, Loss: 297.82658035209226\n", "Epoch 278, Loss: 297.82657965652004\n", "Epoch 279, Loss: 297.8265789610178\n", "Epoch 280, Loss: 297.82657826558574\n", "Epoch 281, Loss: 297.8265775702237\n", "Epoch 282, Loss: 297.82657687493185\n", "Epoch 283, Loss: 297.8265761797103\n", "Epoch 284, Loss: 297.82657548455876\n", "Epoch 285, Loss: 297.82657478947743\n", "Epoch 286, Loss: 297.82657409446637\n", "Epoch 287, Loss: 297.8265733995255\n", "Epoch 288, Loss: 297.826572704655\n", "Epoch 289, Loss: 297.8265720098549\n", "Epoch 290, Loss: 297.82657131512497\n", "Epoch 291, Loss: 297.8265706204655\n", "Epoch 292, Loss: 297.82656992587636\n", "Epoch 293, Loss: 297.82656923135755\n", "Epoch 294, Loss: 297.82656853690935\n", "Epoch 295, Loss: 297.8265678425316\n", "Epoch 296, Loss: 297.82656714822417\n", "Epoch 297, Loss: 297.82656645398737\n", "Epoch 298, Loss: 297.8265657598211\n", "Epoch 299, Loss: 297.82656506572545\n", "Epoch 300, Loss: 297.8265643717003\n", "Epoch 301, Loss: 297.8265636777458\n", "Epoch 302, Loss: 297.8265629838619\n", "Epoch 303, Loss: 297.8265622900487\n", "Epoch 304, Loss: 297.8265615963061\n", "Epoch 305, Loss: 297.82656090263424\n", "Epoch 306, Loss: 297.8265602090333\n", "Epoch 307, Loss: 297.82655951550294\n", "Epoch 308, Loss: 297.82655882204335\n", "Epoch 309, Loss: 297.8265581286547\n", "Epoch 310, Loss: 297.82655743533684\n", "Epoch 311, Loss: 297.8265567420898\n", "Epoch 312, Loss: 297.8265560489138\n", "Epoch 313, Loss: 297.8265553558085\n", "Epoch 314, Loss: 297.82655466277436\n", "Epoch 315, Loss: 297.8265539698111\n", "Epoch 316, Loss: 297.8265532769188\n", "Epoch 317, Loss: 297.8265525840975\n", "Epoch 318, Loss: 297.8265518913472\n", "Epoch 319, Loss: 297.8265511986681\n", "Epoch 320, Loss: 297.82655050606\n", "Epoch 321, Loss: 297.8265498135231\n", "Epoch 322, Loss: 297.8265491210572\n", "Epoch 323, Loss: 297.82654842866265\n", "Epoch 324, Loss: 297.8265477363392\n", "Epoch 325, Loss: 297.826547044087\n", "Epoch 326, Loss: 297.826546351906\n", "Epoch 327, Loss: 297.82654565979635\n", "Epoch 328, Loss: 297.8265449677579\n", "Epoch 329, Loss: 297.8265442757908\n", "Epoch 330, Loss: 297.8265435838951\n", "Epoch 331, Loss: 297.82654289207085\n", "Epoch 332, Loss: 297.8265422003178\n", "Epoch 333, Loss: 297.82654150863635\n", "Epoch 334, Loss: 297.8265408170263\n", "Epoch 335, Loss: 297.8265401254876\n", "Epoch 336, Loss: 297.8265394340205\n", "Epoch 337, Loss: 297.826538742625\n", "Epoch 338, Loss: 297.8265380513009\n", "Epoch 339, Loss: 297.8265373600486\n", "Epoch 340, Loss: 297.8265366688676\n", "Epoch 341, Loss: 297.82653597775845\n", "Epoch 342, Loss: 297.82653528672085\n", "Epoch 343, Loss: 297.826534595755\n", "Epoch 344, Loss: 297.82653390486087\n", "Epoch 345, Loss: 297.8265332140384\n", "Epoch 346, Loss: 297.8265325232878\n", "Epoch 347, Loss: 297.8265318326088\n", "Epoch 348, Loss: 297.8265311420018\n", "Epoch 349, Loss: 297.8265304514666\n", "Epoch 350, Loss: 297.82652976100314\n", "Epoch 351, Loss: 297.82652907061174\n", "Epoch 352, Loss: 297.82652838029213\n", "Epoch 353, Loss: 297.8265276900445\n", "Epoch 354, Loss: 297.82652699986875\n", "Epoch 355, Loss: 297.82652630976514\n", "Epoch 356, Loss: 297.82652561973345\n", "Epoch 357, Loss: 297.82652492977377\n", "Epoch 358, Loss: 297.82652423988617\n", "Epoch 359, Loss: 297.82652355007065\n", "Epoch 360, Loss: 297.8265228603273\n", "Epoch 361, Loss: 297.8265221706561\n", "Epoch 362, Loss: 297.82652148105706\n", "Epoch 363, Loss: 297.8265207915302\n", "Epoch 364, Loss: 297.8265201020755\n", "Epoch 365, Loss: 297.8265194126932\n", "Epoch 366, Loss: 297.82651872338306\n", "Epoch 367, Loss: 297.8265180341453\n", "Epoch 368, Loss: 297.8265173449798\n", "Epoch 369, Loss: 297.8265166558866\n", "Epoch 370, Loss: 297.8265159668659\n", "Epoch 371, Loss: 297.82651527791745\n", "Epoch 372, Loss: 297.8265145890415\n", "Epoch 373, Loss: 297.82651390023807\n", "Epoch 374, Loss: 297.82651321150695\n", "Epoch 375, Loss: 297.82651252284853\n", "Epoch 376, Loss: 297.8265118342626\n", "Epoch 377, Loss: 297.8265111457491\n", "Epoch 378, Loss: 297.82651045730825\n", "Epoch 379, Loss: 297.8265097689401\n", "Epoch 380, Loss: 297.8265090806445\n", "Epoch 381, Loss: 297.8265083924216\n", "Epoch 382, Loss: 297.82650770427136\n", "Epoch 383, Loss: 297.82650701619394\n", "Epoch 384, Loss: 297.82650632818905\n", "Epoch 385, Loss: 297.8265056402571\n", "Epoch 386, Loss: 297.8265049523977\n", "Epoch 387, Loss: 297.82650426461134\n", "Epoch 388, Loss: 297.82650357689784\n", "Epoch 389, Loss: 297.82650288925714\n", "Epoch 390, Loss: 297.82650220168927\n", "Epoch 391, Loss: 297.82650151419443\n", "Epoch 392, Loss: 297.8265008267725\n", "Epoch 393, Loss: 297.8265001394235\n", "Epoch 394, Loss: 297.8264994521476\n", "Epoch 395, Loss: 297.8264987649447\n", "Epoch 396, Loss: 297.82649807781473\n", "Epoch 397, Loss: 297.82649739075794\n", "Epoch 398, Loss: 297.82649670377424\n", "Epoch 399, Loss: 297.82649601686364\n", "Epoch 400, Loss: 297.8264953300262\n", "Epoch 401, Loss: 297.826494643262\n", "Epoch 402, Loss: 297.82649395657097\n", "Epoch 403, Loss: 297.8264932699532\n", "Epoch 404, Loss: 297.8264925834086\n", "Epoch 405, Loss: 297.8264918969374\n", "Epoch 406, Loss: 297.8264912105394\n", "Epoch 407, Loss: 297.8264905242148\n", "Epoch 408, Loss: 297.82648983796355\n", "Epoch 409, Loss: 297.8264891517858\n", "Epoch 410, Loss: 297.82648846568134\n", "Epoch 411, Loss: 297.8264877796504\n", "Epoch 412, Loss: 297.8264870936928\n", "Epoch 413, Loss: 297.82648640780883\n", "Epoch 414, Loss: 297.8264857219984\n", "Epoch 415, Loss: 297.8264850362614\n", "Epoch 416, Loss: 297.82648435059804\n", "Epoch 417, Loss: 297.8264836650082\n", "Epoch 418, Loss: 297.8264829794921\n", "Epoch 419, Loss: 297.82648229404964\n", "Epoch 420, Loss: 297.8264816086808\n", "Epoch 421, Loss: 297.8264809233857\n", "Epoch 422, Loss: 297.82648023816444\n", "Epoch 423, Loss: 297.8264795530168\n", "Epoch 424, Loss: 297.82647886794297\n", "Epoch 425, Loss: 297.82647818294294\n", "Epoch 426, Loss: 297.82647749801686\n", "Epoch 427, Loss: 297.8264768131645\n", "Epoch 428, Loss: 297.8264761283861\n", "Epoch 429, Loss: 297.82647544368155\n", "Epoch 430, Loss: 297.82647475905094\n", "Epoch 431, Loss: 297.82647407449446\n", "Epoch 432, Loss: 297.8264733900118\n", "Epoch 433, Loss: 297.82647270560335\n", "Epoch 434, Loss: 297.8264720212687\n", "Epoch 435, Loss: 297.82647133700834\n", "Epoch 436, Loss: 297.8264706528221\n", "Epoch 437, Loss: 297.82646996870983\n", "Epoch 438, Loss: 297.8264692846718\n", "Epoch 439, Loss: 297.8264686007079\n", "Epoch 440, Loss: 297.8264679168184\n", "Epoch 441, Loss: 297.8264672330029\n", "Epoch 442, Loss: 297.8264665492618\n", "Epoch 443, Loss: 297.82646586559486\n", "Epoch 444, Loss: 297.8264651820023\n", "Epoch 445, Loss: 297.82646449848403\n", "Epoch 446, Loss: 297.8264638150402\n", "Epoch 447, Loss: 297.8264631316708\n", "Epoch 448, Loss: 297.8264624483758\n", "Epoch 449, Loss: 297.8264617651551\n", "Epoch 450, Loss: 297.8264610820091\n", "Epoch 451, Loss: 297.82646039893746\n", "Epoch 452, Loss: 297.82645971594036\n", "Epoch 453, Loss: 297.8264590330179\n", "Epoch 454, Loss: 297.8264583501699\n", "Epoch 455, Loss: 297.82645766739654\n", "Epoch 456, Loss: 297.82645698469787\n", "Epoch 457, Loss: 297.8264563020737\n", "Epoch 458, Loss: 297.82645561952444\n", "Epoch 459, Loss: 297.8264549370498\n", "Epoch 460, Loss: 297.82645425464983\n", "Epoch 461, Loss: 297.8264535723248\n", "Epoch 462, Loss: 297.82645289007445\n", "Epoch 463, Loss: 297.8264522078989\n", "Epoch 464, Loss: 297.8264515257983\n", "Epoch 465, Loss: 297.8264508437724\n", "Epoch 466, Loss: 297.8264501618215\n", "Epoch 467, Loss: 297.82644947994555\n", "Epoch 468, Loss: 297.82644879814444\n", "Epoch 469, Loss: 297.8264481164186\n", "Epoch 470, Loss: 297.82644743476743\n", "Epoch 471, Loss: 297.82644675319153\n", "Epoch 472, Loss: 297.82644607169055\n", "Epoch 473, Loss: 297.8264453902647\n", "Epoch 474, Loss: 297.8264447089139\n", "Epoch 475, Loss: 297.82644402763833\n", "Epoch 476, Loss: 297.82644334643805\n", "Epoch 477, Loss: 297.82644266531275\n", "Epoch 478, Loss: 297.8264419842628\n", "Epoch 479, Loss: 297.82644130328805\n", "Epoch 480, Loss: 297.8264406223886\n", "Epoch 481, Loss: 297.8264399415644\n", "Epoch 482, Loss: 297.8264392608157\n", "Epoch 483, Loss: 297.8264385801421\n", "Epoch 484, Loss: 297.82643789954403\n", "Epoch 485, Loss: 297.82643721902133\n", "Epoch 486, Loss: 297.8264365385741\n", "Epoch 487, Loss: 297.8264358582022\n", "Epoch 488, Loss: 297.82643517790603\n", "Epoch 489, Loss: 297.8264344976852\n", "Epoch 490, Loss: 297.82643381753996\n", "Epoch 491, Loss: 297.8264331374703\n", "Epoch 492, Loss: 297.8264324574763\n", "Epoch 493, Loss: 297.82643177755784\n", "Epoch 494, Loss: 297.82643109771504\n", "Epoch 495, Loss: 297.82643041794796\n", "Epoch 496, Loss: 297.8264297382566\n", "Epoch 497, Loss: 297.82642905864094\n", "Epoch 498, Loss: 297.82642837910106\n", "Epoch 499, Loss: 297.82642769963695\n", "Epoch 500, Loss: 297.8264270202487\n", "Epoch 501, Loss: 297.8264263409362\n", "Epoch 502, Loss: 297.82642566169966\n", "Epoch 503, Loss: 297.826424982539\n", "Epoch 504, Loss: 297.8264243034543\n", "Epoch 505, Loss: 297.82642362444545\n", "Epoch 506, Loss: 297.8264229455126\n", "Epoch 507, Loss: 297.8264222666558\n", "Epoch 508, Loss: 297.82642158787496\n", "Epoch 509, Loss: 297.82642090917034\n", "Epoch 510, Loss: 297.8264202305416\n", "Epoch 511, Loss: 297.82641955198915\n", "Epoch 512, Loss: 297.82641887351275\n", "Epoch 513, Loss: 297.8264181951125\n", "Epoch 514, Loss: 297.8264175167885\n", "Epoch 515, Loss: 297.82641683854075\n", "Epoch 516, Loss: 297.82641616036915\n", "Epoch 517, Loss: 297.82641548227383\n", "Epoch 518, Loss: 297.82641480425497\n", "Epoch 519, Loss: 297.8264141263123\n", "Epoch 520, Loss: 297.82641344844603\n", "Epoch 521, Loss: 297.8264127706562\n", "Epoch 522, Loss: 297.82641209294275\n", "Epoch 523, Loss: 297.82641141530576\n", "Epoch 524, Loss: 297.8264107377451\n", "Epoch 525, Loss: 297.826410060261\n", "Epoch 526, Loss: 297.82640938285346\n", "Epoch 527, Loss: 297.8264087055225\n", "Epoch 528, Loss: 297.826408028268\n", "Epoch 529, Loss: 297.82640735109027\n", "Epoch 530, Loss: 297.8264066739891\n", "Epoch 531, Loss: 297.8264059969646\n", "Epoch 532, Loss: 297.8264053200167\n", "Epoch 533, Loss: 297.8264046431456\n", "Epoch 534, Loss: 297.82640396635117\n", "Epoch 535, Loss: 297.8264032896337\n", "Epoch 536, Loss: 297.82640261299275\n", "Epoch 537, Loss: 297.8264019364288\n", "Epoch 538, Loss: 297.82640125994163\n", "Epoch 539, Loss: 297.8264005835315\n", "Epoch 540, Loss: 297.8263999071981\n", "Epoch 541, Loss: 297.82639923094166\n", "Epoch 542, Loss: 297.8263985547622\n", "Epoch 543, Loss: 297.82639787865975\n", "Epoch 544, Loss: 297.82639720263427\n", "Epoch 545, Loss: 297.8263965266858\n", "Epoch 546, Loss: 297.82639585081455\n", "Epoch 547, Loss: 297.82639517502025\n", "Epoch 548, Loss: 297.82639449930315\n", "Epoch 549, Loss: 297.8263938236632\n", "Epoch 550, Loss: 297.8263931481004\n", "Epoch 551, Loss: 297.8263924726149\n", "Epoch 552, Loss: 297.8263917972065\n", "Epoch 553, Loss: 297.8263911218754\n", "Epoch 554, Loss: 297.82639044662164\n", "Epoch 555, Loss: 297.8263897714451\n", "Epoch 556, Loss: 297.8263890963461\n", "Epoch 557, Loss: 297.82638842132434\n", "Epoch 558, Loss: 297.82638774638\n", "Epoch 559, Loss: 297.82638707151307\n", "Epoch 560, Loss: 297.82638639672365\n", "Epoch 561, Loss: 297.8263857220117\n", "Epoch 562, Loss: 297.8263850473772\n", "Epoch 563, Loss: 297.8263843728203\n", "Epoch 564, Loss: 297.826383698341\n", "Epoch 565, Loss: 297.82638302393923\n", "Epoch 566, Loss: 297.82638234961513\n", "Epoch 567, Loss: 297.8263816753686\n", "Epoch 568, Loss: 297.82638100119976\n", "Epoch 569, Loss: 297.82638032710867\n", "Epoch 570, Loss: 297.82637965309533\n", "Epoch 571, Loss: 297.82637897915964\n", "Epoch 572, Loss: 297.8263783053019\n", "Epoch 573, Loss: 297.8263776315219\n", "Epoch 574, Loss: 297.82637695781983\n", "Epoch 575, Loss: 297.8263762841955\n", "Epoch 576, Loss: 297.8263756106491\n", "Epoch 577, Loss: 297.8263749371807\n", "Epoch 578, Loss: 297.8263742637902\n", "Epoch 579, Loss: 297.82637359047766\n", "Epoch 580, Loss: 297.8263729172433\n", "Epoch 581, Loss: 297.82637224408677\n", "Epoch 582, Loss: 297.8263715710084\n", "Epoch 583, Loss: 297.826370898008\n", "Epoch 584, Loss: 297.8263702250859\n", "Epoch 585, Loss: 297.82636955224183\n", "Epoch 586, Loss: 297.82636887947604\n", "Epoch 587, Loss: 297.8263682067884\n", "Epoch 588, Loss: 297.826367534179\n", "Epoch 589, Loss: 297.82636686164784\n", "Epoch 590, Loss: 297.8263661891951\n", "Epoch 591, Loss: 297.82636551682054\n", "Epoch 592, Loss: 297.82636484452433\n", "Epoch 593, Loss: 297.8263641723065\n", "Epoch 594, Loss: 297.8263635001672\n", "Epoch 595, Loss: 297.82636282810614\n", "Epoch 596, Loss: 297.8263621561237\n", "Epoch 597, Loss: 297.8263614842196\n", "Epoch 598, Loss: 297.82636081239417\n", "Epoch 599, Loss: 297.8263601406472\n", "Epoch 600, Loss: 297.8263594689787\n", "Epoch 601, Loss: 297.826358797389\n", "Epoch 602, Loss: 297.82635812587785\n", "Epoch 603, Loss: 297.82635745444526\n", "Epoch 604, Loss: 297.82635678309146\n", "Epoch 605, Loss: 297.8263561118164\n", "Epoch 606, Loss: 297.82635544062\n", "Epoch 607, Loss: 297.8263547695023\n", "Epoch 608, Loss: 297.82635409846347\n", "Epoch 609, Loss: 297.8263534275034\n", "Epoch 610, Loss: 297.8263527566223\n", "Epoch 611, Loss: 297.8263520858201\n", "Epoch 612, Loss: 297.82635141509684\n", "Epoch 613, Loss: 297.8263507444523\n", "Epoch 614, Loss: 297.8263500738869\n", "Epoch 615, Loss: 297.8263494034004\n", "Epoch 616, Loss: 297.8263487329929\n", "Epoch 617, Loss: 297.8263480626645\n", "Epoch 618, Loss: 297.8263473924152\n", "Epoch 619, Loss: 297.82634672224503\n", "Epoch 620, Loss: 297.8263460521538\n", "Epoch 621, Loss: 297.82634538214194\n", "Epoch 622, Loss: 297.8263447122093\n", "Epoch 623, Loss: 297.82634404235574\n", "Epoch 624, Loss: 297.8263433725814\n", "Epoch 625, Loss: 297.8263427028865\n", "Epoch 626, Loss: 297.82634203327075\n", "Epoch 627, Loss: 297.8263413637344\n", "Epoch 628, Loss: 297.82634069427735\n", "Epoch 629, Loss: 297.82634002489976\n", "Epoch 630, Loss: 297.82633935560165\n", "Epoch 631, Loss: 297.8263386863829\n", "Epoch 632, Loss: 297.8263380172436\n", "Epoch 633, Loss: 297.82633734818376\n", "Epoch 634, Loss: 297.8263366792035\n", "Epoch 635, Loss: 297.8263360103027\n", "Epoch 636, Loss: 297.8263353414817\n", "Epoch 637, Loss: 297.82633467274024\n", "Epoch 638, Loss: 297.8263340040784\n", "Epoch 639, Loss: 297.8263333354962\n", "Epoch 640, Loss: 297.8263326669936\n", "Epoch 641, Loss: 297.826331998571\n", "Epoch 642, Loss: 297.82633133022796\n", "Epoch 643, Loss: 297.82633066196473\n", "Epoch 644, Loss: 297.82632999378137\n", "Epoch 645, Loss: 297.8263293256778\n", "Epoch 646, Loss: 297.8263286576541\n", "Epoch 647, Loss: 297.8263279897103\n", "Epoch 648, Loss: 297.82632732184646\n", "Epoch 649, Loss: 297.8263266540626\n", "Epoch 650, Loss: 297.8263259863587\n", "Epoch 651, Loss: 297.8263253187347\n", "Epoch 652, Loss: 297.82632465119093\n", "Epoch 653, Loss: 297.82632398372704\n", "Epoch 654, Loss: 297.8263233163434\n", "Epoch 655, Loss: 297.8263226490398\n", "Epoch 656, Loss: 297.82632198181636\n", "Epoch 657, Loss: 297.82632131467324\n", "Epoch 658, Loss: 297.82632064761026\n", "Epoch 659, Loss: 297.82631998062743\n", "Epoch 660, Loss: 297.8263193137249\n", "Epoch 661, Loss: 297.8263186469027\n", "Epoch 662, Loss: 297.82631798016087\n", "Epoch 663, Loss: 297.8263173134994\n", "Epoch 664, Loss: 297.82631664691814\n", "Epoch 665, Loss: 297.82631598041746\n", "Epoch 666, Loss: 297.8263153139972\n", "Epoch 667, Loss: 297.82631464765734\n", "Epoch 668, Loss: 297.8263139813981\n", "Epoch 669, Loss: 297.82631331521924\n", "Epoch 670, Loss: 297.82631264912106\n", "Epoch 671, Loss: 297.82631198310344\n", "Epoch 672, Loss: 297.8263113171664\n", "Epoch 673, Loss: 297.82631065131005\n", "Epoch 674, Loss: 297.8263099855343\n", "Epoch 675, Loss: 297.82630931983937\n", "Epoch 676, Loss: 297.82630865422504\n", "Epoch 677, Loss: 297.8263079886915\n", "Epoch 678, Loss: 297.8263073232387\n", "Epoch 679, Loss: 297.8263066578669\n", "Epoch 680, Loss: 297.82630599257584\n", "Epoch 681, Loss: 297.8263053273656\n", "Epoch 682, Loss: 297.82630466223634\n", "Epoch 683, Loss: 297.8263039971879\n", "Epoch 684, Loss: 297.82630333222056\n", "Epoch 685, Loss: 297.8263026673342\n", "Epoch 686, Loss: 297.8263020025288\n", "Epoch 687, Loss: 297.82630133780435\n", "Epoch 688, Loss: 297.82630067316114\n", "Epoch 689, Loss: 297.826300008599\n", "Epoch 690, Loss: 297.8262993441179\n", "Epoch 691, Loss: 297.8262986797181\n", "Epoch 692, Loss: 297.8262980153994\n", "Epoch 693, Loss: 297.82629735116194\n", "Epoch 694, Loss: 297.82629668700577\n", "Epoch 695, Loss: 297.8262960229308\n", "Epoch 696, Loss: 297.8262953589371\n", "Epoch 697, Loss: 297.82629469502484\n", "Epoch 698, Loss: 297.826294031194\n", "Epoch 699, Loss: 297.8262933674444\n", "Epoch 700, Loss: 297.8262927037763\n", "Epoch 701, Loss: 297.82629204018974\n", "Epoch 702, Loss: 297.8262913766845\n", "Epoch 703, Loss: 297.82629071326073\n", "Epoch 704, Loss: 297.82629004991867\n", "Epoch 705, Loss: 297.82628938665823\n", "Epoch 706, Loss: 297.8262887234792\n", "Epoch 707, Loss: 297.8262880603819\n", "Epoch 708, Loss: 297.82628739736623\n", "Epoch 709, Loss: 297.8262867344323\n", "Epoch 710, Loss: 297.82628607158\n", "Epoch 711, Loss: 297.8262854088095\n", "Epoch 712, Loss: 297.8262847461207\n", "Epoch 713, Loss: 297.82628408351377\n", "Epoch 714, Loss: 297.8262834209886\n", "Epoch 715, Loss: 297.8262827585453\n", "Epoch 716, Loss: 297.8262820961839\n", "Epoch 717, Loss: 297.8262814339045\n", "Epoch 718, Loss: 297.826280771707\n", "Epoch 719, Loss: 297.8262801095915\n", "Epoch 720, Loss: 297.82627944755797\n", "Epoch 721, Loss: 297.8262787856065\n", "Epoch 722, Loss: 297.82627812373715\n", "Epoch 723, Loss: 297.8262774619497\n", "Epoch 724, Loss: 297.82627680024444\n", "Epoch 725, Loss: 297.8262761386215\n", "Epoch 726, Loss: 297.8262754770806\n", "Epoch 727, Loss: 297.826274815622\n", "Epoch 728, Loss: 297.82627415424554\n", "Epoch 729, Loss: 297.82627349295143\n", "Epoch 730, Loss: 297.8262728317395\n", "Epoch 731, Loss: 297.82627217061\n", "Epoch 732, Loss: 297.82627150956284\n", "Epoch 733, Loss: 297.8262708485981\n", "Epoch 734, Loss: 297.82627018771575\n", "Epoch 735, Loss: 297.8262695269158\n", "Epoch 736, Loss: 297.8262688661983\n", "Epoch 737, Loss: 297.82626820556334\n", "Epoch 738, Loss: 297.826267545011\n", "Epoch 739, Loss: 297.8262668845411\n", "Epoch 740, Loss: 297.82626622415387\n", "Epoch 741, Loss: 297.82626556384935\n", "Epoch 742, Loss: 297.82626490362736\n", "Epoch 743, Loss: 297.826264243488\n", "Epoch 744, Loss: 297.82626358343146\n", "Epoch 745, Loss: 297.82626292345765\n", "Epoch 746, Loss: 297.82626226356655\n", "Epoch 747, Loss: 297.8262616037582\n", "Epoch 748, Loss: 297.8262609440329\n", "Epoch 749, Loss: 297.8262602843903\n", "Epoch 750, Loss: 297.82625962483047\n", "Epoch 751, Loss: 297.82625896535376\n", "Epoch 752, Loss: 297.82625830595987\n", "Epoch 753, Loss: 297.8262576466491\n", "Epoch 754, Loss: 297.82625698742123\n", "Epoch 755, Loss: 297.82625632827643\n", "Epoch 756, Loss: 297.8262556692147\n", "Epoch 757, Loss: 297.82625501023597\n", "Epoch 758, Loss: 297.8262543513405\n", "Epoch 759, Loss: 297.8262536925281\n", "Epoch 760, Loss: 297.82625303379893\n", "Epoch 761, Loss: 297.8262523751529\n", "Epoch 762, Loss: 297.82625171659015\n", "Epoch 763, Loss: 297.82625105811076\n", "Epoch 764, Loss: 297.8262503997146\n", "Epoch 765, Loss: 297.82624974140174\n", "Epoch 766, Loss: 297.82624908317234\n", "Epoch 767, Loss: 297.8262484250262\n", "Epoch 768, Loss: 297.82624776696355\n", "Epoch 769, Loss: 297.8262471089844\n", "Epoch 770, Loss: 297.82624645108865\n", "Epoch 771, Loss: 297.8262457932765\n", "Epoch 772, Loss: 297.82624513554777\n", "Epoch 773, Loss: 297.8262444779027\n", "Epoch 774, Loss: 297.8262438203412\n", "Epoch 775, Loss: 297.8262431628633\n", "Epoch 776, Loss: 297.8262425054691\n", "Epoch 777, Loss: 297.82624184815865\n", "Epoch 778, Loss: 297.82624119093174\n", "Epoch 779, Loss: 297.82624053378873\n", "Epoch 780, Loss: 297.82623987672946\n", "Epoch 781, Loss: 297.826239219754\n", "Epoch 782, Loss: 297.8262385628624\n", "Epoch 783, Loss: 297.82623790605464\n", "Epoch 784, Loss: 297.82623724933075\n", "Epoch 785, Loss: 297.82623659269086\n", "Epoch 786, Loss: 297.8262359361348\n", "Epoch 787, Loss: 297.8262352796629\n", "Epoch 788, Loss: 297.8262346232749\n", "Epoch 789, Loss: 297.8262339669709\n", "Epoch 790, Loss: 297.8262333107511\n", "Epoch 791, Loss: 297.82623265461535\n", "Epoch 792, Loss: 297.8262319985638\n", "Epoch 793, Loss: 297.8262313425964\n", "Epoch 794, Loss: 297.8262306867131\n", "Epoch 795, Loss: 297.8262300309141\n", "Epoch 796, Loss: 297.82622937519943\n", "Epoch 797, Loss: 297.826228719569\n", "Epoch 798, Loss: 297.82622806402287\n", "Epoch 799, Loss: 297.82622740856107\n", "Epoch 800, Loss: 297.8262267531837\n", "Epoch 801, Loss: 297.82622609789064\n", "Epoch 802, Loss: 297.826225442682\n", "Epoch 803, Loss: 297.826224787558\n", "Epoch 804, Loss: 297.8262241325183\n", "Epoch 805, Loss: 297.8262234775633\n", "Epoch 806, Loss: 297.8262228226928\n", "Epoch 807, Loss: 297.82622216790685\n", "Epoch 808, Loss: 297.8262215132056\n", "Epoch 809, Loss: 297.8262208585888\n", "Epoch 810, Loss: 297.8262202040569\n", "Epoch 811, Loss: 297.8262195496096\n", "Epoch 812, Loss: 297.82621889524705\n", "Epoch 813, Loss: 297.82621824096935\n", "Epoch 814, Loss: 297.8262175867764\n", "Epoch 815, Loss: 297.8262169326682\n", "Epoch 816, Loss: 297.82621627864495\n", "Epoch 817, Loss: 297.82621562470666\n", "Epoch 818, Loss: 297.8262149708532\n", "Epoch 819, Loss: 297.82621431708463\n", "Epoch 820, Loss: 297.8262136634011\n", "Epoch 821, Loss: 297.82621300980264\n", "Epoch 822, Loss: 297.82621235628915\n", "Epoch 823, Loss: 297.82621170286075\n", "Epoch 824, Loss: 297.8262110495175\n", "Epoch 825, Loss: 297.8262103962594\n", "Epoch 826, Loss: 297.8262097430863\n", "Epoch 827, Loss: 297.8262090899985\n", "Epoch 828, Loss: 297.82620843699596\n", "Epoch 829, Loss: 297.8262077840786\n", "Epoch 830, Loss: 297.8262071312466\n", "Epoch 831, Loss: 297.82620647849984\n", "Epoch 832, Loss: 297.8262058258384\n", "Epoch 833, Loss: 297.82620517326245\n", "Epoch 834, Loss: 297.8262045207719\n", "Epoch 835, Loss: 297.82620386836675\n", "Epoch 836, Loss: 297.82620321604696\n", "Epoch 837, Loss: 297.8262025638128\n", "Epoch 838, Loss: 297.82620191166416\n", "Epoch 839, Loss: 297.8262012596011\n", "Epoch 840, Loss: 297.82620060762355\n", "Epoch 841, Loss: 297.8261999557316\n", "Epoch 842, Loss: 297.8261993039254\n", "Epoch 843, Loss: 297.8261986522048\n", "Epoch 844, Loss: 297.82619800056995\n", "Epoch 845, Loss: 297.8261973490208\n", "Epoch 846, Loss: 297.82619669755746\n", "Epoch 847, Loss: 297.8261960461799\n", "Epoch 848, Loss: 297.82619539488826\n", "Epoch 849, Loss: 297.82619474368244\n", "Epoch 850, Loss: 297.82619409256245\n", "Epoch 851, Loss: 297.8261934415284\n", "Epoch 852, Loss: 297.82619279058036\n", "Epoch 853, Loss: 297.82619213971833\n", "Epoch 854, Loss: 297.8261914889422\n", "Epoch 855, Loss: 297.82619083825216\n", "Epoch 856, Loss: 297.82619018764836\n", "Epoch 857, Loss: 297.8261895371304\n", "Epoch 858, Loss: 297.82618888669873\n", "Epoch 859, Loss: 297.8261882363531\n", "Epoch 860, Loss: 297.8261875860939\n", "Epoch 861, Loss: 297.8261869359208\n", "Epoch 862, Loss: 297.826186285834\n", "Epoch 863, Loss: 297.82618563583344\n", "Epoch 864, Loss: 297.8261849859192\n", "Epoch 865, Loss: 297.8261843360914\n", "Epoch 866, Loss: 297.82618368634996\n", "Epoch 867, Loss: 297.8261830366949\n", "Epoch 868, Loss: 297.82618238712627\n", "Epoch 869, Loss: 297.82618173764416\n", "Epoch 870, Loss: 297.8261810882486\n", "Epoch 871, Loss: 297.82618043893956\n", "Epoch 872, Loss: 297.826179789717\n", "Epoch 873, Loss: 297.8261791405811\n", "Epoch 874, Loss: 297.82617849153183\n", "Epoch 875, Loss: 297.82617784256917\n", "Epoch 876, Loss: 297.8261771936933\n", "Epoch 877, Loss: 297.8261765449042\n", "Epoch 878, Loss: 297.8261758962016\n", "Epoch 879, Loss: 297.826175247586\n", "Epoch 880, Loss: 297.82617459905725\n", "Epoch 881, Loss: 297.82617395061516\n", "Epoch 882, Loss: 297.8261733022601\n", "Epoch 883, Loss: 297.82617265399193\n", "Epoch 884, Loss: 297.82617200581075\n", "Epoch 885, Loss: 297.8261713577164\n", "Epoch 886, Loss: 297.8261707097091\n", "Epoch 887, Loss: 297.82617006178884\n", "Epoch 888, Loss: 297.8261694139557\n", "Epoch 889, Loss: 297.82616876620966\n", "Epoch 890, Loss: 297.82616811855064\n", "Epoch 891, Loss: 297.8261674709788\n", "Epoch 892, Loss: 297.82616682349425\n", "Epoch 893, Loss: 297.8261661760969\n", "Epoch 894, Loss: 297.82616552878676\n", "Epoch 895, Loss: 297.82616488156384\n", "Epoch 896, Loss: 297.8261642344284\n", "Epoch 897, Loss: 297.8261635873801\n", "Epoch 898, Loss: 297.82616294041935\n", "Epoch 899, Loss: 297.8261622935459\n", "Epoch 900, Loss: 297.82616164676\n", "Epoch 901, Loss: 297.8261610000614\n", "Epoch 902, Loss: 297.8261603534504\n", "Epoch 903, Loss: 297.82615970692694\n", "Epoch 904, Loss: 297.826159060491\n", "Epoch 905, Loss: 297.82615841414264\n", "Epoch 906, Loss: 297.82615776788197\n", "Epoch 907, Loss: 297.826157121709\n", "Epoch 908, Loss: 297.82615647562363\n", "Epoch 909, Loss: 297.8261558296259\n", "Epoch 910, Loss: 297.8261551837161\n", "Epoch 911, Loss: 297.826154537894\n", "Epoch 912, Loss: 297.82615389215977\n", "Epoch 913, Loss: 297.82615324651323\n", "Epoch 914, Loss: 297.8261526009547\n", "Epoch 915, Loss: 297.826151955484\n", "Epoch 916, Loss: 297.8261513101013\n", "Epoch 917, Loss: 297.8261506648065\n", "Epoch 918, Loss: 297.8261500195997\n", "Epoch 919, Loss: 297.826149374481\n", "Epoch 920, Loss: 297.82614872945044\n", "Epoch 921, Loss: 297.8261480845078\n", "Epoch 922, Loss: 297.8261474396533\n", "Epoch 923, Loss: 297.8261467948871\n", "Epoch 924, Loss: 297.826146150209\n", "Epoch 925, Loss: 297.82614550561914\n", "Epoch 926, Loss: 297.8261448611175\n", "Epoch 927, Loss: 297.82614421670417\n", "Epoch 928, Loss: 297.8261435723791\n", "Epoch 929, Loss: 297.8261429281424\n", "Epoch 930, Loss: 297.8261422839941\n", "Epoch 931, Loss: 297.82614163993424\n", "Epoch 932, Loss: 297.8261409959628\n", "Epoch 933, Loss: 297.82614035207973\n", "Epoch 934, Loss: 297.82613970828527\n", "Epoch 935, Loss: 297.8261390645793\n", "Epoch 936, Loss: 297.826138420962\n", "Epoch 937, Loss: 297.8261377774331\n", "Epoch 938, Loss: 297.82613713399303\n", "Epoch 939, Loss: 297.82613649064143\n", "Epoch 940, Loss: 297.8261358473787\n", "Epoch 941, Loss: 297.8261352042046\n", "Epoch 942, Loss: 297.8261345611193\n", "Epoch 943, Loss: 297.8261339181227\n", "Epoch 944, Loss: 297.826133275215\n", "Epoch 945, Loss: 297.82613263239614\n", "Epoch 946, Loss: 297.8261319896662\n", "Epoch 947, Loss: 297.82613134702507\n", "Epoch 948, Loss: 297.826130704473\n", "Epoch 949, Loss: 297.8261300620098\n", "Epoch 950, Loss: 297.8261294196356\n", "Epoch 951, Loss: 297.8261287773505\n", "Epoch 952, Loss: 297.8261281351545\n", "Epoch 953, Loss: 297.8261274930475\n", "Epoch 954, Loss: 297.82612685102976\n", "Epoch 955, Loss: 297.8261262091011\n", "Epoch 956, Loss: 297.82612556726167\n", "Epoch 957, Loss: 297.8261249255115\n", "Epoch 958, Loss: 297.8261242838506\n", "Epoch 959, Loss: 297.8261236422789\n", "Epoch 960, Loss: 297.8261230007966\n", "Epoch 961, Loss: 297.8261223594037\n", "Epoch 962, Loss: 297.82612171810007\n", "Epoch 963, Loss: 297.82612107688584\n", "Epoch 964, Loss: 297.8261204357612\n", "Epoch 965, Loss: 297.82611979472597\n", "Epoch 966, Loss: 297.8261191537804\n", "Epoch 967, Loss: 297.82611851292415\n", "Epoch 968, Loss: 297.8261178721576\n", "Epoch 969, Loss: 297.8261172314807\n", "Epoch 970, Loss: 297.8261165908933\n", "Epoch 971, Loss: 297.82611595039566\n", "Epoch 972, Loss: 297.8261153099878\n", "Epoch 973, Loss: 297.82611466966955\n", "Epoch 974, Loss: 297.8261140294412\n", "Epoch 975, Loss: 297.8261133893026\n", "Epoch 976, Loss: 297.82611274925375\n", "Epoch 977, Loss: 297.82611210929485\n", "Epoch 978, Loss: 297.82611146942594\n", "Epoch 979, Loss: 297.8261108296469\n", "Epoch 980, Loss: 297.8261101899578\n", "Epoch 981, Loss: 297.8261095503587\n", "Epoch 982, Loss: 297.8261089108496\n", "Epoch 983, Loss: 297.82610827143054\n", "Epoch 984, Loss: 297.8261076321016\n", "Epoch 985, Loss: 297.8261069928628\n", "Epoch 986, Loss: 297.8261063537141\n", "Epoch 987, Loss: 297.8261057146557\n", "Epoch 988, Loss: 297.8261050756874\n", "Epoch 989, Loss: 297.82610443680943\n", "Epoch 990, Loss: 297.82610379802173\n", "Epoch 991, Loss: 297.82610315932436\n", "Epoch 992, Loss: 297.8261025207173\n", "Epoch 993, Loss: 297.8261018822007\n", "Epoch 994, Loss: 297.8261012437744\n", "Epoch 995, Loss: 297.8261006054387\n", "Epoch 996, Loss: 297.82609996719333\n", "Epoch 997, Loss: 297.8260993290385\n", "Epoch 998, Loss: 297.8260986909742\n", "Epoch 999, Loss: 297.8260980530006\n", "final parameters: m=0.45136980910052144, c=0.49775672565271384, sigma=1562.2616856027405\n" ] } ], "source": [ "## Problem 2\n", "x = np.array([8, 16, 22, 33, 50, 51])\n", "y = np.array([5, 20, 14, 32, 42, 58])\n", "\n", "# $-\\frac{n}{\\sigma}+\\frac{1}{\\sigma^3}\\sum_{i=1}^n(y_i - (mx+c))^2$\n", "dsigma = lambda sigma, c, m, x: -len(x) / sigma + np.sum(\n", " [(xi - (m * x + c)) ** 2 for xi in x]\n", ") / (sigma**3)\n", "# $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(y_i - (mx+c))$\n", "dc = lambda sigma, c, m, x: -np.sum([xi - (m * x + c) for xi in x]) / (sigma**2)\n", "# $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(x_i(y_i - (mx+c)))$\n", "dm = lambda sigma, c, m, x: -np.sum([x * (xi - (m * x + c)) for xi in x]) / (sigma**2)\n", "\n", "\n", "log2 = []\n", "\n", "\n", "def SGD_problem2(\n", " sigma: float,\n", " c: float,\n", " m: float,\n", " x: np.array,\n", " y: np.array,\n", " learning_rate=0.01,\n", " n_epochs=1000,\n", "):\n", " global log2\n", " log2 = []\n", " for epoch in range(n_epochs):\n", " sigma += learning_rate * dsigma(sigma, c, m, x)\n", " c += learning_rate * dc(sigma, c, m, x)\n", " m += learning_rate * dm(sigma, c, m, x)\n", "\n", " log2.append(\n", " {\n", " \"Epoch\": epoch,\n", " \"New sigma\": sigma,\n", " \"New c\": c,\n", " \"New m\": m,\n", " \"dc\": dc(sigma, c, m, x),\n", " \"dm\": dm(sigma, c, m, x),\n", " \"dsigma\": dsigma(sigma, c, m, x),\n", " \"Loss\": loss((m * x + c), sigma, y),\n", " }\n", " )\n", " print(f\"Epoch {epoch}, Loss: {loss((m * x + c), sigma, y)}\")\n", " return np.array([sigma, c, m])\n", "\n", "\n", "result = SGD_problem2(0.5, 0.5, 0.5, x, y)\n", "print(f\"final parameters: m={result[2]}, c={result[1]}, sigma={result[0]}\")" ] }, { "cell_type": "markdown", "id": "0562b012-f4ca-47de-bc76-e0eb2bf1e509", "metadata": {}, "source": [ "loss appears to be decreasing. Uncollapse cell for output" ] }, { "cell_type": "markdown", "id": "bed9f3ce-c15c-4f30-8906-26f3e51acf30", "metadata": {}, "source": [ "# Bike Rides and the Poisson Model" ] }, { "cell_type": "markdown", "id": "975e2ef5-f5d5-45a3-b635-8faef035906f", "metadata": {}, "source": [ "Knowing that the poisson pdf is $P(k) = \\frac{\\lambda^k e^{-\\lambda}}{k!}$, we can find the negative log likelihood of the data as $-\\log(\\Pi_{i=1}^n P(k_i)) = -\\sum_{i=1}^n \\log(\\frac{\\lambda^k_i e^{-\\lambda}}{k_i!}) = \\sum_{i=1}^n -\\ln(\\lambda) k_i + \\ln(k_i!) + \\lambda$. Which simplified, gives $n\\lambda + \\sum_{i=1}^n \\ln(k_i!) - \\sum_{i=1}^n k_i \\ln(\\lambda)$. Differentiating with respect to $\\lambda$ gives $n - \\sum_{i=1}^n \\frac{k_i}{\\lambda}$. Which is our desired $\\frac{\\partial L}{\\partial \\lambda}$!" ] }, { "cell_type": "code", "execution_count": 10, "id": "3877723c-179e-4759-bed5-9eb70110ded2", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SGD Problem 3\n", "l: [3215.17703224]\n", "l diff at start 999.4184849065878\n", "l diff at end 535.134976163929\n", "l is improving\n", "SGD Problem 3\n", "l: [2326.70336987]\n", "l diff at start -998.7262223631474\n", "l diff at end -353.33868620734074\n", "l is improving\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochNew lambdadlambdaLossl_star
001681.315834-127.119133-3.899989e+062680.042056
111682.587025-126.861418-3.900150e+062680.042056
221683.855639-126.604614-3.900311e+062680.042056
331685.121685-126.348715-3.900471e+062680.042056
441686.385173-126.093716-3.900631e+062680.042056
..................
9959952325.399976-32.636710-3.948159e+062680.042056
9969962325.726343-32.602100-3.948170e+062680.042056
9979972326.052364-32.567536-3.948180e+062680.042056
9989982326.378040-32.533018-3.948191e+062680.042056
9999992326.703370-32.498547-3.948201e+062680.042056
\n", "

1000 rows × 5 columns

\n", "
" ], "text/plain": [ " Epoch New lambda dlambda Loss l_star\n", "0 0 1681.315834 -127.119133 -3.899989e+06 2680.042056\n", "1 1 1682.587025 -126.861418 -3.900150e+06 2680.042056\n", "2 2 1683.855639 -126.604614 -3.900311e+06 2680.042056\n", "3 3 1685.121685 -126.348715 -3.900471e+06 2680.042056\n", "4 4 1686.385173 -126.093716 -3.900631e+06 2680.042056\n", ".. ... ... ... ... ...\n", "995 995 2325.399976 -32.636710 -3.948159e+06 2680.042056\n", "996 996 2325.726343 -32.602100 -3.948170e+06 2680.042056\n", "997 997 2326.052364 -32.567536 -3.948180e+06 2680.042056\n", "998 998 2326.378040 -32.533018 -3.948191e+06 2680.042056\n", "999 999 2326.703370 -32.498547 -3.948201e+06 2680.042056\n", "\n", "[1000 rows x 5 columns]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "df = pd.read_csv(\"../data/01_raw/nyc_bb_bicyclist_counts.csv\")\n", "\n", "dlambda = lambda l, k: len(k) - np.sum([ki / l for ki in k])\n", "\n", "\n", "def SGD_problem3(\n", " l: float,\n", " k: np.array,\n", " learning_rate=0.01,\n", " n_epochs=1000,\n", "):\n", " global log3\n", " log3 = []\n", " for epoch in range(n_epochs):\n", " l -= learning_rate * dlambda(l, k)\n", " # $n\\lambda + \\sum_{i=1}^n \\ln(k_i!) - \\sum_{i=1}^n k_i \\ln(\\lambda)$\n", " # the rest of the loss function is commented out because it's a\n", " # constant and was causing overflows. It is unnecessary, and a useless\n", " # pain.\n", " loss = len(k) * l - np.sum(\n", " [ki * np.log(l) for ki in k]\n", " ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])\n", "\n", " log3.append(\n", " {\n", " \"Epoch\": epoch,\n", " \"New lambda\": l,\n", " \"dlambda\": dlambda(l, k),\n", " \"Loss\": loss,\n", " }\n", " )\n", " # print(f\"Epoch {epoch}\", f\"Loss: {loss}\")\n", " return np.array([l])\n", "\n", "\n", "l_star = df[\"BB_COUNT\"].mean()\n", "\n", "\n", "def debug_SGD_3(data, l=1000):\n", " print(\"SGD Problem 3\")\n", " print(f\"l: {SGD_problem3(l, data)}\")\n", " dflog = pd.DataFrame(log3)\n", " dflog[\"l_star\"] = l_star\n", " print(f\"l diff at start {dflog.iloc[0]['New lambda'] - dflog.iloc[0]['l_star']}\")\n", " print(f\"l diff at end {dflog.iloc[-1]['New lambda'] - dflog.iloc[-1]['l_star']}\")\n", " if np.abs(dflog.iloc[-1][\"New lambda\"] - dflog.iloc[-1][\"l_star\"]) < np.abs(\n", " dflog.iloc[0][\"New lambda\"] - dflog.iloc[0][\"l_star\"]\n", " ):\n", " print(\"l is improving\")\n", " else:\n", " print(\"l is not improving\")\n", " return dflog\n", "\n", "\n", "debug_SGD_3(data=df[\"BB_COUNT\"].values, l=l_star + 1000)\n", "debug_SGD_3(data=df[\"BB_COUNT\"].values, l=l_star - 1000)" ] }, { "cell_type": "markdown", "id": "c05192f9-78ae-4bdb-9df5-cac91006d79f", "metadata": {}, "source": [ "l approaches the l_star and decreases the loss function." ] }, { "cell_type": "markdown", "id": "4955b868-7f67-4760-bf86-39f6edd55871", "metadata": {}, "source": [ "## Maximum Likelihood II" ] }, { "cell_type": "markdown", "id": "cd7e6e62-3f64-43e5-bf2c-3cb514411446", "metadata": {}, "source": [ "The partial of the poisson was found to be $nE^{w.x}*x - \\sum_{i=1}^{n}k x.x$" ] }, { "cell_type": "code", "execution_count": 18, "id": "7c8b167d-c397-4155-93f3-d826c279fbb2", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SGD Problem 4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_615396/2481416868.py:22: RuntimeWarning: divide by zero encountered in log\n", " [ki * np.log(l) for ki in k]\n" ] }, { "ename": "ValueError", "evalue": "operands could not be broadcast together with shapes (3,) (214,) ", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[18], line 44\u001b[0m\n\u001b[1;32m 40\u001b[0m dflog \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(log4)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m dflog\n\u001b[0;32m---> 44\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mdebug_SGD_3\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 45\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdf\u001b[49m\u001b[43m[\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mHIGH_T\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLOW_T\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mPRECIP\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_numpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[18], line 39\u001b[0m, in \u001b[0;36mdebug_SGD_3\u001b[0;34m(data, w)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdebug_SGD_3\u001b[39m(data, w\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39marray([\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m])):\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSGD Problem 4\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 39\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[43mSGD_problem4\u001b[49m\u001b[43m(\u001b[49m\u001b[43mw\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;250;43m \u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 40\u001b[0m dflog \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(log4)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m dflog\n", "Cell \u001b[0;32mIn[18], line 24\u001b[0m, in \u001b[0;36mSGD_problem4\u001b[0;34m(w, x, learning_rate, n_epochs)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# custom\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# loss = x.shape[0] * np.exp(np.dot(x, w))\u001b[39;00m\n\u001b[1;32m 21\u001b[0m loss_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m k, l: \u001b[38;5;28mlen\u001b[39m(k) \u001b[38;5;241m*\u001b[39m l \u001b[38;5;241m-\u001b[39m np\u001b[38;5;241m.\u001b[39msum(\n\u001b[1;32m 22\u001b[0m [ki \u001b[38;5;241m*\u001b[39m np\u001b[38;5;241m.\u001b[39mlog(l) \u001b[38;5;28;01mfor\u001b[39;00m ki \u001b[38;5;129;01min\u001b[39;00m k]\n\u001b[1;32m 23\u001b[0m ) \u001b[38;5;66;03m# + np.sum([np.log(np.math.factorial(ki)) for ki in k])\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mloss_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdot\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 25\u001b[0m log4\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 26\u001b[0m {\n\u001b[1;32m 27\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch\u001b[39m\u001b[38;5;124m\"\u001b[39m: epoch,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m }\n\u001b[1;32m 32\u001b[0m )\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mloss\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", "Cell \u001b[0;32mIn[18], line 22\u001b[0m, in \u001b[0;36mSGD_problem4..\u001b[0;34m(k, l)\u001b[0m\n\u001b[1;32m 18\u001b[0m w \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m learning_rate \u001b[38;5;241m*\u001b[39m dw(w, x)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# custom\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# loss = x.shape[0] * np.exp(np.dot(x, w))\u001b[39;00m\n\u001b[1;32m 21\u001b[0m loss_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m k, l: \u001b[38;5;28mlen\u001b[39m(k) \u001b[38;5;241m*\u001b[39m l \u001b[38;5;241m-\u001b[39m np\u001b[38;5;241m.\u001b[39msum(\n\u001b[0;32m---> 22\u001b[0m [ki \u001b[38;5;241m*\u001b[39m np\u001b[38;5;241m.\u001b[39mlog(l) \u001b[38;5;28;01mfor\u001b[39;00m ki \u001b[38;5;129;01min\u001b[39;00m k]\n\u001b[1;32m 23\u001b[0m ) \u001b[38;5;66;03m# + np.sum([np.log(np.math.factorial(ki)) for ki in k])\u001b[39;00m\n\u001b[1;32m 24\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_fn(x, np\u001b[38;5;241m.\u001b[39mexp(np\u001b[38;5;241m.\u001b[39mdot(x, w)))\n\u001b[1;32m 25\u001b[0m log4\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 26\u001b[0m {\n\u001b[1;32m 27\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch\u001b[39m\u001b[38;5;124m\"\u001b[39m: epoch,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m }\n\u001b[1;32m 32\u001b[0m )\n", "Cell \u001b[0;32mIn[18], line 22\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 18\u001b[0m w \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m learning_rate \u001b[38;5;241m*\u001b[39m dw(w, x)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# custom\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# loss = x.shape[0] * np.exp(np.dot(x, w))\u001b[39;00m\n\u001b[1;32m 21\u001b[0m loss_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m k, l: \u001b[38;5;28mlen\u001b[39m(k) \u001b[38;5;241m*\u001b[39m l \u001b[38;5;241m-\u001b[39m np\u001b[38;5;241m.\u001b[39msum(\n\u001b[0;32m---> 22\u001b[0m [\u001b[43mki\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43ml\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m ki \u001b[38;5;129;01min\u001b[39;00m k]\n\u001b[1;32m 23\u001b[0m ) \u001b[38;5;66;03m# + np.sum([np.log(np.math.factorial(ki)) for ki in k])\u001b[39;00m\n\u001b[1;32m 24\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_fn(x, np\u001b[38;5;241m.\u001b[39mexp(np\u001b[38;5;241m.\u001b[39mdot(x, w)))\n\u001b[1;32m 25\u001b[0m log4\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 26\u001b[0m {\n\u001b[1;32m 27\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch\u001b[39m\u001b[38;5;124m\"\u001b[39m: epoch,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m }\n\u001b[1;32m 32\u001b[0m )\n", "\u001b[0;31mValueError\u001b[0m: operands could not be broadcast together with shapes (3,) (214,) " ] } ], "source": [ "## pset 4\n", "\n", "dw = lambda w, x: np.sum([len(x) * np.exp(np.dot(xi, w)) * x - np.sum(np.dot(x.T,x)) for xi in x])\n", "\n", "#primitive = lambda xi, wi: (x.shape[0] * np.exp(wi * xi) * xi) - (xi**2)\n", "#p_dw = lambda w, xi: np.array([primitive(xi, wi) for xi, wi in ])\n", "\n", "\n", "def SGD_problem4(\n", " w: np.array,\n", " x: np.array,\n", " learning_rate=0.01,\n", " n_epochs=1000,\n", "):\n", " global log4\n", " log4 = []\n", " for epoch in range(n_epochs):\n", " w -= learning_rate * dw(w, x)\n", " # custom\n", " # loss = x.shape[0] * np.exp(np.dot(x, w))\n", " loss_fn = lambda k, l: len(k) * l - np.sum(\n", " [ki * np.log(l) for ki in k]\n", " ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])\n", " loss = loss_fn(x, np.exp(np.dot(x, w)))\n", " log4.append(\n", " {\n", " \"Epoch\": epoch,\n", " \"New w\": w,\n", " \"dw\": dw(w, x),\n", " \"Loss\": loss,\n", " }\n", " )\n", " print(f\"Epoch {epoch}\", f\"Loss: {loss}\")\n", " return w\n", "\n", "\n", "def debug_SGD_3(data, w=np.array([1, 1])):\n", " print(\"SGD Problem 4\")\n", " print(f\"w: {SGD_problem4(w, data)}\")\n", " dflog = pd.DataFrame(log4)\n", " return dflog\n", "\n", "\n", "_ = debug_SGD_3(\n", " data=df[[\"HIGH_T\", \"LOW_T\", \"PRECIP\"]].to_numpy(),\n", " w=np.array([1.0, 1.0, 1.0]),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "7c00197d-873d-41b0-a458-dc8478b40f52", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 5 }