from datetime import datetime
import numpy as np
import json
import os
import re
import heapq
from collections import defaultdict
import tempfile
from typing import Dict, Tuple, List, Literal

import gradio as gr
from datatrove.utils.stats import MetricStatsDict

from src.logic.graph_settings import Grouping

# Ways of choosing which groups to show in a grouped plot.
PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]


def prepare_for_non_grouped_plotting(
    metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int
) -> Dict[float, float]:
    """Aggregate an ungrouped metric into a histogram keyed by rounded value.

    Each key of ``metric`` is parsed as a float and rounded to ``rounding``
    decimals. The ``total`` of every entry that falls into the same rounded
    key is summed.

    Args:
        metric: Mapping from a stringified numeric value to its stats. Each
            value must expose a ``total`` attribute.
        normalization: If True, scale the sums so they add up to 1. Scaling
            is skipped when the sum is zero, so the result is not all-NaN.
        rounding: Number of decimals passed to ``np.round`` for the keys.

    Returns:
        Mapping from rounded key to the summed (optionally normalized) total.
    """
    keys = np.array([float(key) for key in metric.keys()])
    values = np.array([value.total for value in metric.values()])
    rounded_keys = np.round(keys, rounding)
    # ``indices`` maps every original key to the position of its rounded bucket.
    unique_keys, indices = np.unique(rounded_keys, return_inverse=True)
    metrics_rounded = np.zeros_like(unique_keys, dtype=float)
    # Unbuffered add: repeated bucket indices accumulate instead of overwriting.
    np.add.at(metrics_rounded, indices, values)
    if normalization:
        normalizer = np.sum(metrics_rounded)
        # Dividing an all-zero (or empty) histogram by 0 would yield NaNs.
        if normalizer != 0:
            metrics_rounded /= normalizer
    return dict(zip(unique_keys, metrics_rounded))


def prepare_for_group_plotting(
    metric: Dict[str, MetricStatsDict],
    top_k: int,
    direction: PARTITION_OPTIONS,
    regex: str | None,
    rounding: int,
) -> Tuple[List[str], List[float], List[float]]:
    """Select up to ``top_k`` groups of a grouped metric for plotting.

    Args:
        metric: Mapping from group name to its stats. Each value must expose
            ``mean``, ``standard_deviation`` and ``n``.
        top_k: Maximum number of groups to return. Values <= 0 return nothing.
        direction: ``"Top"`` gives the highest rounded means first.
            ``"Most frequent (n_docs)"`` gives the largest ``n`` first. Any
            other value gives the lowest rounded means first.
        regex: Optional pattern. Only keys where ``re.match`` succeeds (a
            match anchored at the start of the key) are kept. Empty or None
            disables filtering.
        rounding: Number of decimals used to round the means.

    Returns:
        Three parallel lists: group keys, rounded means and standard
        deviations, in the selected order.
    """
    top_k = max(int(top_k), 0)
    regex_compiled = re.compile(regex) if regex else None
    filtered_metric = {
        key: value
        for key, value in metric.items()
        if regex_compiled is None or regex_compiled.match(key)
    }
    keys = np.array(list(filtered_metric.keys()))
    means = np.array([float(value.mean) for value in filtered_metric.values()])
    stds = np.array([value.standard_deviation for value in filtered_metric.values()])
    rounded_means = np.round(means, rounding)

    # Reverse first, then slice. The former ``[-top_k:]`` selected every
    # element when top_k == 0, because -0 == 0.
    if direction == "Top":
        top_indices = np.argsort(rounded_means)[::-1][:top_k]
    elif direction == "Most frequent (n_docs)":
        totals = np.array([int(value.n) for value in filtered_metric.values()])
        top_indices = np.argsort(totals)[::-1][:top_k]
    else:
        top_indices = np.argsort(rounded_means)[:top_k]

    top_keys = keys[top_indices]
    top_means = rounded_means[top_indices]
    top_stds = stds[top_indices]
    return top_keys.tolist(), top_means.tolist(), top_stds.tolist()


def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str, grouping: Grouping):
    """Dump the plotted stats to a timestamped JSON file and expose it for download.

    Args:
        exported_data: Mapping from a name to a stats object whose
            ``to_dict()`` items are ``(key, mapping)`` pairs. Each pair is
            written as ``{"value": key, **mapping}`` and sorted by ``key``.
        metric_name: Used as the first component of the file name.
        grouping: Interpolated into the file name with ``str()`` formatting.

    Returns:
        ``None`` if there is nothing to export. Otherwise a visible
        ``gr.File`` pointing at the written JSON file.
    """
    if not exported_data:
        return None
    file_name = f"{metric_name}_{grouping}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
    # Write into the system temp dir instead of cluttering (or failing on a
    # read-only) server working directory. Gradio accepts files returned
    # from this location.
    file_path = os.path.join(tempfile.gettempdir(), file_name)
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                name: sorted(
                    [{"value": key, **value} for key, value in dt.to_dict().items()],
                    key=lambda x: x["value"],
                )
                for name, dt in exported_data.items()
            },
            f,
            indent=2,
        )
    return gr.File(value=file_path, visible=True)