# Spaces: Sleeping
import math | |
import operator as op | |
import itertools as it | |
import functools as ft | |
import collections as cl | |
from pathlib import Path | |
from dataclasses import fields, asdict | |
import pandas as pd | |
import gradio as gr | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from datasets import load_dataset | |
from scipy.special import expit | |
from hdinterval import HDI, HDInterval | |
# Describes one tab of the UI: its display name, the documentation
# sub-directory it reads from, and the dataset repository it loads.
TabGroup = cl.namedtuple(
    typename='TabGroup',
    field_names='name, docs, dataset',
)
# | |
# | |
# | |
def load(repo):
    """Return the ``alpha`` rows of a dataset repository as a DataFrame.

    The ``train`` split of ``load_dataset(str(repo))`` is converted to
    pandas, its ``element`` column is renamed to ``model``, and only the
    columns ``chain``, ``sample``, ``parameter``, ``model`` and ``value``
    are kept. Rows whose ``parameter`` is not ``"alpha"`` are discarded,
    after which the ``parameter`` column itself is dropped.
    """
    model_col = 'model'
    param_col = 'parameter'
    wanted = ['chain', 'sample', param_col, model_col, 'value']

    frame = load_dataset(str(repo)).get('train').to_pandas()
    frame = frame.rename(columns={'element': model_col})
    frame = frame.filter(items=wanted)
    frame = frame.query(f'{param_col} == "alpha"')

    return frame.drop(columns=param_col)
def summarize(df, ci=0.95):
    """Collapse the samples in *df* to one summary row per model.

    Rows are grouped by ``model`` in order of first appearance. For each
    group the result holds the model name, the median of ``value``
    (``ability``), the width of the interval obtained from
    ``HDInterval(values)(ci)`` (``uncertainty``), and every field of that
    interval as reported by ``dataclasses.asdict``.
    """
    def _describe(name, samples):
        interval = HDInterval(samples)(ci)
        row = {
            'model': name,
            'ability': samples.median(),
            'uncertainty': interval.width(),
        }
        # Interval fields are merged last, exactly as the dict is built.
        row.update(asdict(interval))
        return row

    rows = (_describe(name, group['value'])
            for (name, group) in df.groupby('model', sort=False))

    return pd.DataFrame.from_records(rows)
def rank(df, ascending, name='rank'):
    """Order *df* by ability and attach a 1-based position column.

    Rows are sorted by ``ability`` in the direction given by *ascending*;
    ties are broken by ``uncertainty`` in the opposite direction. The
    ``uncertainty`` column is removed from the result, and the position
    of each row (starting at 1) is inserted as the first column, called
    *name*.
    """
    keys = ['ability', 'uncertainty']
    directions = [ascending, not ascending]

    ordered = df.sort_values(by=keys, ascending=directions)
    ordered = ordered.drop(columns='uncertainty').reset_index(drop=True)
    ordered.index = ordered.index + 1  # positions count from one, not zero

    return ordered.reset_index(names=name)
def compare(df, model_1, model_2):
    """Return the inverse logit of the per-sample difference of two models.

    The rows of *df* belonging to either model are reshaped so that each
    (``chain``, ``sample``) pair is one row with one ``value`` column per
    model. The result is ``expit(model_1 - model_2)`` for every such
    pair, indexed by (``chain``, ``sample``).
    """
    # The query expression below looks this local up by name (@models).
    models = [model_1, model_2]

    subset = df.query('model in @models')
    wide = subset.pivot(index=['chain', 'sample'],
                        columns='model',
                        values='value')
    difference = wide[model_1] - wide[model_2]

    return expit(difference)
# | |
# | |
# | |
class DataPlotter:
    """Base class that renders ``self.df`` onto a new matplotlib figure.

    Subclasses supply the actual drawing by overriding :meth:`draw`;
    :meth:`plot` handles figure creation, the grid, and layout.
    """

    def __init__(self, df):
        self.df = df

    def plot(self):
        """Create a 200-dpi figure, draw on its axes, and return the figure."""
        figure = plt.figure(dpi=200)
        axes = figure.gca()

        self.draw(axes)

        grid_style = {
            'visible': True,
            'axis': 'both',
            'alpha': 0.25,
            'linestyle': 'dotted',
        }
        axes.grid(**grid_style)
        figure.tight_layout()

        return figure

    def draw(self, ax):
        """Draw onto *ax*; must be provided by subclasses."""
        raise NotImplementedError()
class RankPlotter(DataPlotter):
    """Plot the highest-ability models with their interval bounds.

    The samples are summarized and ranked by ascending ability, with the
    rank stored in the ``y`` column. Only the last *top* rows (the ones
    with the highest ability) are kept, ordered by descending rank.
    """

    _y = 'y'

    # Exposed as a property: ``draw`` passes ``self.y`` directly to
    # matplotlib, which needs the column values rather than a bound method.
    @property
    def y(self):
        """The rank column (``_y``) of the rows being plotted."""
        return self.df[self._y]

    def __init__(self, df, top=10):
        """Summarize and rank *df*, keeping the *top* highest-ability rows."""
        view = rank(summarize(df), True, self._y)
        view = (view
                .tail(top)
                .sort_values(by=self._y, ascending=False))
        super().__init__(view)

    def draw(self, ax):
        """Scatter ability against rank, with a horizontal bar per model.

        Each bar spans the ``lower`` to ``upper`` columns at the model's
        rank; the y ticks are labelled with the model names.
        """
        self.df.plot.scatter('ability', self._y, ax=ax)
        ax.hlines(self.y,
                  xmin=self.df['lower'],
                  xmax=self.df['upper'],
                  alpha=0.5)
        ax.set_xlabel(ax.get_xlabel().title())
        ax.set_ylabel('')
        ax.set_yticks(self.y, self.df['model'])
