victan committed on
Commit
87809bd
1 Parent(s): 77d1870

Upload seamless_communication/cli/eval_utils/compute_metrics.py with huggingface_hub

seamless_communication/cli/eval_utils/compute_metrics.py ADDED
@@ -0,0 +1,371 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import json
import logging
from pathlib import Path
from typing import Optional, Tuple, Union

import pandas as pd
import whisper
from fairseq2.typing import Device
from jiwer import cer, wer
from sacrebleu.metrics.base import Score, Signature
from sacrebleu.metrics.bleu import BLEU
from sacrebleu.metrics.chrf import CHRF
from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
from tqdm import tqdm
from whisper import Whisper
from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)

def init_whisper_model(
    device: Device,
    whisper_model_name: str = "large",
) -> Whisper:
    return whisper.load_model(name=whisper_model_name, device=device)

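# Illustrative usage sketch, not part of the upstream file: the model is
# loaded once and reused for every transcription call. The device string and
# the smaller "base" checkpoint below are hypothetical choices.
#
#   asr_model = init_whisper_model(Device("cuda:0"), whisper_model_name="base")
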
def transcribe_series(
    audio_paths_series: pd.Series,
    asr_model: Whisper,
    audio_lang: str,
    beam_size: int = 1,
    temperature: float = 0.0,
) -> pd.Series:
    """Transcribes each audio file path in the series and returns a series of transcriptions.

    Args:
        audio_paths_series (pd.Series): each line contains a path to an audio file.
        asr_model: ASR model used for transcription, e.g. Whisper.
        audio_lang (str): language spoken in the given audio, passed to the ASR model.
        beam_size (int): whisper beam size. Defaults to 1.
        temperature (float): whisper temperature. Defaults to 0.0 to avoid fallback decoding (see details below).

    Returns:
        pd.Series: series where each line contains the transcription of the corresponding audio from audio_paths_series.

    The whisper model implements decoding with fallback:
    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L147
    The core idea is that decoding at each time step may be repeated if at least one
    criterion to "fall back", i.e. start over, fires. The number of fallback iterations
    is determined by the schedule of temperature values:
    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L41
    By default this schedule is active with temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    i.e. even with beam_size=5 decoding might fall back and switch to sampling with
    temperature > 0, in which case beam search is not used in the fallback iteration.
    Explicitly setting temperature=0.0 overrides the schedule, so fallback decoding
    runs a single iteration, i.e. there are no fallbacks. This allows reproducible
    evaluation without sampling variation. Beware that this may introduce repetition
    loops in the transcriptions and lead to a worse ASR-BLEU score in the end.
    """

    if len(audio_lang) == 3:
        # convert the three-letter code to a two-letter one to make it work with whisper
        audio_lang = LANG3_LANG2[audio_lang]

    transcriptions = {}

    for idx, audio_path in tqdm(
        audio_paths_series.items(),
        desc=f"Transcribing {audio_paths_series.name} column",
        total=len(audio_paths_series),
    ):
        hypo = asr_model.transcribe(
            audio_path,
            temperature=temperature,
            beam_size=beam_size,
            language=audio_lang,
        )["text"].strip()
        transcriptions[idx] = hypo

    transcriptions_series = pd.Series(transcriptions)
    transcriptions_series.name = f"{audio_paths_series.name}_transcribed"

    return transcriptions_series

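# Illustrative sketch, not part of the upstream file: with the default
# temperature=0.0 decoding is deterministic (the fallback schedule described
# above is disabled). The wav paths are hypothetical.
#
#   paths = pd.Series(["utt0.wav", "utt1.wav"], name="pred_tgt_audio")
#   hypos = transcribe_series(paths, asr_model, audio_lang="eng", beam_size=5)
#   # hypos.name == "pred_tgt_audio_transcribed"
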
def whisper_normalize_series(
    transcription_series: pd.Series, text_lang: str
) -> pd.Series:
    """Normalizes a text series using the whisper normalizer. English has a dedicated normalizer in the whisper package.

    Args:
        transcription_series (pd.Series): each line contains arbitrary text written in text_lang.
        text_lang (str): language of the text in the series.

    Returns:
        pd.Series: series with normalized text.
    """
    if text_lang == "eng":
        normalizer = EnglishTextNormalizer()
    else:
        normalizer = BasicTextNormalizer()

    norm_transcriptions = {}

    for idx, text in transcription_series.items():
        norm_transcriptions[idx] = normalizer(text)

    norm_transcriptions_series = pd.Series(norm_transcriptions)
    norm_transcriptions_series.name = transcription_series.name

    return norm_transcriptions_series

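# Illustrative sketch, not part of the upstream file: the whisper normalizers
# lowercase, strip punctuation, and (for English) expand contractions, so that
# hypotheses and references are compared in the same surface form.
#
#   s = pd.Series({0: "Hello, World! It's me."})
#   whisper_normalize_series(s, "eng")
#   # -> roughly "hello world it is me"
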
def compute_asr_bleu(
    audio_paths_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    asr_model: Whisper,
    whisper_normalize_text: bool = True,
    beam_size: int = 1,
    temperature: float = 0.0,
    return_transcriptions: bool = True,
) -> Tuple[Score, Signature, pd.DataFrame]:
    """Wraps the functions above to compute corpus-level ASR-BLEU.

    ASR decoding hyper-parameters default to deterministic values to ensure
    reproducibility across evaluations.

    Args:
        audio_paths_series (pd.Series): each line contains a path to an audio file.
        ref_text_series (pd.Series): each line contains the text reference to compare the audio against.
        lang (str): the language of both the audio and ref_text.
        asr_model: whisper ASR model.
        whisper_normalize_text (bool): normalize both text hypotheses and references if True. Defaults to True.
        beam_size (int): beam size for whisper generation.
        temperature (float): temperature sampling value for whisper generation.
        return_transcriptions (bool): if True, also return a DataFrame with audio paths, transcriptions, and references.
    """

    audio_transcriptions = transcribe_series(
        audio_paths_series,
        asr_model,
        audio_lang=lang,
        beam_size=beam_size,
        temperature=temperature,
    )
    asr_bleu, asr_bleu_signature = compute_corpus_metric_score(
        audio_transcriptions, ref_text_series, lang, whisper_normalize_text
    )
    asr_bleu_signature.info["whisper_asr_beam_size"] = beam_size
    asr_bleu_signature.info["whisper_asr_temperature"] = temperature
    asr_bleu_signature.info["whisper_asr_language"] = lang

    transcript_df = None
    if return_transcriptions:
        transcript_df = pd.concat(
            [
                audio_paths_series,
                audio_transcriptions,
                ref_text_series,
            ],
            axis=1,
            keys=["audio", "transcript", "reference"],
        )
    return asr_bleu, asr_bleu_signature, transcript_df

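# Illustrative sketch, not part of the upstream file: ASR-BLEU scores the
# generated speech by first transcribing it and then computing BLEU between
# the transcripts and the text references. The path below is hypothetical.
#
#   bleu, sig, transcripts = compute_asr_bleu(
#       audio_paths_series=pd.Series(["gen0.wav"], name="pred_tgt_audio"),
#       ref_text_series=pd.Series(["hello world"]),
#       lang="eng",
#       asr_model=asr_model,
#   )
#   print(bleu.score, sig.format())
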
def get_tokenizer(lang: str, metric: str = "bleu") -> str:
    """Gets the tokenizer for a language and metric.

    Args:
        lang (str): three-letter code of the language.
        metric (str): metric being computed. Valid values are "bleu" and "error_rate".
    """
    lang_tok_map = {
        "cmn": "char",
        "jpn": "char",
        "tha": "char",
        "lao": "char",
        "mya": "char",
    }
    default = (
        "13a" if metric == "bleu" else "word"
    )  # 13a is the default tokenizer for BLEU; word-level units are the default for error rates
    tok = lang_tok_map.get(lang, default)
    return tok

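# Illustrative sketch, not part of the upstream file: languages written
# without word boundaries are scored at the character level; everything else
# falls through to the metric's default unit.
#
#   get_tokenizer("cmn")                       # -> "char"
#   get_tokenizer("fra")                       # -> "13a"
#   get_tokenizer("fra", metric="error_rate")  # -> "word", i.e. WER rather than CER
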
def compute_asr_error_rate(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text: bool = True,
) -> Tuple[float, str]:
    """Wraps the normalization functions and computes the ASR WER/CER score.

    Args:
        hyp_text_series (pd.Series): each line contains an s2t model prediction or first-pass prediction.
        ref_text_series (pd.Series): each line contains the corresponding text reference.
        lang (str): the language of both hypotheses and references.
        whisper_normalize_text (bool, optional): normalize both text hypotheses and references if True. Defaults to True.

    Returns:
        (error_rate, description): the WER or CER value and a short description string.
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    tokenizer_name = get_tokenizer(lang, metric="error_rate")
    metric_name = wer if tokenizer_name == "word" else cer
    # jiwer expects references first, then hypotheses
    metric_score = metric_name(ref_text_series.to_list(), hyp_text_series.to_list())
    return metric_score, f"{metric_name.__name__} is {metric_score}"

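# Illustrative sketch, not part of the upstream file: WER counts word-level
# substitutions, insertions, and deletions against the reference. The strings
# below are hypothetical.
#
#   score, desc = compute_asr_error_rate(
#       hyp_text_series=pd.Series(["hello word"]),
#       ref_text_series=pd.Series(["hello world"]),
#       lang="eng",
#   )
#   # desc -> "wer is 0.5": one substitution over two reference words
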
def compute_corpus_metric_score(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text: bool = True,
    metric: str = "bleu",
) -> Tuple[Score, Signature]:
    """Wraps the normalization functions and computes the corpus-level BLEU/chrF++ score.

    Args:
        hyp_text_series (pd.Series): each line contains an s2t model prediction or first-pass prediction.
        ref_text_series (pd.Series): each line contains the corresponding text reference.
        lang (str): the language of both hypotheses and references.
        whisper_normalize_text (bool, optional): normalize both text hypotheses and references if True. Defaults to True.
        metric (str): metric to compute, either "bleu" or "chrF++". Defaults to "bleu".

    Returns:
        (Score, Signature)
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    tokenizer_name = get_tokenizer(lang)
    corpus_metric_score_metric: Union[BLEU, CHRF]
    if metric == "bleu":
        corpus_metric_score_metric = BLEU(
            lowercase=whisper_normalize_text, tokenize=tokenizer_name
        )  # lowercasing is applied if we use whisper_normalize_text
    elif metric == "chrF++":
        corpus_metric_score_metric = CHRF(word_order=2)
    else:
        raise ValueError(f"Unknown metric: {metric}")

    corpus_metric_score = corpus_metric_score_metric.corpus_score(
        hyp_text_series.to_list(), [ref_text_series.to_list()]
    )
    corpus_metric_score_signature = corpus_metric_score_metric.get_signature()
    corpus_metric_score_signature.info["whisper_normalize"] = whisper_normalize_text

    return corpus_metric_score, corpus_metric_score_signature

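# Illustrative sketch, not part of the upstream file: sacrebleu's corpus_score
# expects a list of hypotheses and a list of reference streams, which is why
# the references are wrapped in an extra list above.
#
#   bleu, sig = compute_corpus_metric_score(
#       hyp_text_series=pd.Series(["the cat sat on the mat"]),
#       ref_text_series=pd.Series(["the cat sat on the mat"]),
#       lang="eng",
#   )
#   # an exact match yields bleu.score == 100.0; sig.format() records the config
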
def compute_quality_metrics(
    output_manifest_tsv_path: Path,
    output_path: Path,
    tgt_lang: str,
    task: str,
    device: Device,
    whisper_model_name: str = "large",
    whisper_normalize_text_output: bool = False,
    ref_text_col_name: str = "ref_tgt_text",
    pred_text_col_name: Optional[str] = "pred_tgt_text",
    pred_audio_col_name: str = "pred_tgt_audio",
) -> str:
    """Wraps the ASR and S2T BLEU functions above to compute metrics from a TSV manifest composed on the expressivity side.

    Args:
        output_manifest_tsv_path (Path): output manifest containing the reference and prediction columns named by the *_col_name arguments.
        output_path (Path): directory to write the metric files to.
        tgt_lang (str): the target language being evaluated.
        task (str): the task being evaluated.
        device (Device): device to use for inference.
        whisper_model_name (str): whisper model name. Defaults to "large".
        whisper_normalize_text_output (bool): normalizes the text output using the whisper normalizer if set to True.
        ref_text_col_name (str): column name in the TSV corresponding to the reference target text.
        pred_text_col_name (str): column name in the TSV corresponding to the predicted target text.
            Setting this value to None skips the text metrics.
        pred_audio_col_name (str): column name in the TSV corresponding to the predicted target audio.
    """
    df = pd.read_csv(
        output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
    )
    task = task.upper()

    if not output_path.exists():
        output_path.mkdir(parents=True, exist_ok=True)

    if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
        metric = "chrF++" if task == "T2TT" else "bleu"
        text_metric, text_metric_signature = compute_corpus_metric_score(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
            metric=metric,
        )
        text_metric_json = text_metric.format(
            signature=text_metric_signature.format(), is_json=True
        )

        if task == "T2TT":
            filename = "t2tt_chrf.json"
            cur_task = "T2TT"
        else:
            filename = (
                "s2tt_bleu_normalized.json"
                if whisper_normalize_text_output
                else "s2tt_bleu.json"
            )
            cur_task = "S2TT"

        with open(output_path / filename, "w") as f:
            f.write(text_metric_json)

        logger.info(f"{cur_task} {metric}:\n{text_metric_json}")

    if task in ["T2ST", "S2ST"]:
        whisper_model = init_whisper_model(device, whisper_model_name)
        (
            asr_bleu_normalized,
            asr_bleu_normalized_signature,
            transcripts_df,
        ) = compute_asr_bleu(
            audio_paths_series=df[pred_audio_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            asr_model=whisper_model,
            whisper_normalize_text=True,
        )
        transcripts_df.to_csv(
            (output_path / "whisper_audio_transcriptions.tsv"),
            sep="\t",
            index=False,
            encoding="utf-8",
            escapechar="\\",
        )

        asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name

        asr_bleu_normalized_json = asr_bleu_normalized.format(
            signature=asr_bleu_normalized_signature.format(), is_json=True
        )
        filename = f"{task.lower()}_asr_bleu_normalized.json"

        with open(output_path / filename, "w") as f:
            f.write(asr_bleu_normalized_json)

        logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")

    if task == "ASR":
        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
        )
        d = {
            "name": "WER",
            "score": asr_error_rate,
            "signature": asr_error_rate_signature,
        }
        asr_error_rate_json = json.dumps(d, indent=1, ensure_ascii=False)

        filename = "asr_error_rate.json"

        with open(output_path / filename, "w") as f:
            f.write(asr_error_rate_json)

        logger.info(f"ASR: {asr_error_rate_json}")

    return filename
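
# Illustrative end-to-end sketch, not part of the upstream file: a hypothetical
# S2ST evaluation. The manifest path is made up; the return value is the name
# of the last metric file written.
#
#   compute_quality_metrics(
#       output_manifest_tsv_path=Path("eval/manifest.tsv"),
#       output_path=Path("eval/metrics"),
#       tgt_lang="fra",
#       task="S2ST",
#       device=Device("cuda:0"),
#   )
#   # writes s2tt_bleu.json, whisper_audio_transcriptions.tsv and
#   # s2st_asr_bleu_normalized.json under eval/metrics/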