|
import pandas as pd |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
from typing import Tuple |
|
import plotly.express as px |
|
|
|
# Multiplier applied to a tool's share of all requests before that share is
# added to the tool's accuracy (see get_weighted_accuracy).
VOLUME_FACTOR_REGULARIZATION: float = 0.5

# Min-max bounds of the raw (accuracy + volume bonus) score handed to
# scale_value.
# NOTE(review): presumably accuracy is a percentage in [0, 100] and the volume
# bonus is at most 0.5, hence the 100.5 upper bound -- confirm the rationale
# for the -0.5 lower bound.
UNSCALED_WEIGHTED_ACCURACY_INTERVAL: Tuple[float, float] = (-0.5, 100.5)

# Target range the raw score is rescaled into.
SCALED_WEIGHTED_ACCURACY_INTERVAL: Tuple[float, float] = (0, 1)
|
|
|
|
|
# Fixed color for each known tool name. The plotting functions in this module
# pass this mapping to seaborn (``palette``) and plotly
# (``color_discrete_map``) so a given tool is drawn in the same color in
# every chart.
tools_palette = {
    "prediction-request-reasoning": "darkorchid",
    "claude-prediction-offline": "rebeccapurple",
    "prediction-request-reasoning-claude": "slateblue",
    "prediction-request-rag-claude": "steelblue",
    "prediction-online": "darkcyan",
    "prediction-offline": "mediumaquamarine",
    "claude-prediction-online": "mediumseagreen",
    "prediction-online-sme": "yellowgreen",
    "prediction-url-cot-claude": "gold",
    "prediction-offline-sme": "orange",
    "prediction-request-rag": "chocolate",
}
|
|
|
# Size applied to the plotly figures through ``fig.update_layout`` in the
# "rotated" plotting functions.
HEIGHT = 400
WIDTH = 1100
|
|
|
|
|
def scale_value(
    value: float,
    min_max_bounds: Tuple[float, float],
    scale_bounds: Tuple[float, float] = (0, 1),
) -> float:
    """Linearly map *value* from one interval onto another (min-max scaling).

    Args:
        value: The number to rescale. It is not clamped, so a value outside
            ``min_max_bounds`` maps outside ``scale_bounds``.
        min_max_bounds: ``(lower, upper)`` bounds of the source interval.
        scale_bounds: ``(lower, upper)`` bounds of the target interval;
            defaults to ``(0, 1)``.

    Returns:
        The position of *value* within the source interval, expressed in the
        target interval.

    Raises:
        ZeroDivisionError: If both source bounds are equal (for plain Python
            numbers).
    """
    source_low, source_high = min_max_bounds
    source_span = source_high - source_low

    # Fraction of the way *value* sits between the source bounds.
    fraction = (value - source_low) / source_span

    target_low, target_high = scale_bounds
    target_span = target_high - target_low
    return fraction * target_span + target_low
|
|
|
|
|
def get_weighted_accuracy(row, global_requests: int):
    """Return the volume-weighted accuracy score for one tool.

    The raw score is the tool's accuracy plus its share of all requests
    multiplied by ``VOLUME_FACTOR_REGULARIZATION``. That raw score is then
    min-max scaled from ``UNSCALED_WEIGHTED_ACCURACY_INTERVAL`` into
    ``SCALED_WEIGHTED_ACCURACY_INTERVAL``.

    Args:
        row: Mapping-like record exposing ``"tool_accuracy"`` and
            ``"total_requests"`` (compute_weighted_accuracy passes one
            DataFrame row at a time).
        global_requests: Total number of requests the tool's own count is
            compared against.

    Returns:
        The scaled weighted accuracy.
    """
    accuracy = row["tool_accuracy"]
    request_share = row["total_requests"] / global_requests
    raw_score = accuracy + request_share * VOLUME_FACTOR_REGULARIZATION
    return scale_value(
        raw_score,
        UNSCALED_WEIGHTED_ACCURACY_INTERVAL,
        SCALED_WEIGHTED_ACCURACY_INTERVAL,
    )
|
|
|
|
|
def compute_weighted_accuracy(tools_accuracy: pd.DataFrame):
    """Add a ``weighted_accuracy`` column to *tools_accuracy*.

    The request total is the sum of the ``total_requests`` column; each row's
    score comes from ``get_weighted_accuracy``.

    Args:
        tools_accuracy: Frame with ``tool_accuracy`` and ``total_requests``
            columns. It is modified in place.

    Returns:
        The same DataFrame object, now carrying ``weighted_accuracy``.
    """
    requests_total = tools_accuracy.total_requests.sum()

    def _score_row(tool_row):
        # Every row is weighted against the same overall request count.
        return get_weighted_accuracy(tool_row, requests_total)

    tools_accuracy["weighted_accuracy"] = tools_accuracy.apply(_score_row, axis=1)
    return tools_accuracy
