{ "cells": [ { "cell_type": "markdown", "source": [ "# Gaussian Maximum Likelihood\n", "\n", "## MLE of a Gaussian $p_{model}(x|w)$\n", "\n", "You are given an array of data points called `data`. Your course site plots the negative log-likelihood function for several candidate hypotheses. Estimate the parameters of the Gaussian $p_{model}$ by coding an implementation that estimates its optimal parameters (15 points) and explaining what it does (10 points). You are free to use any gradient-based optimization method you like." ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "### Solution Explanation\n", "\n", "Because the dataset is small, we compute the gradient over the full batch at every step.\n", "We run 5000 iterations and, in each one, increment the current estimates of the mean and variance by `learning_rate * batch_gradient`, where the gradient is that of the log-likelihood.\n", "\n", "The update follows the standard gradient-descent rule, with the negative log-likelihood as the loss $L$:\n", "\n", "\\begin{equation}\n", "w_{k+1} := w_k - \\eta \\cdot \\nabla_w L(w_k)\n", "\\end{equation}\n", "\n", "Since $L = -l$, where $l$ is the log-likelihood, subtracting $\\eta \\nabla_w L$ is the same as adding $\\eta \\nabla_w l$, which is why the code increments the parameters.\n", "\n", "The log-likelihood function for the Gaussian distribution, given a set of observations $X = (x_1, \\dots, x_N)$ and parameters $\\mu$ and $\\sigma^2$, is:\n", "\n", "\\begin{equation}\n", "l(\\mu, \\sigma^2 | X) = \\sum_{i=1}^{N} \\left[ -\\frac{1}{2} \\log(2\\pi\\sigma^2) - \\frac{(x_i - \\mu)^2}{2\\sigma^2} \\right]\n", "\\end{equation}\n", "\n", "where $x_i$ are the observed values and the sum is over all observations.\n", "\n", "The partial derivatives of the log-likelihood function with respect to `μ` and `σ²` are given below; both vanish at the maximum.\n", "\n", "1. With respect to `μ`:\n", "\n", "\\begin{equation}\n", "\\frac{\\partial l(\\mu, \\sigma^2 | X)}{\\partial \\mu} = \\frac{1}{\\sigma^2}\\sum_{i=1}^{N}(x_i - \\mu) = 0\n", "\\end{equation}\n", "\n", "2. 
With respect to `σ²`:\n", "\n", "\\begin{equation}\n", "\\frac{\\partial l(\\mu, \\sigma^2 | X)}{\\partial \\sigma^2} = \\frac{1}{2\\sigma^2}\\left( -N + \\frac{1}{\\sigma^2}\\sum_{i=1}^{N}(x_i - \\mu)^2 \\right) = 0\n", "\\end{equation}\n" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 122, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal mean: 6.214285714285691, Optimal variance: 5.881910450405248\n" ] } ], "source": [ "import numpy as np\n", "data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]\n", "\n", "def ll_derivative_wrt_mean(x, mean, variance):\n", "    \"\"\"Derivative of the log-likelihood with respect to the mean\"\"\"\n", "    return np.sum([x_i - mean for x_i in x]) / variance\n", "\n", "def ll_derivative_wrt_variance(x, mean, variance):\n", "    \"\"\"Derivative of the log-likelihood with respect to the variance\"\"\"\n", "    N = len(x)\n", "    return 1 / (2 * variance) * ( -N + (1/variance)*(np.sum([(x_i - mean)**2 for x_i in x])))\n", "\n", "n_epochs = 5000\n", "t0, t1 = 50, 5000\n", "\n", "def learning_rate(t):\n", "    \"\"\"Decaying step size: t0 / (t + t1)\"\"\"\n", "    return t0 / (t+t1)\n", "\n", "mean = variance = 1\n", "x = data\n", "\n", "# Run n_epochs iterations of batch gradient ascent on the log-likelihood\n", "for epoch in range(n_epochs):\n", "\n", "    # Calculate the gradients for the whole batch together\n", "    d_mean = ll_derivative_wrt_mean(x, mean, variance)\n", "    d_variance = ll_derivative_wrt_variance(x, mean, variance)\n", "\n", "    # Update the parameters as explained in the solution\n", "    mean += learning_rate(epoch) * d_mean\n", "    variance += learning_rate(epoch) * d_variance\n", "\n", "print(f\"Optimal mean: {mean}, Optimal variance: {variance}\")\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-06-21T05:34:14.953104700Z", "start_time": "2023-06-21T05:34:14.850380200Z" } } }, { "attachments": {}, "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2023-06-21T01:18:52.570706Z", "start_time": "2023-06-21T01:18:52.109912Z" } }, "source": [ "## MLE 
of a conditional Gaussian $p_{model}(y|x,w)$\n", "\n", "You are given a problem that involves the relationship between $x$ and $y$. Estimate the parameters of a $p_{model}$ that fit the dataset (x,y) shown below. You are free to use any gradient-based optimization method you like.\n" ] }, { "cell_type": "markdown", "source": [ "## Solution\n", "\n", "We use the equations below to optimize the conditional Gaussian model, in which $y$ is Gaussian with mean $a x + b$ and variance $\\sigma^2$.\n", "\n", "The log-likelihood function for this Gaussian distribution, given a set of observations (x, y) and parameters a, b, and $\\sigma^2$, is:\n", "\n", "\\begin{equation}\n", "\\log L(a, b, \\sigma^2 | x, y) = \\sum_{i=1}^{N} \\left[ -\\frac{1}{2} \\log(2\\pi\\sigma^2) - \\frac{(y_i - (a x_i + b))^2}{2\\sigma^2} \\right]\n", "\\end{equation}\n", "\n", "where $y_i$ are the observed values, $x_i$ are the inputs, and the sum is over all $N$ observations.\n", "\n", "The partial derivatives of the log-likelihood function with respect to `a`, `b`, and `σ²` are:\n", "\n", "1. With respect to `a`:\n", "\n", "\\begin{equation}\n", "\\frac{\\partial \\log L}{\\partial a} = \\frac{1}{\\sigma^2}\\sum_{i=1}^{N} x_i (y_i - (a x_i + b))\n", "\\end{equation}\n", "\n", "2. With respect to `b`:\n", "\n", "\\begin{equation}\n", "\\frac{\\partial \\log L}{\\partial b} = \\frac{1}{\\sigma^2}\\sum_{i=1}^{N} (y_i - (a x_i + b))\n", "\\end{equation}\n", "\n", "3. 
With respect to `σ²`:\n", "\n", "\\begin{equation}\n", "\\frac{\\partial \\log L}{\\partial \\sigma^2} = \\frac{1}{2\\sigma^2}\\left( -N + \\frac{1}{\\sigma^2}\\sum_{i=1}^{N}(y_i - (a x_i + b))^2 \\right)\n", "\\end{equation}" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 123, "metadata": { "ExecuteTime": { "end_time": "2023-06-21T05:34:20.610461700Z", "start_time": "2023-06-21T05:34:19.323415600Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimal a: 1.0366003850912346, Optimal b: -2.5915990246152423, Optimal variance: 34.65335537389535\n" ] } ], "source": [ "import numpy as np\n", "\n", "# Given data\n", "x = np.array([8, 16, 22, 33, 50, 51])\n", "y = np.array([5, 20, 14, 32, 42, 58])\n", "\n", "# Initial values\n", "a = b = variance = 1\n", "# Set learning rate and number of epochs\n", "learning_rate = 0.01\n", "epochs = 50000\n", "\n", "# Define derivative functions\n", "def derivative_a(x, y, a, b, variance):\n", "    \"\"\"Derivative of the log-likelihood with respect to the slope a\"\"\"\n", "    t_sum = 0\n", "    for i in range(len(x)):\n", "        t_sum += x[i] * (y[i] - (a * x[i] + b))\n", "\n", "    return t_sum / variance\n", "\n", "def derivative_b(x, y, a, b, variance):\n", "    \"\"\"Derivative of the log-likelihood with respect to the intercept b\"\"\"\n", "    t_sum = 0\n", "    for i in range(len(x)):\n", "        t_sum += y[i] - (a * x[i] + b)\n", "    return t_sum / variance\n", "\n", "def derivative_variance(x, y, a, b, variance):\n", "    \"\"\"Derivative of the log-likelihood with respect to the variance\"\"\"\n", "    t_sum = 0\n", "    N = len(x)\n", "    for i in range(len(x)):\n", "        t_sum += np.power(y[i] - (a * x[i] + b), 2)\n", "    ans = (1 / (2 * variance)) * ( -N + (1/variance) * t_sum )\n", "    return ans\n", "\n", "# Batch gradient ascent on the log-likelihood\n", "for _ in range(epochs):\n", "    # Parameters are updated one after another, so each step uses the values updated before it\n", "    variance += learning_rate * derivative_variance(x, y, a, b, variance)\n", "    a += learning_rate * derivative_a(x, y, a, b, variance)\n", "    b += learning_rate * derivative_b(x, y, a, b, variance)\n", "\n", "    # Ensure variance remains positive\n", "    variance = max(variance, 1e-6)\n", "\n", "print(f\"Optimal 
a: {a}, Optimal b: {b}, Optimal variance: {variance}\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "7d6993cb2f9ce9a59d5d7380609d9cb5192a9dedd2735a011418ad9e827eb538" } } }, "nbformat": 4, "nbformat_minor": 2 }