import math
import operator as op
import itertools as it
import functools as ft
import collections as cl
from pathlib import Path
from dataclasses import fields, asdict

import pandas as pd
import gradio as gr
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
from scipy.special import expit
from matplotlib.ticker import FixedLocator, StrMethodFormatter

from hdinterval import HDI, HDInterval

# One UI tab: its display name, the documentation subdirectory under
# docs/, and the dataset name under the jerome-white Hub namespace.
TabGroup = cl.namedtuple('TabGroup', 'name, docs, dataset')

#
#
#
def load(repo):
    """Load the draws of the 'alpha' parameter from a dataset.

    repo: dataset identifier; converted with str() and passed to
        datasets.load_dataset (e.g. 'owner/name').

    Returns a DataFrame with columns chain, sample, model and value,
    restricted to rows whose 'parameter' is "alpha". The dataset's
    'element' column is exposed as 'model'.
    """
    parameter = 'parameter'
    model = 'model'
    items = [
        'chain',
        'sample',
        parameter,
        model,
        'value',
    ]

    dataset = load_dataset(str(repo))

    return (dataset
            .get('train')
            .to_pandas()
            .rename(columns={'element': model})
            .filter(items=items)
            .query(f'{parameter} == "alpha"')
            .drop(columns=parameter))

def summarize(df, ci=0.95):
    """Summarize each model's draws into one row.

    Each row holds 'model', 'ability' (median of 'value'), 'uncertainty'
    (width of the `ci` HDI), plus every field of that HDI (via asdict;
    these include 'lower' and 'upper', which the plotters read).
    Models keep their order of first appearance (groupby sort=False).
    """
    def _aggregate(i, g):
        values = g['value']
        hdi = HDInterval(values)
        interval = hdi(ci)

        agg = {
            'model': i,
            'ability': values.median(),
            'uncertainty': interval.width(),
        }
        agg.update(asdict(interval))

        return agg

    groups = df.groupby('model', sort=False)
    records = it.starmap(_aggregate, groups)

    return pd.DataFrame.from_records(records)

def rank(df, ascending, name='rank'):
    """Order a summarize() frame by ability and prepend a 1-based rank.

    Ties in ability are broken by uncertainty in the opposite direction
    to `ascending`. The 'uncertainty' column is dropped, and the rank is
    inserted as the first column, named `name`.
    """
    df = (df
          .sort_values(by=['ability', 'uncertainty'],
                       ascending=[ascending, not ascending])
          .drop(columns='uncertainty')
          .reset_index(drop=True))
    df.index += 1

    return df.reset_index(names=name)

def compare(df, model_1, model_2):
    """Return expit(value_1 - value_2) for every (chain, sample) draw.

    The result is a Series indexed by (chain, sample), holding the inverse
    logit of the difference between the two models' values in each draw.
    """
    mcol = 'model'
    models = [
        model_1,
        model_2,
    ]
    view = (df
            .query(f'{mcol} in @models')
            .pivot(index=['chain', 'sample'],
                   columns=mcol,
                   values='value'))

    return expit(view[model_1] - view[model_2])

#
#
#
class DataPlotter:
    """Base class that renders `self.df` onto a new matplotlib figure.

    Subclasses implement draw(ax).
    """

    def __init__(self, df):
        self.df = df

    def plot(self):
        fig = plt.figure(dpi=200)
        ax = fig.gca()
        self.draw(ax)
        ax.grid(visible=True, axis='both', alpha=0.25, linestyle='dotted')
        fig.tight_layout()
        # A new figure is created on every request. Detach it from pyplot's
        # global figure registry so that figures do not accumulate for the
        # lifetime of the server. The Figure object itself remains usable
        # (e.g. Figure.savefig) by the caller.
        plt.close(fig)

        return fig

    def draw(self, ax):
        raise NotImplementedError()

class RankPlotter(DataPlotter):
    """Median ability per model, with horizontal HDI bars.

    Only the `top` highest-ability models are shown; the best model has
    the largest rank value and therefore sits at the top of the y-axis.
    """

    _y = 'y'

    @ft.cached_property
    def y(self):
        return self.df[self._y]

    def __init__(self, df, top=10):
        view = rank(summarize(df), True, self._y)
        view = (view
                .tail(top)
                .sort_values(by=self._y, ascending=False))
        super().__init__(view)

    def draw(self, ax):
        self.df.plot.scatter('ability', self._y, ax=ax)
        ax.hlines(self.y,
                  xmin=self.df['lower'],
                  xmax=self.df['upper'],
                  alpha=0.5)

        ax.set_xlabel(ax.get_xlabel().title())
        ax.set_ylabel('')
        ax.set_yticks(self.y, self.df['model'])

