theo commited on
Commit
26742b2
1 Parent(s): 2656c08

refactor multiselect to handle error uniformly

Browse files
Files changed (1) hide show
  1. tagging_app.py +92 -53
tagging_app.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  from pathlib import Path
3
- from typing import List, Tuple
4
 
5
  import streamlit as st
6
  import yaml
@@ -59,10 +59,32 @@ def load_ds_datas():
59
 
60
 
61
  def split_known(vals: List[str], okset: List[str]) -> Tuple[List[str], List[str]]:
 
 
62
  return [v for v in vals if v in okset], [v for v in vals if v not in okset]
63
 
64
 
65
- def new_pre_loaded():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  return {
67
  "task_categories": [],
68
  "task_ids": [],
@@ -76,7 +98,7 @@ def new_pre_loaded():
76
  }
77
 
78
 
79
- pre_loaded = new_pre_loaded()
80
  datasets_md = load_ds_datas()
81
  existing_tag_sets = {name: mds["metadata"] for name, mds in datasets_md.items()}
82
  all_dataset_ids = list(existing_tag_sets.keys())
@@ -112,39 +134,50 @@ preloaded_id = None
112
  did_index = 0
113
  if len(preload) == 1 and preload[0] in all_dataset_ids:
114
  preloaded_id, *_ = preload
115
- pre_loaded = existing_tag_sets[preloaded_id] or new_pre_loaded()
116
  did_index = all_dataset_ids.index(preloaded_id)
117
 
118
  did = st.sidebar.selectbox(label="Choose dataset to load tag set from", options=all_dataset_ids, index=did_index)
119
 
120
  leftbtn, rightbtn = st.sidebar.beta_columns(2)
121
  if leftbtn.button("pre-load tagset"):
122
- pre_loaded = existing_tag_sets[did] or new_pre_loaded()
123
  st.experimental_set_query_params(preload_dataset=did)
124
  if rightbtn.button("flush state"):
125
- pre_loaded = new_pre_loaded()
126
  st.experimental_set_query_params()
127
 
128
  if preloaded_id is not None:
129
- st.sidebar.markdown(f"Took [{preloaded_id}](https://huggingface.co/datasets/{preloaded_id}) as base tagset.")
 
 
 
 
 
 
 
130
 
131
 
132
  leftcol, _, rightcol = st.beta_columns([12, 1, 12])
133
 
134
 
135
  leftcol.markdown("### Supported tasks")
136
- task_categories = leftcol.multiselect(
 
 
137
  "What categories of task does the dataset support?",
138
- options=list(task_set.keys()),
139
- default=pre_loaded["task_categories"],
140
- format_func=lambda tg: f"{tg} : {task_set[tg]['description']}",
141
  )
142
  task_specifics = []
143
  for tg in task_categories:
144
- task_specs = leftcol.multiselect(
 
 
145
  f"What specific *{tg}* tasks does the dataset support?",
146
- options=task_set[tg]["options"],
147
- default=[ts for ts in pre_loaded["task_ids"] if ts in task_set[tg]["options"]],
148
  )
149
  if "other" in task_specs:
150
  other_task = st.text_input(
@@ -157,13 +190,13 @@ for tg in task_categories:
157
 
158
 
159
  leftcol.markdown("### Languages")
160
- filtered_existing_languages = [lgc for lgc in set(pre_loaded["languages"]) if lgc not in language_set_restricted]
161
- pre_loaded["languages"] = [lgc for lgc in set(pre_loaded["languages"]) if lgc in language_set_restricted]
162
 
163
- multilinguality = leftcol.multiselect(
 
 
164
  "Does the dataset contain more than one language?",
165
- options=list(multilinguality_set.keys()),
166
- default=pre_loaded["multilinguality"],
167
  format_func=lambda m: f"{m} : {multilinguality_set[m]}",
168
  )
169
 
@@ -175,41 +208,40 @@ if "other" in multilinguality:
175
  st.write(f"Registering other-{other_multilinguality} multilinguality")
176
  multilinguality[multilinguality.index("other")] = f"other-{other_multilinguality}"
177
 
178
- if len(filtered_existing_languages) > 0:
179
- leftcol.markdown(f"**Found bad language codes in existing tagset**:\n{filtered_existing_languages}")
180
- languages = leftcol.multiselect(
 
181
  "What languages are represented in the dataset?",
182
- options=list(language_set_restricted.keys()),
183
- default=pre_loaded["languages"],
184
  format_func=lambda m: f"{m} : {language_set_restricted[m]}",
185
  )
186
 
187
 
188
  leftcol.markdown("### Dataset creators")
189
- ok, nonok = split_known(pre_loaded["language_creators"], creator_set["language"])
190
- if len(nonok) > 0:
191
- leftcol.markdown(f"**Found bad codes in existing tagset**:\n{nonok}")
192
- language_creators = leftcol.multiselect(
193
  "Where does the text in the dataset come from?",
194
- options=creator_set["language"],
195
- default=ok,
196
  )
197
- ok, nonok = split_known(pre_loaded["annotations_creators"], creator_set["annotations"])
198
- if len(nonok) > 0:
199
- leftcol.markdown(f"**Found bad codes in existing tagset**:\n{nonok}")
200
- annotations_creators = leftcol.multiselect(
201
  "Where do the annotations in the dataset come from?",
202
- options=creator_set["annotations"],
203
- default=ok,
204
  )
205
 
206
- ok, nonok = split_known(pre_loaded["licenses"], list(license_set.keys()))
207
- if len(nonok) > 0:
208
- leftcol.markdown(f"**Found bad codes in existing tagset**:\n{nonok}")
209
- licenses = leftcol.multiselect(
210
  "What licenses is the dataset under?",
211
- options=list(license_set.keys()),
212
- default=ok,
213
  format_func=lambda l: f"{l} : {license_set[l]}",
214
  )
215
  if "other" in licenses:
@@ -219,24 +251,31 @@ if "other" in licenses:
219
  )
