saattrupdan commited on
Commit
c34e772
1 Parent(s): 8157f53

feat: Add update colours button

Browse files
Files changed (1) hide show
  1. app.py +44 -27
app.py CHANGED
@@ -232,40 +232,30 @@ DATASETS = [
232
  ]
233
 
234
 
235
- def main() -> None:
236
- """Produce a radial plot."""
237
 
238
- global last_fetch
239
- results_dfs = fetch_results()
240
- last_fetch = dt.datetime.now()
 
 
 
 
241
 
242
- all_languages = sorted(
243
- [language.name for language in ALL_LANGUAGES.values()],
244
- key=lambda language_name: language_name.lower(),
245
- )
246
- danish_models = sorted(
247
- list({model_id for model_id in results_dfs[DANISH].index}),
248
- key=lambda model_id: model_id.lower(),
249
- )
250
 
251
  # Get distinct RGB values for all models
252
  all_models = list(
253
  {model_id for df in results_dfs.values() for model_id in df.index}
254
  )
255
- colour_mapping: dict[str, tuple[int, int, int]] = dict()
256
 
257
  for i in it.count():
258
  min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
259
-
260
- if i > 0:
261
- logger.info(
262
- f"All retries failed. Trying again with min colour distance "
263
- f"{min_colour_distance}."
264
- )
265
-
266
  retries_left = 10 * len(all_models)
267
  for model_id in all_models:
268
- random.seed(hash(model_id) + i)
269
  r, g, b = 0, 0, 0
270
  too_bright, similar_to_other_model = True, True
271
  while (too_bright or similar_to_other_model) and retries_left > 0:
@@ -287,6 +277,28 @@ def main() -> None:
287
  )
288
  break
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
291
  gr.Markdown(INTRO_MARKDOWN)
292
 
@@ -340,6 +352,11 @@ def main() -> None:
340
  interactive=True,
341
  scale=1,
342
  )
 
 
 
 
 
343
  with gr.Row():
344
  plot = gr.Plot(
345
  value=produce_radial_plot(
@@ -349,7 +366,6 @@ def main() -> None:
349
  show_scale=show_scale_checkbox.value,
350
  plot_width=plot_width_slider.value,
351
  plot_height=plot_height_slider.value,
352
- colour_mapping=colour_mapping,
353
  results_dfs=results_dfs,
354
  ),
355
  )
@@ -371,7 +387,6 @@ def main() -> None:
371
  update_plot_kwargs = dict(
372
  fn=partial(
373
  produce_radial_plot,
374
- colour_mapping=colour_mapping,
375
  results_dfs=results_dfs,
376
  ),
377
  inputs=[
@@ -391,6 +406,11 @@ def main() -> None:
391
  plot_width_slider.change(**update_plot_kwargs)
392
  plot_height_slider.change(**update_plot_kwargs)
393
 
 
 
 
 
 
394
  demo.launch()
395
 
396
 
@@ -483,7 +503,6 @@ def produce_radial_plot(
483
  show_scale: bool,
484
  plot_width: int,
485
  plot_height: int,
486
- colour_mapping: dict[str, tuple[int, int, int]],
487
  results_dfs: dict[Language, pd.DataFrame] | None,
488
  ) -> go.Figure:
489
  """Produce a radial plot as a plotly figure.
@@ -501,8 +520,6 @@ def produce_radial_plot(
501
  The width of the plot.
502
  plot_height:
503
  The height of the plot.
504
- colour_mapping:
505
- A mapping from model ids to RGB triplets.
506
  results_dfs:
507
  The results dataframes for each language.
508
 
 
232
  ]
233
 
234
 
235
+ def update_colour_mapping(results_dfs: dict[Language, pd.DataFrame]) -> None:
236
+ """Get a mapping from model ids to RGB triplets.
237
 
238
+ Args:
239
+ results_dfs:
240
+ The results dataframes for each language.
241
+ """
242
+ global colour_mapping
243
+ global seed
244
+ seed += 1
245
 
246
+ gr.Info(f"Updating colour mapping...")
 
 
 
 
 
 
 
247
 
248
  # Get distinct RGB values for all models
249
  all_models = list(
250
  {model_id for df in results_dfs.values() for model_id in df.index}
251
  )
252
+ colour_mapping = dict()
253
 
254
  for i in it.count():
255
  min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
 
 
 
 
 
 
 
256
  retries_left = 10 * len(all_models)
257
  for model_id in all_models:
258
+ random.seed(hash(model_id) + i + seed)
259
  r, g, b = 0, 0, 0
260
  too_bright, similar_to_other_model = True, True
261
  while (too_bright or similar_to_other_model) and retries_left > 0:
 
277
  )
278
  break
279
 
280
+
281
+ def main() -> None:
282
+ """Produce a radial plot."""
283
+
284
+ global last_fetch
285
+ results_dfs = fetch_results()
286
+ last_fetch = dt.datetime.now()
287
+
288
+ all_languages = sorted(
289
+ [language.name for language in ALL_LANGUAGES.values()],
290
+ key=lambda language_name: language_name.lower(),
291
+ )
292
+ danish_models = sorted(
293
+ list({model_id for model_id in results_dfs[DANISH].index}),
294
+ key=lambda model_id: model_id.lower(),
295
+ )
296
+
297
+ global colour_mapping
298
+ global seed
299
+ seed = 4242
300
+ update_colour_mapping(results_dfs=results_dfs)
301
+
302
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
303
  gr.Markdown(INTRO_MARKDOWN)
304
 
 
352
  interactive=True,
353
  scale=1,
354
  )
355
+ update_colours_button = gr.Button(
356
+ value="Update colours",
357
+ interactive=True,
358
+ scale=1,
359
+ )
360
  with gr.Row():
361
  plot = gr.Plot(
362
  value=produce_radial_plot(
 
366
  show_scale=show_scale_checkbox.value,
367
  plot_width=plot_width_slider.value,
368
  plot_height=plot_height_slider.value,
 
369
  results_dfs=results_dfs,
370
  ),
371
  )
 
387
  update_plot_kwargs = dict(
388
  fn=partial(
389
  produce_radial_plot,
 
390
  results_dfs=results_dfs,
391
  ),
392
  inputs=[
 
406
  plot_width_slider.change(**update_plot_kwargs)
407
  plot_height_slider.change(**update_plot_kwargs)
408
 
409
+ # Update colours when the button is clicked
410
+ update_colours_button.click(
411
+ fn=partial(update_colour_mapping, results_dfs=results_dfs),
412
+ ).then(**update_plot_kwargs)
413
+
414
  demo.launch()
415
 
416
 
 
503
  show_scale: bool,
504
  plot_width: int,
505
  plot_height: int,
 
506
  results_dfs: dict[Language, pd.DataFrame] | None,
507
  ) -> go.Figure:
508
  """Produce a radial plot as a plotly figure.
 
520
  The width of the plot.
521
  plot_height:
522
  The height of the plot.
 
 
523
  results_dfs:
524
  The results dataframes for each language.
525