jerome-white's picture
Shift less-than 0.5 note further left
4ed3a8f
raw
history blame
7.93 kB
import math
import operator as op
import itertools as it
import functools as ft
import collections as cl
from pathlib import Path
from dataclasses import fields, asdict
import pandas as pd
import gradio as gr
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
from scipy.special import expit
from matplotlib.ticker import FixedLocator, StrMethodFormatter
from hdinterval import HDI, HDInterval
# One dashboard tab: its title (``name``), the subdirectory of ``docs/``
# holding its Markdown files (``docs``), and the dataset name under
# ``jerome-white/`` (``dataset``).
TabGroup = cl.namedtuple('TabGroup', ['name', 'docs', 'dataset'])
#
#
#
def load(repo):
    """Load posterior samples of the ``alpha`` parameter from a dataset.

    Parameters
    ----------
    repo : str or pathlib.Path
        Dataset identifier handed to ``load_dataset``. A ``Path`` is
        converted with forward slashes so the identifier is the same on
        every operating system.

    Returns
    -------
    pandas.DataFrame
        The ``train`` split, with the ``element`` column renamed to
        ``model``, restricted to the ``chain``, ``sample``, ``model`` and
        ``value`` columns, and keeping only rows whose ``parameter`` is
        ``"alpha"`` (the ``parameter`` column itself is dropped).
    """
    parameter = 'parameter'
    model = 'model'
    items = [
        'chain',
        'sample',
        parameter,
        model,
        'value',
    ]

    # str(Path(...)) joins with the OS separator ("\\" on Windows), which is
    # not a valid dataset id; as_posix() always joins with "/".
    repo_id = repo.as_posix() if isinstance(repo, Path) else str(repo)
    dataset = load_dataset(repo_id)

    return (dataset
            .get('train')
            .to_pandas()
            .rename(columns={'element': model})
            .filter(items=items)
            .query(f'{parameter} == "alpha"')
            .drop(columns=parameter))
def summarize(df, ci=0.95):
    """Reduce posterior samples to one summary row per model.

    Groups ``df`` by ``model`` (keeping order of first appearance). For each
    group, the row holds:

    - ``model``: the group key.
    - ``ability``: the median of the ``value`` samples.
    - ``uncertainty``: ``width()`` of the ``HDInterval`` evaluated at ``ci``.
    - The interval's own fields, as given by ``dataclasses.asdict``.

    ``ci`` is the interval mass handed to ``HDInterval`` (default 0.95).
    """
    def _row(name, group):
        samples = group['value']
        bounds = HDInterval(samples)(ci)
        head = dict(
            model=name,
            ability=samples.median(),
            uncertainty=bounds.width(),
        )
        return {**head, **asdict(bounds)}

    grouped = df.groupby('model', sort=False)
    rows = (_row(name, group) for (name, group) in grouped)
    return pd.DataFrame.from_records(rows)
def rank(df, ascending, name='rank'):
    """Order model summaries by ability and number them from 1.

    Rows are sorted by ``ability`` (ascending when ``ascending`` is true),
    with ties broken by ``uncertainty`` in the opposite direction. The
    ``uncertainty`` column is then dropped, and a 1-based position column
    called ``name`` is inserted as the first column.
    """
    directions = [ascending, not ascending]
    ordered = df.sort_values(by=['ability', 'uncertainty'],
                             ascending=directions)
    ordered = ordered.drop(columns='uncertainty').reset_index(drop=True)
    # Shift the fresh 0-based RangeIndex so positions start at 1.
    ordered.index = ordered.index + 1
    return ordered.reset_index(names=name)
def compare(df, model_1, model_2):
    """Per-draw inverse logit of the ability difference between two models.

    Pairs the ``value`` samples of ``model_1`` and ``model_2`` by
    ``(chain, sample)``. Returns ``expit(value_1 - value_2)`` indexed by
    that pair.
    """
    selected = df[df['model'].isin([model_1, model_2])]
    paired = selected.pivot(index=['chain', 'sample'],
                            columns='model',
                            values='value')
    difference = paired[model_1] - paired[model_2]
    return expit(difference)
