|
from functools import partial |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
import gradio as gr |
|
from typing import Dict, List |
|
|
|
from src.logic.data_processing import PARTITION_OPTIONS, prepare_for_non_grouped_plotting, prepare_for_group_plotting |
|
from src.logic.graph_settings import Grouping |
|
from src.logic.utils import set_alpha |
|
from datatrove.utils.stats import MetricStatsDict |
|
|
|
def plot_scatter(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    log_scale_x: bool,
    log_scale_y: bool,
    normalization: bool,
    rounding: int,
    cumsum: bool,
    perc: bool,
    progress: gr.Progress,
) -> go.Figure:
    """Draw one line trace per entry of ``data`` on a single figure.

    Entries are plotted in sorted-name order. Each histogram is first passed
    through ``prepare_for_non_grouped_plotting``; the keys of its result,
    sorted, become the x values and the matching values become the y values.

    Args:
        data: Mapping of trace name to the histogram to plot.
        metric_name: Used in the figure title and as the x-axis title.
        log_scale_x: Request a logarithmic x axis.
        log_scale_y: Request a logarithmic y axis.
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``; also
            selects the y-axis title ("Frequency" when true, else "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Replace the y values by their cumulative sum.
        perc: Multiply the y values by 100 (applied after ``cumsum``).
        progress: Gradio progress tracker; its ``tqdm`` wraps the iteration.

    Returns:
        The figure holding one ``go.Scatter`` trace per entry of ``data``.
    """
    palette = px.colors.qualitative.Plotly
    fig = go.Figure()
    # Pre-bind x/y: they are inspected after the loop, and without this an
    # empty `data` combined with a log-scale flag raised UnboundLocalError.
    x, y = [], []

    # Sort by name so the trace order and the colour assigned to each name
    # do not depend on the insertion order of the incoming mapping.
    data = dict(sorted(data.items()))
    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                # Cycle through the palette when there are more traces than colours.
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"

    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        # x and y refer to the last trace plotted (or are empty when nothing
        # was plotted); a log axis is only requested when it has > 1 point.
        xaxis_type="log" if log_scale_x and len(x) > 1 else None,
        yaxis_type="log" if log_scale_y and len(y) > 1 else None,
        width=1200,
        height=600,
        showlegend=True,
    )

    return fig
|
|
|
def plot_bars(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    top_k: int,
    direction: PARTITION_OPTIONS,
    regex: str | None,
    rounding: int,
    log_scale_x: bool,
    log_scale_y: bool,
    show_stds: bool,
    progress: gr.Progress,
):
    """Build a bar chart with one ``go.Bar`` trace per entry of ``data``.

    For every entry, ``prepare_for_group_plotting`` supplies the x values,
    the y values and the error-bar array of the trace.

    Args:
        data: Mapping of trace name to the histogram to plot.
        metric_name: Used in the figure title and as the x-axis title.
        top_k: Forwarded to ``prepare_for_group_plotting``.
        direction: Forwarded to ``prepare_for_group_plotting``.
        regex: Forwarded to ``prepare_for_group_plotting``.
        rounding: Forwarded to ``prepare_for_group_plotting``.
        log_scale_x: Request a logarithmic x axis.
        log_scale_y: Request a logarithmic y axis.
        show_stds: Whether the error bars are visible.
        progress: Gradio progress tracker; its ``tqdm`` wraps the iteration.

    Returns:
        The populated ``go.Figure``.
    """
    palette = px.colors.qualitative.Plotly
    figure = go.Figure()
    # Values of the most recently plotted trace; they stay empty when `data`
    # is empty and are consulted below for the log-axis decision.
    labels = []
    means = []

    entries = progress.tqdm(data.items(), total=len(data), desc="Plotting...")
    for idx, (run_name, stats) in enumerate(entries):
        labels, means, stds = prepare_for_group_plotting(stats, top_k, direction, regex, rounding)

        # Cycle through the palette when there are more traces than colours.
        bar_color = set_alpha(palette[idx % len(palette)], 0.5)
        bar_trace = go.Bar(
            x=labels,
            y=means,
            name=f"{run_name} Mean",
            marker={"color": bar_color},
            error_y={"type": "data", "array": stds, "visible": show_stds},
        )
        figure.add_trace(bar_trace)

    # A log axis is only requested when the last trace has more than one point.
    x_axis_type = "log" if log_scale_x and len(labels) > 1 else None
    y_axis_type = "log" if log_scale_y and len(means) > 1 else None

    figure.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type=x_axis_type,
        yaxis_type=y_axis_type,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )

    return figure
|
|
|
|
|
|
|
|
|
def plot_data(
    metric_data: Dict[str, MetricStatsDict],
    metric_name: str,
    normalize: bool,
    rounding: int,
    grouping: Grouping,
    top_n: int,
    direction: PARTITION_OPTIONS,
    group_regex: str,
    log_scale_x: bool,
    log_scale_y: bool,
    cdf: bool,
    perc: bool,
    show_stds: bool,
) -> tuple[go.Figure, gr.Row, str]:
    """Dispatch to the line or bar plot depending on ``grouping``.

    When ``grouping`` equals ``"histogram"`` the data is drawn with
    ``plot_scatter`` and a min/max table is generated alongside it; any other
    grouping is drawn with ``plot_bars`` and the table text is empty.

    Args:
        metric_data: Mapping of run name to the histogram to plot.
        metric_name: Name shown in the figure title and x-axis title.
        normalize: Passed to ``plot_scatter`` as ``normalization``.
        rounding: Passed to whichever plotting function is used.
        grouping: Selects the plot type.
        top_n: Passed to ``plot_bars`` as ``top_k``.
        direction: Passed to ``plot_bars``.
        group_regex: Passed to ``plot_bars`` as ``regex``.
        log_scale_x: Request a logarithmic x axis.
        log_scale_y: Request a logarithmic y axis.
        cdf: Passed to ``plot_scatter`` as ``cumsum``.
        perc: Passed to ``plot_scatter``.
        show_stds: Passed to ``plot_bars``.

    Returns:
        A ``(figure, row_update, table_text)`` tuple.
    """
    # NOTE(review): both branches return a Row update with visible=True, even
    # though the non-histogram branch has no table text -- confirm intended.
    # NOTE(review): component-level `.update()` looks like the older Gradio
    # API -- verify against the Gradio version this project pins.
    if grouping == "histogram":
        line_figure = plot_scatter(
            data=metric_data,
            metric_name=metric_name,
            log_scale_x=log_scale_x,
            log_scale_y=log_scale_y,
            normalization=normalize,
            rounding=rounding,
            cumsum=cdf,
            perc=perc,
            progress=gr.Progress(),
        )
        table_text = generate_min_max_hist_data(metric_data)
        return line_figure, gr.Row.update(visible=True), table_text

    bar_figure = plot_bars(
        data=metric_data,
        metric_name=metric_name,
        top_k=top_n,
        direction=direction,
        regex=group_regex,
        rounding=rounding,
        log_scale_x=log_scale_x,
        log_scale_y=log_scale_y,
        show_stds=show_stds,
        progress=gr.Progress(),
    )
    return bar_figure, gr.Row.update(visible=True), ""
|
|
|
def generate_min_max_hist_data(data: Dict[str, MetricStatsDict]) -> str:
    """Render a Markdown table of the smallest and largest histogram key per run.

    Every key of a run's histogram is converted with ``float()``; the minimum
    and maximum of those values are printed with four decimal places.

    Args:
        data: Mapping of run name to that run's histogram.

    Returns:
        A Markdown table with the columns ``Run``, ``Min`` and ``Max`` and one
        row per run, in the iteration order of ``data``. A run whose histogram
        has no keys gets ``N/A`` in both value columns. With no runs at all,
        only the two header lines are returned.

    Raises:
        ValueError: If a histogram key cannot be converted by ``float()``.
    """
    header = "| Run | Min | Max |\n|-----|-----|-----|\n"

    rows = []
    for run, histogram in data.items():
        # Convert the keys once and reuse them for both extremes.
        keys = [float(key) for key in histogram.keys()]
        if keys:
            rows.append(f"| {run} | {min(keys):.4f} | {max(keys):.4f} |")
        else:
            # min()/max() raise ValueError on an empty sequence; one empty run
            # must not prevent the table for the remaining runs.
            rows.append(f"| {run} | N/A | N/A |")

    return header + "\n".join(rows)