saattrupdan commited on
Commit
81fc601
1 Parent(s): 79f114e

fix: Ensure different colours, remove Faroese

Browse files
Files changed (1) hide show
  1. app.py +61 -23
app.py CHANGED
@@ -123,6 +123,7 @@ paper](https://aclanthology.org/2023.nodalida-1.20):
123
 
124
 
125
  UPDATE_FREQUENCY_MINUTES = 30
 
126
 
127
 
128
  class Task(BaseModel):
@@ -165,11 +166,11 @@ KNOWLEDGE = Task(name="knowledge", metric="mcc")
165
  REASONING = Task(name="reasoning", metric="mcc")
166
  ALL_TASKS = [obj for obj in globals().values() if isinstance(obj, Task)]
167
 
 
168
  DANISH = Language(code="da", name="Danish")
169
  NORWEGIAN = Language(code="no", name="Norwegian")
170
  SWEDISH = Language(code="sv", name="Swedish")
171
  ICELANDIC = Language(code="is", name="Icelandic")
172
- FAROESE = Language(code="fo", name="Faroese")
173
  GERMAN = Language(code="de", name="German")
174
  DUTCH = Language(code="nl", name="Dutch")
175
  ENGLISH = Language(code="en", name="English")
@@ -189,7 +190,6 @@ DATASETS = [
189
  Dataset(name="norne-nb", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
190
  Dataset(name="norne-nn", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
191
  Dataset(name="mim-gold-ner", language=ICELANDIC, task=INFORMATION_EXTRACTION),
192
- Dataset(name="fone", language=FAROESE, task=INFORMATION_EXTRACTION),
193
  Dataset(name="germeval", language=GERMAN, task=INFORMATION_EXTRACTION),
194
  Dataset(name="conll-nl", language=DUTCH, task=INFORMATION_EXTRACTION),
195
  Dataset(name="conll-en", language=ENGLISH, task=INFORMATION_EXTRACTION),
@@ -198,7 +198,6 @@ DATASETS = [
198
  Dataset(name="scala-nb", language=NORWEGIAN, task=GRAMMAR),
199
  Dataset(name="scala-nn", language=NORWEGIAN, task=GRAMMAR),
200
  Dataset(name="scala-is", language=ICELANDIC, task=GRAMMAR),
201
- Dataset(name="scala-fo", language=FAROESE, task=GRAMMAR),
202
  Dataset(name="scala-de", language=GERMAN, task=GRAMMAR),
203
  Dataset(name="scala-nl", language=DUTCH, task=GRAMMAR),
204
  Dataset(name="scala-en", language=ENGLISH, task=GRAMMAR),
@@ -247,13 +246,48 @@ def main() -> None:
247
  results_dfs = fetch_results()
248
  last_fetch = dt.datetime.now()
249
 
250
- all_languages = [
251
- language.name for language in ALL_LANGUAGES.values()
252
- ]
253
- danish_models = list({
254
- model_id
255
- for model_id in results_dfs[DANISH].index
256
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
259
  gr.Markdown(INTRO_MARKDOWN)
@@ -317,17 +351,18 @@ def main() -> None:
317
  show_scale=show_scale_checkbox.value,
318
  plot_width=plot_width_slider.value,
319
  plot_height=plot_height_slider.value,
 
320
  results_dfs=results_dfs,
321
  ),
322
  )
323
- with gr.Row():
324
- gr.Markdown(
325
- "<center>Made with ❤️ by the <a href=\"https://alexandra.dk\">"
326
- "Alexandra Institute</a>.</center>"
327
- )
328
  with gr.Tab(label="About"):
329
  gr.Markdown(ABOUT_MARKDOWN)
330
 
 
 
 
 
 
331
  language_names_dropdown.change(
332
  fn=partial(update_model_ids_dropdown, results_dfs=results_dfs),
333
  inputs=[language_names_dropdown, model_ids_dropdown],
@@ -336,7 +371,11 @@ def main() -> None:
336
 
337
  # Update plot when anything changes
338
  update_plot_kwargs = dict(
339
- fn=partial(produce_radial_plot, results_dfs=results_dfs),
 
 
 
 
340
  inputs=[
341
  model_ids_dropdown,
342
  language_names_dropdown,
@@ -443,6 +482,7 @@ def produce_radial_plot(
443
  show_scale: bool,
444
  plot_width: int,
445
  plot_height: int,
 
446
  results_dfs: dict[Language, pd.DataFrame] | None,
447
  ) -> go.Figure:
448
  """Produce a radial plot as a plotly figure.
@@ -460,6 +500,8 @@ def produce_radial_plot(
460
  The width of the plot.
461
  plot_height:
462
  The height of the plot.
 
 
463
  results_dfs:
464
  The results dataframes for each language.
465
 
@@ -551,17 +593,13 @@ def produce_radial_plot(
551
  # Add the results to a plotly figure
552
  fig = go.Figure()
553
  for model_id, result_list in zip(model_ids, results):
554
-
555
- # Generate colour for model, as an RGB triplet. The same model will always
556
- # have the same colour
557
- random.seed(model_id)
558
- r, g, b = tuple(random.randint(0, 255) for _ in range(3))
559
-
560
  fig.add_trace(go.Scatterpolar(
561
  r=result_list,
562
  theta=[task.name for task in tasks],
563
- fill='toself',
564
  name=model_id,
 
 
565
  line=dict(color=f'rgb({r}, {g}, {b})'),
566
  ))
567
 
 
123
 
124
 
125
  UPDATE_FREQUENCY_MINUTES = 30
126
+ MIN_COLOUR_DISTANCE_BETWEEN_MODELS = 200
127
 
128
 
129
  class Task(BaseModel):
 
166
  REASONING = Task(name="reasoning", metric="mcc")
167
  ALL_TASKS = [obj for obj in globals().values() if isinstance(obj, Task)]
168
 
169
+
170
  DANISH = Language(code="da", name="Danish")
171
  NORWEGIAN = Language(code="no", name="Norwegian")
172
  SWEDISH = Language(code="sv", name="Swedish")
173
  ICELANDIC = Language(code="is", name="Icelandic")
 
174
  GERMAN = Language(code="de", name="German")
175
  DUTCH = Language(code="nl", name="Dutch")
176
  ENGLISH = Language(code="en", name="English")
 
190
  Dataset(name="norne-nb", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
191
  Dataset(name="norne-nn", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
192
  Dataset(name="mim-gold-ner", language=ICELANDIC, task=INFORMATION_EXTRACTION),
 
193
  Dataset(name="germeval", language=GERMAN, task=INFORMATION_EXTRACTION),
194
  Dataset(name="conll-nl", language=DUTCH, task=INFORMATION_EXTRACTION),
195
  Dataset(name="conll-en", language=ENGLISH, task=INFORMATION_EXTRACTION),
 
198
  Dataset(name="scala-nb", language=NORWEGIAN, task=GRAMMAR),
199
  Dataset(name="scala-nn", language=NORWEGIAN, task=GRAMMAR),
200
  Dataset(name="scala-is", language=ICELANDIC, task=GRAMMAR),
 
201
  Dataset(name="scala-de", language=GERMAN, task=GRAMMAR),
202
  Dataset(name="scala-nl", language=DUTCH, task=GRAMMAR),
203
  Dataset(name="scala-en", language=ENGLISH, task=GRAMMAR),
 
246
  results_dfs = fetch_results()
247
  last_fetch = dt.datetime.now()
248
 
249
+ all_languages = [language.name for language in ALL_LANGUAGES.values()]
250
+ danish_models = list({model_id for model_id in results_dfs[DANISH].index})
251
+
252
+ # Get distinct RGB values for all models
253
+ all_models = list(
254
+ {model_id for df in results_dfs.values() for model_id in df.index}
255
+ )
256
+ colour_mapping: dict[str, tuple[int, int, int]] = dict()
257
+
258
+ for i in it.count():
259
+ min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
260
+
261
+ if i > 0:
262
+ logger.info(
263
+ f"All retries failed. Trying again with min colour distance "
264
+ f"{min_colour_distance}."
265
+ )
266
+
267
+ random.seed(4242 + i)
268
+ retries_left = 10 * len(all_models)
269
+ for model_id in all_models:
270
+ r, g, b = 0, 0, 0
271
+ too_bright, similar_to_other_model = True, True
272
+ while (too_bright or similar_to_other_model) and retries_left > 0:
273
+ r, g, b = tuple(random.randint(0, 255) for _ in range(3))
274
+ too_bright = np.min([r, g, b]) > 200
275
+ similar_to_other_model = any(
276
+ np.abs(
277
+ np.array(colour) - np.array([r, g, b])
278
+ ).sum() < min_colour_distance
279
+ for colour in colour_mapping.values()
280
+ )
281
+ retries_left -= 1
282
+ logger.info(f"Retries left to find a colour mapping: {retries_left}")
283
+ colour_mapping[model_id] = (r, g, b)
284
+
285
+ if retries_left:
286
+ logger.info(
287
+ f"Successfully found a colour mapping with min colour distance "
288
+ f"{min_colour_distance}."
289
+ )
290
+ break
291
 
292
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
293
  gr.Markdown(INTRO_MARKDOWN)
 
351
  show_scale=show_scale_checkbox.value,
352
  plot_width=plot_width_slider.value,
353
  plot_height=plot_height_slider.value,
354
+ colour_mapping=colour_mapping,
355
  results_dfs=results_dfs,
356
  ),
357
  )
 
 
 
 
 
358
  with gr.Tab(label="About"):
359
  gr.Markdown(ABOUT_MARKDOWN)
360
 
361
+ gr.Markdown(
362
+ "<center>Made with ❤️ by the <a href=\"https://alexandra.dk\">"
363
+ "Alexandra Institute</a>.</center>"
364
+ )
365
+
366
  language_names_dropdown.change(
367
  fn=partial(update_model_ids_dropdown, results_dfs=results_dfs),
368
  inputs=[language_names_dropdown, model_ids_dropdown],
 
371
 
372
  # Update plot when anything changes
373
  update_plot_kwargs = dict(
374
+ fn=partial(
375
+ produce_radial_plot,
376
+ colour_mapping=colour_mapping,
377
+ results_dfs=results_dfs,
378
+ ),
379
  inputs=[
380
  model_ids_dropdown,
381
  language_names_dropdown,
 
482
  show_scale: bool,
483
  plot_width: int,
484
  plot_height: int,
485
+ colour_mapping: dict[str, tuple[int, int, int]],
486
  results_dfs: dict[Language, pd.DataFrame] | None,
487
  ) -> go.Figure:
488
  """Produce a radial plot as a plotly figure.
 
500
  The width of the plot.
501
  plot_height:
502
  The height of the plot.
503
+ colour_mapping:
504
+ A mapping from model ids to RGB triplets.
505
  results_dfs:
506
  The results dataframes for each language.
507
 
 
593
  # Add the results to a plotly figure
594
  fig = go.Figure()
595
  for model_id, result_list in zip(model_ids, results):
596
+ r, g, b = colour_mapping[model_id]
 
 
 
 
 
597
  fig.add_trace(go.Scatterpolar(
598
  r=result_list,
599
  theta=[task.name for task in tasks],
 
600
  name=model_id,
601
+ fill='toself',
602
+ fillcolor=f'rgba({r}, {g}, {b}, 0.6)',
603
  line=dict(color=f'rgb({r}, {g}, {b})'),
604
  ))
605