220
  st.write(f"Registering other-{other_license} license")
221
  licenses[licenses.index("other")] = f"other-{other_license}"
222
- # link ro supported datasets
 
223
  pre_select_ext_a = []
224
- if "original" in pre_loaded["source_datasets"]:
225
  pre_select_ext_a += ["original"]
226
- if any([p.startswith("extended") for p in pre_loaded["source_datasets"]]):
227
  pre_select_ext_a += ["extended"]
228
- extended = leftcol.multiselect(
 
 
229
  "Does the dataset contain original data and/or was it extended from other datasets?",
230
- options=["original", "extended"],
231
- default=pre_select_ext_a,
232
  )
233
  source_datasets = ["original"] if "original" in extended else []
 
 
234
  if "extended" in extended:
235
- pre_select_ext_b = [p.split("|")[1] for p in pre_loaded["source_datasets"] if p.startswith("extended")]
236
- extended_sources = leftcol.multiselect(
 
 
237
  "Which other datasets does this one use data from?",
238
- options=all_dataset_ids + ["other"],
239
- default=pre_select_ext_b,
240
  )
241
  if "other" in extended_sources:
242
  other_extended_sources = st.text_input(
@@ -248,7 +287,7 @@ if "extended" in extended:
248
  source_datasets += [f"extended|{src}" for src in extended_sources]
249
 
250
  size_cats = ["unknown", "n<1K", "1K<n<10K", "10K<n<100K", "100K<n<1M", "n>1M"]
251
- current_size_cats = pre_loaded.get("size_categories") or ["unknown"]
252
  ok, nonok = split_known(current_size_cats, size_cats)
253
  if len(nonok) > 0:
254
  leftcol.markdown(f"**Found bad codes in existing tagset**:\n{nonok}")
 
1
  import json
2
  from pathlib import Path
3
+ from typing import Callable, List, Tuple
4
 
5
  import streamlit as st
6
  import yaml
 
59
 
60
 
61
  def split_known(vals: List[str], okset: List[str]) -> Tuple[List[str], List[str]]:
62
+ if vals is None:
63
+ return [], []
64
  return [v for v in vals if v in okset], [v for v in vals if v not in okset]
65
 
66
 
67
+ def multiselect(
68
+ w: st.delta_generator.DeltaGenerator,
69
+ title: str,
70
+ markdown: str,
71
+ values: List[str],
72
+ valid_set: List[str],
73
+ format_func: Callable = str,
74
+ ):
75
+ valid_values, invalid_values = split_known(values, valid_set)
76
+ w.markdown(
77
+ """
78
+ #### {title}
79
+ {errors}
80
+ """.format(
81
+ title=title, errors="" if len(invalid_values) == 0 else f"_Found invalid values:_ `{invalid_values}`"
82
+ )
83
+ )
84
+ return w.multiselect(markdown, valid_set, default=valid_values, format_func=format_func)
85
+
86
+
87
+ def new_state():
88
  return {
89
  "task_categories": [],
90
  "task_ids": [],
 
98
  }
99
 
100
 
101
+ state = new_state()
102
  datasets_md = load_ds_datas()
103
  existing_tag_sets = {name: mds["metadata"] for name, mds in datasets_md.items()}
104
  all_dataset_ids = list(existing_tag_sets.keys())
 
134
  did_index = 0
135
  if len(preload) == 1 and preload[0] in all_dataset_ids:
136
  preloaded_id, *_ = preload
137
+ state = existing_tag_sets[preloaded_id] or new_state()
138
  did_index = all_dataset_ids.index(preloaded_id)
139
 
140
  did = st.sidebar.selectbox(label="Choose dataset to load tag set from", options=all_dataset_ids, index=did_index)
141
 
142
  leftbtn, rightbtn = st.sidebar.beta_columns(2)
143
  if leftbtn.button("pre-load tagset"):
144
+ state = existing_tag_sets[did] or new_state()
145
  st.experimental_set_query_params(preload_dataset=did)
146
  if rightbtn.button("flush state"):
147
+ state = new_state()
148
  st.experimental_set_query_params()
149
 
150
  if preloaded_id is not None:
151
+ st.sidebar.markdown(
152
+ f"""
153
+ Took [`{preloaded_id}`](https://huggingface.co/datasets/{preloaded_id}) as base tagset:
154
+ ```yaml
155
+ {yaml.dump(state)}
156
+ ```
157
+ """
158
+ )
159
 
160
 
161
  leftcol, _, rightcol = st.beta_columns([12, 1, 12])
162
 
163
 
164
  leftcol.markdown("### Supported tasks")
165
+ task_categories = multiselect(
166
+ leftcol,
167
+ "Task category",
168
  "What categories of task does the dataset support?",
169
+ values=state["task_categories"],
170
+ valid_set=list(task_set.keys()),
171
+ format_func=lambda tg: f"{tg}: {task_set[tg]['description']}",
172
  )
173
  task_specifics = []
174
  for tg in task_categories:
175
+ task_specs = multiselect(
176
+ leftcol,
177
+ "Specific tasks",
178
  f"What specific *{tg}* tasks does the dataset support?",
179
+ values=[ts for ts in state["task_ids"] if ts in task_set[tg]["options"]],
180
+ valid_set=task_set[tg]["options"],
181
  )
182
  if "other" in task_specs:
183
  other_task = st.text_input(
 
190
 
191
 
192
  leftcol.markdown("### Languages")
 
 
193
 
194
+ multilinguality = multiselect(
195
+ leftcol,
196
+ "Monolingual?",
197
  "Does the dataset contain more than one language?",
198
+ values=state["multilinguality"],
199
+ valid_set=list(multilinguality_set.keys()),
200
  format_func=lambda m: f"{m} : {multilinguality_set[m]}",
201
  )
202
 
 
208
  st.write(f"Registering other-{other_multilinguality} multilinguality")
209
  multilinguality[multilinguality.index("other")] = f"other-{other_multilinguality}"
210
 
211
+
212
+ languages = multiselect(
213
+ leftcol,
214
+ "Languages",
215
  "What languages are represented in the dataset?",
216
+ values=state["languages"],
217
+ valid_set=list(language_set_restricted.keys()),
218
  format_func=lambda m: f"{m} : {language_set_restricted[m]}",
219
  )
220
 
221
 
222
  leftcol.markdown("### Dataset creators")
223
+ language_creators = multiselect(
224
+ leftcol,
225
+ "Data origin",
 
226
  "Where does the text in the dataset come from?",
227
+ values=state["language_creators"],
228
+ valid_set=creator_set["language"],
229
  )
230
+ annotations_creators = multiselect(
231
+ leftcol,
232
+ "Annotations origin",
 
233
  "Where do the annotations in the dataset come from?",
234
+ values=state["annotations_creators"],
235
+ valid_set=creator_set["annotations"],
236
  )
237
 
238
+
239
+ licenses = multiselect(
240
+ leftcol,
241
+ "Licenses",
242
  "What licenses is the dataset under?",
243
+ valid_set=list(license_set.keys()),
244
+ values=state["licenses"],
245
  format_func=lambda l: f"{l} : {license_set[l]}",
246
  )
247
  if "other" in licenses:
 
251
  )
252
  st.write(f"Registering other-{other_license} license")
253
  licenses[licenses.index("other")] = f"other-{other_license}"
254
+
255
+ # link to supported datasets
256
  pre_select_ext_a = []
257
+ if "original" in state["source_datasets"]:
258
  pre_select_ext_a += ["original"]
259
+ if any([p.startswith("extended") for p in state["source_datasets"]]):
260
  pre_select_ext_a += ["extended"]
261
+ extended = multiselect(
262
+ leftcol,
263
+ "Relations to existing work",
264
  "Does the dataset contain original data and/or was it extended from other datasets?",
265
+ values=pre_select_ext_a,
266
+ valid_set=["original", "extended"],
267
  )
268
  source_datasets = ["original"] if "original" in extended else []
269
+
270
+ # todo: show bad tags
271
  if "extended" in extended:
272
+ pre_select_ext_b = [p.split("|")[1] for p in state["source_datasets"] if p.startswith("extended")]
273
+ extended_sources = multiselect(
274
+ leftcol,
275
+ "Linked datasets",
276
  "Which other datasets does this one use data from?",
277
+ values=pre_select_ext_b,
278
+ valid_set=all_dataset_ids + ["other"],
279
  )
280
  if "other" in extended_sources:
281
  other_extended_sources = st.text_input(
 
287
  source_datasets += [f"extended|{src}" for src in extended_sources]
288
 
289
  size_cats = ["unknown", "n<1K", "1K<n<10K", "10K<n<100K", "100K<n<1M", "n>1M"]
290
+ current_size_cats = state.get("size_categories") or ["unknown"]
291
  ok, nonok = split_known(current_size_cats, size_cats)
292
  if len(nonok) > 0:
293
  leftcol.markdown(f"**Found bad codes in existing tagset**:\n{nonok}")