|
from nmtscore import NMTScorer |
|
import gradio as gr |
|
|
|
from dataclasses import dataclass |
|
from typing import List, Union, Optional |
|
import numpy as np |
|
|
|
from scipy.special import softmax |
|
from scipy.stats import permutation_test |
|
|
|
|
|
@dataclass
class TranslationDirectionResult:
    """Outcome of comparing the two possible translation directions of a text pair.

    ``raw_prob_1_to_2`` and ``raw_prob_2_to_1`` hold the unnormalized scores of the
    lang1→lang2 and lang2→lang1 directions; the ``prob_*`` properties turn them
    into a two-way softmax distribution. ``pvalue`` is only set when a permutation
    test was requested.
    """

    sentence1: Union[str, List[str]]
    sentence2: Union[str, List[str]]
    lang1: str
    lang2: str
    raw_prob_1_to_2: float
    raw_prob_2_to_1: float
    pvalue: Optional[float] = None

    def _direction_probs(self):
        """Softmax over the raw scores, ordered (lang1→lang2, lang2→lang1)."""
        return softmax([self.raw_prob_1_to_2, self.raw_prob_2_to_1])

    @property
    def num_sentences(self):
        """Number of sentence pairs: the list length, or 1 for a single string."""
        if isinstance(self.sentence1, list):
            return len(self.sentence1)
        return 1

    @property
    def prob_1_to_2(self):
        """Softmax-normalized probability of the lang1→lang2 direction."""
        return self._direction_probs()[0]

    @property
    def prob_2_to_1(self):
        """Softmax-normalized probability of the lang2→lang1 direction."""
        return self._direction_probs()[1]

    @property
    def predicted_direction(self) -> str:
        """``"<source>→<target>"`` for the higher-scoring direction; ties go to lang1→lang2."""
        forward = self.raw_prob_1_to_2 >= self.raw_prob_2_to_1
        source, target = (self.lang1, self.lang2) if forward else (self.lang2, self.lang1)
        return '→'.join((source, target))

    def __str__(self):
        plural = "s" if self.num_sentences > 1 else ""
        lines = [
            f"Predicted direction: {self.predicted_direction}",
            f"{self.num_sentences} sentence pair{plural}",
            f"{self.lang1}→{self.lang2}: {self.prob_1_to_2:.3f}",
            f"{self.lang2}→{self.lang1}: {self.prob_2_to_1:.3f}",
        ]
        text = "\n".join(lines)
        if self.pvalue is not None:
            text += f"\np-value: {self.pvalue}\n"
        return text
