Spaces:
Running
Running
Commit
·
efa06b4
1
Parent(s):
a74444e
Move filtering and generation of plots to its own file
Browse files- src/data_processing.py +129 -0
src/data_processing.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import plotly.express as px
|
| 3 |
+
|
| 4 |
+
from config.constants import (
|
| 5 |
+
CC_BENCHMARKS,
|
| 6 |
+
LC_BENCHMARKS,
|
| 7 |
+
NON_RTL_METRICS,
|
| 8 |
+
RTL_METRICS,
|
| 9 |
+
S2R_BENCHMARKS,
|
| 10 |
+
SCATTER_PLOT_X_TICKS,
|
| 11 |
+
TYPE_COLORS,
|
| 12 |
+
Y_AXIS_LIMITS,
|
| 13 |
+
)
|
| 14 |
+
from utils import filter_bench, filter_bench_all, filter_RTLRepo, handle_special_cases
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state):
    """Filter leaderboard data based on user selections.

    Args:
        task: Selected task ("Spec-to-RTL", "Code Completion", or
            "Line Completion †").
        benchmark: Selected benchmark name, or "All" for every benchmark
            belonging to the selected task.
        model_type: Model-type filter (label may carry a trailing emoji),
            or "All" for no type filtering.
        search_query: Case-insensitive substring matched against model names.
        max_params: Inclusive upper bound on parameter count; coerced to float.
        state: App state exposing get_current_df() and get_current_agg().

    Returns:
        The filtered/aggregated DataFrame produced by the matching
        filter_bench / filter_bench_all / filter_RTLRepo helper, or None
        when benchmark == "All" and the task is unrecognized (matches the
        original implicit fall-through).
    """
    df = state.get_current_df()
    subset = df.copy()

    # Benchmarks belonging to each task; only consulted when "All"
    # benchmarks is selected (a specific benchmark overrides the task).
    task_benchmarks = {
        "Spec-to-RTL": S2R_BENCHMARKS,
        "Code Completion": CC_BENCHMARKS,
        "Line Completion †": LC_BENCHMARKS,
    }
    if benchmark == "All":
        valid_benchmarks = task_benchmarks.get(task)
        if valid_benchmarks is not None:
            subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
    else:
        # A concrete benchmark selection replaces any task-level filtering.
        subset = df[df["Benchmark"] == benchmark]

    if model_type != "All":
        # Dropdown labels include an emoji suffix; compare the bare name.
        subset = subset[subset["Model Type"] == model_type.split(" ")[0]]
    if search_query:
        subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)]

    subset = subset[subset["Params"] <= float(max_params)]

    if benchmark == "All":
        if task == "Spec-to-RTL":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R")
        if task == "Code Completion":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC")
        if task == "Line Completion †":
            return filter_RTLRepo(subset)
        # Unknown task with "All" benchmarks: original code fell through
        # without returning, so preserve the implicit None.
        return None
    if benchmark == "RTL-Repo":
        return filter_RTLRepo(subset)

    # Aggregation column for benchmarks that have one; None otherwise.
    agg_column = {
        "VerilogEval S2R": "Agg VerilogEval S2R",
        "VerilogEval MC": "Agg VerilogEval MC",
        "RTLLM": "Agg RTLLM",
        "VeriGen": "Agg VeriGen",
    }.get(benchmark)
    return filter_bench(subset, state.get_current_agg(), agg_column)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def generate_scatter_plot(benchmark, metric, state):
    """Generate a params-vs-score scatter plot for the given benchmark and metric.

    Args:
        benchmark: Benchmark name; may be remapped by handle_special_cases.
        metric: Score column to plot on the y axis; may be remapped as well.
        state: App state exposing get_current_df().

    Returns:
        A plotly Figure with one point per model (x = parameter count on a
        log scale, y = the selected metric, marker size scaled by params).
    """
    benchmark, metric = handle_special_cases(benchmark, metric)

    df = state.get_current_df()
    subset = df[df["Benchmark"] == benchmark]
    if benchmark == "RTL-Repo":
        # RTL-Repo only reports exact matching; average EM scores per model.
        subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)]
        detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean()
        detailed_scores.rename(columns={"Score": "Exact Matching (EM)"}, inplace=True)
    else:
        # One row per model, one column per metric.
        detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index()

    details = df[["Model", "Params", "Model Type"]].drop_duplicates("Model")
    scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna(
        subset=["Params", metric]
    )

    scatter_data["x"] = scatter_data["Params"]
    scatter_data["y"] = scatter_data[metric]
    # Sub-linear size scaling so the largest models don't dominate the plot.
    scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40

    # Fix: previously a "color" column was computed from TYPE_COLORS but never
    # handed to plotly, so the configured palette was silently ignored. Build
    # an explicit discrete color map instead; unknown types fall back to gray
    # (preserving the original fillna("gray") intent).
    color_map = {
        model_type: TYPE_COLORS.get(model_type, "gray")
        for model_type in scatter_data["Model Type"].unique()
    }

    y_range = Y_AXIS_LIMITS.get(metric, [0, 80])

    fig = px.scatter(
        scatter_data,
        x="x",
        y="y",
        log_x=True,
        size="size",
        color="Model Type",
        color_discrete_map=color_map,
        text="Model",
        hover_data={metric: ":.2f"},
        title=f"Params vs. {metric} for {benchmark}",
        labels={"x": "# Params (Log Scale)", "y": metric},
        template="plotly_white",
        height=600,
        width=1200,
    )

    fig.update_traces(
        textposition="top center",
        textfont_size=10,
        marker=dict(opacity=0.8, line=dict(width=0.5, color="black")),
    )
    fig.update_layout(
        xaxis=dict(
            showgrid=True,
            type="log",
            tickmode="array",
            tickvals=SCATTER_PLOT_X_TICKS["tickvals"],
            ticktext=SCATTER_PLOT_X_TICKS["ticktext"],
        ),
        showlegend=False,
        yaxis=dict(range=y_range),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
    )

    return fig
|