Ludvig commited on
Commit
87c83ab
β€’
1 Parent(s): 4c94861

Adds application. Progression quickly!

Browse files
Files changed (9) hide show
  1. app.py +535 -0
  2. cvms_version.R +1 -0
  3. data.py +72 -0
  4. generate_data.R +46 -0
  5. plot.R +194 -0
  6. requirements.txt +14 -0
  7. small_example.csv +9 -0
  8. text_sections.py +103 -0
  9. utils.py +25 -0
app.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ App for plotting confusion matrix with `cvms::plot_confusion_matrix()`.
3
+
4
+ TODO:
5
+ - IMPORTANT! Allow specifying which class probabilities are of! (See plot prob_of_class)
6
+ - Allow setting threshold - manual, max J, spec/sens
7
+ - Add bg box around confusion matrix plot as text dissappears on dark mode!
8
+ - ggsave does not use dpi??
9
+ - allow svg, pdf?
10
+ - entered count -> counts (upload as well)
11
+ - Add full reset button (empty cache on different files)
12
+
13
+ """
14
+
15
+ import pathlib
16
+ import tempfile
17
+ from PIL import Image
18
+ import streamlit as st # Import last
19
+ import pandas as pd
20
+ from pandas.api.types import is_float_dtype
21
+ from itertools import combinations
22
+ from collections import OrderedDict
23
+
24
+ from utils import call_subprocess, clean_string_for_non_alphanumerics
25
+ from data import read_data, read_data_cached, DownloadHeader, generate_data
26
+ from text_sections import (
27
+ intro_text,
28
+ columns_text,
29
+ upload_predictions_text,
30
+ upload_counts_text,
31
+ generate_data_text,
32
+ design_text,
33
+ enter_count_data_text,
34
+ )
35
+
36
+
37
+ # Create temporary directory
38
+
39
+
40
+ @st.cache_resource
41
+ def set_tmp_dir():
42
+ """
43
+ Must cache to avoid regenerating!
44
+ Must be the same throughout the iterations!
45
+ """
46
+ temp_dir = tempfile.TemporaryDirectory()
47
+ return temp_dir, temp_dir.name
48
+
49
+
50
+ temp_dir, temp_dir_path = set_tmp_dir()
51
+ gen_data_store_path = pathlib.Path(f"{temp_dir_path}/generated_data.csv")
52
+ data_store_path = pathlib.Path(f"{temp_dir_path}/data.csv")
53
+ conf_mat_path = pathlib.Path(f"{temp_dir_path}/confusion_matrix.png")
54
+
55
+
56
+ def input_choice_callback():
57
+ """
58
+ Resets steps to 0.
59
+ Used when switching between input methods.
60
+ """
61
+ st.session_state["step"] = 0
62
+ st.session_state["input_type"] = None
63
+
64
+ # Remove old tmp files
65
+ if gen_data_store_path.exists():
66
+ gen_data_store_path.unlink()
67
+ if data_store_path.exists():
68
+ data_store_path.unlink()
69
+ if conf_mat_path.exists():
70
+ conf_mat_path.unlink()
71
+
72
+
73
+ # Text
74
+ intro_text()
75
+
76
+ # Start step counter
77
+ # Required to make dependent forms work
78
+ if st.session_state.get("step") is None:
79
+ st.session_state["step"] = 0
80
+
81
+ input_choice = st.radio(
82
+ label="Input",
83
+ options=["Upload predictions", "Upload counts", "Generate", "Enter counts"],
84
+ index=0,
85
+ horizontal=True,
86
+ on_change=input_choice_callback,
87
+ )
88
+
89
+ # Check whether the expected output
90
+ if st.session_state.get("input_type") is None:
91
+ if input_choice in ["Upload predictions", "Generate"]:
92
+ st.session_state["input_type"] = "data"
93
+ else:
94
+ st.session_state["input_type"] = "counts"
95
+
96
+ # Load data
97
+ if input_choice == "Upload predictions":
98
+ with st.form(key="data_form"):
99
+ upload_predictions_text()
100
+ data_path = st.file_uploader("Upload a dataset", type=["csv"])
101
+ if st.form_submit_button(label="Use data"):
102
+ if data_path:
103
+ st.session_state["step"] = 1
104
+ else:
105
+ st.session_state["step"] = 0
106
+ st.markdown("Please upload a file first (or **generate** some random data to try the function).")
107
+
108
+ if st.session_state["step"] >= 1:
109
+ # Read and store (tmp) data
110
+ df = read_data_cached(data_path)
111
+ with st.form(key="column_form"):
112
+ columns_text()
113
+ target_col = st.selectbox("Targets column", options=list(df.columns))
114
+ prediction_col = st.selectbox(
115
+ "Predictions column", options=list(df.columns)
116
+ )
117
+
118
+ if st.form_submit_button(label="Set columns"):
119
+ st.session_state["step"] = 2
120
+
121
+ # Load data
122
+ elif input_choice == "Upload counts":
123
+ with st.form(key="data_form"):
124
+ upload_counts_text()
125
+ data_path = st.file_uploader("Upload a dataset", type=["csv"])
126
+ if st.form_submit_button(label="Use data"):
127
+ if data_path:
128
+ st.session_state["step"] = 1
129
+ else:
130
+ st.session_state["step"] = 0
131
+ st.write("Please upload a file first.")
132
+
133
+ if st.session_state["step"] >= 1:
134
+ # Read and store (tmp) data
135
+ df = read_data_cached(data_path)
136
+ with st.form(key="column_form"):
137
+ columns_text()
138
+ target_col = st.selectbox("Targets column", options=list(df.columns))
139
+ prediction_col = st.selectbox(
140
+ "Predictions column", options=list(df.columns)
141
+ )
142
+ n_col = st.selectbox(
143
+ "Counts column", options=list(df.columns)
144
+ )
145
+
146
+ if st.form_submit_button(label="Set columns"):
147
+ st.session_state["step"] = 2
148
+
149
+
150
+ # Generate data
151
+ elif input_choice == "Generate":
152
+
153
+ def reset_generation_callback():
154
+ p = pathlib.Path(gen_data_store_path)
155
+ if p.exists():
156
+ p.unlink()
157
+
158
+ with st.form(key="generate_form"):
159
+ generate_data_text()
160
+ col1, col2, col3 = st.columns(3)
161
+ with col1:
162
+ num_classes = st.number_input(
163
+ "# Classes",
164
+ value=3,
165
+ min_value=2,
166
+ help="Number of classes to generate data for.",
167
+ )
168
+ with col2:
169
+ num_observations = st.number_input(
170
+ "# Observations",
171
+ value=30,
172
+ min_value=2,
173
+ max_value=10000,
174
+ help="Number of observations to generate data for.",
175
+ )
176
+ with col3:
177
+ seed = st.number_input("Random Seed", value=42, min_value=0)
178
+ if st.form_submit_button(
179
+ label="Generate data", on_click=reset_generation_callback
180
+ ):
181
+ st.session_state["step"] = 2
182
+
183
+ if st.session_state["step"] >= 2:
184
+ generate_data(
185
+ out_path=gen_data_store_path,
186
+ num_classes=num_classes,
187
+ num_observations=num_observations,
188
+ seed=seed,
189
+ )
190
+ df = read_data(gen_data_store_path)
191
+ target_col = "Target"
192
+ prediction_col = "Predicted Class"
193
+
194
+ elif input_choice == "Enter counts":
195
+
196
+ def repopulate_matrix_callback():
197
+ if "entered_counts" not in st.session_state:
198
+ if "entered_counts" in st.session_state:
199
+ st.session_state.pop("entered_counts")
200
+
201
+ with st.form(key="enter_classes_form"):
202
+ enter_count_data_text()
203
+ classes_joined = st.text_input("Classes (comma-separated)")
204
+
205
+ if st.form_submit_button(
206
+ label="Populate matrix", on_click=repopulate_matrix_callback
207
+ ):
208
+ # Extract class names from comma-separated list
209
+ st.session_state["classes"] = [
210
+ clean_string_for_non_alphanumerics(s) for s in classes_joined.split(",")
211
+ ]
212
+
213
+ # Calculate all pairs of predictions and targets
214
+ all_pairs = list(combinations(st.session_state["classes"], 2))
215
+ all_pairs += [(pair[1], pair[0]) for pair in all_pairs]
216
+ all_pairs += [(c, c) for c in st.session_state["classes"]]
217
+
218
+ # Prepopulate the matrix
219
+ st.session_state["entered_counts"] = pd.DataFrame(
220
+ all_pairs, columns=["Target", "Prediction"]
221
+ )
222
+
223
+ st.session_state["step"] = 1
224
+
225
+ if st.session_state["step"] >= 1:
226
+ with st.form(key="enter_counts_form"):
227
+ st.write("Fill in the counts for `N(Target, Prediction)` pairs.")
228
+ count_input_fields = OrderedDict()
229
+
230
+ num_cols = 3
231
+ cols = st.columns(num_cols)
232
+ for i, (targ, pred) in enumerate(
233
+ zip(
234
+ st.session_state["entered_counts"]["Target"],
235
+ st.session_state["entered_counts"]["Prediction"],
236
+ )
237
+ ):
238
+ count_input_fields[f"{targ}____{pred}"] = cols[
239
+ i % num_cols
240
+ ].number_input(f"N({targ}, {pred})", step=1)
241
+
242
+ if st.form_submit_button(
243
+ label="Generate data",
244
+ ):
245
+ st.session_state["entered_counts"]["N"] = [
246
+ int(val) for val in count_input_fields.values()
247
+ ]
248
+ st.session_state["step"] = 2
249
+
250
+ if st.session_state["step"] >= 2:
251
+ DownloadHeader.header_and_data_download(
252
+ "Entered counts",
253
+ data=st.session_state["entered_counts"],
254
+ file_name="Confusion_Matrix_Counts.csv",
255
+ help="Download counts",
256
+ )
257
+ st.write(st.session_state["entered_counts"])
258
+
259
+ target_col = "Target"
260
+ prediction_col = "Prediction"
261
+ n_col = "N"
262
+
263
+ if st.session_state["step"] >= 2:
264
+ if st.session_state["input_type"] == "data":
265
+ # Remove unused columns
266
+ df = df.loc[:, [target_col, prediction_col]]
267
+
268
+ # Ensure targets are strings
269
+ df[target_col] = df[target_col].astype(str)
270
+ df[target_col] = df[target_col].apply(lambda x: x.replace(" ", "_"))
271
+
272
+ # Save to tmp directory to allow reading in R script
273
+ df.to_csv(data_store_path)
274
+
275
+ # Extract unique classes
276
+ st.session_state["classes"] = sorted([str(c) for c in df[target_col].unique()])
277
+
278
+ predictions_are_probabilities = is_float_dtype(df[prediction_col])
279
+ if predictions_are_probabilities and len(st.session_state["classes"]) != 2:
280
+ st.error(
281
+ "Predictions can only be probabilities in binary classification. "
282
+ f"Got {len(st.session_state['classes'])} classes."
283
+ )
284
+
285
+ st.subheader("The Data")
286
+ col1, col2, col3 = st.columns([2, 2, 2])
287
+ with col2:
288
+ st.write(df.head(5))
289
+ st.write(f"{df.shape} (first 5 rows).")
290
+
291
+ else:
292
+ st.session_state["entered_counts"].to_csv(data_store_path)
293
+
294
+ # Check the number of classes
295
+ num_classes = len(st.session_state["classes"])
296
+ print(st.session_state["classes"])
297
+ if num_classes < 2:
298
+ # TODO Handle better than throwing error?
299
+ raise ValueError(
300
+ "Uploaded data must contain 2 or more classes in `Targets column`. "
301
+ f"Got {num_classes} target classes."
302
+ )
303
+
304
+ with st.form(key="settings_form"):
305
+ design_text()
306
+ col1, col2 = st.columns(2)
307
+ with col1:
308
+ selected_classes = st.multiselect(
309
+ "Select classes (min=2, order is respected)",
310
+ options=st.session_state["classes"],
311
+ default=st.session_state["classes"],
312
+ help="Select the classes to create the confusion matrix for. "
313
+ "Any observation with either a target or prediction "
314
+ "of another class is excluded.",
315
+ )
316
+ with col2:
317
+ if st.session_state["input_type"] == "data" and predictions_are_probabilities:
318
+ prob_of_class = st.selectbox(
319
+ "Probabilities are of (not working)",
320
+ options=st.session_state["classes"],
321
+ index=1,
322
+ )
323
+ else:
324
+ prob_of_class = None
325
+
326
+ default_elements = [
327
+ "Counts",
328
+ "Normalized Counts (%)",
329
+ "Zero Shading",
330
+ "Arrows",
331
+ ]
332
+ if num_classes < 6:
333
+ # Percentages clutter too much with many classes
334
+ default_elements += [
335
+ "Row Percentages",
336
+ "Column Percentages",
337
+ ]
338
+ elements_to_add = st.multiselect(
339
+ "Add the following elements",
340
+ options=[
341
+ "Sum Tiles",
342
+ "Counts",
343
+ "Normalized Counts (%)",
344
+ "Row Percentages",
345
+ "Column Percentages",
346
+ "Zero Shading",
347
+ "Zero Percentages",
348
+ "Zero Text",
349
+ "Arrows",
350
+ ],
351
+ default=default_elements,
352
+ )
353
+
354
+ col1, col2, col3 = st.columns(3)
355
+ with col1:
356
+ counts_on_top = st.checkbox(
357
+ "Counts on top (not working)",
358
+ help="Whether to switch the positions of the counts and normalized counts (%). "
359
+ "That is, the counts become the big centralized numbers and the "
360
+ "normalized counts go below with a smaller font size.",
361
+ )
362
+ with col2:
363
+ diag_percentages_only = st.checkbox("Diagonal row/column percentages only")
364
+ with col3:
365
+ num_digits = st.number_input(
366
+ "Digits", value=2, help="Number of digits to round percentages to."
367
+ )
368
+
369
+ element_flags = [
370
+ key
371
+ for key, val in {
372
+ "--add_sums": "Sum Tiles" in elements_to_add,
373
+ "--add_counts": "Counts" in elements_to_add,
374
+ "--add_normalized": "Normalized Counts (%)" in elements_to_add,
375
+ "--add_row_percentages": "Row Percentages" in elements_to_add,
376
+ "--add_col_percentages": "Column Percentages" in elements_to_add,
377
+ "--add_zero_percentages": "Zero Percentages" in elements_to_add,
378
+ "--add_zero_text": "Zero Text" in elements_to_add,
379
+ "--add_zero_shading": "Zero Shading" in elements_to_add,
380
+ "--add_arrows": "Arrows" in elements_to_add,
381
+ "--counts_on_top": counts_on_top,
382
+ "--diag_percentages_only": diag_percentages_only,
383
+ }.items()
384
+ if val
385
+ ]
386
+
387
+ palette = st.selectbox(
388
+ "Color Palette",
389
+ options=["Blues", "Greens", "Oranges", "Greys", "Purples", "Reds"],
390
+ )
391
+
392
+ # Ask for output parameters
393
+ # TODO: Set default based on number of classes and sum tiles
394
+ col1, col2, col3 = st.columns(3)
395
+ with col1:
396
+ width = st.number_input("Width (px)", value=1200 + 100 * (num_classes - 2))
397
+ with col2:
398
+ height = st.number_input(
399
+ "Height (px)", value=1200 + 100 * (num_classes - 2)
400
+ )
401
+ with col3:
402
+ dpi = st.number_input("DPI (not working)", value=320)
403
+
404
+ if st.form_submit_button(label="Apply"):
405
+ st.session_state["step"] = 3
406
+
407
+ if st.session_state["step"] >= 3:
408
+ plotting_args = [
409
+ "--data_path",
410
+ f"'{data_store_path}'",
411
+ "--out_path",
412
+ f"'{conf_mat_path}'",
413
+ "--target_col",
414
+ f"'{target_col}'",
415
+ "--prediction_col",
416
+ f"'{prediction_col}'",
417
+ "--width",
418
+ f"{width}",
419
+ "--height",
420
+ f"{height}",
421
+ "--dpi",
422
+ f"{dpi}",
423
+ "--classes",
424
+ f"{','.join(selected_classes)}",
425
+ "--digits",
426
+ f"{num_digits}",
427
+ "--palette",
428
+ f"{palette}",
429
+ ]
430
+
431
+ if st.session_state["input_type"] == "counts":
432
+ # The input data are counts
433
+ plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
434
+
435
+ plotting_args += element_flags
436
+
437
+ plotting_args = " ".join(plotting_args)
438
+
439
+ call_subprocess(
440
+ f"Rscript plot.R {plotting_args}",
441
+ message="Plotting script",
442
+ return_output=True,
443
+ encoding="UTF-8",
444
+ )
445
+
446
+ DownloadHeader.header_and_image_download(
447
+ "The confusion matrix plot", filepath=conf_mat_path
448
+ )
449
+ col1, col2, col3 = st.columns([2, 8, 2])
450
+ with col2:
451
+ image = Image.open(conf_mat_path)
452
+ st.image(
453
+ image,
454
+ caption="Confusion Matrix",
455
+ # width=500,
456
+ use_column_width=None,
457
+ clamp=False,
458
+ channels="RGB",
459
+ output_format="auto",
460
+ )
461
+
462
+ # evaluation = dplyr.select(
463
+ # evaluation,
464
+ # "Balanced Accuracy",
465
+ # "Accuracy",
466
+ # "F1",
467
+ # "Sensitivity",
468
+ # "Specificity",
469
+ # "Pos Pred Value",
470
+ # "Neg Pred Value",
471
+ # "AUC",
472
+ # "Kappa",
473
+ # "MCC",
474
+ # )
475
+ # evaluation_py = ro.conversion.rpy2py(evaluation)
476
+ # st.write(evaluation_py)
477
+
478
+ # confusion_matrix_py = ro.conversion.rpy2py(confusion_matrix)
479
+ # st.write(confusion_matrix_py)
480
+
481
+ # evaluation = dplyr.select(
482
+ # evaluation,
483
+ # "Balanced Accuracy",
484
+ # "Accuracy",
485
+ # "F1",
486
+ # "Sensitivity",
487
+ # "Specificity",
488
+ # "Pos Pred Value",
489
+ # "Neg Pred Value",
490
+ # "AUC",
491
+ # "Kappa",
492
+ # "MCC",
493
+ # )
494
+ # evaluation_py = ro.conversion.rpy2py(evaluation)
495
+ # st.write(evaluation_py)
496
+
497
+ # temp_dir.cleanup()
498
+
499
+ else:
500
+ st.write("Please upload data.")
501
+
502
+
503
+ # target_col = "Target",
504
+ # prediction_col = "Prediction",
505
+ # counts_col = "N",
506
+ # class_order = NULL,
507
+ # add_sums = FALSE,
508
+ # add_counts = TRUE,
509
+ # add_normalized = TRUE,
510
+ # add_row_percentages = TRUE,
511
+ # add_col_percentages = TRUE,
512
+ # diag_percentages_only = FALSE,
513
+ # rm_zero_percentages = TRUE,
514
+ # rm_zero_text = TRUE,
515
+ # add_zero_shading = TRUE,
516
+ # add_arrows = TRUE,
517
+ # counts_on_top = FALSE,
518
+ # palette = "Blues",
519
+ # intensity_by = "counts",
520
+ # theme_fn = ggplot2::theme_minimal,
521
+ # place_x_axis_above = TRUE,
522
+ # rotate_y_text = TRUE,
523
+ # digits = 1,
524
+ # font_counts = font(),
525
+ # font_normalized = font(),
526
+ # font_row_percentages = font(),
527
+ # font_col_percentages = font(),
528
+ # arrow_size = 0.048,
529
+ # arrow_nudge_from_text = 0.065,
530
+ # tile_border_color = NA,
531
+ # tile_border_size = 0.1,
532
+ # tile_border_linetype = "solid",
533
+ # sums_settings = sum_tile_settings(),
534
+ # darkness = 0.8
535
+ # )
cvms_version.R ADDED
@@ -0,0 +1 @@
 
 
1
+ print(packageVersion("cvms"))
data.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import pandas as pd
3
+ import streamlit as st
4
+ from utils import call_subprocess
5
+
6
+
7
+ def read_data(data):
8
+ if data is not None:
9
+ df = pd.read_csv(data)
10
+ return df
11
+ else:
12
+ return None
13
+
14
+
15
+ @st.cache_data
16
+ def read_data_cached(data):
17
+ return read_data(data)
18
+
19
+
20
+ def generate_data(out_path, num_classes, num_observations, seed) -> None:
21
+ call_subprocess(
22
+ f"Rscript generate_data.R --out_path {out_path} --num_classes {num_classes} --num_observations {num_observations} --seed {seed}",
23
+ message="Data generation script",
24
+ return_output=True,
25
+ encoding="UTF-8",
26
+ )
27
+
28
+
29
+ class DownloadHeader:
30
+ """
31
+ Class for showing header and download button (for an image file) in the same row.
32
+ """
33
+
34
+ @staticmethod
35
+ def header_and_image_download(
36
+ header, filepath, key=None, label="Download", help="Download plot"
37
+ ):
38
+ col1, col2 = st.columns([9, 2])
39
+ with col1:
40
+ st.subheader(header)
41
+ with col2:
42
+ st.write("")
43
+ with open(filepath, "rb") as img:
44
+ st.download_button(
45
+ label=label,
46
+ data=img,
47
+ file_name=pathlib.Path(filepath).name,
48
+ mime="image/png",
49
+ key=key,
50
+ help=help,
51
+ )
52
+
53
+ @staticmethod
54
+ def _convert_df_to_csv(data, **kwargs):
55
+ return data.to_csv(**kwargs).encode("utf-8")
56
+
57
+ @staticmethod
58
+ def header_and_data_download(
59
+ header, data, file_name, key=None, label="Download", help="Download data"
60
+ ):
61
+ col1, col2 = st.columns([9, 2])
62
+ with col1:
63
+ st.subheader(header)
64
+ with col2:
65
+ st.write("")
66
+ st.download_button(
67
+ label=label,
68
+ data=DownloadHeader._convert_df_to_csv(data, index=False),
69
+ file_name=file_name,
70
+ key=key,
71
+ help=help,
72
+ )
generate_data.R ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env Rscript
2
+ library(optparse)
3
+ library(cvms)
4
+
5
+ option_list <- list(
6
+ make_option(c("--out_path"), type="character",
7
+ help="Path to save data at."),
8
+ make_option(c("--num_classes"), type="integer",
9
+ help="Number of classes."),
10
+ make_option(c("--num_observations"), type="integer",
11
+ help="Number of observations."),
12
+ make_option(c("--seed"), type="integer",
13
+ help="Number of observations.")
14
+
15
+ )
16
+
17
+ opt_parser <- OptionParser(option_list=option_list)
18
+ opt <- parse_args(opt_parser)
19
+
20
+ print(opt)
21
+
22
+ # Set seed if given
23
+ if (!is.null(opt$seed)){
24
+ set.seed(opt$seed)
25
+ }
26
+
27
+ # Make fairly certain predictions
28
+ rcertain <- function(n) {
29
+ (runif(n, min = 1, max = 100)^1.4) / 100
30
+ }
31
+
32
+ # Generate data
33
+ data <- cvms::multiclass_probability_tibble(
34
+ num_classes=opt$num_classes,
35
+ num_observations=opt$num_observations,
36
+ apply_softmax = TRUE,
37
+ FUN = rcertain,
38
+ class_name = "c",
39
+ add_predicted_classes = TRUE,
40
+ add_targets = TRUE
41
+ )
42
+
43
+ data <- data[, c("Predicted Class", "Target")]
44
+
45
+ # Write to disk
46
+ write.csv(data, file = opt$out_path, row.names=FALSE)
plot.R ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env Rscript
2
+ library(optparse)
3
+ library(cvms)
4
+ library(dplyr)
5
+ library(ggplot2)
6
+
7
+ option_list <- list(
8
+ make_option(c("--data_path"), type="character",
9
+ help="Path to data file (.csv)."),
10
+ make_option(c("--out_path"), type="character",
11
+ help="Path to save confusion matrix plot at."),
12
+ make_option(c("--data_are_counts"), action="store_true", default=FALSE,
13
+ help="Indicates that `--data_path` contains counts, not predictions."),
14
+ make_option(c("--target_col"), type="character",
15
+ help="Target column"),
16
+ make_option(c("--prediction_col"), type="character",
17
+ help="Prediction column"),
18
+ make_option(c("--n_col"), type="character",
19
+ help="Count column (when `--data_are_counts`)."),
20
+ make_option(c("--classes"), type="character",
21
+ help="Comma-separated class names. Only these classes will be used - in the specified order."),
22
+ make_option(c("--prob_of_class"), type="character",
23
+ help="Name of class that probabilities are of."),
24
+ make_option(c("--palette"), type="character",
25
+ help="Color palette."),
26
+ make_option(c("--width"), type="integer",
27
+ help="Width of plot in pixels."),
28
+ make_option(c("--height"), type="integer",
29
+ help="Height of plot in pixels."),
30
+ make_option(c("--dpi"), type="integer",
31
+ help="DPI of plot."),
32
+ make_option(c("--add_sums"), action="store_true", default=FALSE,
33
+ help="Wether to add sum tiles."),
34
+ make_option(c("--add_counts"), action="store_true", default=FALSE,
35
+ help="Wether to add counts."),
36
+ make_option(c("--add_normalized"), action="store_true", default=FALSE,
37
+ help="Wether to add normalized counts (i.e. percentages)."),
38
+ make_option(c("--add_row_percentages"), action="store_true", default=FALSE,
39
+ help="Wether to add row percentages."),
40
+ make_option(c("--add_col_percentages"), action="store_true", default=FALSE,
41
+ help="Wether to add column percentages."),
42
+ make_option(c("--add_zero_percentages"), action="store_true", default=FALSE,
43
+ help="Wether to add percentages to zero-tiles."),
44
+ make_option(c("--add_zero_text"), action="store_true", default=FALSE,
45
+ help="Wether to add text to zero-tiles."),
46
+ make_option(c("--add_zero_shading"), action="store_true", default=FALSE,
47
+ help="Wether to add shading to zero-tiles."),
48
+ make_option(c("--add_arrows"), action="store_true", default=FALSE,
49
+ help="Wether to add arrows to row/sum percentages. Requires additional packages."),
50
+ make_option(c("--counts_on_top"), action="store_true", default=FALSE,
51
+ help="Wether to have the counts on top and normalized counts below."),
52
+ make_option(c("--diag_percentages_only"), action="store_true", default=FALSE,
53
+ help="Wether to only show diagonal row/column percentages."),
54
+ make_option(c("--digits"), type="integer",
55
+ help="Number of digits to show for percentages.")
56
+ )
57
+
58
+ opt_parser <- OptionParser(option_list=option_list)
59
+ opt <- parse_args(opt_parser)
60
+
61
+ print(opt)
62
+
63
+ data_are_counts <- opt$data_are_counts
64
+
65
+ # read.csv turns white space into dots
66
+ target_col <- stringr::str_squish(opt$target_col)
67
+ target_col <- stringr::str_replace_all(target_col, " ", ".")
68
+ prediction_col <- stringr::str_squish(opt$prediction_col)
69
+ prediction_col <- stringr::str_replace_all(prediction_col, " ", ".")
70
+
71
+ n_col <- NULL
72
+ if (!is.null(opt$n_col)){
73
+ n_col <- stringr::str_squish(opt$n_col)
74
+ n_col <- stringr::str_replace_all(n_col, " ", ".")
75
+ }
76
+
77
+ # Read and prepare data frame
78
+ df <- tryCatch({
79
+ read.csv(opt$data_path)
80
+ }, error=function(e){
81
+ print(paste0("Failed to read data from ", opt$data_path))
82
+ print(e)
83
+ stop(e)
84
+ })
85
+ print(df)
86
+
87
+ df <- dplyr::as_tibble(df)
88
+ print(df)
89
+ df[[target_col]] <- as.character(df[[target_col]])
90
+
91
+ if (isTRUE(data_are_counts)){
92
+ df[[prediction_col]] <- as.character(df[[prediction_col]])
93
+ }
94
+
95
+ # Predictions can be either probabilities or
96
+ # hard class predictions
97
+ if (is.integer(df[[prediction_col]]) || !is.numeric(df[[prediction_col]])){
98
+ all_present_classes <- sort(
99
+ c(unique(df[[target_col]]),
100
+ unique(df[[prediction_col]])
101
+ )
102
+ )
103
+ } else {
104
+ all_present_classes <- sort(
105
+ unique(df[[target_col]])
106
+ )
107
+ }
108
+
109
+
110
+ if (!is.null(opt$classes)){
111
+ classes <- as.character(unlist(strsplit(opt$classes,"[,:]")), recursive=TRUE)
112
+ } else {
113
+ classes <- all_present_classes
114
+ }
115
+ print(paste0("Selected Classes: ", paste0(classes, collapse=", ")))
116
+
117
+ if (!isTRUE(data_are_counts)){
118
+ # We remove the unwanted classes from the confusion matrix
119
+ # (easier - possibly slower in edge cases)
120
+ family <- ifelse(length(all_present_classes) == 2, "binomial", "multinomial")
121
+ print(df)
122
+
123
+ # TODO : use prob_of_class to ensure probabilities are interpreted correctly!!
124
+ # Might need to invert them to get it to work!
125
+ evaluation <- tryCatch({
126
+ cvms::evaluate(
127
+ data=df,
128
+ target_col=target_col,
129
+ prediction_cols=prediction_col,
130
+ type=family,
131
+ )
132
+ }, error=function(e){
133
+ print("Failed to evaluate data.")
134
+ print(head(df, 5))
135
+ print(e)
136
+ stop(e)
137
+ })
138
+
139
+ confusion_matrix <- evaluation[["Confusion Matrix"]][[1]]
140
+
141
+ } else {
142
+ confusion_matrix <- dplyr::rename(
143
+ df,
144
+ Target = !!target_col,
145
+ Prediction = !!prediction_col,
146
+ N = !!n_col
147
+ )
148
+ }
149
+
150
+ confusion_matrix <- dplyr::filter(
151
+ confusion_matrix,
152
+ Prediction %in% classes,
153
+ Target %in% classes
154
+ )
155
+
156
+
157
+ confusion_matrix_plot <- tryCatch({
158
+ cvms::plot_confusion_matrix(
159
+ confusion_matrix,
160
+ class_order=classes,
161
+ add_sums=opt$add_sums,
162
+ add_counts=opt$add_counts,
163
+ add_normalized=opt$add_normalized,
164
+ add_row_percentages=opt$add_row_percentages,
165
+ add_col_percentages=opt$add_col_percentages,
166
+ rm_zero_percentages=!opt$add_zero_percentages,
167
+ rm_zero_text=!opt$add_zero_text,
168
+ add_zero_shading=opt$add_zero_shading,
169
+ add_arrows=opt$add_arrows,
170
+ counts_on_top=opt$counts_on_top,
171
+ diag_percentages_only=opt$diag_percentages_only,
172
+ digits=as.integer(opt$digits),
173
+ palette=opt$palette
174
+ )
175
+ }, error=function(e){
176
+ print("Failed to create plot from confusion matrix.")
177
+ print(confusion_matrix)
178
+ print(e)
179
+ stop(e)
180
+ })
181
+
182
+ tryCatch({
183
+ ggplot2::ggsave(
184
+ opt$out_path,
185
+ width=opt$width,
186
+ height=opt$height,
187
+ dpi=opt$dpi,
188
+ units="px"
189
+ )
190
+ }, error=function(e){
191
+ print(paste0("Failed to ggsave plot to: ", opt$out_path))
192
+ print(e)
193
+ stop(e)
194
+ })
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rpy2=3.5.1
2
+ pandas
3
+ lazyeval
4
+ r-cvms
5
+ r-dplyr
6
+ r-ggimage
7
+ r-rsvg # Conda forge?
8
+ r-optparse
9
+ r-ggnewscale
10
+ r-stringr
11
+
12
+ # Needs:
13
+ # conda config --add channels conda-forge
14
+ # conda config --set channel_priority strict
small_example.csv ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ target,prediction,predicted_class
2
+ 1,0.3,1
3
+ 2,0.9,2
4
+ 1,0.2,1
5
+ 2,0.9,2
6
+ 1,0.7,2
7
+ 1,0.8,2
8
+ 2,0.5,2
9
+ 2,0.7,2
text_sections.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils import call_subprocess
3
+
4
+
5
+ @st.cache_resource
6
+ def get_cvms_version():
7
+ return (
8
+ str(
9
+ call_subprocess(
10
+ f"Rscript cvms_version.R",
11
+ message="cvms versioning script",
12
+ return_output=True,
13
+ encoding="UTF-8",
14
+ )
15
+ )
16
+ .split("[1]")[-1]
17
+ .replace("β€˜", "")
18
+ .replace("’", "")
19
+ )
20
+
21
+
22
+ def intro_text():
23
+ col1, col2 = st.columns([8, 2])
24
+ with col1:
25
+ st.title("Plot Confusion Matrix")
26
+ st.write(
27
+ "This application allows you to plot a confusion matrix based on your own data. "
28
+ )
29
+ with col2:
30
+ st.image(
31
+ "https://github.com/LudvigOlsen/cvms/raw/master/man/figures/cvms_logo_242x280_250dpi.png",
32
+ width=125,
33
+ )
34
+
35
+ st.write(
36
+ "The plot is created with the [**cvms**](https://github.com/LudvigOlsen/cvms) R package "
37
+ f"(v/{get_cvms_version()}, LR Olsen & HB Zachariae, 2019)."
38
+ )
39
+
40
+ st.write(
41
+ "DATA PRIVACY: In order to transfer the data "
42
+ "between python and R, it is temporarily stored on the servers. "
43
+ "While we, the authors, have no intention of looking at your data, we make "
44
+ "*no guarantees* about the privacy of your data (it is not our servers). "
45
+ "Please do not upload sensitive data. The application "
46
+ "only requires columns with predictions and targets."
47
+ )
48
+
49
+
50
+ def generate_data_text():
51
+ st.subheader("Generate data")
52
+ st.write(
53
+ "If you just want to try out the application, you can generate a dataset with targets and predictions. "
54
+ "Select a number of classes and observations, and you're ready to go! "
55
+ )
56
+
57
+
58
+ def enter_count_data_text():
59
+ st.subheader("Enter counts")
60
+ st.write(
61
+ "If you already have the confusion matrix counts and want to plot them. "
62
+ "Enter the counts and get designing! "
63
+ )
64
+ st.write("Start by entering the names of your classes:")
65
+
66
+
67
+ def upload_counts_text():
68
+ st.subheader("Upload your counts")
69
+ st.write(
70
+ "Plot an existing confusion matrix (counts of target-prediction combinations). "
71
+ "The application expects a `.csv` file with: \n"
72
+ "1) A `target classes` column. \n\n"
73
+ "2) A `predicted classes` column. \n\n"
74
+ "3) A `combination count` column for the "
75
+ "combination frequency of 1 and 2. \n\n"
76
+ "Other columns are currently ignored. "
77
+ "See example of such a .csv file [here] (TODO). "
78
+ )
79
+
80
+ def upload_predictions_text():
81
+ st.subheader("Upload your predictions")
82
+ st.markdown(
83
+ "The application expects a `.csv` file with: \n"
84
+ "1) A `target` column. \n"
85
+ "Targets will be converted into strings. \n\n"
86
+ "2) A `prediction` column. \n"
87
+ "Predictions can be probabilities (binary classification only) or class predictions. \n\n"
88
+ "Other columns are currently ignored. \n\n"
89
+ "You will have the option to select the names of these two columns, so don't "
90
+ "worry too much about the column names in the uploaded data."
91
+ )
92
+
93
+
94
+ def columns_text():
95
+ st.subheader("Specify columns")
96
+ st.write(
97
+ "Please select which of the columns in the data should be used for targets and predictions."
98
+ )
99
+
100
+
101
+ def design_text():
102
+ st.subheader("Design your plot")
103
+ st.write("This is where you customize the design of your confusion matrix plot.")
utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import re, string
3
+
4
+
5
+ def call_subprocess(call_, message, return_output=False, encoding="UTF-8"):
6
+ # With capturing of output
7
+ if return_output:
8
+ try:
9
+ out = subprocess.check_output(call_, shell=True, encoding=encoding)
10
+ except subprocess.CalledProcessError as e:
11
+ print(f"{message}: {call_}")
12
+ raise e
13
+ return out
14
+
15
+ # Without capturing of output
16
+ try:
17
+ subprocess.check_call(call_, shell=True)
18
+ except subprocess.CalledProcessError as e:
19
+ print(f"{message}: {call_}")
20
+ raise e
21
+
22
+
23
+ def clean_string_for_non_alphanumerics(s):
24
+ pattern = re.compile("[\W'_']+")
25
+ return pattern.sub("", s)