Commit 80f72f3 by miwytt
1 Parent(s): 92472dd

Initial commit

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 University of Zurich
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +0,0 @@
- ---
- title: Translation Direction Detection
- emoji: 🦀
- colorFrom: red
- colorTo: green
- sdk: gradio
- sdk_version: 4.14.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -36,8 +36,6 @@ iface = gr.Interface(
      ],
      outputs=gr.Textbox(label="Result"),
      title="Translation Direction Detector",
-     description="Detects the translation direction between two sentences using the M2M100 418M translation model.",
-     theme="dark"
- )
+     description="Detects the translation direction between two parallel sentences using the M2M100 418M translation model.",)

  iface.launch()
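For orientation, here is a rough sketch of the full gr.Interface(...) call that this hunk leaves behind. Only the tail of the call is visible in the diff, so the handler function detect_direction and the two input textboxes are illustrative assumptions, not code from the commit:

import gradio as gr

def detect_direction(sentence1: str, sentence2: str) -> str:
    # Hypothetical handler; the real app.py defines its own function, which is not visible in this hunk.
    return "placeholder result"

iface = gr.Interface(
    fn=detect_direction,  # assumed name
    inputs=[
        gr.Textbox(label="Sentence 1"),  # assumed inputs; only the closing "]," of the list is visible above
        gr.Textbox(label="Sentence 2"),
    ],
    outputs=gr.Textbox(label="Result"),
    title="Translation Direction Detector",
    description="Detects the translation direction between two parallel sentences using the M2M100 418M translation model.",)

iface.launch()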
pyproject.toml ADDED
@@ -0,0 +1,28 @@
+ [project]
+ name = "translation_direction_detection"
+ version = "0.0.1"
+ authors = [
+     { name="Michelle Wastl", email="michelle.wastl@uzh.ch" },
+     { name="Jannis Vamvas", email="vamvas@cl.uzh.ch" },
+     { name="Rico Sennrich", email="sennrich@cl.uzh.ch" },
+ ]
+ description = "Unsupervised translation direction detection using NMT systems"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "transformers<4.34",  # https://github.com/ZurichNLP/nmtscore/issues/7
+     "nmtscore",
+     "scipy",
+ ]
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+ ]
+
+ [project.urls]
+ "Homepage" = "https://github.com/ZurichNLP/translation-direction-detection"
+ "Bug Tracker" = "https://github.com/ZurichNLP/translation-direction-detection/issues"
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
src/translation_direction_detection/__init__.py ADDED
@@ -0,0 +1 @@
+ from translation_direction_detection.detector import TranslationDirectionDetector
src/translation_direction_detection/detector.py ADDED
@@ -0,0 +1,165 @@
+ from dataclasses import dataclass
+ from typing import List, Union, Optional
+
+ import numpy as np
+ from nmtscore import NMTScorer
+ from scipy.special import softmax
+ from scipy.stats import permutation_test
+
+
+ @dataclass
+ class TranslationDirectionResult:
+     sentence1: Union[str, List[str]]
+     sentence2: Union[str, List[str]]
+     lang1: str
+     lang2: str
+     raw_prob_1_to_2: float
+     raw_prob_2_to_1: float
+     pvalue: Optional[float] = None
+
+     @property
+     def num_sentences(self):
+         return len(self.sentence1) if isinstance(self.sentence1, list) else 1
+
+     @property
+     def prob_1_to_2(self):
+         return softmax([self.raw_prob_1_to_2, self.raw_prob_2_to_1])[0]
+
+     @property
+     def prob_2_to_1(self):
+         return softmax([self.raw_prob_1_to_2, self.raw_prob_2_to_1])[1]
+
+     @property
+     def predicted_direction(self) -> str:
+         if self.raw_prob_1_to_2 >= self.raw_prob_2_to_1:
+             return self.lang1 + '→' + self.lang2
+         else:
+             return self.lang2 + '→' + self.lang1
+
+     def __str__(self):
+         s = f"""\
+ Predicted direction: {self.predicted_direction}
+ {self.num_sentences} sentence pair{"s" if self.num_sentences > 1 else ""}
+ {self.lang1}→{self.lang2}: {self.prob_1_to_2:.3f}
+ {self.lang2}→{self.lang1}: {self.prob_2_to_1:.3f}"""
+         if self.pvalue is not None:
+             s += f"\np-value: {self.pvalue}\n"
+         return s
+
+
+ class TranslationDirectionDetector:
+
+     def __init__(self, scorer: NMTScorer = None, use_normalization: bool = False):
+         self.scorer = scorer or NMTScorer()
+         self.use_normalization = use_normalization
+
+     def detect(self,
+                sentence1: Union[str, List[str]],
+                sentence2: Union[str, List[str]],
+                lang1: str,
+                lang2: str,
+                return_pvalue: bool = False,
+                pvalue_n_resamples: int = 9999,
+                score_kwargs: dict = None
+                ) -> TranslationDirectionResult:
+         if isinstance(sentence1, list) and isinstance(sentence2, list):
+             if len(sentence1) != len(sentence2):
+                 raise ValueError("Lists sentence1 and sentence2 must have same length")
+             if len(sentence1) == 0:
+                 raise ValueError("Lists sentence1 and sentence2 must not be empty")
+             if len(sentence1) == 1 and return_pvalue:
+                 raise ValueError("return_pvalue=True requires the documents to have multiple sentences")
+         if lang1 == lang2:
+             raise ValueError("lang1 and lang2 must be different")
+
+         prob_1_to_2 = self.scorer.score_direct(
+             sentence2, sentence1,
+             lang2, lang1,
+             normalize=self.use_normalization,
+             both_directions=False,
+             score_kwargs=score_kwargs
+         )
+         prob_2_to_1 = self.scorer.score_direct(
+             sentence1, sentence2,
+             lang1, lang2,
+             normalize=self.use_normalization,
+             both_directions=False,
+             score_kwargs=score_kwargs
+         )
+         pvalue = None
+
+         if isinstance(sentence1, list):  # document-level
+             # Compute the average probability per target token, across the complete document
+             # 1. Convert probabilities back to log probabilities
+             log_prob_1_to_2 = np.log2(np.array(prob_1_to_2))
+             log_prob_2_to_1 = np.log2(np.array(prob_2_to_1))
+             # 2. Reverse the sentence-level length normalization
+             sentence1_lengths = np.array([self._get_sentence_length(s) for s in sentence1])
+             sentence2_lengths = np.array([self._get_sentence_length(s) for s in sentence2])
+             log_prob_1_to_2 = sentence2_lengths * log_prob_1_to_2
+             log_prob_2_to_1 = sentence1_lengths * log_prob_2_to_1
+             # 3. Sum up the log probabilities across the document
+             total_log_prob_1_to_2 = log_prob_1_to_2.sum()
+             total_log_prob_2_to_1 = log_prob_2_to_1.sum()
+             # 4. Document-level length normalization
+             avg_log_prob_1_to_2 = total_log_prob_1_to_2 / sum(sentence2_lengths)
+             avg_log_prob_2_to_1 = total_log_prob_2_to_1 / sum(sentence1_lengths)
+             # 5. Convert back to probabilities
+             prob_1_to_2 = 2 ** avg_log_prob_1_to_2
+             prob_2_to_1 = 2 ** avg_log_prob_2_to_1
+
+             if return_pvalue:
+                 x = np.vstack([log_prob_1_to_2, sentence2_lengths]).T
+                 y = np.vstack([log_prob_2_to_1, sentence1_lengths]).T
+                 result = permutation_test(
+                     data=(x, y),
+                     statistic=self._statistic_token_mean,
+                     permutation_type="samples",
+                     n_resamples=pvalue_n_resamples,
+                 )
+                 pvalue = result.pvalue
+         else:
+             if return_pvalue:
+                 raise ValueError("return_pvalue=True requires sentence1 and sentence2 to be lists of sentences")
+
+         return TranslationDirectionResult(
+             sentence1=sentence1,
+             sentence2=sentence2,
+             lang1=lang1,
+             lang2=lang2,
+             raw_prob_1_to_2=prob_1_to_2,
+             raw_prob_2_to_1=prob_2_to_1,
+             pvalue=pvalue,
+         )
+
+     def _get_sentence_length(self, sentence: str) -> int:
+         tokens = self.scorer.model.tokenizer.tokenize(sentence)
+         return len(tokens)
+
+     @staticmethod
+     def _statistic_token_mean(x: np.ndarray, y: np.ndarray, axis: int = -1) -> float:
+         """
+         Statistic for scipy.stats.permutation_test
+
+         :param x: Matrix of shape (2 x num_sentences). The first row contains the unnormalized log probability
+           for lang1→lang2, the second row contains the sentence lengths in lang2.
+         :param y: Same as x, but for lang2→lang1
+         :return: Difference between lang1→lang2 and lang2→lang1
+         """
+         if axis != -1:
+             raise NotImplementedError("Only axis=-1 is supported")
+         # Add batch dim
+         if x.ndim == 2:
+             x = x[np.newaxis, ...]
+             y = y[np.newaxis, ...]
+         # Sum up the log probabilities across the document
+         total_log_prob_1_to_2 = x[:, 0].sum(axis=axis)
+         total_log_prob_2_to_1 = y[:, 0].sum(axis=axis)
+         # Document-level length normalization
+         avg_log_prob_1_to_2 = total_log_prob_1_to_2 / x[:, 1].sum(axis=axis)
+         avg_log_prob_2_to_1 = total_log_prob_2_to_1 / y[:, 1].sum(axis=axis)
+         # Convert to probabilities
+         prob_1_to_2 = 2 ** avg_log_prob_1_to_2
+         prob_2_to_1 = 2 ** avg_log_prob_2_to_1
+         # Compute difference
+         return prob_1_to_2 - prob_2_to_1
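
Putting the pieces together, the package added in this commit can be exercised roughly as follows. This is a minimal usage sketch based on the detect() signature and the __init__.py re-export above; the example sentences and language codes are illustrative assumptions, and the default NMTScorer will download its underlying translation model on first use:

from translation_direction_detection import TranslationDirectionDetector

detector = TranslationDirectionDetector()  # falls back to the default NMTScorer()

# Sentence-level detection (example sentences and language codes are made up)
result = detector.detect(
    sentence1="Das ist ein Beispiel.",
    sentence2="This is an example.",
    lang1="de",
    lang2="en",
)
print(result)  # prints the predicted direction and the two normalized probabilities

# Document-level detection: parallel lists of sentences, with a permutation-test p-value
doc_result = detector.detect(
    sentence1=["Erster Satz.", "Zweiter Satz."],
    sentence2=["First sentence.", "Second sentence."],
    lang1="de",
    lang2="en",
    return_pvalue=True,
)
print(doc_result.predicted_direction, doc_result.pvalue)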