jerome-white commited on
Commit
77c903f
1 Parent(s): 8569336

Make figures distinct

Browse files
Files changed (2) hide show
  1. app.py +10 -12
  2. requirements.txt +1 -0
app.py CHANGED
@@ -8,6 +8,7 @@ from pathlib import Path
8
  import pandas as pd
9
  import gradio as gr
10
  import seaborn as sns
 
11
  from datasets import load_dataset
12
  from scipy.special import expit
13
 
@@ -104,18 +105,19 @@ class DataPlotter:
104
  self.df = df
105
 
106
  def plot(self):
107
- ax = self.draw()
 
 
 
108
  ax.grid(visible=True,
109
  axis='both',
110
  alpha=0.25,
111
  linestyle='dotted')
112
-
113
- fig = ax.get_figure()
114
  fig.tight_layout()
115
 
116
  return fig
117
 
118
- def draw(self):
119
  raise NotImplementedError()
120
 
121
  class RankPlotter(DataPlotter):
@@ -132,8 +134,8 @@ class RankPlotter(DataPlotter):
132
  .sort_values(by=self._y, ascending=False))
133
  super().__init__(view)
134
 
135
- def draw(self):
136
- ax = self.df.plot.scatter('ability', self._y)
137
  ax.hlines(self.y,
138
  xmin=self.df['lower'],
139
  xmax=self.df['upper'],
@@ -141,15 +143,13 @@ class RankPlotter(DataPlotter):
141
  ax.set_ylabel('')
142
  ax.set_yticks(self.y, self.df['model'])
143
 
144
- return ax
145
-
146
  class ComparisonPlotter(DataPlotter):
147
  def __init__(self, df, model_1, model_2, ci=0.95):
148
  super().__init__(compare(df, model_1, model_2))
149
  self.interval = hdi(self.df, ci)
150
 
151
- def draw(self):
152
- ax = sns.ecdfplot(self.df)
153
 
154
  (_, color, *_) = sns.color_palette()
155
  ax.axvline(x=self.df.median(),
@@ -161,8 +161,6 @@ class ComparisonPlotter(DataPlotter):
161
  color=color)
162
  ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
163
 
164
- return ax
165
-
166
  def cplot(df, ci=0.95):
167
  def _plot(model_1, model_2):
168
  cp = ComparisonPlotter(df, model_1, model_2, ci)
 
8
  import pandas as pd
9
  import gradio as gr
10
  import seaborn as sns
11
+ import matplotlib.pyplot as plt
12
  from datasets import load_dataset
13
  from scipy.special import expit
14
 
 
105
  self.df = df
106
 
107
  def plot(self):
108
+ fig = plt.figure()
109
+
110
+ ax = fig.gca()
111
+ self.draw(ax)
112
  ax.grid(visible=True,
113
  axis='both',
114
  alpha=0.25,
115
  linestyle='dotted')
 
 
116
  fig.tight_layout()
117
 
118
  return fig
119
 
120
+ def draw(self, ax):
121
  raise NotImplementedError()
122
 
123
  class RankPlotter(DataPlotter):
 
134
  .sort_values(by=self._y, ascending=False))
135
  super().__init__(view)
136
 
137
+ def draw(self, ax):
138
+ self.df.plot.scatter('ability', self._y, ax=ax)
139
  ax.hlines(self.y,
140
  xmin=self.df['lower'],
141
  xmax=self.df['upper'],
 
143
  ax.set_ylabel('')
144
  ax.set_yticks(self.y, self.df['model'])
145
 
 
 
146
  class ComparisonPlotter(DataPlotter):
147
  def __init__(self, df, model_1, model_2, ci=0.95):
148
  super().__init__(compare(df, model_1, model_2))
149
  self.interval = hdi(self.df, ci)
150
 
151
+ def draw(self, ax):
152
+ sns.ecdfplot(self.df, ax=ax)
153
 
154
  (_, color, *_) = sns.color_palette()
155
  ax.axvline(x=self.df.median(),
 
161
  color=color)
162
  ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
163
 
 
 
164
  def cplot(df, ci=0.95):
165
  def _plot(model_1, model_2):
166
  cp = ComparisonPlotter(df, model_1, model_2, ci)
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  datasets
2
  gradio
 
3
  pandas
4
  scipy
5
  seaborn
 
1
  datasets
2
  gradio
3
+ matplotlib
4
  pandas
5
  scipy
6
  seaborn