agh123 commited on
Commit
5b7d0a1
·
1 Parent(s): 0202f73

feat: add quantization to "Model Size vs Performance" plot

Browse files
Files changed (1) hide show
  1. src/components/visualizations.py +132 -72
src/components/visualizations.py CHANGED
@@ -7,6 +7,7 @@ import plotly.express as px
7
  import pandas as pd
8
  from typing import Optional, Dict, List, Set
9
  import plotly.graph_objects as go
 
10
 
11
 
12
  def clean_device_id(device_id: str) -> str:
@@ -16,6 +17,24 @@ def clean_device_id(device_id: str) -> str:
16
  return device_id
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def create_performance_plot(
20
  df: pd.DataFrame, metric: str, title: str, hover_data: List[str] = None
21
  ):
@@ -113,7 +132,9 @@ def filter_dataframe(df: pd.DataFrame, filters: Dict) -> pd.DataFrame:
113
  return filtered_df
114
 
115
 
116
- def create_model_size_performance_plot(df: pd.DataFrame, device_id: str, title: str):
 
 
117
  """Create a plot showing model size vs performance metrics for a specific device"""
118
  if df.empty:
119
  return None
@@ -123,60 +144,103 @@ def create_model_size_performance_plot(df: pd.DataFrame, device_id: str, title:
123
  if device_df.empty:
124
  return None
125
 
 
 
 
 
 
 
 
 
 
 
126
  # Create a new figure with secondary y-axis
127
  fig = go.Figure()
128
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  # Add Token Generation data (left y-axis)
130
- fig.add_trace(
131
- go.Scatter(
132
- x=device_df["Model Size"],
133
- y=device_df["Token Generation"],
134
- name="Token Generation",
135
- mode="markers",
136
- marker=dict(color="#2ecc71"),
137
- yaxis="y",
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  )
139
- )
140
 
141
  # Add Prompt Processing data (right y-axis)
142
- fig.add_trace(
143
- go.Scatter(
144
- x=device_df["Model Size"],
145
- y=device_df["Prompt Processing"],
146
- name="Prompt Processing",
147
- mode="markers",
148
- marker=dict(color="#e74c3c"),
149
- yaxis="y2",
 
 
 
 
 
 
 
 
 
 
 
 
150
  )
151
- )
152
 
153
  # Add trend lines if enough points
154
  if len(device_df) > 2:
155
  # TG trend line
156
  tg_trend = px.scatter(
157
  device_df, x="Model Size", y="Token Generation", trendline="lowess"
158
- ).data[
159
- 1
160
- ] # Get the trend line trace
161
  tg_trend.update(
162
  line=dict(color="#2ecc71", dash="solid"),
163
- name="TG Trend",
164
- showlegend=False,
165
  yaxis="y",
 
166
  )
167
  fig.add_trace(tg_trend)
168
 
169
  # PP trend line
170
  pp_trend = px.scatter(
171
  device_df, x="Model Size", y="Prompt Processing", trendline="lowess"
172
- ).data[
173
- 1
174
- ] # Get the trend line trace
175
  pp_trend.update(
176
  line=dict(color="#e74c3c", dash="solid"),
177
- name="PP Trend",
178
- showlegend=False,
179
  yaxis="y2",
 
180
  )
181
  fig.add_trace(pp_trend)
182
 
@@ -186,10 +250,7 @@ def create_model_size_performance_plot(df: pd.DataFrame, device_id: str, title:
186
  xaxis=dict(
187
  title="Model Size (B)",
188
  gridcolor="lightgrey",
189
- range=[
190
- 0,
191
- max(device_df["Model Size"]) * 1.05,
192
- ], # Start from 0, add 5% padding to max
193
  ),
194
  yaxis=dict(
195
  title="Token Generation (t/s)",
@@ -197,10 +258,7 @@ def create_model_size_performance_plot(df: pd.DataFrame, device_id: str, title:
197
  tickfont=dict(color="#2ecc71"),
198
  gridcolor="lightgrey",
199
  side="left",
200
- range=[
201
- 0,
202
- max(device_df["Token Generation"]) * 1.05,
203
- ], # Start from 0, add 5% padding to max
204
  ),
205
  yaxis2=dict(
206
  title="Prompt Processing (t/s)",
@@ -209,22 +267,21 @@ def create_model_size_performance_plot(df: pd.DataFrame, device_id: str, title:
209
  anchor="x",
210
  overlaying="y",
211
  side="right",
212
- range=[
213
- 0,
214
- max(device_df["Prompt Processing"]) * 1.05,
215
- ], # Start from 0, add 5% padding to max
216
  ),
217
  height=400,
218
  showlegend=True,
219
  plot_bgcolor="white",
220
  legend=dict(
221
- yanchor="middle",
222
- y=0.8,
223
  xanchor="right",
224
  x=0.99,
225
- bgcolor="rgba(255, 255, 255, 0.8)", # Semi-transparent white background
226
  bordercolor="lightgrey",
227
  borderwidth=1,
 
 
228
  ),
229
  )
230
 
@@ -255,30 +312,48 @@ def render_model_size_performance(df: pd.DataFrame, filters: Dict):
255
  device_id: clean_device_id(device_id) for device_id in device_ids
256
  }
257
 
258
- # Device selector for size vs performance plots
259
- selected_device_id = st.selectbox(
260
- "Select Device",
261
- options=device_ids,
262
- format_func=lambda x: device_display_names[
263
- x
264
- ], # Display clean names in dropdown
265
- help="Select a device to view its performance across different model sizes",
266
- key="size_perf_device_selector",
267
- placeholder="Search for a device...",
268
- index=default_index,
269
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  # Create and display the model size vs performance plot
272
  size_perf_fig = create_model_size_performance_plot(
273
  size_perf_df,
274
  selected_device_id,
 
275
  f"Model Size vs Performance Metrics for {device_display_names[selected_device_id]}",
276
  )
277
 
278
  if size_perf_fig:
279
  st.plotly_chart(size_perf_fig, use_container_width=True)
280
  else:
281
- st.warning("No data available for the selected device.")
282
 
283
 
284
  def render_performance_plots(df: pd.DataFrame, filters: Dict):
@@ -721,21 +796,6 @@ def render_device_rankings(df: pd.DataFrame):
721
  with rank_tab3:
722
  st.subheader("🔍 Rankings by Quantization")
723
 
724
- # Helper function to get quantization name from factor
725
- def get_quant_name(factor: float) -> str:
726
- if factor >= 1.0:
727
- return "No Quantization (F16/F32)"
728
- quant_map = {
729
- 0.8: "[i]Q8_x",
730
- 0.6: "[i]Q6_x",
731
- 0.5: "[i]Q5_x",
732
- 0.4: "[i]Q4_x",
733
- 0.3: "[i]Q3_x",
734
- 0.2: "[i]Q2_x",
735
- 0.1: "[i]Q1_x",
736
- }
737
- return quant_map.get(factor, f"Q{int(factor*10)}_x")
738
-
739
  # Group by device and quantization level
740
  quant_rankings = df.copy()
741
  quant_summary = (
 
7
  import pandas as pd
8
  from typing import Optional, Dict, List, Set
9
  import plotly.graph_objects as go
10
+ from ..core.scoring import get_quantization_tier
11
 
12
 
13
  def clean_device_id(device_id: str) -> str:
 
17
  return device_id
18
 
19
 
20
+ def get_quant_name(factor: float) -> str:
21
+ """Get human-readable name for quantization factor"""
22
+ if pd.isna(factor):
23
+ return "Unknown"
24
+ if factor >= 1.0:
25
+ return "No Quantization (F16/F32)"
26
+ quant_map = {
27
+ 0.8: "[i]Q8_x",
28
+ 0.6: "[i]Q6_x",
29
+ 0.5: "[i]Q5_x",
30
+ 0.4: "[i]Q4_x",
31
+ 0.3: "[i]Q3_x",
32
+ 0.2: "[i]Q2_x",
33
+ 0.1: "[i]Q1_x",
34
+ }
35
+ return quant_map.get(factor, f"Q{int(factor*10)}_x")
36
+
37
+
38
  def create_performance_plot(
39
  df: pd.DataFrame, metric: str, title: str, hover_data: List[str] = None
40
  ):
 
132
  return filtered_df
133
 
134
 
135
+ def create_model_size_performance_plot(
136
+ df: pd.DataFrame, device_id: str, quant_filter: str, title: str
137
+ ):
138
  """Create a plot showing model size vs performance metrics for a specific device"""
139
  if df.empty:
140
  return None
 
144
  if device_df.empty:
145
  return None
146
 
147
+ # Filter by quantization if specified
148
+ if quant_filter != "All":
149
+ device_df = device_df[
150
+ device_df["Model ID"].apply(
151
+ lambda x: get_quantization_tier(x) == float(quant_filter)
152
+ )
153
+ ]
154
+ if device_df.empty:
155
+ return None
156
+
157
  # Create a new figure with secondary y-axis
158
  fig = go.Figure()
159
 
160
+ # Define shapes for different quantization levels
161
+ quant_shapes = {
162
+ 1.0: "circle", # F16/F32
163
+ 0.8: "square", # Q8
164
+ 0.6: "diamond", # Q6
165
+ 0.5: "triangle-up", # Q5
166
+ 0.4: "triangle-down", # Q4
167
+ 0.3: "star", # Q3
168
+ 0.2: "pentagon", # Q2
169
+ 0.1: "hexagon", # Q1
170
+ }
171
+
172
  # Add Token Generation data (left y-axis)
173
+ for quant in sorted(device_df["quant_factor"].unique()):
174
+ quant_df = device_df[device_df["quant_factor"] == quant]
175
+ if quant_df.empty:
176
+ continue
177
+
178
+ quant_name = get_quant_name(quant)
179
+ fig.add_trace(
180
+ go.Scatter(
181
+ x=quant_df["Model Size"],
182
+ y=quant_df["Token Generation"],
183
+ name=f"{quant_name}",
184
+ mode="markers",
185
+ marker=dict(
186
+ color="#2ecc71",
187
+ symbol=quant_shapes.get(quant, "circle"),
188
+ size=10,
189
+ ),
190
+ yaxis="y",
191
+ legendgroup="quant",
192
+ showlegend=True,
193
+ )
194
  )
 
195
 
196
  # Add Prompt Processing data (right y-axis)
197
+ for quant in sorted(device_df["quant_factor"].unique()):
198
+ quant_df = device_df[device_df["quant_factor"] == quant]
199
+ if quant_df.empty:
200
+ continue
201
+
202
+ fig.add_trace(
203
+ go.Scatter(
204
+ x=quant_df["Model Size"],
205
+ y=quant_df["Prompt Processing"],
206
+ name=f"{quant_name}",
207
+ mode="markers",
208
+ marker=dict(
209
+ color="#e74c3c",
210
+ symbol=quant_shapes.get(quant, "circle"),
211
+ size=10,
212
+ ),
213
+ yaxis="y2",
214
+ legendgroup="quant",
215
+ showlegend=False, # Don't show duplicate quantization entries in legend
216
+ )
217
  )
 
218
 
219
  # Add trend lines if enough points
220
  if len(device_df) > 2:
221
  # TG trend line
222
  tg_trend = px.scatter(
223
  device_df, x="Model Size", y="Token Generation", trendline="lowess"
224
+ ).data[1]
 
 
225
  tg_trend.update(
226
  line=dict(color="#2ecc71", dash="solid"),
227
+ name="Token Generation",
228
+ showlegend=False, # Hide from legend
229
  yaxis="y",
230
+ legendgroup="metric",
231
  )
232
  fig.add_trace(tg_trend)
233
 
234
  # PP trend line
235
  pp_trend = px.scatter(
236
  device_df, x="Model Size", y="Prompt Processing", trendline="lowess"
237
+ ).data[1]
 
 
238
  pp_trend.update(
239
  line=dict(color="#e74c3c", dash="solid"),
240
+ name="Prompt Processing",
241
+ showlegend=False, # Hide from legend
242
  yaxis="y2",
243
+ legendgroup="metric",
244
  )
245
  fig.add_trace(pp_trend)
246
 
 
250
  xaxis=dict(
251
  title="Model Size (B)",
252
  gridcolor="lightgrey",
253
+ range=[0, max(device_df["Model Size"]) * 1.05],
 
 
 
254
  ),
255
  yaxis=dict(
256
  title="Token Generation (t/s)",
 
258
  tickfont=dict(color="#2ecc71"),
259
  gridcolor="lightgrey",
260
  side="left",
261
+ range=[0, max(device_df["Token Generation"]) * 1.05],
 
 
 
262
  ),
263
  yaxis2=dict(
264
  title="Prompt Processing (t/s)",
 
267
  anchor="x",
268
  overlaying="y",
269
  side="right",
270
+ range=[0, max(device_df["Prompt Processing"]) * 1.05],
 
 
 
271
  ),
272
  height=400,
273
  showlegend=True,
274
  plot_bgcolor="white",
275
  legend=dict(
276
+ yanchor="top",
277
+ y=0.99,
278
  xanchor="right",
279
  x=0.99,
280
+ bgcolor="rgba(255, 255, 255, 0.8)",
281
  bordercolor="lightgrey",
282
  borderwidth=1,
283
+ groupclick="togglegroup", # Toggle all traces in the same group
284
+ title="Quantization", # Add legend title
285
  ),
286
  )
287
 
 
312
  device_id: clean_device_id(device_id) for device_id in device_ids
313
  }
314
 
315
+ # Create columns for device and quantization selectors
316
+ col1, col2 = st.columns([2, 1])
317
+
318
+ with col1:
319
+ # Device selector
320
+ selected_device_id = st.selectbox(
321
+ "Select Device",
322
+ options=device_ids,
323
+ format_func=lambda x: device_display_names[x],
324
+ help="Select a device to view its performance across different model sizes",
325
+ key="size_perf_device_selector",
326
+ placeholder="Search for a device...",
327
+ index=default_index,
328
+ )
329
+
330
+ with col2:
331
+ # Quantization filter
332
+ quant_options = ["All"] + [
333
+ str(q) for q in sorted(size_perf_df["quant_factor"].unique())
334
+ ]
335
+ quant_filter = st.selectbox(
336
+ "Filter by Quantization",
337
+ options=quant_options,
338
+ format_func=lambda x: (
339
+ "All Quantizations" if x == "All" else get_quant_name(float(x))
340
+ ),
341
+ help="Filter data points by quantization level",
342
+ key="size_perf_quant_selector",
343
+ )
344
 
345
  # Create and display the model size vs performance plot
346
  size_perf_fig = create_model_size_performance_plot(
347
  size_perf_df,
348
  selected_device_id,
349
+ quant_filter,
350
  f"Model Size vs Performance Metrics for {device_display_names[selected_device_id]}",
351
  )
352
 
353
  if size_perf_fig:
354
  st.plotly_chart(size_perf_fig, use_container_width=True)
355
  else:
356
+ st.warning("No data available for the selected device and quantization level.")
357
 
358
 
359
  def render_performance_plots(df: pd.DataFrame, filters: Dict):
 
796
  with rank_tab3:
797
  st.subheader("🔍 Rankings by Quantization")
798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799
  # Group by device and quantization level
800
  quant_rankings = df.copy()
801
  quant_summary = (