|
|
|
|
|
def plot_tools_accuracy_graph(tools_accuracy_info: pd.DataFrame):
    """Draw a seaborn bar chart of ``tool_accuracy`` with tools on the y-axis.

    Args:
        tools_accuracy_info: Frame with at least the ``tool`` and
            ``tool_accuracy`` columns. It is not modified; a sorted copy is
            plotted.

    Returns:
        A ``gr.Plot`` wrapping the matplotlib figure, with rows ordered from
        the highest to the lowest ``tool_accuracy``.
    """
    ranked = tools_accuracy_info.sort_values(by="tool_accuracy", ascending=False)

    # seaborn rejects a dict palette that lacks an entry for any plotted hue
    # level, so a tool missing from ``tools_palette`` used to crash the chart.
    # Unknown tools now get a neutral color, matching the plotly charts, which
    # already tolerate tools absent from the mapping.
    fallback_color = "gray"
    palette = {
        tool: tools_palette.get(tool, fallback_color)
        for tool in ranked["tool"].unique()
    }

    # NOTE(review): every call opens a new pyplot-managed figure that this
    # function never closes -- confirm it is called a bounded number of times.
    fig, ax = plt.subplots(figsize=(25, 10))
    sns.barplot(
        ranked,
        x="tool_accuracy",
        y="tool",
        hue="tool",
        dodge=False,
        palette=palette,
        ax=ax,
    )
    ax.set_xlabel("Mech tool_accuracy (%)", fontsize=20)
    ax.set_ylabel("tool", fontsize=20)
    ax.tick_params(axis="y", labelsize=12)
    return gr.Plot(value=fig)
|
|
|
|
|
def plot_tools_accuracy_rotated_graph(tools_accuracy_info: pd.DataFrame):
    """Build a plotly bar chart of ``tool_accuracy`` with tools on the x-axis.

    Args:
        tools_accuracy_info: Frame with at least the ``tool`` and
            ``tool_accuracy`` columns. It is not modified; a sorted copy is
            plotted.

    Returns:
        A ``gr.Plot`` wrapping the plotly figure, sized ``WIDTH`` x ``HEIGHT``
        and ordered from the highest to the lowest ``tool_accuracy``.
    """
    ranked = tools_accuracy_info.sort_values(by="tool_accuracy", ascending=False)
    bar_chart = px.bar(
        ranked,
        x="tool",
        y="tool_accuracy",
        color="tool",
        color_discrete_map=tools_palette,
    )
    # Axis titles and canvas size are applied in one layout update.
    bar_chart.update_layout(
        xaxis_title="Tool",
        yaxis_title="Mech tool_accuracy (%)",
        width=WIDTH,
        height=HEIGHT,
    )
    # Tool names are hidden on the x-axis; presumably the color legend is
    # meant to identify the bars instead.
    bar_chart.update_xaxes(showticklabels=False)
    return gr.Plot(value=bar_chart)
|
|
|
|
|
def plot_tools_weighted_accuracy_graph(tools_accuracy_info: pd.DataFrame):
    """Draw a seaborn bar chart of ``weighted_accuracy`` with tools on the y-axis.

    Args:
        tools_accuracy_info: Frame with at least the ``tool`` and
            ``weighted_accuracy`` columns. It is not modified; a sorted copy
            is plotted.

    Returns:
        A ``gr.Plot`` wrapping the matplotlib figure, with rows ordered from
        the highest to the lowest ``weighted_accuracy``.
    """
    ranked = tools_accuracy_info.sort_values(
        by="weighted_accuracy", ascending=False
    )

    # seaborn rejects a dict palette that lacks an entry for any plotted hue
    # level, so a tool missing from ``tools_palette`` used to crash the chart.
    # Unknown tools now get a neutral color, matching the plotly charts, which
    # already tolerate tools absent from the mapping.
    fallback_color = "gray"
    palette = {
        tool: tools_palette.get(tool, fallback_color)
        for tool in ranked["tool"].unique()
    }

    # NOTE(review): every call opens a new pyplot-managed figure that this
    # function never closes -- confirm it is called a bounded number of times.
    fig, ax = plt.subplots(figsize=(25, 10))
    sns.barplot(
        ranked,
        x="weighted_accuracy",
        y="tool",
        hue="tool",
        dodge=False,
        palette=palette,
        ax=ax,
    )
    ax.set_xlabel("Weighted accuracy metric", fontsize=20)
    ax.set_ylabel("tool", fontsize=20)
    ax.tick_params(axis="y", labelsize=12)
    return gr.Plot(value=fig)
|
|
|
|
|
def plot_tools_weighted_accuracy_rotated_graph(tools_accuracy_info: pd.DataFrame):
    """Build a plotly bar chart of ``weighted_accuracy`` with tools on the x-axis.

    Args:
        tools_accuracy_info: Frame with at least the ``tool`` and
            ``weighted_accuracy`` columns. It is not modified; a sorted copy
            is plotted.

    Returns:
        A ``gr.Plot`` wrapping the plotly figure, sized ``WIDTH`` x ``HEIGHT``
        and ordered from the highest to the lowest ``weighted_accuracy``.
    """
    ranked = tools_accuracy_info.sort_values(
        by="weighted_accuracy", ascending=False
    )
    bar_chart = px.bar(
        ranked,
        x="tool",
        y="weighted_accuracy",
        color="tool",
        color_discrete_map=tools_palette,
    )
    # Axis titles and canvas size are applied in one layout update.
    bar_chart.update_layout(
        xaxis_title="Tool",
        yaxis_title="Weighted accuracy metric",
        width=WIDTH,
        height=HEIGHT,
    )
    # Tool names are hidden on the x-axis; presumably the color legend is
    # meant to identify the bars instead.
    bar_chart.update_xaxes(showticklabels=False)
    return gr.Plot(value=bar_chart)
|
|