class ComparisonPlotter(DataPlotter):
    """Histogram of expit(alpha_1 - alpha_2), with its median and HDI."""

    # Reference value reported by min_inclusive().
    _uncertain = 0.5

    def __init__(self, df, model_1, model_2, ci):
        super().__init__(compare(df, model_1, model_2))
        self.interval = HDInterval(self.df)
        self.ci = ci

    def draw(self, ax):
        hdi = self.interval(self.ci)
        (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)

        # Draw on the axes we were handed rather than pyplot's implicit
        # "current axes".
        ax = sns.histplot(data=self.df, stat='density', color=c_hist, ax=ax)
        ax.set_xlabel('logit$^{-1}$(\u03B1$_{1}$ - \u03B1$_{2}$)')

        self.pr(ax, hdi, c_hdi)
        self.min_inclusive(ax)

    def min_inclusive(self, ax):
        # Annotate the smallest HDI that contains `_uncertain`. The exception
        # semantics come from hdinterval: OverflowError is treated as "not
        # contained in any HDI below 100%", and FloatingPointError means
        # nothing is drawn (TODO: confirm against hdinterval).
        try:
            ci = self.interval.at(self._uncertain)
            inclusive = '\u2208'
        except OverflowError:
            ci = 1
            inclusive = '\u2209'
        except FloatingPointError:
            return

        ax.text(x=0.02,
                y=0.975,
                s=f'{self._uncertain} {inclusive} {ci:.0%} HDI',
                fontsize='small',
                fontstyle='italic',
                horizontalalignment='left',
                verticalalignment='top',
                transform=ax.transAxes)

    def pr(self, ax, hdi, color):
        """Mark the median (dashed line plus top-axis label) and shade the HDI."""
        x = self.df.median()
        zorder = ax.zorder - 1

        # Show the median with one more decimal place than the first x tick
        # label. Fall back to 2 places when that label has no fractional
        # part or there are no labels at all.
        labels = ax.get_xticklabels()
        text = labels[0].get_text() if labels else ''
        (_, dot, fraction) = text.rpartition('.')
        decimals = len(fraction) + 1 if dot else 2
        fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'

        ax.axvline(x=x, color=color, linestyle='dashed')
        ax.axvspan(xmin=hdi.lower,
                   xmax=hdi.upper,
                   alpha=0.15,
                   color=color,
                   zorder=zorder)

        ax_ = ax.secondary_xaxis('top')
        ax_.xaxis.set_major_locator(FixedLocator([x]))
        ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))

#
#
#
class ComparisonMenu:
    """Gradio inputs and click handler for the pairwise comparison plot.

    ci: default HDI mass, in percent, shown in the number input.
    """

    def __init__(self, df, ci=95):
        self.df = df
        self.ci = ci

    def __call__(self, model_1, model_2, ci):
        # Returns None (an empty plot) until both models are chosen.
        if model_1 and model_2:
            if ci is None:
                # The number box was cleared; use the default instead of
                # failing on None / 100.
                ci = self.ci
            ci /= 100

            cp = ComparisonPlotter(self.df, model_1, model_2, ci)
            return cp.plot()

    def build_and_get(self):
        """Yield the two model dropdowns, then the HDI (%) number input."""
        models = self.df['model'].unique()
        choices = sorted(models, key=lambda x: x.lower())

        for i in range(1, 3):
            label = f'Model {i}'
            yield gr.Dropdown(label=label, choices=choices)

        yield gr.Number(value=self.ci,
                        label='HDI (%)',
                        minimum=0,
                        maximum=100)

#
#
#
class DocumentationReader:
    """Read `<root>/<item>.md` files via item access (reader['readme'])."""

    _suffix = '.md'

    def __init__(self, root):
        self.root = root

    def __getitem__(self, item):
        # Explicit encoding so that behaviour does not depend on the
        # platform's default locale.
        return (self
                .root
                .joinpath(item)
                .with_suffix(self._suffix)
                .read_text(encoding='utf-8'))

#
#
#
def layout(tab):
    """Build the contents of one Gradio tab described by TabGroup `tab`."""
    # Hub repo ids always use '/', whereas Path would insert the OS
    # separator (a backslash on Windows).
    df = load(f'jerome-white/{tab.dataset}')
    docs = DocumentationReader(Path('docs', tab.docs))

    with gr.Row():
        with gr.Column():
            gr.Markdown(docs['readme'])

        with gr.Column():
            plotter = RankPlotter(df)
            gr.Plot(plotter.plot())

    with gr.Row():
        view = rank(summarize(df), False)
        columns = { x.name: f'HDI {x.name}' for x in fields(HDI) }
        for i in view.columns:
            columns.setdefault(i, i.title())
        view = (view
                .rename(columns=columns)
                .style.format(precision=4))

        gr.Dataframe(view)

    with gr.Row():
        with gr.Column(scale=3):
            display = gr.Plot()

        with gr.Column():
            gr.Markdown(f'''

            Probability that Model 1 is preferred to Model 2. The
            histogram represents the distribution of inverse logit of
            the difference in model abilities. The dashed vertical line
            is its median. The shaded region demarcates the chosen
            [highest density
            interval](https://cran.r-project.org/package=HDInterval)
            (HDI). The note in the upper left denotes the smallest HDI
            that is inclusive of {ComparisonPlotter._uncertain}.

            ''')

        with gr.Column():
            menu = ComparisonMenu(df)
            inputs = list(menu.build_and_get())
            button = gr.Button(value='Compare!')
            button.click(menu, inputs=inputs, outputs=[display])

    with gr.Accordion('Disclaimer', open=False):
        gr.Markdown(docs['disclaimer'])

#
#
#
with gr.Blocks() as demo:
    tabs = it.starmap(TabGroup, (
        ('Chatbot Arena', 'arena', 'arena-bt-stan'),
        ('Alpaca', 'alpaca', 'alpaca-bt-stan'),
    ))

    for t in tabs:
        with gr.Tab(t.name):
            layout(t)

if __name__ == '__main__':
    demo.launch()