|
|
|
|
|
class TranslationDirectionDetector:
    """Predict in which direction a pair of parallel texts was translated.

    Both directions are scored with ``scorer.score_direct`` and the scores are
    returned as a ``TranslationDirectionResult``.
    """

    def __init__(self, scorer: Optional["NMTScorer"] = None, use_normalization: bool = False):
        """
        :param scorer: Scorer used for ``score_direct`` and for tokenization
            (``scorer.model.tokenizer``). A default ``NMTScorer()`` is created when
            no scorer is given.
        :param use_normalization: Forwarded to ``score_direct`` as ``normalize``.
        """
        self.scorer = scorer or NMTScorer()
        self.use_normalization = use_normalization

    def detect(self,
               sentence1: Union[str, List[str]],
               sentence2: Union[str, List[str]],
               lang1: str,
               lang2: str,
               return_pvalue: bool = False,
               pvalue_n_resamples: int = 9999,
               score_kwargs: Optional[dict] = None,
               ) -> "TranslationDirectionResult":
        """Score both translation directions for a sentence pair or a pair of documents.

        :param sentence1: A sentence in lang1, or a list of sentences (a document).
        :param sentence2: The parallel sentence or document in lang2. It must be the
            same kind (str vs. list) as sentence1 and, for lists, the same length.
        :param lang1: Language code of sentence1.
        :param lang2: Language code of sentence2; must differ from lang1.
        :param return_pvalue: Run a paired permutation test over the sentences and
            attach its p-value. Only valid for documents with at least two sentences.
        :param pvalue_n_resamples: Number of resamples for the permutation test.
        :param score_kwargs: Extra keyword arguments forwarded to ``score_direct``.
        :return: A TranslationDirectionResult with the scores of both directions.
        :raises ValueError: For mismatched or empty inputs, identical languages, or
            ``return_pvalue`` requested for input that cannot support it.
        """
        # All validation happens before scoring so bad requests never pay for two
        # model passes.
        is_document = isinstance(sentence1, list)
        if is_document != isinstance(sentence2, list):
            raise ValueError("sentence1 and sentence2 must both be strings or both be lists")
        if is_document:
            if len(sentence1) != len(sentence2):
                raise ValueError("Lists sentence1 and sentence2 must have same length")
            if len(sentence1) == 0:
                raise ValueError("Lists sentence1 and sentence2 must not be empty")
            if len(sentence1) == 1 and return_pvalue:
                raise ValueError("return_pvalue=True requires the documents to have multiple sentences")
        if lang1 == lang2:
            raise ValueError("lang1 and lang2 must be different")
        if return_pvalue and not is_document:
            raise ValueError("return_pvalue=True requires sentence1 and sentence2 to be lists of sentences")

        # Score each direction. The lang1→lang2 score passes the lang2 side first to
        # score_direct (and vice versa); NOTE(review): this follows nmtscore's
        # (a, b, a_lang, b_lang) argument convention — keep the order as-is.
        prob_1_to_2 = self.scorer.score_direct(
            sentence2, sentence1,
            lang2, lang1,
            normalize=self.use_normalization,
            both_directions=False,
            score_kwargs=score_kwargs,
        )
        prob_2_to_1 = self.scorer.score_direct(
            sentence1, sentence2,
            lang1, lang2,
            normalize=self.use_normalization,
            both_directions=False,
            score_kwargs=score_kwargs,
        )
        pvalue = None

        if is_document:
            # Combine the per-sentence scores into a token-weighted geometric mean:
            # each log2 score is weighted by the token count of the sentence on the
            # target side of that direction (sentence2 for 1→2, sentence1 for 2→1).
            # NOTE(review): presumes score_direct returns per-token probabilities.
            sentence1_lengths = np.array([self._get_sentence_length(s) for s in sentence1])
            sentence2_lengths = np.array([self._get_sentence_length(s) for s in sentence2])
            log_prob_1_to_2 = sentence2_lengths * np.log2(np.array(prob_1_to_2))
            log_prob_2_to_1 = sentence1_lengths * np.log2(np.array(prob_2_to_1))

            prob_1_to_2 = 2 ** (log_prob_1_to_2.sum() / sentence2_lengths.sum())
            prob_2_to_1 = 2 ** (log_prob_2_to_1.sum() / sentence1_lengths.sum())

            if return_pvalue:
                # x and y have shape (num_sentences, 2): weighted log score and token
                # count per sentence. permutation_test (axis=0 by default) hands them
                # to the statistic with the sentence axis last.
                x = np.vstack([log_prob_1_to_2, sentence2_lengths]).T
                y = np.vstack([log_prob_2_to_1, sentence1_lengths]).T
                result = permutation_test(
                    data=(x, y),
                    statistic=self._statistic_token_mean,
                    permutation_type="samples",
                    n_resamples=pvalue_n_resamples,
                )
                pvalue = result.pvalue

        return TranslationDirectionResult(
            sentence1=sentence1,
            sentence2=sentence2,
            lang1=lang1,
            lang2=lang2,
            raw_prob_1_to_2=prob_1_to_2,
            raw_prob_2_to_1=prob_2_to_1,
            pvalue=pvalue,
        )

    def _get_sentence_length(self, sentence: str) -> int:
        """Number of tokens the scorer's tokenizer produces for ``sentence``."""
        return len(self.scorer.model.tokenizer.tokenize(sentence))

    @staticmethod
    def _statistic_token_mean(x: np.ndarray, y: np.ndarray, axis: int = -1) -> np.ndarray:
        """
        Statistic for scipy.stats.permutation_test: the difference between the
        token-weighted mean probabilities of the two directions.

        :param x: Array of shape (2, num_sentences), or (batch, 2, num_sentences) when
            several resamples are evaluated at once. Row 0 holds the length-weighted
            log2 probabilities for lang1→lang2, row 1 the sentence lengths in lang2.
        :param y: Same as x, but for lang2→lang1 (lengths in lang1).
        :param axis: Only -1 (the sentence axis) is supported.
        :return: Array of shape (batch,) — (1,) for unbatched input — holding
            ``prob_1_to_2 - prob_2_to_1``.
        :raises NotImplementedError: If ``axis`` is not -1.
        """
        if axis != -1:
            raise NotImplementedError("Only axis=-1 is supported")

        if x.ndim == 2:
            # Add a batch axis so batched and unbatched calls share one code path.
            x = x[np.newaxis, ...]
            y = y[np.newaxis, ...]

        avg_log_prob_1_to_2 = x[:, 0].sum(axis=axis) / x[:, 1].sum(axis=axis)
        avg_log_prob_2_to_1 = y[:, 0].sum(axis=axis) / y[:, 1].sum(axis=axis)

        return 2 ** avg_log_prob_1_to_2 - 2 ** avg_log_prob_2_to_1
