{ "cells": [ { "cell_type": "code", "execution_count": 4, "id": "138889b92720ce2e", "metadata": { "ExecuteTime": { "end_time": "2024-05-14T09:06:04.487186Z", "start_time": "2024-05-14T09:06:04.255111Z" }, "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
runnamestepsagg_scorecommonsense_qa/acccommonsense_qa/acc_normhellaswag/acchellaswag/acc_normopenbookqa/accopenbookqa/acc_normpiqa/acc...siqa/accsiqa/acc_normwinogrande/accwinogrande/acc_normsciq/accsciq/acc_normarc/accarc/acc_normmmlu/accmmlu/acc_norm
0C400.3308930.1860.2330.2720.2580.1660.2860.542...0.3670.3620.5160.4970.2080.2020.21950.25100.2302940.250147
1C410000.3551120.2290.2600.2860.2880.1280.2500.614...0.3510.4040.5190.4760.5650.5180.26800.29350.2389510.250399
2C420000.3784350.2680.2780.3120.3300.1220.2760.646...0.3750.4000.5090.5000.6760.5770.30650.32300.2472750.255482
3C430000.3877950.2800.2950.3310.3800.1520.2740.660...0.3760.3870.5120.4960.7250.6210.31750.33400.2545340.267363
4C440000.3993200.2960.2980.3510.4060.1680.2820.676...0.3820.4040.5220.5030.7230.6180.32550.34700.2547620.263563
..................................................................
1171The Pile1630000.4637890.3790.3490.4410.5550.2400.3660.701...0.4050.3880.5850.5600.8750.8200.44750.44500.2993780.326313
1172The Pile1640000.4627580.3690.3440.4380.5520.2480.3480.708...0.3950.4010.5770.5670.8740.8060.44650.43550.3020830.331563
1173The Pile1650000.4650260.3830.3500.4380.5530.2340.3520.707...0.4000.4010.5690.5560.8740.8110.44600.44550.3051930.331708
1174The Pile1660000.4623490.3770.3460.4400.5570.2280.3460.711...0.3980.3980.5720.5580.8770.8110.45250.43850.3019520.331295
1175The Pile1670000.4645390.3860.3540.4340.5570.2320.3560.706...0.4020.4020.5730.5590.8670.8020.44750.43750.3019340.330810
\n", "

1176 rows × 21 columns

\n", "
" ], "text/plain": [ " runname steps agg_score commonsense_qa/acc \\\n", "0 C4 0 0.330893 0.186 \n", "1 C4 1000 0.355112 0.229 \n", "2 C4 2000 0.378435 0.268 \n", "3 C4 3000 0.387795 0.280 \n", "4 C4 4000 0.399320 0.296 \n", "... ... ... ... ... \n", "1171 The Pile 163000 0.463789 0.379 \n", "1172 The Pile 164000 0.462758 0.369 \n", "1173 The Pile 165000 0.465026 0.383 \n", "1174 The Pile 166000 0.462349 0.377 \n", "1175 The Pile 167000 0.464539 0.386 \n", "\n", " commonsense_qa/acc_norm hellaswag/acc hellaswag/acc_norm \\\n", "0 0.233 0.272 0.258 \n", "1 0.260 0.286 0.288 \n", "2 0.278 0.312 0.330 \n", "3 0.295 0.331 0.380 \n", "4 0.298 0.351 0.406 \n", "... ... ... ... \n", "1171 0.349 0.441 0.555 \n", "1172 0.344 0.438 0.552 \n", "1173 0.350 0.438 0.553 \n", "1174 0.346 0.440 0.557 \n", "1175 0.354 0.434 0.557 \n", "\n", " openbookqa/acc openbookqa/acc_norm piqa/acc ... siqa/acc \\\n", "0 0.166 0.286 0.542 ... 0.367 \n", "1 0.128 0.250 0.614 ... 0.351 \n", "2 0.122 0.276 0.646 ... 0.375 \n", "3 0.152 0.274 0.660 ... 0.376 \n", "4 0.168 0.282 0.676 ... 0.382 \n", "... ... ... ... ... ... \n", "1171 0.240 0.366 0.701 ... 0.405 \n", "1172 0.248 0.348 0.708 ... 0.395 \n", "1173 0.234 0.352 0.707 ... 0.400 \n", "1174 0.228 0.346 0.711 ... 0.398 \n", "1175 0.232 0.356 0.706 ... 0.402 \n", "\n", " siqa/acc_norm winogrande/acc winogrande/acc_norm sciq/acc \\\n", "0 0.362 0.516 0.497 0.208 \n", "1 0.404 0.519 0.476 0.565 \n", "2 0.400 0.509 0.500 0.676 \n", "3 0.387 0.512 0.496 0.725 \n", "4 0.404 0.522 0.503 0.723 \n", "... ... ... ... ... 
\n", "1171 0.388 0.585 0.560 0.875 \n", "1172 0.401 0.577 0.567 0.874 \n", "1173 0.401 0.569 0.556 0.874 \n", "1174 0.398 0.572 0.558 0.877 \n", "1175 0.402 0.573 0.559 0.867 \n", "\n", " sciq/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n", "0 0.202 0.2195 0.2510 0.230294 0.250147 \n", "1 0.518 0.2680 0.2935 0.238951 0.250399 \n", "2 0.577 0.3065 0.3230 0.247275 0.255482 \n", "3 0.621 0.3175 0.3340 0.254534 0.267363 \n", "4 0.618 0.3255 0.3470 0.254762 0.263563 \n", "... ... ... ... ... ... \n", "1171 0.820 0.4475 0.4450 0.299378 0.326313 \n", "1172 0.806 0.4465 0.4355 0.302083 0.331563 \n", "1173 0.811 0.4460 0.4455 0.305193 0.331708 \n", "1174 0.811 0.4525 0.4385 0.301952 0.331295 \n", "1175 0.802 0.4475 0.4375 0.301934 0.330810 \n", "\n", "[1176 rows x 21 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "from matplotlib.figure import Figure\n", "\n", "df = pd.read_csv(\"../src_data/eval_results.csv\")\n", "df" ] }, { "cell_type": "code", "execution_count": 2, "id": "b610f43caefdf01", "metadata": { "ExecuteTime": { "end_time": "2024-05-14T09:06:04.563945Z", "start_time": "2024-05-14T09:06:04.562142Z" }, "collapsed": false }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 5, "id": "initial_id", "metadata": { "ExecuteTime": { "end_time": "2024-05-14T09:06:37.927921Z", "start_time": "2024-05-14T09:06:37.588025Z" }, "collapsed": true }, "outputs": [], "source": [
 "import json\n",
 "import os\n",
 "from matplotlib import pyplot as plt\n",
 "\n",
 "# Metrics to export; each one is written to its own JSON file.\n",
 "metrics = ['agg_score', 'commonsense_qa/acc_norm', 'hellaswag/acc_norm', 'openbookqa/acc_norm', 'piqa/acc_norm',\n",
 "           'siqa/acc_norm', 'winogrande/acc_norm', 'arc/acc_norm', 'mmlu/acc_norm']\n",
 "\n",
 "# Multiplier applied to the step index to build the x axis.\n",
 "# NOTE(review): presumably 2048 tokens per sequence x 1024 sequences per step - confirm against the training config.\n",
 "TOKENS_PER_STEP = 2048 * 1024\n",
 "# Rescales the x values by 1e-9 (i.e. to billions, if the product above is a token count).\n",
 "BILLION_SCALE = 1e-9\n",
 "\n",
 "\n",
 "def normalize_runname(runname):\n",
 "    \"\"\"Return `runname` with every '/' replaced by '_'.\n",
 "\n",
 "    In this cell it is applied to metric names (e.g. 'arc/acc_norm' -> 'arc_acc_norm')\n",
 "    to build the names of the exported JSON files.\n",
 "    \"\"\"\n",
 "    return runname.replace(\"/\", \"_\")\n",
 "\n",
 "\n",
 "# Average each metric over all rows that share the same (runname, steps) pair.\n",
 "grouped = (\n",
 "    df.groupby([\"runname\", \"steps\"])\n",
 "    .agg({key: \"mean\" for key in metrics})\n",
 "    .reset_index()\n",
 ")\n",
 "\n",
 "file_id = \"../assets/data/plots/dataset_ablations\"\n",
 "# Create the output folder once, before any file (including index.json) is written.\n",
 "os.makedirs(file_id, exist_ok=True)\n",
 "\n",
 "files = {}\n",
 "for metric in metrics:\n",
 "    datas = {}\n",
 "    for name, group in grouped.groupby(\"runname\"):\n",
 "        # One curve per run: metric values ordered by training step.\n",
 "        series = group[[\"steps\", metric]].sort_values(by=\"steps\").set_index(\"steps\")\n",
 "        datas[name] = {\n",
 "            \"x\": (series.index * TOKENS_PER_STEP * BILLION_SCALE).tolist(),\n",
 "            \"y\": series[metric].tolist(),\n",
 "            \"label\": name,\n",
 "        }\n",
 "    # Order the runs by their final value of the metric, highest first.\n",
 "    datas = dict(sorted(datas.items(), key=lambda item: -item[1][\"y\"][-1]))\n",
 "    with open(f\"{file_id}/{normalize_runname(metric)}.json\", \"w\") as f:\n",
 "        json.dump({\n",
 "            \"data\": datas,\n",
 "            \"layout\": {\n",
 "                \"title\": {\n",
 "                    \"text\": \"Dataset ablations\"\n",
 "                },\n",
 "            }\n",
 "        }, f)\n",
 "    files[metric] = {\"file\": f\"{normalize_runname(metric)}.json\"}\n",
 "\n",
 "# Create index: maps every metric to its JSON file and stores the viewer settings.\n",
 "with open(f\"{file_id}/index.json\", \"w\") as f:\n",
 "    json.dump({\n",
 "        \"files\": files,\n",
 "        \"settings\": {\n",
 "            \"defaultMetric\": \"agg_score\",\n",
 "            \"slider\":{\"min\":0,\"max\":30,\"default\":5}\n",
 "        }\n",
 "    }, f)"
 ] }, { "cell_type": "code", "execution_count": 7, "id": "af28ebbd054cdc33", "metadata": { "ExecuteTime": { "end_time": "2024-05-04T22:25:33.206952Z", "start_time": "2024-05-04T22:25:33.205262Z" }, "collapsed": false }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }