hynky's picture
hynky HF staff
⚡️ make it faster
276d919
raw
history blame
4.77 kB
from functools import partial
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import gradio as gr
from typing import Dict, List
from src.logic.data_processing import PARTITION_OPTIONS, prepare_for_non_grouped_plotting, prepare_for_group_plotting
from src.logic.graph_settings import Grouping
from src.logic.utils import set_alpha
from datatrove.utils.stats import MetricStatsDict
def plot_scatter(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    log_scale_x: bool,
    log_scale_y: bool,
    normalization: bool,
    rounding: int,
    cumsum: bool,
    perc: bool,
    progress: gr.Progress,
):
    """Build a line plot with one trace per run's histogram.

    Args:
        data: Mapping from run name to its histogram. Runs are plotted in
            sorted order of name.
        metric_name: Used in the figure title and as the x-axis title.
        log_scale_x: Request a log x-axis. It is only applied when at least
            one trace has more than one point.
        log_scale_y: Request a log y-axis, with the same condition.
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``.
            It also selects the y-axis title ("Frequency" vs "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Plot the running (cumulative) sum of the y values.
        perc: Multiply the y values by 100. This is applied after ``cumsum``.
        progress: Gradio progress tracker used to report per-run progress.

    Returns:
        The assembled ``go.Figure``.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly
    # Number of points in the longest trace. It is tracked across all traces
    # because reading the loop's last x/y after the loop raised NameError for
    # empty `data`, and it ignored every trace except the final one.
    max_points = 0
    sorted_items = sorted(data.items())
    for i, (name, histogram) in enumerate(
        progress.tqdm(sorted_items, total=len(sorted_items), desc="Plotting...")
    ):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()
        max_points = max(max_points, len(x))
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                # Cycle through the qualitative palette at half opacity.
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"
    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        # A log axis needs more than one point to derive a sensible range.
        xaxis_type="log" if log_scale_x and max_points > 1 else None,
        yaxis_type="log" if log_scale_y and max_points > 1 else None,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig
def plot_bars(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    top_k: int,
    direction: PARTITION_OPTIONS,
    regex: str | None,
    rounding: int,
    log_scale_x: bool,
    log_scale_y: bool,
    show_stds: bool,
    progress: gr.Progress,
):
    """Build a grouped bar plot of per-group mean values, one trace per run.

    Args:
        data: Mapping from run name to its grouped statistics. Runs are
            plotted in sorted order of name, the same as ``plot_scatter``, so a
            run gets the same palette colour in both views.
        metric_name: Used in the figure title and as the x-axis title.
        top_k: Forwarded to ``prepare_for_group_plotting``.
        direction: Forwarded to ``prepare_for_group_plotting``.
        regex: Forwarded to ``prepare_for_group_plotting``.
        rounding: Forwarded to ``prepare_for_group_plotting``.
        log_scale_x: Request a log x-axis. It is only applied when some trace
            has more than one x value.
        log_scale_y: Request a log y-axis. It is only applied when some trace
            has more than one y value.
        show_stds: Show the standard-deviation error bars returned by
            ``prepare_for_group_plotting``.
        progress: Gradio progress tracker used to report per-run progress.

    Returns:
        The assembled ``go.Figure``.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly
    # Track the longest trace rather than the last one, so the log-axis
    # decision does not depend on the iteration order of the runs.
    max_x = 0
    max_y = 0
    sorted_items = sorted(data.items())
    for i, (name, histogram) in enumerate(
        progress.tqdm(sorted_items, total=len(sorted_items), desc="Plotting...")
    ):
        x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
        max_x = max(max_x, len(x))
        max_y = max(max_y, len(y))
        fig.add_trace(
            go.Bar(
                x=x,
                y=y,
                name=f"{name} Mean",
                # Cycle through the qualitative palette at half opacity.
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
                error_y=dict(type='data', array=stds, visible=show_stds),
            )
        )

    fig.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type="log" if log_scale_x and max_x > 1 else None,
        yaxis_type="log" if log_scale_y and max_y > 1 else None,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig
# Add any other necessary functions
def plot_data(
    metric_data: Dict[str, MetricStatsDict],
    metric_name: str,
    normalize: bool,
    rounding: int,
    grouping: Grouping,
    top_n: int,
    direction: PARTITION_OPTIONS,
    group_regex: str,
    log_scale_x: bool,
    log_scale_y: bool,
    cdf: bool,
    perc: bool,
    show_stds: bool,
) -> tuple[go.Figure, gr.Row, str]:
    """Dispatch to the line-plot or bar-plot builder based on ``grouping``.

    When ``grouping == "histogram"``, ``plot_scatter`` is called and a
    Markdown min/max table from ``generate_min_max_hist_data`` is returned as
    the third element. Any other grouping calls ``plot_bars`` and returns an
    empty string in that position. In both cases the second element is
    ``gr.Row.update(visible=True)``.
    """
    if grouping != "histogram":
        bar_figure = plot_bars(
            data=metric_data,
            metric_name=metric_name,
            top_k=top_n,
            direction=direction,
            regex=group_regex,
            rounding=rounding,
            log_scale_x=log_scale_x,
            log_scale_y=log_scale_y,
            show_stds=show_stds,
            progress=gr.Progress(),
        )
        return bar_figure, gr.Row.update(visible=True), ""

    line_figure = plot_scatter(
        data=metric_data,
        metric_name=metric_name,
        log_scale_x=log_scale_x,
        log_scale_y=log_scale_y,
        normalization=normalize,
        rounding=rounding,
        cumsum=cdf,
        perc=perc,
        progress=gr.Progress(),
    )
    extremes_table = generate_min_max_hist_data(metric_data)
    return line_figure, gr.Row.update(visible=True), extremes_table
def generate_min_max_hist_data(data: "Dict[str, MetricStatsDict]") -> str:
    """Render a Markdown table of the smallest and largest histogram key per run.

    Args:
        data: Mapping from run name to its histogram. Each histogram's keys
            are converted with ``float()`` and compared as numbers.

    Returns:
        A Markdown table with a "Run | Min | Max" header and one row per run,
        in the iteration order of ``data``. Values are formatted to four
        decimal places. A run whose histogram has no keys shows "N/A" instead
        of raising ``ValueError``. A literal "|" in a run name is escaped as
        "\\|" so it does not split the table cell.
    """
    header = "| Run | Min | Max |\n|-----|-----|-----|\n"
    rows = []
    for run, histogram in data.items():
        # Parse each key once; previously every key was converted twice.
        values = [float(key) for key in histogram.keys()]
        if values:
            low, high = f"{min(values):.4f}", f"{max(values):.4f}"
        else:
            low = high = "N/A"
        safe_run = str(run).replace("|", "\\|")
        rows.append(f"| {safe_run} | {low} | {high} |")
    return header + "\n".join(rows)