"""Gradio demo for the ``jordyvl/ece`` metric: computes ECE and draws a reliability diagram."""

import ast
import json
import sys
from pathlib import Path

import evaluate
import gradio as gr
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

plt.rcParams["figure.dpi"] = 300
# Use the non-interactive Agg backend; see
# https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
plt.switch_backend("agg")


def default_plot():
    """Create the empty two-panel figure used for the reliability diagram.

    Returns:
        A ``(fig, ax1, ax2)`` tuple. ``ax1`` is the upper panel (two of the
        three grid rows) and already holds the dotted "Perfect" diagonal;
        ``ax2`` is the lower panel intended for the per-bin counts.
    """
    fig = plt.figure()
    # Attach the axes to ``fig`` explicitly instead of relying on pyplot's
    # implicit "current figure" global.
    ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2, fig=fig)
    ax2 = plt.subplot2grid((3, 1), (2, 0), fig=fig)

    ranged = np.linspace(0, 1, 10)
    ax1.plot(
        ranged,
        ranged,
        color="darkgreen",
        ls="dotted",
        label="Perfect",
    )

    # Bin differences
    ax1.set_ylabel("Conditional Expectation")
    ax1.set_ylim([0, 1.05])
    ax1.set_title("Reliability Diagram")
    ax1.set_xlim([-0.05, 1.05])  # respective to bin range

    # Bin frequencies
    ax2.set_xlabel("Confidence")
    ax2.set_ylabel("Count")
    ax2.set_xlim([-0.05, 1.05])  # respective to bin range

    return fig, ax1, ax2


def _bin_color(perfect, empirical):
    """Return the face colour for one bar given its reference and actual height."""
    # NOTE(review): bars *below* the diagonal get the colour whose legend entry
    # is "Underconfident"; in the usual reading of a reliability diagram that
    # would be over-confidence. Left unchanged -- confirm the meaning of
    # ``p_bar`` against the metric implementation.
    if perfect == empirical:
        return "limegreen"
    if empirical < perfect:
        return "dodgerblue"
    return "orangered"


def reliability_plot(results):
    """Draw the reliability diagram for a detailed metric result.

    Args:
        results: Mapping with the keys ``"y_bar"``, ``"p_bar"``, ``"bin_freq"``,
            ``"accuracy"`` and ``"p_bar_cont"`` (the keys read below).
            Presumably the ``detail=True`` output of the ECE metric -- exact
            shapes are not visible in this file.

    Returns:
        The matplotlib figure containing both panels.
    """
    # DEV: might still need to write tests in case of equal mass binning
    # DEV: nicer would be to plot like a polygon
    # see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py

    fig, ax1, ax2 = default_plot()

    # Bin differences
    y_bar = results["y_bar"]
    bins_with_left_edge = np.insert(y_bar, 0, 0, axis=0)
    # Insert at index ``len(...)`` to truly append. Index -1 would place 1.0
    # *before* the last element, which yields non-monotonic bin edges whenever
    # the last value of ``y_bar`` is not already 1.0.
    bins_with_right_edge = np.insert(y_bar, len(y_bar), 1.0, axis=0)
    bins_with_leftright_edge = np.insert(
        bins_with_left_edge, len(bins_with_left_edge), 1.0, axis=0
    )
    weights = np.nan_to_num(results["p_bar"], copy=True, nan=0)

    # NOTE: the histogram API is strange -- it is used here as a bar plot:
    # every left edge is one sample and its weight becomes the bar height.
    _, _, patches = ax1.hist(
        bins_with_left_edge,
        weights=weights,
        bins=bins_with_leftright_edge,
    )
    for patch, perfect, empirical in zip(patches, bins_with_right_edge, weights):
        patch.set_facecolor(_bin_color(perfect, empirical))  # over/underconfidence

    ax1handles = [
        mpatches.Patch(color="orangered", label="Overconfident"),
        mpatches.Patch(color="limegreen", label="Perfect", linestyle="dotted"),
        mpatches.Patch(color="dodgerblue", label="Underconfident"),
    ]

    # Bin frequencies: ``bin_freq`` is written only into the positions whose
    # ``p_bar`` entry is not NaN; the remaining bins keep a count of zero.
    anindices = np.where(~np.isnan(results["p_bar"]))[0]
    bin_freqs = np.zeros(len(results["p_bar"]))
    bin_freqs[anindices] = results["bin_freq"]
    ax2.hist(
        bins_with_left_edge,
        weights=bin_freqs,
        color="midnightblue",
        bins=bins_with_leftright_edge,
    )

    acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
    conf_plt = ax2.axvline(
        x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
    )
    ax1.legend(loc="lower right", handles=ax1handles)
    ax2.legend(handles=[acc_plt, conf_plt])

    ax1.set_xticks(bins_with_left_edge)
    ax2.set_xticks(bins_with_left_edge)

    fig.tight_layout()
    # NOTE(review): the figure stays registered with pyplot after being
    # returned; if this runs in a long-lived server, consider closing it once
    # it has been rendered.
    return fig