#
#
#
class DataPlotter:
    """Base class that renders ``self.df`` onto a freshly created figure.

    Subclasses implement :meth:`draw`; :meth:`plot` handles figure
    creation, grid styling and layout.
    """

    def __init__(self, df):
        self.df = df

    def plot(self):
        """Create a figure, let :meth:`draw` fill it, and return it.

        The figure is released from pyplot's registry before returning, even
        if drawing fails, so repeated calls do not accumulate open figures.
        The returned ``Figure`` object can still be saved or rendered.
        """
        fig = plt.figure(dpi=200)
        try:
            ax = fig.gca()
            self.draw(ax)
            ax.grid(visible=True,
                    axis='both',
                    alpha=0.25,
                    linestyle='dotted')
            fig.tight_layout()
        finally:
            # pyplot keeps every figure it creates alive until it is closed.
            plt.close(fig)
        return fig

    def draw(self, ax):
        """Render ``self.df`` onto ``ax``; subclasses must override this."""
        raise NotImplementedError()
class RankPlotter(DataPlotter):
    """Dot-and-whisker chart of the ``top`` highest-ability models.

    Each model is a point at its median ability, with a horizontal line from
    its ``lower`` to its ``upper`` interval bound. The rank number is the
    vertical coordinate, so the strongest model sits highest.
    """

    # Name of the 1-based rank column used as the vertical coordinate.
    _y = 'y'

    def __init__(self, df, top=10):
        # An ascending ranking puts the strongest models last, so the tail
        # holds the ``top`` best; re-order them from best to worst.
        ranked = rank(summarize(df), True, self._y)
        best = ranked.tail(top).sort_values(by=self._y, ascending=False)
        super().__init__(best)

    @ft.cached_property
    def y(self):
        """Vertical positions (rank numbers) of the plotted models."""
        return self.df[self._y]

    def draw(self, ax):
        summary = self.df
        summary.plot.scatter('ability', self._y, ax=ax)
        ax.hlines(self.y,
                  xmin=summary['lower'],
                  xmax=summary['upper'],
                  alpha=0.5)
        ax.set_xlabel(ax.get_xlabel().title())
        ax.set_ylabel('')
        # Label each rank position with its model name.
        ax.set_yticks(self.y, self.df['model'])
class ComparisonPlotter(DataPlotter):
    """Histogram of ``compare(df, model_1, model_2)`` with an HDI overlay.

    ``ci`` is the HDI mass as a fraction (``ComparisonMenu`` passes the
    user's percentage divided by 100).
    """

    # Value whose smallest containing HDI is reported in the corner note.
    _uncertain = 0.5

    def __init__(self, df, model_1, model_2, ci):
        super().__init__(compare(df, model_1, model_2))
        self.interval = HDInterval(self.df)
        self.ci = ci

    def draw(self, ax):
        hdi = self.interval(self.ci)
        (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)

        # Draw on the axes we were given, not on pyplot's "current" axes.
        ax = sns.histplot(data=self.df,
                          stat='density',
                          color=c_hist,
                          ax=ax)
        ax.set_xlabel('logit$^{-1}$(\u03B1$_{1}$ - \u03B1$_{2}$)')
        self.pr(ax, hdi, c_hdi)

        try:
            ci_min = self.interval.at(self._uncertain)
        except ArithmeticError:
            # NOTE(review): presumably HDInterval.at raises this when no
            # interval can be found; the note is then simply omitted.
            pass
        else:
            ax.text(x=0.02,
                    y=0.975,
                    s=f'{self._uncertain} \u2208 {ci_min:.0%} HDI',
                    fontsize='small',
                    fontstyle='italic',
                    horizontalalignment='left',
                    verticalalignment='top',
                    transform=ax.transAxes)

    def pr(self, ax, hdi, color):
        """Mark the median and the HDI, and print the median on a top axis.

        The top-axis label reads ``Pr(M_1 > M_2) = <median>``. It uses one
        more decimal place than the fractional part of the first x tick
        label, or two decimals when that label has no fractional part (or
        there are no tick labels).
        """
        x = self.df.median()
        # One below the axes' zorder; with default zorders this places the
        # shaded span beneath the histogram bars.
        zorder = ax.zorder - 1

        labels = ax.get_xticklabels()
        text = labels[0].get_text() if labels else ''
        (_, point, fraction) = text.partition('.')
        decimals = len(fraction) + 1 if point else 2
        # Doubled braces survive both the f-string and str.format, giving
        # the mathtext subscripts "$_{1}$" / "$_{2}$" in the final label.
        fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'

        ax.axvline(x=x,
                   color=color,
                   linestyle='dashed')
        ax.axvspan(xmin=hdi.lower,
                   xmax=hdi.upper,
                   alpha=0.15,
                   color=color,
                   zorder=zorder)

        ax_ = ax.secondary_xaxis('top')
        ax_.xaxis.set_major_locator(FixedLocator([x]))
        ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))