|
|
|
|
|
# Build the detector once at module level so every call to translate_direction
# reuses the same "m2m100_418M" scorer instead of creating a new one.
detector = TranslationDirectionDetector(scorer=NMTScorer("m2m100_418M"))
|
|
|
def translate_direction(text1, lang1, text2, lang2):
    """Gradio callback that predicts the translation direction of two parallel texts.

    If either text contains a line break, both texts are treated as documents with
    one sentence per line. Surrounding whitespace and blank lines are dropped, so a
    trailing newline or Windows (CRLF) line endings do not create empty sentences.
    Otherwise each text is scored as a single sentence.

    :param text1: Text written in ``lang1``.
    :param lang1: Display name of text1's language, as shown in the dropdown.
    :param text2: Parallel text written in ``lang2``.
    :param lang2: Display name of text2's language.
    :return: The TranslationDirectionResult produced by ``detector.detect``.
    :raises gr.Error: For missing text, unsupported languages, or input the detector
        rejects, so the message is shown to the user instead of a generic failure.
    """
    lang_to_code = {"English": 'en',
                    "German": 'de',
                    "French": 'fr',
                    "Czech": 'cs',
                    "Ukrainian": 'uk',
                    "Chinese": 'zh',
                    "Russian": 'ru',
                    "Bengali": 'bn',
                    "Hindi": 'hi',
                    "Xhosa": 'xh',
                    "Zulu": 'zu',
                    }
    code1 = lang_to_code.get(lang1)
    code2 = lang_to_code.get(lang2)
    if code1 is None or code2 is None:
        raise gr.Error("Please select a supported language for both texts.")

    text1 = (text1 or "").strip()
    text2 = (text2 or "").strip()
    if not text1 or not text2:
        raise gr.Error("Please enter text in both fields.")

    if "\n" in text1 or "\n" in text2:
        # splitlines() also handles "\r\n"; blank lines are dropped so they are not scored.
        sentence1 = [line.strip() for line in text1.splitlines() if line.strip()]
        sentence2 = [line.strip() for line in text2.splitlines() if line.strip()]
    else:
        sentence1 = text1
        sentence2 = text2

    try:
        return detector.detect(sentence1, sentence2, code1, code2)
    except ValueError as err:
        # e.g. identical languages or a different number of lines in the two texts.
        raise gr.Error(str(err)) from err
|
|
|
# Options shown in both language dropdowns. Each name must match a key of the
# language mapping in translate_direction, otherwise it cannot be mapped to a
# language code (the previous "Ukranian" spelling did not match "Ukrainian").
language_choices = ["English", "German", "French", "Czech", "Ukrainian", "Chinese",
                    "Russian", "Bengali", "Hindi", "Xhosa", "Zulu"]

iface = gr.Interface(
    fn=translate_direction,
    inputs=[
        gr.Textbox(placeholder="Enter a single sentence or multiple sentences separated by a line break (Shift+Enter) in language 1 here.", label="Text 1"),
        gr.Dropdown(choices=language_choices, label="Language of Text 1"),
        gr.Textbox(placeholder="Enter a single sentence or multiple sentences separated by a line break (Shift+Enter) in language 2 here.", label="Text 2"),
        gr.Dropdown(choices=language_choices, label="Language of Text 2"),
    ],
    outputs=gr.Textbox(label="Result"),
    title="Translation Direction Detector",
    description="Detects the translation direction between two parallel sentences using the M2M100 418M translation model.",
)

iface.launch()