File size: 3,964 Bytes
c8d36ae
 
 
597bf7d
 
 
 
 
 
 
 
 
 
7a75a86
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e18be25
 
 
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
"""
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from seqeval.metrics import classification_report
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

from src.subpages.page import Context, Page


def _get_evaluation(df):
    """Build the seqeval classification report for ``df`` as display-ready text.

    Positions whose gold label is ``"IGN"`` are removed from both the gold
    labels and the predictions of every row before scoring.

    Args:
        df: DataFrame whose rows expose ``labels`` and ``preds`` sequences.

    Returns:
        The report text, with a run of ``=`` characters inserted in front of
        the column header.
    """

    def _kept_labels(row):
        # Gold labels with the ignore marker filtered out.
        return [lbl for lbl in row.labels if lbl != "IGN"]

    def _kept_preds(row):
        # Predictions at the positions whose gold label is not the ignore marker.
        return [pred for pred, lbl in zip(row.preds, row.labels) if lbl != "IGN"]

    gold = df.apply(_kept_labels, axis=1)
    predicted = df.apply(_kept_preds, axis=1)

    # NOTE(review): seqeval appears to honour `scheme` only together with
    # mode="strict" -- confirm whether strict IOB2 evaluation was intended.
    report: str = classification_report(gold, predicted, scheme="IOB2", digits=3)  # type: ignore

    header = "precision    recall  f1-score   support"
    return report.replace(header, "=" * 12 + "  " + header)


def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True):
    """Plot a confusion matrix and return the matplotlib figure holding it.

    Args:
        y_true: Gold labels, forwarded to sklearn's ``confusion_matrix``.
        y_preds: Predicted labels, forwarded to sklearn's ``confusion_matrix``.
        labels: Passed as ``labels`` to ``confusion_matrix`` and used as the
            display labels of the plot.
        normalize: Forwarded to ``confusion_matrix``; ``None`` leaves the
            matrix un-normalized.
        zero_diagonal: When truthy, the diagonal is overwritten with zeros
            before plotting.

    Returns:
        The figure created via ``plt.subplots`` that the matrix is drawn on.
    """
    matrix = confusion_matrix(y_true, y_preds, labels=labels, normalize=normalize)
    if zero_diagonal:
        # Blank out the diagonal so only the off-diagonal cells carry values.
        np.fill_diagonal(matrix, 0)

    # Integer cell format without normalization, three decimals with it.
    if normalize is None:
        cell_format = "d"
    else:
        cell_format = ".3f"

    figure, axes = plt.subplots(figsize=(10, 10))
    ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=labels).plot(
        ax=axes,
        cmap="Blues",
        colorbar=False,
        include_values=True,
        values_format=cell_format,
        xticks_rotation="vertical",
    )
    return figure


class MetricsPage(Page):
    """Page showing the evaluation report, a support-vs-F1 scatter plot and a confusion matrix."""

    name = "Metrics"
    icon = "graph-up-arrow"

    def get_widget_defaults(self):
        """Return the initial values of this page's widgets, keyed by widget key."""
        return {
            "normalize": True,
            "zero_diagonal": False,
        }

    def render(self, context: Context):
        """Render the metrics page.

        Args:
            context: Supplies ``df`` (scored by ``_get_evaluation``),
                ``df_tokens_cleaned`` (its ``labels`` and ``preds`` columns feed
                the confusion matrix) and ``labels`` (the confusion matrix
                label list).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
            )
            st.write(
                "With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone."
            )

        eval_results = _get_evaluation(context.df)
        report_lines = eval_results.splitlines()
        # Short reports get a column layout; longer ones are written straight
        # to the page (the `st` module offers the same calls as a column).
        if len(report_lines) < 8:
            col1, _, col2 = st.columns([8, 1, 1])
        else:
            col1 = col2 = st

        col1.subheader("🎯 Evaluation Results")
        col1.code(eval_results)

        # The slice drops the first two and the last four report lines
        # (presumably header + blank line, and blank line + three average rows
        # in seqeval's text layout), leaving one row per class.
        data = []
        for line in report_lines[2:-4]:
            # Split from the right so a class name containing whitespace stays
            # in one piece; the four trailing fields are the numeric columns.
            fields = line.strip().rsplit(maxsplit=4)
            if len(fields) < 5:
                # Blank or malformed row: nothing to plot.
                continue
            class_name, _precision, _recall, f1, support = fields
            data.append((class_name, int(support), float(f1)))

        df = pd.DataFrame(data, columns=["class", "support", "f1"])
        scatter = px.scatter(
            df,
            x="support",
            y="f1",
            range_y=(0, 1.05),
            color="class",
        )
        col1.plotly_chart(scatter)

        col2.subheader("🔠 Confusion Matrix")
        normalize = "true" if col2.checkbox("Normalize", key="normalize") else None
        zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
        cm_figure = plot_confusion_matrix(
            y_true=context.df_tokens_cleaned["labels"],
            y_preds=context.df_tokens_cleaned["preds"],
            labels=context.labels,
            normalize=normalize,
            zero_diagonal=zero_diagonal,
        )
        try:
            col2.pyplot(cm_figure)
        finally:
            # pyplot keeps every figure it created registered until it is
            # closed; without this, each rerun of the page would leave another
            # figure in memory.
            plt.close(cm_figure)