File size: 6,626 Bytes
9abdbd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from nmtscore import NMTScorer

from dataclasses import dataclass
from typing import List, Union, Optional
import numpy as np

from scipy.special import softmax
from scipy.stats import permutation_test


@dataclass
class TranslationDirectionResult:
    """
    Result of comparing the two possible translation directions of a sentence pair or document.

    :param sentence1: The text in lang1 – a single sentence or a list of sentences.
    :param sentence2: The text in lang2 – a single sentence or a list of sentences.
    :param lang1: Language code of sentence1.
    :param lang2: Language code of sentence2.
    :param raw_prob_1_to_2: Unnormalized score for the direction lang1→lang2.
    :param raw_prob_2_to_1: Unnormalized score for the direction lang2→lang1.
    :param pvalue: Optional p-value attached to the result; None if no p-value was computed.
    """
    sentence1: Union[str, List[str]]
    sentence2: Union[str, List[str]]
    lang1: str
    lang2: str
    raw_prob_1_to_2: float
    raw_prob_2_to_1: float
    pvalue: Optional[float] = None

    @property
    def num_sentences(self) -> int:
        """Number of sentence pairs: the list length for list input, otherwise 1."""
        return len(self.sentence1) if isinstance(self.sentence1, list) else 1

    @property
    def prob_1_to_2(self) -> float:
        """Softmax-normalized probability of lang1→lang2 (sums to 1 with prob_2_to_1)."""
        return softmax([self.raw_prob_1_to_2, self.raw_prob_2_to_1])[0]

    @property
    def prob_2_to_1(self) -> float:
        """Softmax-normalized probability of lang2→lang1 (sums to 1 with prob_1_to_2)."""
        return softmax([self.raw_prob_1_to_2, self.raw_prob_2_to_1])[1]

    @property
    def predicted_direction(self) -> str:
        """The direction with the higher raw score, e.g. 'de→en'. Ties resolve to lang1→lang2."""
        if self.raw_prob_1_to_2 >= self.raw_prob_2_to_1:
            return self.lang1 + '→' + self.lang2
        else:
            return self.lang2 + '→' + self.lang1

    def __str__(self) -> str:
        """Human-readable multi-line summary; a p-value line is appended only if a p-value is set."""
        # The per-direction labels use the same 'src→tgt' notation as predicted_direction
        s = f"""\
Predicted direction: {self.predicted_direction}
{self.num_sentences} sentence pair{"s" if self.num_sentences > 1 else ""}
{self.lang1}→{self.lang2}: {self.prob_1_to_2:.3f}
{self.lang2}→{self.lang1}: {self.prob_2_to_1:.3f}"""
        if self.pvalue is not None:
            s += f"\np-value: {self.pvalue}\n"
        return s


