ggcristian committed on
Commit
efa06b4
·
1 Parent(s): a74444e

Move filtering and generation of plots to its own file

Browse files
Files changed (1) hide show
  1. src/data_processing.py +129 -0
src/data_processing.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import plotly.express as px
3
+
4
+ from config.constants import (
5
+ CC_BENCHMARKS,
6
+ LC_BENCHMARKS,
7
+ NON_RTL_METRICS,
8
+ RTL_METRICS,
9
+ S2R_BENCHMARKS,
10
+ SCATTER_PLOT_X_TICKS,
11
+ TYPE_COLORS,
12
+ Y_AXIS_LIMITS,
13
+ )
14
+ from utils import filter_bench, filter_bench_all, filter_RTLRepo, handle_special_cases
15
+
16
+
17
def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state):
    """Filter leaderboard data based on user selections.

    Args:
        task: Selected task name ("Spec-to-RTL", "Code Completion",
            or "Line Completion †").
        benchmark: Selected benchmark name, or "All".
        model_type: Model-type label (may carry a trailing emoji), or "All".
        search_query: Case-insensitive substring matched against model names.
        search_query may be empty/None, in which case no name filter applies.
        max_params: Upper bound on the "Params" column (coerced to float).
        state: App state exposing get_current_df() and get_current_agg().

    Returns:
        The DataFrame produced by the matching filter_* helper, or None for an
        unrecognized task when benchmark == "All" (pre-existing behavior kept).
    """
    subset = state.get_current_df().copy()

    # Restrict to the task's own benchmarks when "All" benchmarks is selected.
    task_benchmarks = {
        "Spec-to-RTL": S2R_BENCHMARKS,
        "Code Completion": CC_BENCHMARKS,
        "Line Completion †": LC_BENCHMARKS,
    }
    valid_benchmarks = task_benchmarks.get(task)
    if benchmark == "All" and valid_benchmarks is not None:
        subset = subset[subset["Benchmark"].isin(valid_benchmarks)]

    # A specific benchmark selection overrides the task-level restriction:
    # start over from the full frame and keep only that benchmark's rows.
    if benchmark != "All":
        current_df = state.get_current_df()
        subset = current_df[current_df["Benchmark"] == benchmark]

    if model_type != "All":
        # Dropdown labels carry a trailing emoji; compare on the first word only.
        subset = subset[subset["Model Type"] == model_type.split(" ")[0]]
    if search_query:
        subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)]
    subset = subset[subset["Params"] <= float(max_params)]

    if benchmark == "All":
        if task == "Spec-to-RTL":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R")
        elif task == "Code Completion":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC")
        elif task == "Line Completion †":
            return filter_RTLRepo(subset)
        # NOTE(review): an unknown task falls through and returns None here,
        # matching the original implementation.
    elif benchmark == "RTL-Repo":
        return filter_RTLRepo(subset)
    else:
        # Map each individually-selectable benchmark to its aggregate column;
        # unknown benchmarks pass agg_column=None, as before.
        agg_columns = {
            "VerilogEval S2R": "Agg VerilogEval S2R",
            "VerilogEval MC": "Agg VerilogEval MC",
            "RTLLM": "Agg RTLLM",
            "VeriGen": "Agg VeriGen",
        }
        return filter_bench(subset, state.get_current_agg(), agg_columns.get(benchmark))
67
+
68
+
69
def generate_scatter_plot(benchmark, metric, state):
    """Generate a Params-vs-score scatter plot for the given benchmark and metric.

    Args:
        benchmark: Benchmark name (normalized by handle_special_cases).
        metric: Metric/score column to plot on the y-axis.
        state: App state exposing get_current_df().

    Returns:
        A plotly Figure with parameter count (log x-axis) against the metric.
    """
    benchmark, metric = handle_special_cases(benchmark, metric)

    current_df = state.get_current_df()
    subset = current_df[current_df["Benchmark"] == benchmark]
    if benchmark == "RTL-Repo":
        # RTL-Repo only reports Exact Matching; average the EM scores per model.
        subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)]
        detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean()
        detailed_scores = detailed_scores.rename(columns={"Score": "Exact Matching (EM)"})
    else:
        # One row per model, one column per metric.
        detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index()

    details = current_df[["Model", "Params", "Model Type"]].drop_duplicates("Model")
    # Drop models missing either coordinate so every plotted point is complete.
    scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna(
        subset=["Params", metric]
    )

    scatter_data["x"] = scatter_data["Params"]
    scatter_data["y"] = scatter_data[metric]
    # Sub-linear growth keeps marker sizes readable across orders of magnitude.
    scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40

    y_range = Y_AXIS_LIMITS.get(metric, [0, 80])

    fig = px.scatter(
        scatter_data,
        x="x",
        y="y",
        log_x=True,
        size="size",
        color="Model Type",
        # BUGFIX: TYPE_COLORS was previously mapped into an unused "color"
        # column, so the configured palette never reached the plot. Wire it
        # in directly; types absent from the map fall back to plotly defaults.
        color_discrete_map=TYPE_COLORS,
        text="Model",
        hover_data={metric: ":.2f"},
        title=f"Params vs. {metric} for {benchmark}",
        labels={"x": "# Params (Log Scale)", "y": metric},
        template="plotly_white",
        height=600,
        width=1200,
    )

    fig.update_traces(
        textposition="top center",
        textfont_size=10,
        marker=dict(opacity=0.8, line=dict(width=0.5, color="black")),
    )
    fig.update_layout(
        xaxis=dict(
            showgrid=True,
            type="log",
            tickmode="array",
            tickvals=SCATTER_PLOT_X_TICKS["tickvals"],
            ticktext=SCATTER_PLOT_X_TICKS["ticktext"],
        ),
        showlegend=False,
        yaxis=dict(range=y_range),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
    )

    return fig