import math
from collections import defaultdict

import pandas as pd


def get_setting(name):
    """Return the plot settings for the statistic file called ``name``.

    The returned dict may contain:
      * ``"x"``      - x-axis label
      * ``"xlim"``   - (min, max) limits for the x-axis
      * ``"ylim"``   - (min, max) limits for the y-axis
      * ``"x-log"``  - truthy to draw the x-axis on a log scale
      * ``"round"``  - number of decimals the x keys are rounded to before
                       values sharing a rounded key are summed

    The first group of statistics is matched by substring, the ``global-*``
    ones by exact file name.

    Raises:
        ValueError: if ``name`` does not match any known statistic.
    """
    if "terminal-punct" in name:
        return {"x": "Fraction of lines ended with punctuation", "ylim": (0, 0.1)}

    if "line-dedup" in name:
        return {"x": "Fraction of chars in duplicated lines", "xlim": (0, 0.1), "ylim": (0, 0.02)}

    if "short-line" in name:
        return {"x": "Fraction of lines shorter than 30 chars", "xlim": (0.4, 1.0), "ylim": (0, 0.05)}

    if "avg_words_per_line" in name:
        # The original literal listed "x-log" twice; one entry is enough.
        return {"x": "Avg. words per line", "x-log": True, "round": 0}

    if "avg_line_length" in name:
        # Label used to be a copy of the "words per line" one.
        return {"x": "Avg. line length", "x-log": True, "round": 0}

    if "global-length.json" == name:
        return {"x": "Num. UTF-8 chars", "x-log": True}

    if "global-digit_ratio.json" == name:
        return {"x": "Digit ratio", "xlim": (0, 0.25)}

    if "global-avg_word_length.json" == name:
        return {"x": "Avg. word length", "xlim": (2.5, 6.5)}

    raise ValueError(f"Unknown dataset name: {name}")


def _merge_rounded_bins(line_data, ndigits):
    """Round every x key to ``ndigits`` decimals and sum values that collide."""
    merged = defaultdict(float)
    for point, value in line_data.items():
        merged[str(round(float(point), ndigits))] += value
    return dict(merged)


def _rolling_mean(x, y, window_size):
    """Smooth ``y`` with a trailing rolling mean, trimming ``x`` to match.

    The first ``window_size - 1`` points have no full window and are dropped,
    so each smoothed value is paired with the x of the last point in its
    window. Series shorter than the window are returned unchanged.
    """
    if len(y) < window_size:
        return x, y
    y_smoothed = pd.Series(y).rolling(window=window_size).mean().dropna()
    return x[len(x) - len(y_smoothed):], y_smoothed.to_numpy()


def plot_scatter(data):
    """
    Plot one smoothed line chart per statistic on a shared two-column grid.

    ``data`` is a mapping ``{statistic name: {line name: {x value: y value}}}``.
    The statistic name is passed to ``get_setting`` to pick labels, limits and
    scaling. Each inner mapping is drawn as one line; its y values are asserted
    to sum to 1 (within floating-point tolerance).

    Raises:
        ValueError: if ``data`` is empty or contains an unknown statistic name.
        AssertionError: if the y values of a line do not sum to 1.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    num_datasets = len(data)
    if num_datasets == 0:
        raise ValueError("plot_scatter needs at least one dataset to plot")

    cols = 2  # Number of columns in the grid
    # Ceiling division: floor division left an odd number of datasets (or a
    # single dataset) without enough axes.
    rows = math.ceil(num_datasets / cols)
    # squeeze=False always yields a 2-D array, so flattening is unconditional.
    fig, axs = plt.subplots(rows, cols, figsize=(8 * cols, 3 * rows), dpi=350, squeeze=False)
    axs = axs.flatten()

    legend_handles = []  # Line handles for the shared legend
    legend_labels = []  # Labels for the shared legend (one entry per line name)
    for ax, (name, dataset) in zip(axs, data.items()):
        setting = get_setting(name)
        if "name" in setting:
            ax.set_title(setting["name"])
        if "x" in setting:
            ax.set_xlabel(setting["x"])
        if "xlim" in setting:
            ax.set_xlim(setting["xlim"])
        if "ylim" in setting:
            ax.set_ylim(setting["ylim"])
        if setting.get("x-log"):
            ax.set_xscale('log')

        # Use 3 decimal places for the y-axis labels
        ax.yaxis.set_major_formatter('{x:.3f}')

        # Each dataset may contain multiple lines
        for line_name, line_data in dataset.items():
            if "round" in setting:
                line_data = _merge_rounded_bins(line_data, setting["round"])

            # The values are expected to sum to 1. Compare with a tolerance:
            # an exact == on a float sum fails on rounding error alone.
            total = sum(line_data.values())
            assert math.isclose(total, 1.0, abs_tol=1e-6), (
                f"Values of {line_name!r} in {name!r} sum to {total}, expected 1"
            )

            # NOTE(review): rename_dataset is not defined in the visible cells;
            # it has to be provided by another cell before this one runs.
            line_name = rename_dataset(line_name)

            # Sort the points numerically by their x key
            sorted_line_data = dict(sorted(line_data.items(), key=lambda item: float(item[0])))

            window_size = setting.get("window_size", 5)  # Smoothing window (points)
            x = np.array(list(sorted_line_data.keys()), dtype=float)
            y = np.array(list(sorted_line_data.values()), dtype=float)
            x, y = _rolling_mean(x, y, window_size)

            # Use the line name as the label so the same line name is listed
            # only once in the shared legend across subplots.
            line, = ax.plot(x, y, label=line_name)  # Use default colors
            if line_name not in legend_labels:
                legend_handles.append(line)
                legend_labels.append(line_name)

    # Place a single shared legend at the bottom of the figure
    fig.legend(handles=legend_handles, labels=legend_labels, loc='lower center', ncol=1)
    for ax in axs[:num_datasets]:
        ax.set_ylabel('Document Frequency')
    # Hide grid cells that have no dataset (odd number of datasets)
    for ax in axs[num_datasets:]:
        ax.set_visible(False)

    fig.suptitle("Histograms of selected statistics")
    plt.tight_layout(rect=[0, 0.15, 1, 1])  # Keep the bottom 15% free for the legend
    fig.set_size_inches(13, 6)  # Set the figure size to 13 inches by 6 inches
    plt.show()


# NOTE(review): `data` is not defined in the visible cells; it must already
# exist in the kernel (e.g. from an earlier cell) for this call to work.
plot_scatter(data)