Implement evaluation
- README.md +22 -4
- dataset.py +33 -0
- eval.py +159 -0
- sample_text/en2es.m2m100_1.2B.json +1 -0
- sample_text/en2es.m2m100_418M.json +1 -0
- sample_text/{en2es.translation.txt → en2es.translation.m2m100_1.2B.txt} +1 -1
- sample_text/en2es.translation.m2m100_418M.txt +0 -0
- translate.py +6 -2
README.md
CHANGED
@@ -68,7 +68,7 @@ Run `python translate.py -h` for more info.
 ```bash
 accelerate launch translate.py \
 --sentences_path sample_text/en.txt \
---output_path sample_text/en2es.translation.txt \
+--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
 --source_lang en \
 --target_lang es \
 --model_name facebook/m2m100_1.2B
@@ -83,7 +83,7 @@ You can use the Accelerate CLI to configure the Accelerate environment (Run
 ```bash
 accelerate launch --multi_gpu --num_processes 2 --num_machines 1 translate.py \
 --sentences_path sample_text/en.txt \
---output_path sample_text/en2es.translation.txt \
+--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
 --source_lang en \
 --target_lang es \
 --model_name facebook/m2m100_1.2B
@@ -102,7 +102,7 @@ Use the `--precision` flag to choose the precision of the model. You can choose
 ```bash
 accelerate launch translate.py \
 --sentences_path sample_text/en.txt \
---output_path sample_text/en2es.translation.txt \
+--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
 --source_lang en \
 --target_lang es \
 --model_name facebook/m2m100_1.2B \
@@ -111,6 +111,24 @@ accelerate launch translate.py \
 
 ## Evaluate translations
 
-
+To run the evaluation script you need to install [bert_score](https://github.com/Tiiiger/bert_score): `pip install bert_score`
+
+The evaluation script will calculate the following metrics:
+* [SacreBLEU](https://github.com/huggingface/datasets/tree/master/metrics/sacrebleu)
+* [BLEU](https://github.com/huggingface/datasets/tree/master/metrics/bleu)
+* [ROUGE](https://github.com/huggingface/datasets/tree/master/metrics/rouge)
+* [METEOR](https://github.com/huggingface/datasets/tree/master/metrics/meteor)
+* [TER](https://github.com/huggingface/datasets/tree/master/metrics/ter)
+* [BertScore](https://github.com/huggingface/datasets/tree/master/metrics/bertscore)
+
+Run the following command to evaluate the translations:
+
+```bash
+accelerate launch eval.py \
+--pred_path sample_text/es.txt \
+--gold_path sample_text/en2es.translation.m2m100_1.2B.txt
+```
+
+If you want to save the results to a file use the `--output_path` flag.
 
 
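The same evaluation can also be invoked programmatically. A minimal sketch, assuming it is run from the repository root and reusing the paths from the README example above (`eval_files` is defined in the new eval.py shown below):

```python
# Sketch only: programmatic equivalent of the README's `accelerate launch eval.py ...`
# command. The sample_text paths are the ones used in the README example.
from eval import eval_files

results = eval_files(
    pred_path="sample_text/es.txt",
    gold_path="sample_text/en2es.translation.m2m100_1.2B.txt",
    bert_score_model="microsoft/deberta-xlarge-mnli",
    starting_batch_size=64,
    output_path=None,  # pass a .json path to also save the results to disk
)
print(results["sacrebleu"]["score"])
```

The returned dictionary has the same shape as the sample_text/en2es.m2m100_*.json files added in this commit.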
dataset.py
CHANGED
@@ -38,3 +38,36 @@ class DatasetReader(IterableDataset):
         file_itr = open(self.filename, "r")
         mapped_itr = map(self.preprocess, file_itr)
         return mapped_itr
+
+
+class ParallelTextReader(IterableDataset):
+    def __init__(self, pred_path: str, gold_path: str):
+        self.pred_path = pred_path
+        self.gold_path = gold_path
+        pref_filename_lines = count_lines(pred_path)
+        gold_path_lines = count_lines(gold_path)
+        assert pref_filename_lines == gold_path_lines, (
+            f"Lines in {pred_path} and {gold_path} do not match "
+            f"{pref_filename_lines} vs {gold_path_lines}"
+        )
+        self.num_sentences = gold_path_lines
+        self.current_line = 0
+
+    def preprocess(self, pred: str, gold: str):
+        self.current_line += 1
+        pred = pred.rstrip().strip()
+        gold = gold.rstrip().strip()
+        if len(pred) == 0:
+            print(f"Warning: Pred empty sentence at line {self.current_line}")
+        if len(gold) == 0:
+            print(f"Warning: Gold empty sentence at line {self.current_line}")
+        return pred, [gold]
+
+    def __iter__(self):
+        pred_itr = open(self.pred_path, "r")
+        gold_itr = open(self.gold_path, "r")
+        mapped_itr = map(self.preprocess, pred_itr, gold_itr)
+        return mapped_itr
+
+    def __len__(self):
+        return self.num_sentences
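For orientation, the sketch below shows how the new `ParallelTextReader` is meant to be consumed: each item is a `(pred, [gold])` pair, and a transpose-style `collate_fn` (the same one defined in eval.py) turns a batch of pairs into a list of predictions and a list of reference lists. The file names are placeholders.

```python
# Consumption sketch for ParallelTextReader; pred.txt and gold.txt are
# hypothetical line-aligned files with one sentence per line.
from torch.utils.data import DataLoader
from dataset import ParallelTextReader


def collate_fn(batch):
    # [(pred1, [gold1]), (pred2, [gold2])] -> [[pred1, pred2], [[gold1], [gold2]]]
    return list(map(list, zip(*batch)))


reader = ParallelTextReader(pred_path="pred.txt", gold_path="gold.txt")
loader = DataLoader(reader, batch_size=2, collate_fn=collate_fn)
for predictions, references in loader:
    print(predictions, references)
```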
eval.py
ADDED
@@ -0,0 +1,159 @@
+from dataset import ParallelTextReader
+from torch.utils.data import DataLoader
+from accelerate.memory_utils import find_executable_batch_size
+from datasets import load_metric
+from tqdm import tqdm
+import torch
+import json
+import argparse
+import numpy as np
+
+
+def get_dataloader(pred_path: str, gold_path: str, batch_size: int):
+    """
+    Returns a dataloader for the given files.
+    """
+
+    def collate_fn(batch):
+        return list(map(list, zip(*batch)))
+
+    reader = ParallelTextReader(pred_path=pred_path, gold_path=gold_path)
+    dataloader = DataLoader(reader, batch_size=batch_size, collate_fn=collate_fn)
+    return dataloader
+
+
+def eval_files(
+    pred_path: str,
+    gold_path: str,
+    bert_score_model: str,
+    starting_batch_size: int = 128,
+    output_path: str = None,
+):
+    """
+    Evaluates the given files.
+    """
+    if torch.cuda.is_available():
+        device = "cuda:0"
+        print("We will use a GPU to calculate BertScore.")
+    else:
+        device = "cpu"
+        print(
+            f"We will use the CPU to calculate BertScore, this can be slow for large datasets."
+        )
+
+    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)
+    print("Loading sacrebleu...")
+    sacrebleu = load_metric("sacrebleu")
+    print("Loading rouge...")
+    rouge = load_metric("rouge")
+    print("Loading bleu...")
+    bleu = load_metric("bleu")
+    print("Loading meteor...")
+    meteor = load_metric("meteor")
+    print("Loading ter...")
+    ter = load_metric("ter")
+    print("Loading BertScore...")
+    bert_score = load_metric("bertscore")
+
+    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
+        for predictions, references in dataloader:
+            sacrebleu.add_batch(predictions=predictions, references=references)
+            rouge.add_batch(predictions=predictions, references=references)
+            bleu.add_batch(
+                predictions=[p.split() for p in predictions],
+                references=[[r[0].split()] for r in references],
+            )
+            meteor.add_batch(predictions=predictions, references=references)
+            ter.add_batch(predictions=predictions, references=references)
+            bert_score.add_batch(predictions=predictions, references=references)
+            pbar.update(len(predictions))
+
+    result_dictionary = {}
+    print(f"Computing sacrebleu")
+    result_dictionary["sacrebleu"] = sacrebleu.compute()
+    print(f"Computing rouge score")
+    result_dictionary["rouge"] = rouge.compute()
+    print(f"Computing bleu score")
+    result_dictionary["bleu"] = bleu.compute()
+    print(f"Computing meteor score")
+    result_dictionary["meteor"] = meteor.compute()
+    print(f"Computing ter score")
+    result_dictionary["ter"] = ter.compute()
+
+    @find_executable_batch_size(starting_batch_size=starting_batch_size)
+    def inference(batch_size):
+        nonlocal bert_score, bert_score_model
+        print(f"Computing bert score with batch size {batch_size} on {device}")
+        results = bert_score.compute(
+            model_type=bert_score_model,
+            batch_size=batch_size,
+            device=device,
+            use_fast_tokenizer=True,
+        )
+
+        results["precision"] = np.average(results["precision"])
+        results["recall"] = np.average(results["recall"])
+        results["f1"] = np.average(results["f1"])
+
+        return results
+
+    result_dictionary["bert_score"] = inference()
+
+    if output_path is not None:
+        with open(output_path, "w") as f:
+            json.dump(result_dictionary, f, indent=4)
+
+    print(f"Results: {json.dumps(result_dictionary,indent=4)}")
+
+    return result_dictionary
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run the translation evaluation experiments"
+    )
+    parser.add_argument(
+        "--pred_path",
+        type=str,
+        required=True,
+        help="Path to a txt file containing the predicted sentences.",
+    )
+
+    parser.add_argument(
+        "--gold_path",
+        type=str,
+        required=True,
+        help="Path to a txt file containing the gold sentences.",
+    )
+
+    parser.add_argument(
+        "--starting_batch_size",
+        type=int,
+        default=64,
+        help="Starting batch size for BertScore, we will automatically reduce it if we find an OOM error.",
+    )
+
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        default=None,
+        help="Path to a json file to save the results. If not given, the results will be printed to the console.",
+    )
+
+    parser.add_argument(
+        "--bert_score_model",
+        type=str,
+        default="microsoft/deberta-xlarge-mnli",
+        help="Model to use for BertScore. See: https://github.com/huggingface/datasets/tree/master/metrics/bertscore"
+        "and https://github.com/Tiiiger/bert_score for more details.",
+    )
+
+    args = parser.parse_args()
+
+    eval_files(
+        pred_path=args.pred_path,
+        gold_path=args.gold_path,
+        starting_batch_size=args.starting_batch_size,
+        output_path=args.output_path,
+        bert_score_model=args.bert_score_model,
+    )
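The script accumulates every metric batch by batch and only calls `compute()` once at the end. A toy, self-contained sketch of that add_batch/compute pattern with the `datasets` SacreBLEU metric (the two sentence pairs are made up):

```python
# Toy sketch of the accumulate-then-compute pattern used in eval.py.
from datasets import load_metric

sacrebleu = load_metric("sacrebleu")
sacrebleu.add_batch(
    predictions=["the cat sat on the mat"],
    references=[["the cat sat on the mat"]],
)
sacrebleu.add_batch(
    predictions=["hello there"],
    references=[["hello world"]],
)
print(sacrebleu.compute()["score"])  # one corpus-level score over both batches
```

BertScore is the exception: its computation is wrapped with accelerate's `find_executable_batch_size`, so the batch size is automatically reduced and the call retried if the GPU runs out of memory.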
sample_text/en2es.m2m100_1.2B.json
ADDED
@@ -0,0 +1 @@
{"sacrebleu": {"score": 32.101150640281695, "counts": [19160, 11392, 7558, 5186], "totals": [31477, 30479, 29481, 28485], "precisions": [60.86984147155066, 37.37655434889596, 25.636850853091822, 18.20607337195015], "bp": 1.0, "sys_len": 31477, "ref_len": 30102}, "rouge": {"rouge1": [[0.5852396804366098, 0.6089057437338691, 0.5919486437026797], [0.5964621218261164, 0.6200342221830797, 0.6029705008756368], [0.6068321807422377, 0.6311106822798185, 0.61324805661008]], "rouge2": [[0.3710985389559613, 0.38708055355385995, 0.3761201217327784], [0.3844850790869714, 0.40017782122170353, 0.38920434271970195], [0.3968990790506025, 0.41382310483690327, 0.4022299418726329]], "rougeL": [[0.5351505034410595, 0.5564838960633809, 0.5410602618870524], [0.5457898501195475, 0.5677049056091881, 0.5519189480892548], [0.5575497491149766, 0.5787856637940312, 0.5630101422167583]], "rougeLsum": [[0.5352116089085267, 0.5570236521823667, 0.5415939934790461], [0.5463246235983789, 0.5676427704754348, 0.5522237812823654], [0.5581141358005033, 0.5796683147249665, 0.5630221371759908]]}, "bleu": {"bleu": 0.2842153038526809, "precisions": [0.5535070989616444, 0.33646946844340314, 0.22383069265549602, 0.15653135365661033], "brevity_penalty": 1.0, "length_ratio": 1.0469217970049918, "translation_length": 28314, "reference_length": 27045}, "meteor": {"meteor": 0.4880039569987408}, "ter": {"score": 59.500831946755405, "num_edits": 16092, "ref_length": 27045.0}, "bert_score": {"precision": 0.8192511852383614, "recall": 0.8262866012752056, "f1": 0.8223477345705033, "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.11(hug_trans=4.18.0)_fast-tokenizer"}}
sample_text/en2es.m2m100_418M.json
ADDED
@@ -0,0 +1 @@
{"sacrebleu": {"score": 29.035496917461597, "counts": [18582, 10514, 6681, 4387], "totals": [31477, 30479, 29481, 28485], "precisions": [59.033580074339994, 34.49588241084025, 22.662053525999795, 15.401088292083553], "bp": 1.0, "sys_len": 31477, "ref_len": 30388}, "rouge": {"rouge1": [[0.5661701202298134, 0.5806961045770566, 0.5693885562082325], [0.5768745925790656, 0.5926959547911554, 0.5803693779677083], [0.5871085218904836, 0.6035331460243276, 0.5900979805085623]], "rouge2": [[0.34243414046469267, 0.35226400857606666, 0.34469210847048837], [0.3545484183384055, 0.36470783370743065, 0.3569058648048812], [0.36612813327517263, 0.37717476449671, 0.3689653665404565]], "rougeL": [[0.5129704896656746, 0.526995889564155, 0.5162056185006965], [0.523632841460358, 0.5375452284094455, 0.5267080806612512], [0.5350158816319085, 0.5480980981777757, 0.5372302857012781]], "rougeLsum": [[0.5126805856827783, 0.5265189554049317, 0.5155154093959223], [0.5239559133309495, 0.5380410013947112, 0.5271022617246641], [0.5351934954578494, 0.5491115103854219, 0.5381174565735956]]}, "bleu": {"bleu": 0.2546886610724999, "precisions": [0.5339761248852158, 0.30784155806120955, 0.19560013678331242, 0.1308640025272469], "brevity_penalty": 1.0, "length_ratio": 1.0353982300884956, "translation_length": 28314, "reference_length": 27346}, "meteor": {"meteor": 0.4630996837124251}, "ter": {"score": 61.848167922182405, "num_edits": 16913, "ref_length": 27346.0}, "bert_score": {"precision": 0.8128398380875588, "recall": 0.8185442119538784, "f1": 0.8153291321396827, "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.11(hug_trans=4.18.0)_fast-tokenizer"}}
sample_text/{en2es.translation.txt → en2es.translation.m2m100_1.2B.txt}
RENAMED
@@ -997,4 +997,4 @@ Quiero felicitarle, lamentablemente en su ausencia, por la forma exhaustiva y ri
 Él mencionó anteriormente que el informe se llevó a cabo con una mayoría significativa, pero no con mi apoyo.
 Por lo tanto, aunque no comparto sus conclusiones, creo que él ha ilustrado en su informe muchas de las cuestiones que la Comisión debe abordar.
 La primera es la posibilidad de renacentización de la política de competencia.
-Sé que la Comisión se opone a esto, pero el potencial existe.
+Sé que la Comisión se opone a esto, pero el potencial existe.
sample_text/en2es.translation.m2m100_418M.txt
ADDED
The diff for this file is too large to render (see the raw diff).
translate.py
CHANGED
@@ -122,6 +122,7 @@ def main(
         total=total_lines, desc="Dataset translation", leave=True, ascii=True
     ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
         with torch.no_grad():
+            first_batch = True
             for batch in data_loader:
                 batch["input_ids"] = batch["input_ids"]
                 batch["attention_mask"] = batch["attention_mask"]
@@ -141,8 +142,11 @@
                 tgt_text = tokenizer.batch_decode(
                     generated_tokens, skip_special_tokens=True
                 )
-
-
+                if not first_batch:
+                    print(file=output_file)
+                else:
+                    first_batch = False
+                print("\n".join(tgt_text), file=output_file, end="")
 
                 pbar.update(len(tgt_text))
 
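The `first_batch` flag makes the loop write a separating newline before every batch except the first, so the output file never ends with a trailing blank line even though translations are written incrementally. A toy sketch with an in-memory buffer:

```python
# Toy sketch of the incremental writing pattern introduced in translate.py.
import io

output_file = io.StringIO()
first_batch = True
for tgt_text in [["line 1", "line 2"], ["line 3"]]:
    if not first_batch:
        print(file=output_file)  # newline separating this batch from the previous one
    else:
        first_batch = False
    print("\n".join(tgt_text), file=output_file, end="")

assert output_file.getvalue() == "line 1\nline 2\nline 3"
```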