#
#
#
class ComparisonMenu:
    """Gradio callback plus the input widgets for the comparison plot."""

    def __init__(self, df, ci=95):
        self.df = df
        # Default HDI, in percent.
        self.ci = ci

    def __call__(self, model_1, model_2, ci):
        """Return the comparison figure, or ``None`` until both models are set.

        ``ci`` is a percentage. An empty field (``None``) falls back to the
        menu's default instead of raising ``TypeError``.
        """
        if not (model_1 and model_2):
            return None
        if ci is None:
            ci = self.ci
        cp = ComparisonPlotter(self.df, model_1, model_2, ci / 100)
        return cp.plot()

    def build_and_get(self):
        """Yield two model dropdowns and the HDI input, in ``__call__`` order."""
        models = self.df['model'].unique()
        choices = sorted(models, key=lambda x: x.lower())
        for i in range(1, 3):
            label = f'Model {i}'
            yield gr.Dropdown(label=label, choices=choices)
        yield gr.Number(value=self.ci,
                        label='HDI (%)',
                        minimum=0,
                        maximum=100)
#
#
#
class DocumentationReader:
    """Read Markdown documents stored under a root directory by name."""

    _suffix = '.md'

    def __init__(self, root):
        # Accept plain strings as well as Path objects.
        self.root = Path(root)

    def __getitem__(self, item):
        """Return the text of ``<root>/<item>.md``.

        Any existing suffix on ``item`` is replaced by ``.md``. The file is
        always decoded as UTF-8, regardless of the platform locale.
        """
        path = self.root.joinpath(item).with_suffix(self._suffix)
        return path.read_text(encoding='utf-8')
#
#
#
def layout(tab):
    """Populate the current Gradio container with one dashboard tab.

    Parameters
    ----------
    tab : TabGroup
        ``tab.dataset`` names the dataset under ``jerome-white/`` that is
        passed to ``load``. ``tab.docs`` names the subdirectory of ``docs/``
        holding ``readme.md`` and ``disclaimer.md``.

    The tab has four parts:

    - The README beside a ranking plot.
    - A ranking table.
    - A pairwise comparison panel driven by ``ComparisonMenu``.
    - A collapsible disclaimer.
    """
    df = load(Path('jerome-white', tab.dataset))
    # Use the argument, not the module-level loop variable ``t``.
    docs = DocumentationReader(Path('docs', tab.docs))

    # README next to the ranking plot.
    with gr.Row():
        with gr.Column():
            gr.Markdown(docs['readme'])
        with gr.Column():
            plotter = RankPlotter(df)
            gr.Plot(plotter.plot())

    # Ranking table, best model first; interval fields get an "HDI" prefix
    # and every other column is title-cased.
    with gr.Row():
        view = rank(summarize(df), False)
        columns = {x.name: f'HDI {x.name}' for x in fields(HDI)}
        for i in view.columns:
            columns.setdefault(i, i.title())
        view = (view
                .rename(columns=columns)
                .style.format(precision=4))
        gr.Dataframe(view)

    # Pairwise comparison: plot and explanation on the left, inputs on the
    # right.
    with gr.Row():
        with gr.Column(scale=3):
            display = gr.Plot()
            with gr.Row():
                with gr.Column():
                    gr.Markdown(f'''
                    Probability that Model 1 is preferred to Model 2. The
                    histogram represents the distribution of the inverse
                    logit of the difference in model abilities. The dashed
                    vertical line is its median. The shaded region
                    demarcates the chosen [highest density
                    interval](https://cran.r-project.org/package=HDInterval)
                    (HDI). The note in the upper left denotes the smallest
                    HDI that is inclusive of
                    {ComparisonPlotter._uncertain}.
                    ''')
        with gr.Column():
            menu = ComparisonMenu(df)
            inputs = list(menu.build_and_get())
            button = gr.Button(value='Compare!')
            button.click(menu, inputs=inputs, outputs=[display])

    with gr.Accordion('Disclaimer', open=False):
        gr.Markdown(docs['disclaimer'])
#
#
#
with gr.Blocks() as demo:
    # One tab per (title, docs subdirectory, dataset name) triple.
    tabs = (TabGroup(*spec) for spec in (
        ('Chatbot Arena', 'arena', 'arena-bt-stan'),
        ('Alpaca', 'alpaca', 'alpaca-bt-stan'),
    ))
    # ``t`` is a module-level name; check layout() before renaming it.
    for t in tabs:
        with gr.Tab(t.name):
            layout(t)

demo.launch()