{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5b3e9c1a7f2d4e60",
   "metadata": {},
   "source": [
    "# Score by CommonCrawl dump\n",
    "\n",
    "For every run in `commoncrawl_dumps.csv`, average `agg_score` over its rows with the largest `steps` (the last checkpoints), then plot how that average evolves across CommonCrawl dumps."
   ]
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {},
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "id": "2a9e5d7c3b1f4068",
   "metadata": {},
   "source": [
    "# NOTE(review): absolute, machine-specific paths - point them at your own checkout.\n",
    "DATA_DIR = Path(\"/home/gui/hf_dev/datatrove/blogpost/data\")\n",
    "PLOTS_DIR = Path(\"/home/gui/hf_dev/datatrove/blogpost/plots\")\n",
    "\n",
    "# How many rows (largest `steps` first) of each run are averaged.\n",
    "N_LAST_ROWS = 6"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "id": "157e18836c20793c",
   "metadata": {},
   "source": [
    "df = pd.read_csv(DATA_DIR / \"commoncrawl_dumps.csv\")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "8d2f6a4c0e1b3957",
   "metadata": {},
   "source": [
    "## Last-checkpoint statistics per run\n",
    "\n",
    "Each run is summarised by the mean (`avg`) and sample standard deviation (`std_dev`) of `agg_score` over its `N_LAST_ROWS` rows with the largest `steps`."
   ]
  },
  {
   "cell_type": "code",
   "id": "af7c0416a6371f9a",
   "metadata": {},
   "source": [
    "def last_n_stats(frame, n=N_LAST_ROWS):\n",
    "    \"\"\"Mean and sample std of `agg_score` over the last `n` rows of each run.\n",
    "\n",
    "    Args:\n",
    "        frame: DataFrame with `runname`, `steps` and `agg_score` columns.\n",
    "        n: number of rows with the largest `steps` kept per run; runs with\n",
    "            fewer rows use all of them.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame indexed by sorted `runname` with columns `avg` (mean) and\n",
    "        `std_dev` (ddof=1, so NaN for a run with a single row).\n",
    "    \"\"\"\n",
    "    # A global sort + groupby().head() selects the rows without groupby().apply(),\n",
    "    # which warns about operating on the grouping column since pandas 2.2.\n",
    "    last_rows = frame.sort_values(by='steps', ascending=False).groupby('runname').head(n)\n",
    "    return last_rows.groupby('runname')['agg_score'].agg(avg='mean', std_dev='std')\n",
    "\n",
    "\n",
    "result = last_n_stats(df)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "id": "65c0cd58c6f9f9d6",
   "metadata": {},
   "source": [
    "result"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "e4a7c2d95b0f1836",
   "metadata": {},
   "source": [
    "## Score by dump\n",
    "\n",
    "One point per run, in run-name order, coloured by year; dotted lines mark the first dump of each year."
   ]
  },
  {
   "cell_type": "code",
   "id": "412ed6b4570d10e9",
   "metadata": {},
   "source": [
    "# Runs in name order; the text before the first '-' of a run name is its year.\n",
    "result_sorted = result.sort_index()\n",
    "years = result_sorted.index.str.split('-').str[0].astype(int)\n",
    "positions = np.arange(len(result_sorted))\n",
    "\n",
    "# tab10 with its last two colours prepended: a 12-colour discrete colormap for the years.\n",
    "tab10_colors = plt.cm.tab10(np.linspace(0, 1, plt.cm.tab10.N))\n",
    "year_cmap = mcolors.ListedColormap(np.concatenate((tab10_colors[-2:], tab10_colors)))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 10))\n",
    "# Grey connecting line drawn underneath the year-coloured markers.\n",
    "ax.plot(positions, result_sorted['avg'], linestyle='-', color='gray', alpha=0.5, zorder=1)\n",
    "ax.scatter(positions, result_sorted['avg'], c=years, cmap=year_cmap, marker='o', s=100, zorder=2)\n",
    "\n",
    "ax.set_xlabel('Year', fontsize=18)\n",
    "ax.set_ylabel('Average Agg Score', fontsize=18)\n",
    "ax.set_title('Score by dump', fontsize=24)\n",
    "ax.set_xticks(positions)\n",
    "ax.set_xticklabels(years, ha='center', fontsize=14)\n",
    "ax.tick_params(axis='y', labelsize=14)\n",
    "\n",
    "# A dotted separator at the first dump of each year.\n",
    "year_values = years.to_numpy()\n",
    "year_starts = [i for i, year in enumerate(year_values) if i == 0 or year != year_values[i - 1]]\n",
    "for start in year_starts:\n",
    "    ax.axvline(x=start, color='grey', linestyle=':')\n",
    "\n",
    "# Show each year's label once, at the (floored) middle of its run of dumps. The last year\n",
    "# has no following start to pair with, so the final tick keeps its label, left-aligned.\n",
    "starts = np.array(year_starts)\n",
    "label_positions = set(((starts[:-1] + starts[1:]) // 2).tolist())\n",
    "tick_labels = ax.get_xticklabels()\n",
    "for pos, label in enumerate(tick_labels[:-1]):\n",
    "    label.set_visible(pos in label_positions)\n",
    "tick_labels[-1].set_horizontalalignment('left')\n",
    "\n",
    "fig.savefig(PLOTS_DIR / 'score_by_dump.png', bbox_inches='tight', dpi=300)\n",
    "plt.show()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "dd4bbdf230df5953",
   "metadata": {},
   "source": [
    "## Flipped axes\n",
    "\n",
    "Same statistics with one row per run on the y-axis and the average score on the x-axis."
   ]
  },
  {
   "cell_type": "code",
   "id": "49656c68704a55ca",
   "metadata": {},
   "source": [
    "# One row per run (name order); the text before the first '-' of a run name is its year.\n",
    "result_sorted = result.sort_index()\n",
    "years = result_sorted.index.str.split('-').str[0].astype(int)\n",
    "rows = range(len(result_sorted))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 20))\n",
    "ax.scatter(result_sorted['avg'], rows, c=years, cmap='tab20', marker='o', s=100)\n",
    "ax.plot(result_sorted['avg'], rows, linestyle='-', color='gray', alpha=0.5)\n",
    "\n",
    "# The score is on the x-axis and the runs (dumps) on the y-axis here.\n",
    "ax.set_xlabel('Average Agg Score')\n",
    "ax.set_ylabel('Dump')\n",
    "# NOTE(review): the title mentions 3 checkpoints while N_LAST_ROWS rows are averaged per run\n",
    "# (presumably 3 checkpoints x 2 seeds) - confirm before publishing.\n",
    "ax.set_title('Score by dump. 3 last checkpoints of each seed avgd')\n",
    "ax.set_yticks(rows)\n",
    "ax.set_yticklabels(result_sorted.index, ha='right', rotation_mode='anchor')\n",
    "ax.grid(True)\n",
    "\n",
    "# Own file name: saving to score_by_dump.png would overwrite the figure from the previous cell.\n",
    "fig.savefig(PLOTS_DIR / 'score_by_dump_flipped.png', bbox_inches='tight', dpi=300)\n",
    "plt.show()"
   ],
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}