dawood HF staff NimaBoscarino commited on
Commit
924d3bd
0 Parent(s):

Duplicate from society-ethics/disaggregators

Browse files

Co-authored-by: Nima Boscarino <NimaBoscarino@users.noreply.huggingface.co>

Files changed (9) hide show
  1. .gitattributes +34 -0
  2. .gitignore +56 -0
  3. README.md +27 -0
  4. app.py +547 -0
  5. cached_data.pkl +3 -0
  6. generate_datasets.py +40 -0
  7. medmcqa.pkl +3 -0
  8. prep_data.py +151 -0
  9. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Locked files
2
+ *.lock
3
+ !dvc.lock
4
+
5
+ # Extracted dummy data
6
+ datasets/**/dummy_data-zip-extracted/
7
+
8
+ # Compiled python modules.
9
+ *.pyc
10
+
11
+ # Byte-compiled
12
+ _pycache__/
13
+ .cache/
14
+
15
+ # Python egg metadata, regenerated from source files by setuptools.
16
+ *.egg-info
17
+ .eggs/
18
+
19
+ # PyPI distribution artifacts.
20
+ build/
21
+ dist/
22
+
23
+ # Environments
24
+ .env
25
+ .venv
26
+ env/
27
+ venv/
28
+ ENV/
29
+ env.bak/
30
+ venv.bak/
31
+
32
+ # pyenv
33
+ .python-version
34
+
35
+ # Tests
36
+ .pytest_cache/
37
+
38
+ # Other
39
+ *.DS_Store
40
+
41
+ # PyCharm/vscode
42
+ .idea
43
+ .vscode
44
+
45
+ # Vim
46
+ .*.swp
47
+
48
+ # playground
49
+ /playground
50
+
51
+ # Sphinx documentation
52
+ docs/_build/
53
+ docs/source/_build/
54
+
55
+ *.ipynb
56
+ data_measurements_tool/
README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Disaggregators
3
+ emoji: 🔥
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.12.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: society-ethics/disaggregators
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+
16
+ Friction log:
17
+ - (Gradio) the dataset's .click handler doesn't refresh if I chagnge the samples in the dataset, so I have to find a workaround...
18
+ - (DMT) installing it is a nightmare, especially when I just want to use it as a library
19
+
20
+ ```python
21
+ from data_measurements.dataset_statistics import DatasetStatisticsCacheClass as dmt_cls
22
+
23
+ dstats = dmt_cls(dset_name="NimaBoscarino/medmcqa_age_gender", dset_config=None, split_name="train", text_field="question", label_field=(), label_names=[], use_cache=True) # Maybe values for label_field and label_names??
24
+
25
+ label_obj = labels.DMTHelper(dstats, load_only=False, save=dstats.save)
26
+
27
+ ```
app.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset
3
+ import matplotlib as mpl
4
+ mpl.use('Agg')
5
+ from typing import List
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import joblib
9
+ import itertools
10
+ import pandas as pd
11
+
12
+ cached_artifacts = joblib.load("cached_data.pkl")
13
+
14
+ laion = load_dataset("society-ethics/laion2B-en_continents", split="train").to_pandas()
15
+ medmcqa = load_dataset("society-ethics/medmcqa_age_gender_custom", split="train").to_pandas()
16
+ stack = load_dataset("society-ethics/the-stack-tabs_spaces", split="train").to_pandas()\
17
+ .drop(columns=["max_stars_repo_licenses", "max_issues_repo_licenses", "max_forks_repo_licenses"])
18
+
19
+ cached_artifacts["laion"]["text"] = {
20
+ "title": "Disaggregating by continent with a built-in module",
21
+ "description": """
22
+ The [`laion/laion2b-en` dataset](https://huggingface.co/datasets/laion/laion2B-en), created by [LAION](https://laion.ai), is used to train image generation models such as [Stable Diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion). The dataset contains pairs of images and captions, but we might also be curious about the distribution of specific topics, such as <u>continents</u>, mentioned in the captions.
23
+
24
+ The original dataset doesn't contain metadata about specific continents, but we can attempt to infer it from the `TEXT` feature with `disaggregators`. Note that several factors contribute to a high incidence of false positives, such as the fact that country and city names are frequently used as names for fashion products.
25
+ """,
26
+ "visualization": """
27
+ This view shows you a visualization of the relative proportion of each label in the disaggregated dataset. For this dataset, we've only disaggregated by one category (continent), but there are many possible values for it. While there are many rows that haven't been flagged with a continent (check "None" and see!), this disaggregator doesn't assign *Multiple* continents.
28
+
29
+ To see examples of individual rows, click over to the "Inspect" tab!
30
+ """,
31
+ "code": """
32
+ ```python
33
+ from disaggregators import Disaggregator
34
+ disaggregator = Disaggregator("continent", column="TEXT")
35
+
36
+ # Note: this demo used a subset of the dataset
37
+ from datasets import load_dataset
38
+ ds = load_dataset("laion/laion2B-en", split="train", streaming=True).map(disaggregator)
39
+ ```
40
+ """
41
+ }
42
+
43
+ cached_artifacts["medmcqa"]["text"] = {
44
+ "title": "Overriding configurations for built-in modules",
45
+ "description": """
46
+ Meta's [Galactica model](https://galactica.org) is trained on a large-scale scientific corpus, which includes the [`medmcqa` dataset](https://huggingface.co/datasets/medmcqa) of medical entrance exam questions. MedMCQA has a `question` feature which often contains a case scenario, where a hypothetical patient presents with a condition.
47
+
48
+ The original dataset doesn't contain metadata about the <u>age</u> and <u>binary gender</u>, but we can infer them with the `age` and `gender` modules. If a module doesn't have the particular label options that you'd like, such as additional genders or specific age buckets, you can override the module's configuration. In this example we've configured the `age` module to use [NIH's MeSH age groups](https://www.ncbi.nlm.nih.gov/mesh/68009273).
49
+ """,
50
+ "visualization": """
51
+ Since we've disaggregated the MedMCQA dataset by *two* categories (age and binary gender), we can click on "Age + Gender" to visualize the proportions of the *intersections* of each group.
52
+
53
+ There are two things to note about this example:
54
+ 1. The disaggregators for age and gender can flag rows as having more than one age or gender, which we've grouped here as "Multiple"
55
+ 2. If you look at the data through the "Inspect" tab, you'll notice that there are some false positives. `disaggregators` is in early development, and these modules are in a very early "proof of concept" stage! Keep an eye out as we develop more sophisticated algorithms for disaggregation, and [join us over on GitHub](https://github.com/huggingface/disaggregators) if you'd like to contribute ideas, documentation, or code.
56
+ """,
57
+ "code": """
58
+ ```python
59
+ from disaggregators import Disaggregator
60
+ from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig
61
+
62
+ class MeSHAgeLabels(AgeLabels):
63
+ INFANT = "infant"
64
+ CHILD_PRESCHOOL = "child_preschool"
65
+ CHILD = "child"
66
+ ADOLESCENT = "adolescent"
67
+ ADULT = "adult"
68
+ MIDDLE_AGED = "middle_aged"
69
+ AGED = "aged"
70
+ AGED_80_OVER = "aged_80_over"
71
+
72
+ age_config = AgeConfig(
73
+ labels=MeSHAgeLabels,
74
+ ages=[list(MeSHAgeLabels)],
75
+ breakpoints=[0, 2, 5, 12, 18, 44, 64, 79]
76
+ )
77
+
78
+ age = Age(config=age_config, column="question")
79
+
80
+ disaggregator = Disaggregator([age, "gender"], column="question")
81
+
82
+ from datasets import load_dataset
83
+ ds = load_dataset("medmcqa", split="train").map(disaggregator)
84
+ ```
85
+ """
86
+ }
87
+
88
+ cached_artifacts["stack"]["text"] = {
89
+ "title": "Creating custom disaggregators",
90
+ "description": """
91
+ [The BigCode Project](https://www.bigcode-project.org/) recently released [`bigcode/the-stack`](https://huggingface.co/datasets/bigcode/the-stack), which contains contains over 6TB of permissively-licensed source code files covering 358 programming languages. One of the languages included is [JSX](https://reactjs.org/docs/introducing-jsx.html), which is an extension to JavaScript specifically designed for the [React UI library](https://reactjs.org/docs/introducing-jsx.html). Let's ask some questions about the React code in this dataset!
92
+
93
+ 1. React lets developers define UI components [as functions or as classes](https://reactjs.org/docs/components-and-props.html#function-and-class-components). Which style is more popular in this dataset?
94
+ 2. Programmers have long argued over using [tabs or spaces](https://www.youtube.com/watch?v=SsoOG6ZeyUI). Who's winning?
95
+
96
+ `disaggregators` makes it easy to add your own disaggregation modules. See the code snippet below for an example 🤗
97
+ """,
98
+ "visualization": """
99
+ Like the MedMCQA example, this dataset has also been disaggregated by more than one category. Using multiple disaggregation modules lets us get insights into interesting *intersections* of the subpopulations in our datasets.
100
+ """,
101
+ "code": """
102
+ ```python
103
+ from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator
104
+
105
+ class TabsSpacesLabels(DisaggregationModuleLabels):
106
+ TABS = "tabs"
107
+ SPACES = "spaces"
108
+
109
+ class TabsSpaces(CustomDisaggregator):
110
+ module_id = "tabs_spaces"
111
+ labels = TabsSpacesLabels
112
+
113
+ def __call__(self, row, *args, **kwargs):
114
+ if "\\t" in row[self.column]:
115
+ return {self.labels.TABS: True, self.labels.SPACES: False}
116
+ else:
117
+ return {self.labels.TABS: False, self.labels.SPACES: True}
118
+
119
+ class ReactComponentLabels(DisaggregationModuleLabels):
120
+ CLASS = "class"
121
+ FUNCTION = "function"
122
+
123
+
124
+ class ReactComponent(CustomDisaggregator):
125
+ module_id = "react_component"
126
+ labels = ReactComponentLabels
127
+
128
+ def __call__(self, row, *args, **kwargs):
129
+ if "extends React.Component" in row[self.column] or "extends Component" in row[self.column]:
130
+ return {self.labels.CLASS: True, self.labels.FUNCTION: False}
131
+ else:
132
+ return {self.labels.CLASS: False, self.labels.FUNCTION: True}
133
+
134
+ disaggregator = Disaggregator([TabsSpaces, ReactComponent], column="content")
135
+
136
+ # Note: this demo used a subset of the dataset
137
+ from datasets import load_dataset
138
+ ds = load_dataset("bigcode/the-stack", data_dir="data/jsx", split="train", streaming=True).map(disaggregator)
139
+ ```
140
+ """
141
+ }
142
+
143
+
144
+ def create_plot(selected_fields, available_fields, distributions, feature_names, plot=None):
145
+ plt.close('all')
146
+ clean_fields = [field for field in selected_fields if field not in ["Multiple", "None"]]
147
+ extra_options = [field for field in selected_fields if field in ["Multiple", "None"]]
148
+
149
+ distributions = distributions.reorder_levels(
150
+ sorted(list(available_fields)) + [idx for idx in distributions.index.names if idx not in available_fields]
151
+ )
152
+ distributions = distributions.sort_index()
153
+
154
+ def get_tuple(field):
155
+ return tuple(True if field == x else False for x in sorted(available_fields))
156
+
157
+ masks = [get_tuple(field) for field in sorted(clean_fields)]
158
+ data = [distributions.get(mask, 0) for mask in masks]
159
+ data = [x.sum() if type(x) != int else x for x in data]
160
+
161
+ if "Multiple" in extra_options:
162
+ masks_mult = [el for el in itertools.product((True, False), repeat=len(available_fields)) if el.count(True) > 1]
163
+ data = data + [sum([distributions.get(mask, pd.Series(dtype=float)).sum() for mask in masks_mult])]
164
+
165
+ if "None" in extra_options:
166
+ none_mask = tuple(False for x in available_fields)
167
+ data = data + [distributions.get(none_mask, pd.Series(dtype=float)).sum()]
168
+
169
+ fig, ax = plt.subplots()
170
+
171
+ title = "Distribution "
172
+ size = 0.3
173
+
174
+ cmap = plt.colormaps["Set3"]
175
+ outer_colors = cmap(np.arange(len(data)))
176
+
177
+ total_sum = sum(data)
178
+ all_fields = sorted(clean_fields) + sorted(extra_options)
179
+ labels = [f"{feature_names.get(c, c)}\n{round(data[i] / total_sum * 100, 2)}%" for i, c in enumerate(all_fields)]
180
+
181
+ ax.pie(data, radius=1, labels=labels, colors=outer_colors,
182
+ wedgeprops=dict(width=size, edgecolor='w'))
183
+
184
+ ax.set(aspect="equal", title=title)
185
+
186
+ if plot is None:
187
+ return gr.Plot(plt)
188
+ else:
189
+ new_plot = plot.update(plt)
190
+ return new_plot
191
+
192
+
193
+ # TODO: Consolidate with the other plot function...
194
+ def create_nested_plot(selected_outer, available_outer, selected_inner, available_inner, distributions, feature_names, plot=None):
195
+ plt.close('all')
196
+
197
+ clean_outer = [field for field in selected_outer if field not in ["Multiple", "None"]]
198
+ extra_outer = [field for field in selected_outer if field in ["Multiple", "None"]]
199
+
200
+ clean_inner = [field for field in selected_inner if field not in ["Multiple", "None"]]
201
+ extra_inner = [field for field in selected_inner if field in ["Multiple", "None"]]
202
+
203
+ distributions = distributions.reorder_levels(
204
+ sorted(list(available_outer)) + sorted(list(available_inner)) + sorted([idx for idx in distributions.index.names if idx not in (available_outer + available_inner)])
205
+ )
206
+ distributions = distributions.sort_index()
207
+
208
+ def get_tuple(field, field_options):
209
+ return tuple(True if field == x else False for x in sorted(field_options))
210
+
211
+ masks_outer = [get_tuple(field, available_outer) for field in sorted(clean_outer)]
212
+ masks_inner = [get_tuple(field, available_inner) for field in sorted(clean_inner)]
213
+
214
+ data_inner = [[distributions.get(m_o + mask, 0) for mask in masks_inner] for m_o in masks_outer]
215
+
216
+ masks_mult_inner = []
217
+ masks_none_inner = []
218
+
219
+ if "Multiple" in extra_inner:
220
+ masks_mult_inner = [el for el in itertools.product((True, False), repeat=len(available_inner)) if el.count(True) > 1]
221
+ masks_mult = [m_o + m_i for m_i in masks_mult_inner for m_o in masks_outer]
222
+ mult_inner_count = [distributions.get(mask, pd.Series(dtype=float)).sum() for mask in masks_mult]
223
+ data_inner = [di + [mult_inner_count[idx]] for idx, di in enumerate(data_inner)]
224
+
225
+ if "None" in extra_inner:
226
+ masks_none_inner = tuple(False for x in available_inner)
227
+ masks_none = [m_o + masks_none_inner for m_o in masks_outer]
228
+ none_inner_count = [distributions.get(mask, pd.Series(dtype=float)).sum() for mask in masks_none]
229
+ data_inner = [di + [none_inner_count[idx]] for idx, di in enumerate(data_inner)]
230
+ if len(available_inner) > 0:
231
+ masks_none_inner = [masks_none_inner]
232
+
233
+ if "Multiple" in extra_outer:
234
+ masks_mult = [el for el in itertools.product((True, False), repeat=len(available_outer)) if el.count(True) > 1]
235
+ data_inner = data_inner + [[
236
+ sum([distributions.get(mask + mask_inner, pd.Series(dtype=float)).sum() for mask in masks_mult])
237
+ for mask_inner in (masks_inner + masks_mult_inner + masks_none_inner)
238
+ ]]
239
+
240
+ if "None" in extra_outer:
241
+ none_mask_outer = tuple(False for x in available_outer)
242
+ data_inner = data_inner + [[distributions.get(none_mask_outer + mask, pd.Series(dtype=float)).sum() for mask in (masks_inner + masks_mult_inner + masks_none_inner)]]
243
+
244
+ fig, ax = plt.subplots()
245
+
246
+ title = "Distribution "
247
+ size = 0.3
248
+
249
+ cmap = plt.colormaps["Set3"]
250
+ cmap2 = plt.colormaps["Set2"]
251
+ outer_colors = cmap(np.arange(len(data_inner)))
252
+ inner_colors = cmap2(np.arange(len(data_inner[0])))
253
+
254
+ total_sum = sum(sum(data_inner, []))
255
+ data_outer = [sum(x) for x in data_inner]
256
+ all_fields_outer = sorted(clean_outer) + sorted(extra_outer)
257
+
258
+ clean_labels_outer = [f"{feature_names.get(c, c)}\n{round(data_outer[i] / total_sum * 100, 2)}%" for i, c in enumerate(all_fields_outer)]
259
+ clean_labels_inner = [feature_names[c] for c in sorted(clean_inner)]
260
+
261
+ ax.pie(data_outer, radius=1, labels=clean_labels_outer, colors=outer_colors,
262
+ wedgeprops=dict(width=size, edgecolor='w'))
263
+
264
+ patches, _ = ax.pie(list(itertools.chain(*data_inner)), radius=1 - size, colors=inner_colors,
265
+ wedgeprops=dict(width=size, edgecolor='w'))
266
+
267
+ ax.set(aspect="equal", title=title)
268
+ fig.legend(handles=patches, labels=clean_labels_inner + sorted(extra_inner), loc="lower left")
269
+
270
+ if plot is None:
271
+ return gr.Plot(plt)
272
+ else:
273
+ new_plot = plot.update(plt)
274
+ return new_plot
275
+
276
+
277
+ def select_new_base_plot(plot, disagg_check, disagg_by, artifacts):
278
+ if disagg_by == "Both":
279
+ disaggs = sorted(list(artifacts["disaggregators"]))
280
+
281
+ all_choices = sorted([[x for x in artifacts["data_fields"] if x.startswith(d)] for d in disaggs], key=len, reverse=True)
282
+
283
+ selected_choices = list(artifacts["data_fields"])
284
+ choices = selected_choices + [f"{disagg}.{extra}" for disagg in disaggs for extra in ["Multiple", "None"]]
285
+
286
+ # Map feature names to labels
287
+ choices = [artifacts["feature_names"].get(x, x) for x in choices]
288
+ selected_choices = [artifacts["feature_names"].get(x, x) for x in selected_choices]
289
+
290
+ # Choose new options
291
+ new_check = disagg_check.update(choices=sorted(choices), value=selected_choices)
292
+
293
+ # Generate plot
294
+ new_plot = create_nested_plot(
295
+ all_choices[0], all_choices[0],
296
+ all_choices[1], all_choices[1],
297
+ artifacts["distributions"],
298
+ artifacts["feature_names"],
299
+ plot=plot
300
+ )
301
+
302
+ return new_plot, new_check
303
+
304
+ else:
305
+ selected_choices = [field for field in artifacts["data_fields"] if field.startswith(disagg_by)]
306
+ choices = selected_choices + ["Multiple", "None"]
307
+
308
+ # Map feature names to labels
309
+ choices_for_check = [artifacts["feature_names"].get(x, x) for x in choices]
310
+ selected_choices_for_check = [artifacts["feature_names"].get(x, x) for x in selected_choices]
311
+
312
+ # Choose new options
313
+ new_check = disagg_check.update(choices=choices_for_check, value=selected_choices_for_check)
314
+
315
+ # Generate plot
316
+ new_plot = create_plot(
317
+ sorted(selected_choices), sorted(selected_choices), artifacts["distributions"], artifacts["feature_names"],
318
+ plot=plot
319
+ )
320
+
321
+ return new_plot, new_check
322
+
323
+
324
+ def select_new_sub_plot(plot, disagg_check, disagg_by, artifacts):
325
+ if disagg_by == "Both":
326
+ disaggs = sorted(list(artifacts["disaggregators"]))
327
+
328
+ all_choices = sorted([[x for x in artifacts["data_fields"] if x.startswith(d)] for d in disaggs], key=len, reverse=True)
329
+
330
+ choice1 = all_choices[0][0].split(".")[0]
331
+ choice2 = all_choices[1][0].split(".")[0]
332
+
333
+ check1 = [dc for dc in disagg_check if dc.startswith(choice1)]
334
+ check2 = [dc for dc in disagg_check if dc.startswith(choice2)]
335
+
336
+ check1 = ["Multiple" if c == f"{c.split('.')[0]}.Multiple" else c for c in check1]
337
+ check1 = ["None" if c == f"{c.split('.')[0]}.None" else c for c in check1]
338
+ check2 = ["Multiple" if c == f"{c.split('.')[0]}.Multiple" else c for c in check2]
339
+ check2 = ["None" if c == f"{c.split('.')[0]}.None" else c for c in check2]
340
+
341
+ new_plot = create_nested_plot(
342
+ check1, all_choices[0],
343
+ check2, all_choices[1],
344
+ artifacts["distributions"],
345
+ artifacts["feature_names"],
346
+ plot=plot
347
+ )
348
+
349
+ return new_plot
350
+ else:
351
+ selected_choices = [field for field in artifacts["data_fields"] if field.startswith(disagg_by)]
352
+
353
+ # Generate plot
354
+ new_plot = create_plot(
355
+ disagg_check, selected_choices, artifacts["distributions"], artifacts["feature_names"],
356
+ plot=plot
357
+ )
358
+
359
+ return new_plot
360
+
361
+
362
+ def visualization_filter(plot, artifacts, default_value, intersect=False):
363
+ def map_labels_to_fields(labels: List[str]):
364
+ return [list(artifacts["feature_names"].keys())[list(artifacts["feature_names"].values()).index(x)] if not any([extra in x for extra in ["Multiple", "None"]]) else x for x in labels]
365
+
366
+ def map_category_to_disaggregator(category: str): # e.g. Gender, Age, Gender + Age -> gender, age, Both
367
+ return list(artifacts["feature_names"].keys())[list(artifacts["feature_names"].values()).index(category)]
368
+
369
+ choices = sorted(list(artifacts["disaggregators"]))
370
+ if intersect:
371
+ choices = choices + ["Both"]
372
+
373
+ # Map categories to nice names
374
+ choices = [artifacts["feature_names"][c] for c in choices]
375
+
376
+ disagg_radio = gr.Radio(
377
+ label="Disaggregate by...",
378
+ choices=choices,
379
+ value=artifacts["feature_names"][default_value],
380
+ interactive=True
381
+ )
382
+
383
+ selected_choices = [field for field in artifacts["data_fields"] if field.startswith(f"{default_value}.")]
384
+ choices = selected_choices + ["Multiple", "None"]
385
+
386
+ # Map feature names to labels
387
+ choices = [artifacts["feature_names"].get(x, x) for x in choices]
388
+ selected_choices = [artifacts["feature_names"].get(x, x) for x in selected_choices]
389
+
390
+ disagg_check = gr.CheckboxGroup(
391
+ label="Features",
392
+ choices=choices,
393
+ interactive=True,
394
+ value=selected_choices,
395
+ )
396
+
397
+ disagg_radio.change(
398
+ lambda x: select_new_base_plot(plot, disagg_check, map_category_to_disaggregator(x), artifacts),
399
+ inputs=[disagg_radio],
400
+ outputs=[plot, disagg_check]
401
+ )
402
+
403
+ disagg_check.change(
404
+ lambda x, y: select_new_sub_plot(plot, map_labels_to_fields(x), map_category_to_disaggregator(y), artifacts),
405
+ inputs=[disagg_check, disagg_radio],
406
+ outputs=[plot]
407
+ )
408
+
409
+
410
+ def generate_components(dataset, artifacts, intersect=True):
411
+ gr.Markdown(f"### {artifacts['text']['title']}")
412
+ gr.Markdown(artifacts['text']['description'])
413
+
414
+ with gr.Accordion(label="💻 Click me to see the code!", open=False):
415
+ gr.Markdown(artifacts["text"]["code"])
416
+
417
+ with gr.Tab("Visualize"):
418
+ with gr.Row(elem_id="visualization-window"):
419
+ with gr.Column():
420
+ disagg_by = sorted(list(artifacts["disaggregators"]))[0]
421
+ selected_choices = [field for field in artifacts["data_fields"] if field.startswith(disagg_by)]
422
+ plot = create_plot(
423
+ sorted(selected_choices),
424
+ sorted(selected_choices),
425
+ artifacts["distributions"],
426
+ artifacts["feature_names"]
427
+ )
428
+
429
+ with gr.Column():
430
+ gr.Markdown("### Visualization")
431
+ gr.Markdown(artifacts["text"]["visualization"])
432
+ visualization_filter(plot, artifacts, disagg_by, intersect=intersect)
433
+ with gr.Tab("Inspect"):
434
+ with gr.Row():
435
+ with gr.Column(scale=1):
436
+ gr.Markdown("### Data Inspector")
437
+ gr.Markdown("This tab lets you filter the disaggregated dataset and inspect individual elements. Set as many filters as you like, and then click \"Apply filters\" to fetch a random subset of rows that match *all* of the filters you've selected.")
438
+
439
+ filter_groups = gr.CheckboxGroup(choices=sorted(list(artifacts["data_fields"])), label="Filters")
440
+ fetch_subset = gr.Button("Apply filters")
441
+
442
+ sample_dataframe = gr.State(value=dataset.sample(10))
443
+
444
+ def fetch_new_samples(filters):
445
+ if len(filters) == 0:
446
+ new_dataset = dataset.sample(10)
447
+ else:
448
+ filter_query = " & ".join([f"`{f}`" for f in filters])
449
+ new_dataset = dataset.query(filter_query)
450
+ if new_dataset.shape[0] > 0:
451
+ new_dataset = new_dataset.sample(10)
452
+
453
+ new_samples = [[
454
+ x[1][artifacts["column"]],
455
+ ", ".join([col for col in artifacts["data_fields"] if x[1][col]]),
456
+ ] for x in new_dataset.iterrows()]
457
+ return sample_rows.update(samples=new_samples), new_dataset
458
+
459
+ sample_rows = gr.Dataset(
460
+ samples=[[
461
+ x[1][artifacts["column"]],
462
+ ", ".join([col for col in artifacts["data_fields"] if x[1][col]]),
463
+ ] for x in sample_dataframe.value.iterrows()],
464
+ components=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
465
+ type="index"
466
+ )
467
+ with gr.Column(scale=1):
468
+ row_inspector = gr.DataFrame(
469
+ wrap=True,
470
+ visible=False
471
+ )
472
+
473
+ fetch_subset.click(
474
+ fetch_new_samples,
475
+ inputs=[filter_groups],
476
+ outputs=[sample_rows, sample_dataframe],
477
+ )
478
+
479
+ sample_rows.click(
480
+ lambda df, index: row_inspector.update(visible=True, value=df.iloc[index].reset_index()),
481
+ inputs=[sample_dataframe, sample_rows],
482
+ outputs=[row_inspector]
483
+ )
484
+
485
+
486
+ with gr.Blocks(css="#visualization-window {flex-direction: row-reverse;}") as demo:
487
+ gr.Markdown("# Exploring Disaggregated Data with 🤗 Disaggregators")
488
+ with gr.Accordion("About this demo 👀"):
489
+ gr.Markdown("## What's in your dataset?")
490
+ gr.Markdown("""
491
+ Addressing fairness and bias in machine learning models is [more important than ever](https://www.vice.com/en/article/bvm35w/this-tool-lets-anyone-see-the-bias-in-ai-image-generators)!
492
+ One form of fairness is equal performance across different groups or features.
493
+ To measure this, evaluation datasets must be disaggregated across the different groups of interest.
494
+ """)
495
+
496
+ gr.Markdown("The `disaggregators` library ([GitHub](https://github.com/huggingface/disaggregators)) provides an interface and a collection of modules to help you disaggregate datasets by different groups. Click through each of the tabs below to see it in action!")
497
+ gr.Markdown("""
498
+ After tinkering with the demo, you can install 🤗 Disaggregators with:
499
+ ```bash
500
+ pip install disaggregators
501
+ ```
502
+ """)
503
+ gr.Markdown("Each tab below will show you a feature of `disaggregators` used on a different dataset. First, you'll learn about using the built-in disaggregation modules. The second tab will show you how to override the configurations for the existing modules. Finally, the third tab will show you how to incorporate your own custom modules.")
504
+
505
+ with gr.Tab("🐊 LAION: Built-in Modules Example"):
506
+ generate_components(laion, cached_artifacts["laion"], intersect=False)
507
+ with gr.Tab("🔧 MedMCQA: Configuration Example"):
508
+ generate_components(medmcqa, cached_artifacts["medmcqa"])
509
+ with gr.Tab("🎡 The Stack: Custom Disaggregation Example"):
510
+ generate_components(stack, cached_artifacts["stack"])
511
+
512
+ with gr.Accordion(label="💡How is this calculated?", open=False):
513
+ gr.Markdown("""
514
+ ## Continent
515
+
516
+ Continents are inferred by identifying geographic terms and their related countries using [geograpy3](https://github.com/somnathrakshit/geograpy3). The results are then mapped to [their respective continents](https://github.com/bigscience-workshop/data_sourcing/blob/master/sourcing_sprint/resources/country_regions.json).
517
+
518
+ ## Age
519
+
520
+ Ages are inferred by using [spaCy](https://spacy.io) to seek "date" tokens in strings.
521
+
522
+ ## Gender
523
+
524
+ Binary gender is inferred by checking for words against the [md_gender_bias](https://huggingface.co/datasets/md_gender_bias) dataset.
525
+
526
+ ```
527
+ @inproceedings{dinan-etal-2020-multi,
528
+ title = "Multi-Dimensional Gender Bias Classification",
529
+ author = "Dinan, Emily and
530
+ Fan, Angela and
531
+ Wu, Ledell and
532
+ Weston, Jason and
533
+ Kiela, Douwe and
534
+ Williams, Adina",
535
+ year = "2020",
536
+ publisher = "Association for Computational Linguistics",
537
+ url = "https://www.aclweb.org/anthology/2020.emnlp-main.23",
538
+ doi = "10.18653/v1/2020.emnlp-main.23",
539
+ ```
540
+
541
+ ## Learn more!
542
+
543
+ Visit the [GitHub repository](https://github.com/huggingface/disaggregators) to learn about using the `disaggregators` library and to leave feedback 🤗
544
+ """)
545
+
546
+
547
+ demo.launch()
cached_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:194b625aa5a5fe55ae5a94fe8ca6e618c90a702ce34cc4aae9040f5c5f66ce54
3
+ size 16181
generate_datasets.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from disaggregators import Disaggregator
3
+ from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig
4
+
5
+
6
+ class MeSHAgeLabels(AgeLabels):
7
+ INFANT = "infant"
8
+ CHILD_PRESCHOOL = "child_preschool"
9
+ CHILD = "child"
10
+ ADOLESCENT = "adolescent"
11
+ ADULT = "adult"
12
+ MIDDLE_AGED = "middle_aged"
13
+ AGED = "aged"
14
+ AGED_80_OVER = "aged_80_over"
15
+
16
+
17
+ age = Age(
18
+ config=AgeConfig(
19
+ labels=MeSHAgeLabels,
20
+ ages=[
21
+ MeSHAgeLabels.INFANT,
22
+ MeSHAgeLabels.CHILD_PRESCHOOL,
23
+ MeSHAgeLabels.CHILD,
24
+ MeSHAgeLabels.ADOLESCENT,
25
+ MeSHAgeLabels.ADULT,
26
+ MeSHAgeLabels.MIDDLE_AGED,
27
+ MeSHAgeLabels.AGED,
28
+ MeSHAgeLabels.AGED_80_OVER
29
+ ],
30
+ breakpoints=[0, 2, 5, 12, 18, 44, 64, 79]
31
+ ),
32
+ column="question"
33
+ )
34
+
35
+ disaggregator = Disaggregator([age, "gender"], column="question")
36
+
37
+ ds = load_dataset("medmcqa", split="train")
38
+
39
+ ds_mapped = ds.map(disaggregator)
40
+ ds_mapped.push_to_hub("society-ethics/medmcqa_age_gender_custom")
medmcqa.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cdc5bc914c22ce5cb946b35d585ed1f52ba269b4da43b0b1b9a5d872908c77f
3
+ size 538
prep_data.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator
3
+ from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig
4
+
5
+ import matplotlib
6
+ matplotlib.use('TKAgg')
7
+
8
+ import joblib
9
+ import os
10
+
11
+ cache_file = "cached_data.pkl"
12
+ cache_dict = {}
13
+
14
+ if os.path.exists(cache_file):
15
+ cache_dict = joblib.load("cached_data.pkl")
16
+
17
+ class MeSHAgeLabels(AgeLabels):
18
+ INFANT = "infant"
19
+ CHILD_PRESCHOOL = "child_preschool"
20
+ CHILD = "child"
21
+ ADOLESCENT = "adolescent"
22
+ ADULT = "adult"
23
+ MIDDLE_AGED = "middle_aged"
24
+ AGED = "aged"
25
+ AGED_80_OVER = "aged_80_over"
26
+
27
+
28
+ age = Age(
29
+ config=AgeConfig(
30
+ labels=MeSHAgeLabels,
31
+ ages=[list(MeSHAgeLabels)],
32
+ breakpoints=[0, 2, 5, 12, 18, 44, 64, 79]
33
+ ),
34
+ column="question"
35
+ )
36
+
37
+
38
+ class TabsSpacesLabels(DisaggregationModuleLabels):
39
+ TABS = "tabs"
40
+ SPACES = "spaces"
41
+
42
+
43
+ class TabsSpaces(CustomDisaggregator):
44
+ module_id = "tabs_spaces"
45
+ labels = TabsSpacesLabels
46
+
47
+ def __call__(self, row, *args, **kwargs):
48
+ if "\t" in row[self.column]:
49
+ return {self.labels.TABS: True, self.labels.SPACES: False}
50
+ else:
51
+ return {self.labels.TABS: False, self.labels.SPACES: True}
52
+
53
+
54
+ class ReactComponentLabels(DisaggregationModuleLabels):
55
+ CLASS = "class"
56
+ FUNCTION = "function"
57
+
58
+
59
+ class ReactComponent(CustomDisaggregator):
60
+ module_id = "react_component"
61
+ labels = ReactComponentLabels
62
+
63
+ def __call__(self, row, *args, **kwargs):
64
+ if "extends React.Component" in row[self.column] or "extends Component" in row[self.column]:
65
+ return {self.labels.CLASS: True, self.labels.FUNCTION: False}
66
+ else:
67
+ return {self.labels.CLASS: False, self.labels.FUNCTION: True}
68
+
69
+
70
+ configs = {
71
+ "laion": {
72
+ "disaggregation_modules": ["continent"],
73
+ "dataset_name": "society-ethics/laion2B-en_continents",
74
+ "column": "TEXT",
75
+ "feature_names": {
76
+ "continent.africa": "Africa",
77
+ "continent.americas": "Americas",
78
+ "continent.asia": "Asia",
79
+ "continent.europe": "Europe",
80
+ "continent.oceania": "Oceania",
81
+
82
+ # Parent level
83
+ "continent": "Continent",
84
+ }
85
+ },
86
+ "medmcqa": {
87
+ "disaggregation_modules": [age, "gender"],
88
+ "dataset_name": "society-ethics/medmcqa_age_gender_custom",
89
+ "column": "question",
90
+ "feature_names": {
91
+ "age.infant": "Infant",
92
+ "age.child_preschool": "Preschool",
93
+ "age.child": "Child",
94
+ "age.adolescent": "Adolescent",
95
+ "age.adult": "Adult",
96
+ "age.middle_aged": "Middle Aged",
97
+ "age.aged": "Aged",
98
+ "age.aged_80_over": "Aged 80+",
99
+ "gender.male": "Male",
100
+ "gender.female": "Female",
101
+
102
+ # Parent level
103
+ "gender": "Gender",
104
+ "age": "Age",
105
+ "Both": "Age + Gender",
106
+ }
107
+ },
108
+ "stack": {
109
+ "disaggregation_modules": [TabsSpaces, ReactComponent],
110
+ "dataset_name": "society-ethics/the-stack-tabs_spaces",
111
+ "column": "content",
112
+ "feature_names": {
113
+ "react_component.class": "Class",
114
+ "react_component.function": "Function",
115
+ "tabs_spaces.tabs": "Tabs",
116
+ "tabs_spaces.spaces": "Spaces",
117
+
118
+ # Parent level
119
+ "react_component": "React Component Syntax",
120
+ "tabs_spaces": "Tabs vs. Spaces",
121
+ "Both": "React Component Syntax + Tabs vs. Spaces",
122
+ }
123
+ }
124
+ }
125
+
126
+
127
+ def generate_cached_data(disaggregation_modules, dataset_name, column, feature_names):
128
+ disaggregator = Disaggregator(disaggregation_modules, column=column)
129
+ ds = load_dataset(dataset_name, split="train")
130
+ df = ds.to_pandas()
131
+
132
+ all_fields = {*disaggregator.fields, "None"}
133
+ distributions = df[sorted(list(disaggregator.fields))].value_counts()
134
+
135
+ return {
136
+ "fields": all_fields,
137
+ "data_fields": disaggregator.fields,
138
+ "distributions": distributions,
139
+ "disaggregators": [module.name for module in disaggregator.modules],
140
+ "column": column,
141
+ "feature_names": feature_names,
142
+ }
143
+
144
+
145
+ cache_dict.update({
146
+ "laion": generate_cached_data(**configs["laion"]),
147
+ "medmcqa": generate_cached_data(**configs["medmcqa"]),
148
+ "stack": generate_cached_data(**configs["stack"])
149
+ })
150
+
151
+ joblib.dump(cache_dict, cache_file)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ gradio==3.12.0
2
+ datasets
3
+ matplotlib
4
+ numpy
5
+ joblib