Iker committed on
Commit 62b1ca5
1 Parent(s): d0a815c

Implement evaluation

README.md CHANGED
@@ -68,7 +68,7 @@ Run `python translate.py -h` for more info.
```bash
accelerate launch translate.py \
--sentences_path sample_text/en.txt \
- --output_path sample_text/en2es.translation.txt \
+ --output_path sample_text/en2es.translation.m2m100_1.2B.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B
@@ -83,7 +83,7 @@ You can use the Accelerate CLI to configure the Accelerate environment (Run
```bash
accelerate launch --multi_gpu --num_processes 2 --num_machines 1 translate.py \
--sentences_path sample_text/en.txt \
- --output_path sample_text/en2es.translation.txt \
+ --output_path sample_text/en2es.translation.m2m100_1.2B.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B
@@ -102,7 +102,7 @@ Use the `--precision` flag to choose the precision of the model. You can choose
```bash
accelerate launch translate.py \
--sentences_path sample_text/en.txt \
- --output_path sample_text/en2es.translation.txt \
+ --output_path sample_text/en2es.translation.m2m100_1.2B.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B \
@@ -111,6 +111,24 @@ accelerate launch translate.py \

## Evaluate translations

- Work in progress...
+ To run the evaluation script you need to install [bert_score](https://github.com/Tiiiger/bert_score): `pip install bert_score`
+
+ The evaluation script will calculate the following metrics:
+ * [SacreBLEU](https://github.com/huggingface/datasets/tree/master/metrics/sacrebleu)
+ * [BLEU](https://github.com/huggingface/datasets/tree/master/metrics/bleu)
+ * [ROUGE](https://github.com/huggingface/datasets/tree/master/metrics/rouge)
+ * [METEOR](https://github.com/huggingface/datasets/tree/master/metrics/meteor)
+ * [TER](https://github.com/huggingface/datasets/tree/master/metrics/ter)
+ * [BertScore](https://github.com/huggingface/datasets/tree/master/metrics/bertscore)
+
+ Run the following command to evaluate the translations:
+
+ ```bash
+ accelerate launch eval.py \
+ --pred_path sample_text/en2es.translation.m2m100_1.2B.txt \
+ --gold_path sample_text/es.txt
+ ```
+
+ If you want to save the results to a file, use the `--output_path` flag.


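If you prefer calling the evaluation from Python instead of the CLI, the sketch below uses the `eval_files` function added in `eval.py` in this commit. It is only an illustration: the paths are the repository's sample files, `results.json` is a hypothetical output name, and the keyword arguments mirror the `eval_files` signature shown further down.

```python
# Minimal sketch: programmatic use of eval.py's eval_files.
from eval import eval_files

results = eval_files(
    pred_path="sample_text/en2es.translation.m2m100_1.2B.txt",  # model translations
    gold_path="sample_text/es.txt",                             # reference translations
    bert_score_model="microsoft/deberta-xlarge-mnli",           # default BertScore model
    starting_batch_size=64,
    output_path="results.json",                                 # optional JSON dump (hypothetical name)
)
print(results["sacrebleu"]["score"])  # corpus-level SacreBLEU score
```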
 
dataset.py CHANGED
@@ -38,3 +38,36 @@ class DatasetReader(IterableDataset):
        file_itr = open(self.filename, "r")
        mapped_itr = map(self.preprocess, file_itr)
        return mapped_itr
+
+
+ class ParallelTextReader(IterableDataset):
+     def __init__(self, pred_path: str, gold_path: str):
+         self.pred_path = pred_path
+         self.gold_path = gold_path
+         pred_filename_lines = count_lines(pred_path)
+         gold_path_lines = count_lines(gold_path)
+         assert pred_filename_lines == gold_path_lines, (
+             f"Lines in {pred_path} and {gold_path} do not match: "
+             f"{pred_filename_lines} vs {gold_path_lines}"
+         )
+         self.num_sentences = gold_path_lines
+         self.current_line = 0
+
+     def preprocess(self, pred: str, gold: str):
+         self.current_line += 1
+         pred = pred.strip()
+         gold = gold.strip()
+         if len(pred) == 0:
+             print(f"Warning: Pred empty sentence at line {self.current_line}")
+         if len(gold) == 0:
+             print(f"Warning: Gold empty sentence at line {self.current_line}")
+         return pred, [gold]
+
+     def __iter__(self):
+         pred_itr = open(self.pred_path, "r")
+         gold_itr = open(self.gold_path, "r")
+         mapped_itr = map(self.preprocess, pred_itr, gold_itr)
+         return mapped_itr
+
+     def __len__(self):
+         return self.num_sentences
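A small usage sketch of the new reader, mirroring how `eval.py` consumes it: each item is a `(pred, [gold])` pair, and the collate function transposes a batch of pairs into a list of predictions and a list of references. The file paths are the repository's sample files and are only examples.

```python
from torch.utils.data import DataLoader

from dataset import ParallelTextReader


def collate_fn(batch):
    # [(p1, [g1]), (p2, [g2])] -> [[p1, p2], [[g1], [g2]]]
    return list(map(list, zip(*batch)))


reader = ParallelTextReader(
    pred_path="sample_text/en2es.translation.m2m100_1.2B.txt",
    gold_path="sample_text/es.txt",
)
loader = DataLoader(reader, batch_size=64, collate_fn=collate_fn)
for predictions, references in loader:
    pass  # pass predictions/references to the metrics, as eval.py does
```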
eval.py ADDED
@@ -0,0 +1,159 @@
+ from dataset import ParallelTextReader
+ from torch.utils.data import DataLoader
+ from accelerate.memory_utils import find_executable_batch_size
+ from datasets import load_metric
+ from tqdm import tqdm
+ import torch
+ import json
+ import argparse
+ import numpy as np
+
+
+ def get_dataloader(pred_path: str, gold_path: str, batch_size: int):
+     """
+     Returns a dataloader for the given files.
+     """
+
+     def collate_fn(batch):
+         return list(map(list, zip(*batch)))
+
+     reader = ParallelTextReader(pred_path=pred_path, gold_path=gold_path)
+     dataloader = DataLoader(reader, batch_size=batch_size, collate_fn=collate_fn)
+     return dataloader
+
+
+ def eval_files(
+     pred_path: str,
+     gold_path: str,
+     bert_score_model: str,
+     starting_batch_size: int = 128,
+     output_path: str = None,
+ ):
+     """
+     Evaluates the given files.
+     """
+     if torch.cuda.is_available():
+         device = "cuda:0"
+         print("We will use a GPU to calculate BertScore.")
+     else:
+         device = "cpu"
+         print(
+             "We will use the CPU to calculate BertScore; this can be slow for large datasets."
+         )
+
+     dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)
+     print("Loading sacrebleu...")
+     sacrebleu = load_metric("sacrebleu")
+     print("Loading rouge...")
+     rouge = load_metric("rouge")
+     print("Loading bleu...")
+     bleu = load_metric("bleu")
+     print("Loading meteor...")
+     meteor = load_metric("meteor")
+     print("Loading ter...")
+     ter = load_metric("ter")
+     print("Loading BertScore...")
+     bert_score = load_metric("bertscore")
+
+     with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
+         for predictions, references in dataloader:
+             sacrebleu.add_batch(predictions=predictions, references=references)
+             rouge.add_batch(predictions=predictions, references=references)
+             bleu.add_batch(
+                 predictions=[p.split() for p in predictions],
+                 references=[[r[0].split()] for r in references],
+             )
+             meteor.add_batch(predictions=predictions, references=references)
+             ter.add_batch(predictions=predictions, references=references)
+             bert_score.add_batch(predictions=predictions, references=references)
+             pbar.update(len(predictions))
+
+     result_dictionary = {}
+     print("Computing sacrebleu")
+     result_dictionary["sacrebleu"] = sacrebleu.compute()
+     print("Computing rouge score")
+     result_dictionary["rouge"] = rouge.compute()
+     print("Computing bleu score")
+     result_dictionary["bleu"] = bleu.compute()
+     print("Computing meteor score")
+     result_dictionary["meteor"] = meteor.compute()
+     print("Computing ter score")
+     result_dictionary["ter"] = ter.compute()
+
+     @find_executable_batch_size(starting_batch_size=starting_batch_size)
+     def inference(batch_size):
+         nonlocal bert_score, bert_score_model
+         print(f"Computing bert score with batch size {batch_size} on {device}")
+         results = bert_score.compute(
+             model_type=bert_score_model,
+             batch_size=batch_size,
+             device=device,
+             use_fast_tokenizer=True,
+         )
+
+         results["precision"] = np.average(results["precision"])
+         results["recall"] = np.average(results["recall"])
+         results["f1"] = np.average(results["f1"])
+
+         return results
+
+     result_dictionary["bert_score"] = inference()
+
+     if output_path is not None:
+         with open(output_path, "w") as f:
+             json.dump(result_dictionary, f, indent=4)
+
+     print(f"Results: {json.dumps(result_dictionary, indent=4)}")
+
+     return result_dictionary
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Run the translation evaluation experiments"
+     )
+     parser.add_argument(
+         "--pred_path",
+         type=str,
+         required=True,
+         help="Path to a txt file containing the predicted sentences.",
+     )
+
+     parser.add_argument(
+         "--gold_path",
+         type=str,
+         required=True,
+         help="Path to a txt file containing the gold sentences.",
+     )
+
+     parser.add_argument(
+         "--starting_batch_size",
+         type=int,
+         default=64,
+         help="Starting batch size for BertScore; we will automatically reduce it if we find an OOM error.",
+     )
+
+     parser.add_argument(
+         "--output_path",
+         type=str,
+         default=None,
+         help="Path to a json file to save the results. If not given, the results will be printed to the console.",
+     )
+
+     parser.add_argument(
+         "--bert_score_model",
+         type=str,
+         default="microsoft/deberta-xlarge-mnli",
+         help="Model to use for BertScore. See: https://github.com/huggingface/datasets/tree/master/metrics/bertscore "
+         "and https://github.com/Tiiiger/bert_score for more details.",
+     )
+
+     args = parser.parse_args()
+
+     eval_files(
+         pred_path=args.pred_path,
+         gold_path=args.gold_path,
+         starting_batch_size=args.starting_batch_size,
+         output_path=args.output_path,
+         bert_score_model=args.bert_score_model,
+     )
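The two sample result files added below show the structure that `eval_files` writes when `--output_path` is given. A short sketch of consuming one of them, with the keys taken from the sample files themselves:

```python
import json

# Load the metrics JSON produced by eval.py --output_path.
with open("sample_text/en2es.m2m100_1.2B.json") as f:
    results = json.load(f)

print(results["sacrebleu"]["score"])  # corpus-level SacreBLEU score
print(results["ter"]["score"])        # translation edit rate
print(results["bert_score"]["f1"])    # averaged BertScore F1
```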
sample_text/en2es.m2m100_1.2B.json ADDED
@@ -0,0 +1 @@
+ {"sacrebleu": {"score": 32.101150640281695, "counts": [19160, 11392, 7558, 5186], "totals": [31477, 30479, 29481, 28485], "precisions": [60.86984147155066, 37.37655434889596, 25.636850853091822, 18.20607337195015], "bp": 1.0, "sys_len": 31477, "ref_len": 30102}, "rouge": {"rouge1": [[0.5852396804366098, 0.6089057437338691, 0.5919486437026797], [0.5964621218261164, 0.6200342221830797, 0.6029705008756368], [0.6068321807422377, 0.6311106822798185, 0.61324805661008]], "rouge2": [[0.3710985389559613, 0.38708055355385995, 0.3761201217327784], [0.3844850790869714, 0.40017782122170353, 0.38920434271970195], [0.3968990790506025, 0.41382310483690327, 0.4022299418726329]], "rougeL": [[0.5351505034410595, 0.5564838960633809, 0.5410602618870524], [0.5457898501195475, 0.5677049056091881, 0.5519189480892548], [0.5575497491149766, 0.5787856637940312, 0.5630101422167583]], "rougeLsum": [[0.5352116089085267, 0.5570236521823667, 0.5415939934790461], [0.5463246235983789, 0.5676427704754348, 0.5522237812823654], [0.5581141358005033, 0.5796683147249665, 0.5630221371759908]]}, "bleu": {"bleu": 0.2842153038526809, "precisions": [0.5535070989616444, 0.33646946844340314, 0.22383069265549602, 0.15653135365661033], "brevity_penalty": 1.0, "length_ratio": 1.0469217970049918, "translation_length": 28314, "reference_length": 27045}, "meteor": {"meteor": 0.4880039569987408}, "ter": {"score": 59.500831946755405, "num_edits": 16092, "ref_length": 27045.0}, "bert_score": {"precision": 0.8192511852383614, "recall": 0.8262866012752056, "f1": 0.8223477345705033, "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.11(hug_trans=4.18.0)_fast-tokenizer"}}
sample_text/en2es.m2m100_418M.json ADDED
@@ -0,0 +1 @@
+ {"sacrebleu": {"score": 29.035496917461597, "counts": [18582, 10514, 6681, 4387], "totals": [31477, 30479, 29481, 28485], "precisions": [59.033580074339994, 34.49588241084025, 22.662053525999795, 15.401088292083553], "bp": 1.0, "sys_len": 31477, "ref_len": 30388}, "rouge": {"rouge1": [[0.5661701202298134, 0.5806961045770566, 0.5693885562082325], [0.5768745925790656, 0.5926959547911554, 0.5803693779677083], [0.5871085218904836, 0.6035331460243276, 0.5900979805085623]], "rouge2": [[0.34243414046469267, 0.35226400857606666, 0.34469210847048837], [0.3545484183384055, 0.36470783370743065, 0.3569058648048812], [0.36612813327517263, 0.37717476449671, 0.3689653665404565]], "rougeL": [[0.5129704896656746, 0.526995889564155, 0.5162056185006965], [0.523632841460358, 0.5375452284094455, 0.5267080806612512], [0.5350158816319085, 0.5480980981777757, 0.5372302857012781]], "rougeLsum": [[0.5126805856827783, 0.5265189554049317, 0.5155154093959223], [0.5239559133309495, 0.5380410013947112, 0.5271022617246641], [0.5351934954578494, 0.5491115103854219, 0.5381174565735956]]}, "bleu": {"bleu": 0.2546886610724999, "precisions": [0.5339761248852158, 0.30784155806120955, 0.19560013678331242, 0.1308640025272469], "brevity_penalty": 1.0, "length_ratio": 1.0353982300884956, "translation_length": 28314, "reference_length": 27346}, "meteor": {"meteor": 0.4630996837124251}, "ter": {"score": 61.848167922182405, "num_edits": 16913, "ref_length": 27346.0}, "bert_score": {"precision": 0.8128398380875588, "recall": 0.8185442119538784, "f1": 0.8153291321396827, "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.11(hug_trans=4.18.0)_fast-tokenizer"}}
sample_text/{en2es.translation.txt → en2es.translation.m2m100_1.2B.txt} RENAMED
@@ -997,4 +997,4 @@ Quiero felicitarle, lamentablemente en su ausencia, por la forma exhaustiva y ri
Él mencionó anteriormente que el informe se llevó a cabo con una mayoría significativa, pero no con mi apoyo.
Por lo tanto, aunque no comparto sus conclusiones, creo que él ha ilustrado en su informe muchas de las cuestiones que la Comisión debe abordar.
La primera es la posibilidad de renacentización de la política de competencia.
- Sé que la Comisión se opone a esto, pero el potencial existe.
+ Sé que la Comisión se opone a esto, pero el potencial existe.
sample_text/en2es.translation.m2m100_418M.txt ADDED
The diff for this file is too large to render. See raw diff
 
translate.py CHANGED
@@ -122,6 +122,7 @@ def main(
        total=total_lines, desc="Dataset translation", leave=True, ascii=True
    ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
        with torch.no_grad():
+             first_batch = True
            for batch in data_loader:
                batch["input_ids"] = batch["input_ids"]
                batch["attention_mask"] = batch["attention_mask"]
@@ -141,8 +142,11 @@
                tgt_text = tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )
-
-                 print("\n".join(tgt_text), file=output_file)
+                 if not first_batch:
+                     print(file=output_file)
+                 else:
+                     first_batch = False
+                 print("\n".join(tgt_text), file=output_file, end="")

                pbar.update(len(tgt_text))

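The `translate.py` change above writes each batch without a trailing newline and inserts the separating newline only before subsequent batches, so the output file ends exactly after the last translated sentence. A standalone sketch of the same pattern, with made-up batch contents and a hypothetical output filename:

```python
# Write batches of lines so the file ends without a trailing newline.
batches = [["sentence 1", "sentence 2"], ["sentence 3"]]  # illustrative data

with open("demo_output.txt", "w", encoding="utf-8") as output_file:
    first_batch = True
    for tgt_text in batches:
        if not first_batch:
            print(file=output_file)  # newline separating this batch from the previous one
        else:
            first_batch = False
        print("\n".join(tgt_text), file=output_file, end="")
# demo_output.txt now holds three lines and no final newline
```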