class ComparisonPlotter(DataPlotter):
    """Plot the distribution of ``compare(df, model_1, model_2)``.

    The comparison values become ``self.df``; an ``HDInterval`` is built
    from them and evaluated at *ci* when drawing.
    """

    def __init__(self, df, model_1, model_2, ci):
        probabilities = compare(df, model_1, model_2)
        super().__init__(probabilities)
        self.hdi = HDInterval(self.df)
        self.ci = ci

    def draw(self, ax):
        """Draw the ECDF, the median line, and the interval band on *ax*."""
        interval = self.hdi(self.ci)
        sns.ecdfplot(self.df, ax=ax)

        # The second palette colour is shared by the median line and band.
        accent = sns.color_palette()[1]
        ax.axvline(x=self.df.median(), color=accent, linestyle='dashed')
        ax.axvspan(xmin=interval.lower,
                   xmax=interval.upper,
                   alpha=0.15,
                   color=accent)

        ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')

        # Best effort: if this step raises ArithmeticError the annotation
        # is simply left out of the figure.
        try:
            ci_mid = self.hdi.at(0.5)
            annotation = {
                'x': 0.01,
                'y': 0.99,
                's': f'0.5-min HDI: {ci_mid:.0%}',
                'horizontalalignment': 'left',
                'verticalalignment': 'top',
                'transform': ax.transAxes,
            }
            ax.text(**annotation)
        except ArithmeticError:
            pass
class ComparisonMenu:
    """Input widgets and click handler for the two-model comparison.

    *df* holds the samples to compare; *ci* is the initial value of the
    interval input, expressed as a percentage.
    """

    def __init__(self, df, ci=95):
        self.df = df
        self.ci = ci

    def __call__(self, model_1, model_2, ci):
        """Return the comparison figure; *ci* arrives as a percentage."""
        plotter = ComparisonPlotter(self.df, model_1, model_2, ci / 100)
        return plotter.plot()

    def build_and_get(self):
        """Yield two model dropdowns followed by the interval number input.

        Dropdown choices are the distinct ``model`` values, sorted
        case-insensitively.
        """
        names = sorted(self.df['model'].unique(),
                       key=op.methodcaller('lower'))

        for position in (1, 2):
            yield gr.Dropdown(label=f'Model {position}', choices=names)

        yield gr.Number(value=self.ci,
                        label='HDI (%)',
                        minimum=0,
                        maximum=100)
# | |
# | |
# | |
class DocumentationReader:
    """Look up Markdown documents under a root directory by name.

    ``reader[item]`` returns the text of ``root/item`` after its suffix
    has been set to ``.md``.
    """

    _suffix = '.md'

    def __init__(self, root):
        self.root = root

    def __getitem__(self, item):
        target = self.root.joinpath(item).with_suffix(self._suffix)
        return target.read_text()
# | |
# | |
# | |
def layout(tab):
    """Build the Gradio components for one tab.

    *tab* supplies ``dataset`` (the repository name loaded from under
    ``jerome-white``) and ``docs`` (the sub-directory of ``docs`` holding
    this tab's Markdown files). The tab contains, top to bottom: the
    readme beside the rank plot, the ranked summary table, the
    model-comparison plot with its controls, and a collapsed disclaimer.

    Must be called inside an active Gradio layout context.
    """
    df = load(Path('jerome-white', tab.dataset))
    # Read the documentation directory from the argument rather than from
    # whatever module-level variable the caller happens to be looping with.
    docs = DocumentationReader(Path('docs', tab.docs))

    with gr.Row():
        with gr.Column():
            gr.Markdown(docs['readme'])

        with gr.Column():
            plotter = RankPlotter(df)
            gr.Plot(plotter.plot())

    with gr.Row():
        view = rank(summarize(df), False)

        # Interval fields get an "HDI " prefix; every other column is
        # simply title-cased.
        columns = { x.name: f'HDI {x.name}' for x in fields(HDI) }
        for i in view.columns:
            columns.setdefault(i, i.title())
        view = (view
                .rename(columns=columns)
                .style.format(precision=4))

        gr.Dataframe(view)

    with gr.Row():
        with gr.Column(scale=3):
            display = gr.Plot()

        with gr.Row():
            with gr.Column():
                # Text kept flush-left so the literal carries no leading
                # indentation into the rendered Markdown.
                gr.Markdown('''
Probability that Model 1 is preferred to Model 2. The
solid blue curve is a CDF of that distribution;
formally the inverse logit of the difference in model
abilities. The dashed orange vertical line is the
median, while the band surrounding it is the [highest
density
interval](https://cran.r-project.org/package=HDInterval)
of your choice (default 95%).
''')

            with gr.Column():
                menu = ComparisonMenu(df)
                inputs = list(menu.build_and_get())

                button = gr.Button(value='Compare!')
                # The menu object is itself the click handler.
                button.click(menu, inputs=inputs, outputs=[display])

    with gr.Accordion('Disclaimer', open=False):
        gr.Markdown(docs['disclaimer'])
# | |
# | |
# | |
# Assemble the application: one tab per (name, docs, dataset) triple.
with gr.Blocks() as demo:
    tabs = (TabGroup(*spec) for spec in (
        ('Alpaca', 'alpaca', 'alpaca-bt-stan'),
        ('Chatbot Arena', 'arena', 'arena-bt-stan'),
    ))
    for t in tabs:
        with gr.Tab(t.name):
            layout(t)

demo.launch()