def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
    """Compute the metric on the submitted table and build its reliability diagram.

    Args:
        data: Table with ``"predictions"`` and ``"references"`` columns. Rows
            containing missing values are dropped when it is a DataFrame.
        n_bins: Forwarded to the metric as ``n_bins``.
        bin_range: Accepted for interface compatibility; not used here.
        scheme: Forwarded to the metric as ``scheme``.
        proxy: Forwarded to the metric as ``proxy``.
        p: Forwarded to the metric as ``p``.

    Returns:
        A ``(ece, figure)`` tuple: the ``"ECE"`` entry of the metric result and
        the reliability diagram.
    """
    # DEV: check on invalid datatypes with better warnings

    if isinstance(data, pd.DataFrame):
        # Work on a cleaned copy instead of mutating the caller's frame.
        data = data.dropna()

    # Predictions that are not already lists are parsed as Python literals
    # (``ast.literal_eval`` never executes code).
    predictions = [
        prediction if isinstance(prediction, list) else ast.literal_eval(prediction)
        for prediction in data["predictions"]
    ]
    references = list(data["references"])

    # Calls the metric's ``_compute`` directly with ``detail=True`` because the
    # plot needs the per-bin entries of the result, not only the score.
    results = metric._compute(
        predictions,
        references,
        n_bins=n_bins,
        scheme=scheme,
        proxy=proxy,
        p=p,
        detail=True,
    )
    plot = reliability_plot(results)
    return results["ECE"], plot


sliders = [
    gr.Slider(0, 100, value=10, label="n_bins"),
    gr.Slider(
        0, 100, value=None, label="bin_range", visible=False
    ),  # DEV: need to have a double slider
    gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"),
    gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"),
    gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"),
]

slider_defaults = [slider.value for slider in sliders]

# example data
# Top-level ``gr.Dataframe`` replaces the deprecated ``gr.inputs.Dataframe``,
# matching the other components used in this file.
component = gr.Dataframe(
    headers=["predictions", "references"], col_count=2, datatype="number", type="pandas"
)

component.value = [
    [[0.6, 0.2, 0.2], 0],
    [[0.7, 0.1, 0.2], 2],
    [[0, 0.95, 0.05], 1],
]
sample_data = [[component] + slider_defaults]

local_path = Path(sys.path[0])
metric = evaluate.load("jordyvl/ece")
outputs = [gr.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
# outputs[1].value = default_plot().__dict__  # DEV: Does not work in gradio; needs to be JSON encoded

iface = gr.Interface(
    fn=compute_and_plot,
    inputs=[component] + sliders,
    outputs=outputs,
    description=metric.info.description,
    article=evaluate.utils.parse_readme(local_path / "README.md"),
    title=f"Metric: {metric.name}",
    # examples=sample_data  # DEV: ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
).launch()