{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Reproducing some scaling laws results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf). Can't get the numbers to match exactly, but can still be used as a rough guide to help determine compute-optimal models. Also contains related utilities for calculating flops and param counts." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "%matplotlib inline" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## params\n", "\n", "First some parameter calculations:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "123.653376" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):\n", " \"\"\" Given GPT config calculate total number of parameters \"\"\"\n", " ffw_size = 4*d_model # in GPT the number of intermediate features is always 4*d_model\n", " # token and position embeddings\n", " embeddings = d_model * vocab_size + d_model * seq_len\n", " # transformer blocks\n", " attention = 3*d_model**2 + 3*d_model # weights and biases\n", " attproj = d_model**2 + d_model\n", " ffw = d_model*(ffw_size) + ffw_size\n", " ffwproj = ffw_size*d_model + d_model\n", " layernorms = 2*2*d_model\n", " # dense\n", " ln_f = 2*d_model\n", " dense = d_model*vocab_size # note: no bias here\n", " # note: embeddings are not included in the param count!\n", " total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n", " return total_params\n", "\n", "gpt2 = dict(seq_len = 1024, vocab_size = 50257, d_model = 768, num_heads = 12, num_layers = 12)\n", "gpt_params(**gpt2)/1e6" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation. Now Chinchilla parameters:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def chinchilla_params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n", " \"\"\" Parameters in the Chinchilla models. Unlike GPT they use relative positional embeddings. 
\"\"\"\n", " # token embeddings only\n", " embeddings = d_model * vocab_size\n", " # transformer blocks\n", " attention = 3*d_model**2 + 3*d_model # weights and biases\n", " relative_pos = d_model**2 + 2*d_model # relative keys, content bias, relative bias\n", " attproj = d_model**2 + d_model\n", " ffw = d_model*ffw_size + ffw_size\n", " ffwproj = ffw_size*d_model + d_model\n", " layernorms = 2*2*d_model\n", " # dense\n", " ln_f = 2*d_model\n", " dense = d_model*vocab_size # note: no bias here\n", " # note: embeddings are not included in the param count!\n", " total_params = num_layers*(attention + relative_pos + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n", " return total_params\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[44000000.0, 512, 2048, 64, 8, 8]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load in all the 50 Chinchilla models on the last page of the paper\n", "import json\n", "chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'\n", "chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n", "chilchilla_models[0] # tuples of params, d_model, ffw_size, kv_size, n_heads, n_layers from Table A9" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "our estimated params: 12296.1623M, chinchilla params: 12295.0000M, d_model: 4608, n_heads: 36, n_layers: 44\n", "our estimated params: 13124.4826M, chinchilla 
params: 12569.0000M, d_model: 4608, n_heads: 32, n_layers: 47\n", "our estimated params: 14614.4279M, chinchilla params: 13735.0000M, d_model: 4864, n_heads: 32, n_layers: 47\n", "our estimated params: 16037.5039M, chinchilla params: 14940.0000M, d_model: 4992, n_heads: 32, n_layers: 49\n", "our estimated params: 16184.4582M, chinchilla params: 16183.0000M, d_model: 5120, n_heads: 40, n_layers: 47\n" ] } ], "source": [ "for m in chilchilla_models[-5:]: # only print last 5 models of the table\n", " p, d, f, k, h, l = m\n", " nparams = chinchilla_params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l, ffw_size=f)\n", " print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We are almost able to reproduce the parameter counts for the Chinchilla models.\n", "\n", "Now turning to FLOPs:" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## flops" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def chinchilla_flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n", " \"\"\" \n", " Calculate total number of FLOPs, see Chinchilla \n", " paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n", " \"\"\" \n", " key_size = d_model // num_heads\n", "\n", " # embeddings\n", " embeddings = 2 * seq_len * vocab_size * d_model\n", "\n", " # attention\n", " # key, query, value projections\n", " attention = 2 * 3 * seq_len * d_model * (key_size * num_heads)\n", " # key @ query logits\n", " attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n", " # softmax\n", " attsoftmax = 3 * num_heads * seq_len * seq_len # 3* is for subtract (max), exp, divide (?)\n", " # softmax @ value reductions\n", " attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n", " # final linear\n", " attlinear = 2 * seq_len * (key_size * num_heads) * d_model\n", " att = attention + attlogits + attsoftmax + attvalue + attlinear\n", " # feed forward\n", " dense = 2 * seq_len * (d_model * ffw_size + d_model * ffw_size)\n", "\n", " # logits\n", " logits = 2 * seq_len * d_model * vocab_size\n", " \n", " # this is what you'd expect:\n", " # forward_flops = embeddings + num_layers * (att + dense) + logits\n", " # but:\n", " # per author correspondence apparently there is typo in the paper,\n", " # they do not count embeddings and logits to repro table 4. So instead:\n", " forward_flops = num_layers * (att + dense)\n", " backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n", " total_flops = forward_flops + backward_flops\n", "\n", " return total_flops\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | seq_len | \n", "vocab_size | \n", "d_model | \n", "num_heads | \n", "num_layers | \n", "ffw_size | \n", "N | \n", "F | \n", "approx_flops | \n", "chinch_flops | \n", "ratio | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "2048 | \n", "32000 | \n", "640 | \n", "10 | \n", "10 | \n", "2560 | \n", "73825280 | \n", "929877196800 | \n", "907165040640 | \n", "9.298772e+11 | \n", "1.025036 | \n", "
1 | \n", "2048 | \n", "32000 | \n", "1024 | \n", "16 | \n", "20 | \n", "4096 | \n", "305707008 | \n", "4135248199680 | \n", "3756527714304 | \n", "4.135248e+12 | \n", "1.100817 | \n", "
2 | \n", "2048 | \n", "32000 | \n", "1280 | \n", "10 | \n", "24 | \n", "5120 | \n", "552604160 | \n", "7353453772800 | \n", "6790399918080 | \n", "7.353454e+12 | \n", "1.082919 | \n", "
3 | \n", "2048 | \n", "32000 | \n", "1792 | \n", "14 | \n", "26 | \n", "7168 | \n", "1143453696 | \n", "14670316437504 | \n", "14050759016448 | \n", "1.467032e+13 | \n", "1.044094 | \n", "
4 | \n", "2048 | \n", "32000 | \n", "2048 | \n", "16 | \n", "28 | \n", "8192 | \n", "1593126912 | \n", "20220437594112 | \n", "19576343494656 | \n", "2.022044e+13 | \n", "1.032902 | \n", "
5 | \n", "2048 | \n", "32000 | \n", "3584 | \n", "28 | \n", "40 | \n", "14336 | \n", "6796274688 | \n", "83021046743040 | \n", "83512623366144 | \n", "8.302105e+13 | \n", "0.994114 | \n", "