# In [None]:
from collections import defaultdict

import pandas as pd


def get_setting(name):
    """Return the plot settings for the statistic identified by ``name``.

    Substring rules are tried first, in the order written below; the
    ``global-*.json`` statistics are matched by exact name afterwards.

    Args:
        name: Name of the statistic, e.g. ``"global-length.json"``.

    Returns:
        A new dict with a subset of these keys:

        * ``"x"``: label of the x-axis.
        * ``"xlim"`` / ``"ylim"``: ``(low, high)`` axis limits.
        * ``"x-log"``: ``True`` when the x-axis should be log-scaled.
        * ``"round"``: number of decimals passed to ``round()`` when
          ``plot_scatter`` buckets the x values.

    Raises:
        ValueError: If ``name`` matches none of the known statistics.
    """
    if "terminal-punct" in name:
        return {"x": "Fraction of lines ended with punctuation", "ylim": (0, 0.1)}

    if "line-dedup" in name:
        return {
            "x": "Fraction of chars in duplicated lines",
            "xlim": (0, 0.1),
            "ylim": (0, 0.02),
        }

    if "short-line" in name:
        return {
            "x": "Fraction of lines shorter than 30 chars",
            "xlim": (0.4, 1.0),
            "ylim": (0, 0.05),
        }

    if "avg_words_per_line" in name:
        return {"x": "Avg. words per line", "x-log": True, "round": 0}

    if "avg_line_length" in name:
        # This statistic is a line length, not a word count; it previously
        # reused the "Avg. words per line" label from the branch above.
        return {"x": "Avg. line length", "x-log": True, "round": 0}

    if "global-length.json" == name:
        return {"x": "Num. UTF-8 chars", "x-log": True}

    if "global-digit_ratio.json" == name:
        return {"x": "Digit ratio", "xlim": (0, 0.25)}

    if "global-avg_word_length.json" == name:
        return {"x": "Avg. word length", "xlim": (2.5, 6.5)}

    raise ValueError(f"Unknown dataset name: {name}")


def _merge_rounded_keys(line_data, ndigits):
    """Round every x key to ``ndigits`` decimals and sum the values that collide."""
    merged = defaultdict(float)
    for key, value in line_data.items():
        merged[round(float(key), ndigits)] += value
    return dict(merged)


def _rolling_mean(x, y, window_size):
    """Smooth ``y`` with a trailing rolling mean and trim ``x`` to the same length.

    Series shorter than ``window_size`` are returned unchanged.
    """
    if len(y) < window_size:
        return x, y
    # The first ``window_size - 1`` rolling means are NaN and get dropped, so
    # each smoothed value is paired with the x of the last point in its window.
    smoothed = pd.Series(y).rolling(window=window_size).mean().dropna()
    return x[len(x) - len(smoothed):], smoothed.to_numpy()


def plot_scatter(data):
    """Plot one smoothed line chart per statistic on a shared two-column grid.

    Args:
        data: Mapping of statistic name (looked up with ``get_setting``) to a
            mapping of line name to ``{x value: frequency}``. The x values
            must be convertible to ``float`` and the frequencies of each line
            must sum to 1.

    Raises:
        ValueError: If ``data`` is empty, if a statistic name is unknown to
            ``get_setting``, or if the frequencies of a line do not sum to 1.
    """
    import math

    import matplotlib.pyplot as plt
    import numpy as np

    num_datasets = len(data)
    if num_datasets == 0:
        raise ValueError("No datasets to plot")

    cols = 2  # Number of columns in the grid
    # Ceiling division so that an odd number of datasets still gets enough axes.
    rows = (num_datasets + cols - 1) // cols
    fig, axs = plt.subplots(
        rows, cols, figsize=(8 * cols, 3 * rows), dpi=350, squeeze=False
    )
    # squeeze=False always yields a 2-D array, so flattening is always valid.
    axs = axs.flatten()

    legend_handles = []  # One handle per distinct line name
    legend_labels = []
    for ax, (name, dataset) in zip(axs, data.items()):
        setting = get_setting(name)
        if "name" in setting:
            ax.set_title(setting["name"])
        if "x" in setting:
            ax.set_xlabel(setting["x"])
        if "xlim" in setting:
            ax.set_xlim(setting["xlim"])
        if "ylim" in setting:
            ax.set_ylim(setting["ylim"])
        if setting.get("x-log"):
            ax.set_xscale('log')

        # Use 3 decimal places for the y-axis labels
        ax.yaxis.set_major_formatter('{x:.3f}')

        window_size = setting.get("window_size", 5)  # Points per smoothing window

        # Each dataset may contain multiple lines
        for line_name, line_data in dataset.items():
            if "round" in setting:
                line_data = _merge_rounded_keys(line_data, setting["round"])

            # The frequencies must form a distribution. Compare with a
            # tolerance: summing floats rarely gives exactly 1.
            total = sum(line_data.values())
            if not math.isclose(total, 1.0, rel_tol=1e-6):
                raise ValueError(
                    f"Frequencies of line {line_name!r} in {name!r} "
                    f"sum to {total}, expected 1"
                )

            line_name = rename_dataset(line_name)

            # Order the points by their numeric x value before smoothing
            ordered = sorted(line_data.items(), key=lambda item: float(item[0]))
            x = np.array([key for key, _ in ordered], dtype=float)
            y = np.array([value for _, value in ordered], dtype=float)
            x, y = _rolling_mean(x, y, window_size)

            # The label is the (renamed) line name, so the same line shares
            # one legend entry across all subplots.
            line, = ax.plot(x, y, label=line_name)  # Use default colors
            if line_name not in legend_labels:
                legend_handles.append(line)
                legend_labels.append(line_name)

    for ax in axs[:num_datasets]:
        ax.set_ylabel('Document Frequency')
    # Hide the spare axis left over when the number of datasets is odd
    for ax in axs[num_datasets:]:
        ax.set_visible(False)

    # Place a single shared legend at the bottom of the figure
    fig.legend(handles=legend_handles, labels=legend_labels, loc='lower center', ncol=1)
    fig.suptitle("Histograms of selected statistics")

    # Fix the final size (13 x 6 inches) before computing the layout;
    # resizing afterwards would leave the tight layout stale.
    fig.set_size_inches(13, 6)
    fig.tight_layout(rect=[0, 0.15, 1, 1])  # Leave room at the bottom for the legend
    plt.show()

# Render the figure. NOTE(review): `data` is not defined in the visible part of
# this file; presumably an earlier notebook cell builds it -- confirm.
plot_scatter(data)
