# NOTE(review): removed non-Python scraping residue that preceded the imports
# (a "File size" line, a row of git-blame commit hashes, and a column of line
# numbers) — it was not source code and would prevent the module from parsing.
import string
import numpy as np
import torch
from laser_encoders import LaserEncoderPipeline
from scipy.spatial.distance import cosine
from simalign import SentenceAligner
from transformers import AutoModel, AutoTokenizer
# setup global variables on import (bad practice, but whatever)
# --------------------------------------------------------------
# Word aligner: SimAlign over xlm-roberta-base, hidden layer 6, loaded from
# TensorFlow weights. NOTE(review): all three models are loaded eagerly at
# import time, which is slow and makes the module hard to test in isolation;
# consider lazy initialization.
aligner = SentenceAligner(model="xlm-roberta-base", layer=6, from_tf = True)
# LASER sentence-embedding pipelines: German (source side) and English
# (target side), used by __get_bertscore.
de_encoder = LaserEncoderPipeline(lang="deu_Latn")
en_encoder = LaserEncoderPipeline(lang="eng_Latn")
def accuracy(src_sentence: str, trg_sentence: str) -> dict:
    """
    Score a translation by comparing the source and target sentences.

    Parameters:
        src_sentence (str): The source sentence.
        trg_sentence (str): The target sentence.

    Returns:
        dict: {"score": percentage similarity score (0-100),
        "errors": list of suspected mistranslations from word alignment}.
    """
    # Normalize both sides first (strip punctuation, lowercase) so the
    # aligner and the encoders see comparable input.
    cleaned_src = __preprocess_text(src_sentence)
    cleaned_trg = __preprocess_text(trg_sentence)
    alignment_errors = __get_alignment_score(cleaned_src, cleaned_trg)
    similarity = __get_bertscore(cleaned_src, cleaned_trg)
    return {
        "score": __bertscore_to_percentage(similarity),
        "errors": alignment_errors,
    }
def __preprocess_text(text: str) -> str:
    """
    Normalize text for comparison: drop all punctuation, then lowercase.

    Parameters:
        text (str): The text to preprocess.

    Returns:
        str: The preprocessed text.
    """
    # Keep every character that is not ASCII punctuation, then lowercase.
    without_punct = "".join(ch for ch in text if ch not in string.punctuation)
    return without_punct.lower()
def __get_bertscore(src_sentence: str, trg_sentence: str) -> float:
    """
    Cosine similarity between the sentence embeddings of the two sentences.

    NOTE(review): despite the name, this uses LASER sentence embeddings via
    the module-level encoders, not BERTScore. Encoder names suggest a German
    source and an English target — confirm with callers.

    Parameters:
        src_sentence (str): The source sentence.
        trg_sentence (str): The target sentence.

    Returns:
        float: Cosine similarity, nominally in [-1, 1].
    """
    src_vec = de_encoder.encode_sentences([src_sentence])[0]
    trg_vec = en_encoder.encode_sentences([trg_sentence])[0]
    # scipy's `cosine` is a *distance*, so similarity = 1 - distance.
    return 1 - cosine(src_vec, trg_vec)
def __bertscore_to_percentage(similarity: float, debug: bool = False) -> float:
    """
    Map a cosine similarity onto a 0-100 percentage score.

    Two logistic curves centred at 0.60 are evaluated and the larger value
    is kept: the steep one (k=11) sharpens the transition around the
    midpoint, while the shallower one (k=5) keeps the tails from collapsing
    toward 0 or 100 too aggressively.

    Parameters:
        similarity (float): Cosine similarity, nominally in [-1, 1]
            (rarely negative in practice).
        debug (bool): If True, skip the logistic scaling and return the
            raw similarity (rounded) instead.

    Returns:
        float: The score rounded to 2 decimal places — in (0, 100) when
        debug is False, or the rounded raw similarity when debug is True.
    """
    if debug:
        scaled_score = similarity
    else:
        scaled_score = max(
            100 / (1 + np.exp(-11 * (similarity - 0.60))),
            100 / (1 + np.exp(-5 * (similarity - 0.60))),
        )
    return round(scaled_score, 2)
def __get_alignment_score(src_sentence: str, trg_sentence: str) -> list:
    """
    Flag words that the aligner could not match across the two sentences.

    Both sentences are whitespace-tokenized and aligned with SimAlign. Any
    source word absent from the high-precision "inter" (intersection)
    alignment is reported as possibly mistranslated/omitted; any unmatched
    target word as possibly mistranslated/added.

    NOTE(review): despite the name, this returns a list of error dicts,
    not a numeric score.

    Parameters:
        src_sentence (str): The source sentence.
        trg_sentence (str): The target sentence.

    Returns:
        list: Dicts with zero-based "start"/"end" word indexes (into the
        respective sentence's token list) and a human-readable "message".
    """
    src_list = src_sentence.split()
    trg_list = trg_sentence.split()
    # SimAlign returns {method: [(src_idx, trg_idx), ...]}; "inter" is the
    # intersection of both alignment directions (zero-indexed).
    alignments = aligner.get_word_aligns(src_list, trg_list)
    src_aligns = {pair[0] for pair in alignments["inter"]}
    trg_aligns = {pair[1] for pair in alignments["inter"]}

    # enumerate instead of range(len(...)) + repeated indexing.
    mistranslations = [
        {
            "start": i,
            "end": i,
            "message": f"Word {word} possibly mistranslated or omitted",
        }
        for i, word in enumerate(src_list)
        if i not in src_aligns
    ]
    mistranslations.extend(
        {
            "start": i,
            "end": i,
            "message": f"Word {word} possibly mistranslated or added erroneously",
        }
        for i, word in enumerate(trg_list)
        if i not in trg_aligns
    )
    return mistranslations
# NOTE(review): removed a stray trailing "|" left over from extraction.