class TranslationDirectionDetector:
    """
    Predicts the translation direction of a parallel sentence pair or document.

    Both directions are scored with ``scorer.score_direct`` and the two scores are returned in a
    TranslationDirectionResult.
    """

    def __init__(self, scorer: Optional["NMTScorer"] = None, use_normalization: bool = False):
        """
        :param scorer: Scorer used to compute the translation probabilities. If omitted (or falsy), a default
        ``NMTScorer()`` is created.
        :param use_normalization: Passed as ``normalize`` to ``scorer.score_direct``.
        """
        self.scorer = scorer or NMTScorer()
        self.use_normalization = use_normalization

    def detect(self,
               sentence1: Union[str, List[str]],
               sentence2: Union[str, List[str]],
               lang1: str,
               lang2: str,
               return_pvalue: bool = False,
               pvalue_n_resamples: int = 9999,
               score_kwargs: Optional[dict] = None
               ) -> "TranslationDirectionResult":
        """
        Score both translation directions for a sentence pair or a document (list of sentence pairs).

        :param sentence1: Sentence in lang1, or a list of sentences (document-level).
        :param sentence2: Sentence in lang2, or a list of sentences aligned with sentence1.
        :param lang1: Language code of sentence1.
        :param lang2: Language code of sentence2; must differ from lang1.
        :param return_pvalue: If True, additionally run a permutation test over the sentence pairs. Only
        supported for lists with more than one sentence pair.
        :param pvalue_n_resamples: Number of resamples for the permutation test.
        :param score_kwargs: Passed through to ``scorer.score_direct``.
        :return: A TranslationDirectionResult with the raw scores of both directions. For documents, each raw
        score is the average probability per target token across the complete document.
        :raises ValueError: If the inputs are inconsistent (mixed str/list, unequal or zero length, identical
        languages) or if return_pvalue is requested for input that does not support it.
        """
        # All validation happens up front so that invalid requests fail before the model is run
        is_document = isinstance(sentence1, list)
        if is_document != isinstance(sentence2, list):
            raise ValueError("sentence1 and sentence2 must either both be strings or both be lists of strings")
        if is_document:
            if len(sentence1) != len(sentence2):
                raise ValueError("Lists sentence1 and sentence2 must have same length")
            if not sentence1:
                raise ValueError("Lists sentence1 and sentence2 must not be empty")
            if len(sentence1) == 1 and return_pvalue:
                raise ValueError("return_pvalue=True requires the documents to have multiple sentences")
        if lang1 == lang2:
            raise ValueError("lang1 and lang2 must be different")
        if return_pvalue and not is_document:
            raise ValueError("return_pvalue=True requires sentence1 and sentence2 to be lists of sentences")

        # score_direct receives the text being scored first: for lang1→lang2 that is the lang2 side
        prob_1_to_2 = self.scorer.score_direct(
            sentence2, sentence1,
            lang2, lang1,
            normalize=self.use_normalization,
            both_directions=False,
            score_kwargs=score_kwargs
        )
        prob_2_to_1 = self.scorer.score_direct(
            sentence1, sentence2,
            lang1, lang2,
            normalize=self.use_normalization,
            both_directions=False,
            score_kwargs=score_kwargs
        )
        pvalue = None

        if is_document:
            # Compute the average probability per target token, across the complete document
            # 1. Convert probabilities back to log probabilities
            log_prob_1_to_2 = np.log2(np.array(prob_1_to_2))
            log_prob_2_to_1 = np.log2(np.array(prob_2_to_1))
            # 2. Reverse the sentence-level length normalization
            sentence1_lengths = np.array([self._get_sentence_length(s) for s in sentence1])
            sentence2_lengths = np.array([self._get_sentence_length(s) for s in sentence2])
            log_prob_1_to_2 = sentence2_lengths * log_prob_1_to_2
            log_prob_2_to_1 = sentence1_lengths * log_prob_2_to_1
            # 3. Sum up the log probabilities across the document
            total_log_prob_1_to_2 = log_prob_1_to_2.sum()
            total_log_prob_2_to_1 = log_prob_2_to_1.sum()
            # 4. Document-level length normalization
            avg_log_prob_1_to_2 = total_log_prob_1_to_2 / sentence2_lengths.sum()
            avg_log_prob_2_to_1 = total_log_prob_2_to_1 / sentence1_lengths.sum()
            # 5. Convert back to probabilities
            prob_1_to_2 = 2 ** avg_log_prob_1_to_2
            prob_2_to_1 = 2 ** avg_log_prob_2_to_1

            if return_pvalue:
                # One row per sentence pair: (unnormalized log probability, target length)
                x = np.vstack([log_prob_1_to_2, sentence2_lengths]).T
                y = np.vstack([log_prob_2_to_1, sentence1_lengths]).T
                result = permutation_test(
                    data=(x, y),
                    statistic=self._statistic_token_mean,
                    permutation_type="samples",
                    n_resamples=pvalue_n_resamples,
                )
                pvalue = result.pvalue

        return TranslationDirectionResult(
            sentence1=sentence1,
            sentence2=sentence2,
            lang1=lang1,
            lang2=lang2,
            raw_prob_1_to_2=prob_1_to_2,
            raw_prob_2_to_1=prob_2_to_1,
            pvalue=pvalue,
        )

    def _get_sentence_length(self, sentence: str) -> int:
        """Return the number of tokens that the scorer model's tokenizer produces for the sentence."""
        tokens = self.scorer.model.tokenizer.tokenize(sentence)
        return len(tokens)

    @staticmethod
    def _statistic_token_mean(x: np.ndarray, y: np.ndarray, axis: int = -1) -> np.ndarray:
        """
        Statistic for scipy.stats.permutation_test

        :param x: Matrix of shape (2 x num_sentences), optionally with a leading batch dimension. The first row
        contains the unnormalized log probability for lang1→lang2, the second row contains the sentence lengths
        in lang2.
        :param y: Same as x, but for lang2→lang1
        :param axis: Axis to reduce over; only -1 is supported.
        :return: Difference between lang1→lang2 and lang2→lang1, as an array with one value per batch entry
        (a single value if the input had no batch dimension)
        :raises NotImplementedError: If axis is not -1
        """
        if axis != -1:
            raise NotImplementedError("Only axis=-1 is supported")
        # Add batch dim
        if x.ndim == 2:
            x = x[np.newaxis, ...]
            y = y[np.newaxis, ...]
        # Sum up the log probabilities across the document
        total_log_prob_1_to_2 = x[:, 0].sum(axis=axis)
        total_log_prob_2_to_1 = y[:, 0].sum(axis=axis)
        # Document-level length normalization
        avg_log_prob_1_to_2 = total_log_prob_1_to_2 / x[:, 1].sum(axis=axis)
        avg_log_prob_2_to_1 = total_log_prob_2_to_1 / y[:, 1].sum(axis=axis)
        # Convert to probabilities
        prob_1_to_2 = 2 ** avg_log_prob_1_to_2
        prob_2_to_1 = 2 ** avg_log_prob_2_to_1
        # Compute difference
        return prob_1_to_2 - prob_2_to_1