File size: 3,720 Bytes
e2a0b30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import pandas as pd
from plotnine import (
    labs,
    theme,
    theme_bw,
    guides,
    position_nudge,
    aes,
    geom_violin,
    geom_line,
    geom_jitter,
    scale_x_discrete,
    ggplot,
)
from pathlib import Path
import math


class GrewTSEVisualiser:
    """
    Thin object-oriented front end for plotting a syntactic evaluation.

    The actual drawing is delegated to the module-level ``visualise_slope``.
    """

    def __init__(self) -> None:
        # Placeholder attribute; nothing in this class reads or fills it.
        self.data = None

    def visualise_syntactic_performance(
        self,
        results: pd.DataFrame,
        title: str,
        target_x_label: str,
        alt_x_label: str,
        x_axis_label: str,
        y_axis_label: str,
        filename: str,
    ) -> None:
        """
        Plot the outcome of a syntactic performance evaluation and save it.

        :param results: the results DataFrame (per the original documentation,
            the one created by the GrewTSEEvaluator).
        :param title: main title of the diagram.
        :param target_x_label: label for the original target word, i.e. the
            first member of the minimal pair, e.g. 'Accusative'.
        :param alt_x_label: label for the second member of the minimal pair,
            e.g. 'Dative'.
        :param x_axis_label: title of the X axis.
        :param y_axis_label: title of the Y axis.
        :param filename: where the visualisation is saved.
        :return: None
        """
        # Forward by keyword so the argument/parameter pairing is explicit;
        # note the callee takes the output location first and the title last.
        visualise_slope(
            path=filename,
            results=results,
            target_x_label=target_x_label,
            alt_x_label=alt_x_label,
            x_axis_label=x_axis_label,
            y_axis_label=y_axis_label,
            title=title,
        )


def visualise_slope(
    path: Path,
    results: pd.DataFrame,
    target_x_label: str,
    alt_x_label: str,
    x_axis_label: str,
    y_axis_label: str,
    title: str,
) -> None:
    """
    Draw a paired surprisal plot for minimal pairs and save it to ``path``.

    Rows of ``results`` whose ``form_ungrammatical`` is missing or blank are
    dropped. For every remaining row the two probabilities
    (``p_grammatical`` and ``p_ungrammatical``) are converted to surprisal
    (``-log2(p)``) and plotted as two jittered points joined by a grey line,
    with a half violin per condition on either side.

    :param path: destination handed to the plot's ``save`` method. The
        wrapper class in this module passes a plain string filename.
    :param results: DataFrame with the columns ``form_ungrammatical``,
        ``p_grammatical`` and ``p_ungrammatical``. It is not modified.
    :param target_x_label: x-axis category used for ``p_grammatical``.
    :param alt_x_label: x-axis category used for ``p_ungrammatical``.
    :param x_axis_label: title of the X axis.
    :param y_axis_label: title of the Y axis.
    :param title: main title of the plot.
    :raises KeyError: if a required column is absent from ``results``.
    :raises ValueError: if a probability is zero or negative, for which the
        surprisal is undefined.
    :return: None
    """
    required_columns = ("form_ungrammatical", "p_grammatical", "p_ungrammatical")
    missing_columns = [
        column for column in required_columns if column not in results.columns
    ]
    if missing_columns:
        raise KeyError(f"results is missing required column(s): {missing_columns}")

    lsize = 0.65
    fill_alpha = 0.7

    # Keep only rows that actually have an alternative (ungrammatical) form.
    has_alternative = results["form_ungrammatical"].notna() & (
        results["form_ungrammatical"].str.strip() != ""
    )
    # Explicit copy: a column is added below, and assigning into a filtered
    # slice triggers pandas' chained-assignment warning.
    filtered_df = results[has_alternative].copy()

    # The row index identifies a minimal pair so its two points can be joined.
    filtered_df["subject_id"] = filtered_df.index

    # Long format: one row per (pair, condition) carrying the raw probability.
    df_long = pd.melt(
        filtered_df,
        id_vars=["subject_id"],
        value_vars=["p_grammatical", "p_ungrammatical"],
        var_name="source",
        value_name="probability",
    )

    # Map source to fixed x-axis labels
    df_long["x_label"] = df_long["source"].map(
        {"p_grammatical": target_x_label, "p_ungrammatical": alt_x_label}
    )

    def surprisal(p: float) -> float:
        """Return ``-log2(p)``; NaN passes through unchanged."""
        if p <= 0:
            raise ValueError(
                f"cannot compute surprisal for non-positive probability {p!r}"
            )
        return -math.log2(p)

    df_long["surprisal"] = df_long["probability"].apply(surprisal)

    plot = (
        ggplot(df_long, aes(x="x_label", y="surprisal", fill="x_label"))
        # Fix the category order: target on the left, alternative on the right.
        + scale_x_discrete(limits=[target_x_label, alt_x_label])
        + geom_jitter(aes(color="x_label"), width=0.01, alpha=0.7)
        # One thin line per minimal pair, linking its two conditions.
        + geom_line(aes(group="subject_id"), color="gray", alpha=0.7, size=0.2)
        # Half violins, each nudged outwards from its own column of points.
        + geom_violin(
            df_long[df_long["x_label"] == target_x_label],
            aes(x="x_label", y="surprisal", group="x_label"),
            position=position_nudge(x=-0.2),
            style="left-right",
            alpha=fill_alpha,
            size=lsize,
        )
        + geom_violin(
            df_long[df_long["x_label"] == alt_x_label],
            aes(x="x_label", y="surprisal", group="x_label"),
            position=position_nudge(x=0.2),
            style="right-left",
            alpha=fill_alpha,
            size=lsize,
        )
        + guides(fill=False)
        + theme_bw()
        + theme(figure_size=(8, 4), legend_position="none")
        + labs(x=x_axis_label, y=y_axis_label, title=title)
    )
    # NOTE(review): save() gets its own width/height, which presumably take
    # precedence over the figure_size set in theme() above — confirm.
    plot.save(path, width=14, height=8, dpi=300)