Upload seamless_communication/cli/eval_utils/compute_metrics.py with huggingface_hub
Browse files
seamless_communication/cli/eval_utils/compute_metrics.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# MIT_LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Optional, Tuple, Union
|
11 |
+
|
12 |
+
import pandas as pd
|
13 |
+
import whisper
|
14 |
+
from fairseq2.typing import Device
|
15 |
+
from jiwer import cer, wer
|
16 |
+
from sacrebleu.metrics.base import Score, Signature
|
17 |
+
from sacrebleu.metrics.bleu import BLEU
|
18 |
+
from sacrebleu.metrics.chrf import CHRF
|
19 |
+
from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
|
20 |
+
from tqdm import tqdm
|
21 |
+
from whisper import Whisper
|
22 |
+
from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
|
23 |
+
|
24 |
+
logging.basicConfig(
|
25 |
+
level=logging.INFO,
|
26 |
+
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
|
27 |
+
)
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
def init_whisper_model(
|
33 |
+
device: Device,
|
34 |
+
whisper_model_name: str = "large",
|
35 |
+
) -> Whisper:
|
36 |
+
return whisper.load_model(name=whisper_model_name, device=device)
|
37 |
+
|
38 |
+
|
39 |
+
def transcribe_series(
|
40 |
+
audio_paths_series: pd.Series,
|
41 |
+
asr_model: Whisper,
|
42 |
+
audio_lang: str,
|
43 |
+
beam_size: int = 1,
|
44 |
+
temperature: float = 0.0,
|
45 |
+
) -> pd.Series:
|
46 |
+
"""Transcribes each audio filepath from series and returns series of transcriptions
|
47 |
+
Args:
|
48 |
+
audio_paths_series (pd.Series): each line contains path to audio file.
|
49 |
+
asr_model: ASR model to do the transcribing process e.g. Whisper
|
50 |
+
audio_lang (str): what language is used in the given audio, used by ASR model
|
51 |
+
beam_size (int): whisper beam size. Defaults to 1
|
52 |
+
temperature (float): whisper temperature. Defaults to 0.0 to avoid fallback decoding (see details below).
|
53 |
+
Returns:
|
54 |
+
pd.Series: Series where each line has a transcription of corresponding audio from audio_paths_series
|
55 |
+
Whisper model implements decoding with fallback: https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L147
|
56 |
+
The core idea is that decoding at each time step might happen multiple times if at least one criterion to "fall back" i.e.
|
57 |
+
start over is fired. Number of fallback iterations is determined by the schedule of temperature values:
|
58 |
+
https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L41
|
59 |
+
By default this schedule is active and temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0) i.e. even with beam_size 5 it might fell back and
|
60 |
+
turn on sampling by using temperature > 0, in this case the beam search is not used in the fall back iteration.
|
61 |
+
Explicit setting of temperature=0.0 overwrites the schedule and fall back decoding has only one for loop iteration i.e. no fall backs.
|
62 |
+
This allows us to do reproducible evaluation without sample variations. Beware that this might introduce the repetition loops in
|
63 |
+
the transcriptions and lead to worse ASR-BLEU score in the end.
|
64 |
+
"""
|
65 |
+
|
66 |
+
if len(audio_lang) == 3:
|
67 |
+
# to make it work with whisper
|
68 |
+
audio_lang = LANG3_LANG2[audio_lang]
|
69 |
+
|
70 |
+
transcriptions = {}
|
71 |
+
|
72 |
+
for idx, audio_path in tqdm(
|
73 |
+
audio_paths_series.items(),
|
74 |
+
desc=f"Transcribing {audio_paths_series.name} column",
|
75 |
+
total=len(audio_paths_series),
|
76 |
+
):
|
77 |
+
hypo = asr_model.transcribe(
|
78 |
+
audio_path,
|
79 |
+
temperature=temperature,
|
80 |
+
beam_size=beam_size,
|
81 |
+
language=audio_lang,
|
82 |
+
)["text"].strip()
|
83 |
+
transcriptions[idx] = hypo
|
84 |
+
|
85 |
+
transcriptions_series = pd.Series(transcriptions)
|
86 |
+
transcriptions_series.name = f"{audio_paths_series.name}_transcribed"
|
87 |
+
|
88 |
+
return transcriptions_series
|
89 |
+
|
90 |
+
|
91 |
+
def whisper_normalize_series(
|
92 |
+
transcription_series: pd.Series, text_lang: str
|
93 |
+
) -> pd.Series:
|
94 |
+
"""Normalizes the text series using whisper noramlizer. English has a specific one in whisper package.
|
95 |
+
Args:
|
96 |
+
transcription_series (pd.Series): Each line contains arbitrary text written in text_lang
|
97 |
+
text_lang (str): Language of the text in series
|
98 |
+
Returns:
|
99 |
+
pd.Series: Series with normalized text
|
100 |
+
"""
|
101 |
+
if text_lang == "eng":
|
102 |
+
normalizer = EnglishTextNormalizer()
|
103 |
+
else:
|
104 |
+
normalizer = BasicTextNormalizer()
|
105 |
+
|
106 |
+
norm_transcriptions = {}
|
107 |
+
|
108 |
+
for idx, text in transcription_series.items():
|
109 |
+
norm_transcriptions[idx] = normalizer(text)
|
110 |
+
|
111 |
+
norm_transcriptions_series = pd.Series(norm_transcriptions)
|
112 |
+
norm_transcriptions_series.name = transcription_series.name
|
113 |
+
|
114 |
+
return norm_transcriptions_series
|
115 |
+
|
116 |
+
|
117 |
+
def compute_asr_bleu(
|
118 |
+
audio_paths_series: pd.Series,
|
119 |
+
ref_text_series: pd.Series,
|
120 |
+
lang: str,
|
121 |
+
asr_model: Whisper,
|
122 |
+
whisper_normalize_text: bool = True,
|
123 |
+
beam_size: int = 1,
|
124 |
+
temperature: float = 0.0,
|
125 |
+
return_transcriptions: bool = True,
|
126 |
+
) -> Tuple[Score, Signature, pd.DataFrame]:
|
127 |
+
"""Wraps functions above to compute corpus-level ASR-BLEU
|
128 |
+
ASR decoding hyper-parameters are hard coded to ensure reproducibility across evaluations
|
129 |
+
Args:
|
130 |
+
audio_paths_series (pd.Series): each line contains path to audio
|
131 |
+
ref_text_series (pd.Series): each line contains the text reference to compare audio with
|
132 |
+
lang (str): the language of both audio and ref_text
|
133 |
+
asr_model: whisper ASR model
|
134 |
+
whisper_normalize_text (bool): normalize both text hypotheses and reference if True. Defaults to True.
|
135 |
+
beam_size (int): beam_size for whisper generation
|
136 |
+
temperature (float): Temperature sampling value for whisper generation
|
137 |
+
return_transcriptions (bool)
|
138 |
+
"""
|
139 |
+
|
140 |
+
audio_transcriptions = transcribe_series(
|
141 |
+
audio_paths_series,
|
142 |
+
asr_model,
|
143 |
+
audio_lang=lang,
|
144 |
+
beam_size=beam_size,
|
145 |
+
temperature=temperature,
|
146 |
+
)
|
147 |
+
asr_bleu, asr_bleu_signature = compute_corpus_metric_score(
|
148 |
+
audio_transcriptions, ref_text_series, lang, whisper_normalize_text
|
149 |
+
)
|
150 |
+
asr_bleu_signature.info["whisper_asr_beam_size"] = beam_size
|
151 |
+
asr_bleu_signature.info["whisper_asr_temperature"] = temperature
|
152 |
+
asr_bleu_signature.info["whisper_asr_language"] = lang
|
153 |
+
|
154 |
+
transcript_df = None
|
155 |
+
if return_transcriptions:
|
156 |
+
transcript_df = pd.concat(
|
157 |
+
[
|
158 |
+
audio_paths_series,
|
159 |
+
audio_transcriptions,
|
160 |
+
ref_text_series,
|
161 |
+
],
|
162 |
+
axis=1,
|
163 |
+
keys=["audio", "transcript", "reference"],
|
164 |
+
)
|
165 |
+
return asr_bleu, asr_bleu_signature, transcript_df
|
166 |
+
|
167 |
+
|
168 |
+
def get_tokenizer(lang: str, metric: str = "bleu") -> str:
|
169 |
+
"""Get tokenizer for language
|
170 |
+
Args:
|
171 |
+
lang (str): Three letter code of the language
|
172 |
+
metric (str): Metric being computed. Valid values are "bleu" and "asr"
|
173 |
+
"""
|
174 |
+
lang_tok_map = {
|
175 |
+
"cmn": "char",
|
176 |
+
"jpn": "char",
|
177 |
+
"tha": "char",
|
178 |
+
"lao": "char",
|
179 |
+
"mya": "char",
|
180 |
+
}
|
181 |
+
default = (
|
182 |
+
"13a" if metric == "bleu" else "word"
|
183 |
+
) # 13a is the default tokenizer for bleu and wer for asr
|
184 |
+
tok = lang_tok_map.get(lang, default)
|
185 |
+
return tok
|
186 |
+
|
187 |
+
|
188 |
+
def compute_asr_error_rate(
|
189 |
+
hyp_text_series: pd.Series,
|
190 |
+
ref_text_series: pd.Series,
|
191 |
+
lang: str,
|
192 |
+
whisper_normalize_text: bool = True,
|
193 |
+
) -> Tuple[float, str]:
|
194 |
+
"""Wraps normalization functions and computes ASR WER/CER score
|
195 |
+
Args:
|
196 |
+
hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
|
197 |
+
ref_text_series (pd.Series): _description_
|
198 |
+
lang (str): _description_
|
199 |
+
whisper_normalize_text (bool, optional): normalize both text hypotheses and reference if True. Defaults to True.
|
200 |
+
Returns:
|
201 |
+
(MetricScore, MetricScoreSignature)
|
202 |
+
"""
|
203 |
+
if whisper_normalize_text:
|
204 |
+
hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
|
205 |
+
ref_text_series = whisper_normalize_series(ref_text_series, lang)
|
206 |
+
|
207 |
+
tokenizer_name = get_tokenizer(lang, metric="error_rate")
|
208 |
+
metric_name = wer if tokenizer_name == "word" else cer
|
209 |
+
metric_score = metric_name(hyp_text_series.to_list(), ref_text_series.to_list())
|
210 |
+
return metric_score, f"{metric_name.__name__} is {metric_score}"
|
211 |
+
|
212 |
+
|
213 |
+
def compute_corpus_metric_score(
|
214 |
+
hyp_text_series: pd.Series,
|
215 |
+
ref_text_series: pd.Series,
|
216 |
+
lang: str,
|
217 |
+
whisper_normalize_text: bool = True,
|
218 |
+
metric: str = "bleu",
|
219 |
+
) -> Tuple[Score, Signature]:
|
220 |
+
"""Wraps normalization functions and compute corpus-level BLEU/chrF++ score
|
221 |
+
Args:
|
222 |
+
hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
|
223 |
+
ref_text_series (pd.Series): _description_
|
224 |
+
lang (str): _description_
|
225 |
+
whisper_normalize_text (bool, optional): normalize both text hypotheses and reference if True. Defaults to True.
|
226 |
+
Returns:
|
227 |
+
(MetricScore, MetricScoreSignature)
|
228 |
+
"""
|
229 |
+
if whisper_normalize_text:
|
230 |
+
hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
|
231 |
+
ref_text_series = whisper_normalize_series(ref_text_series, lang)
|
232 |
+
|
233 |
+
tokenizer_name = get_tokenizer(lang)
|
234 |
+
corpus_metric_score_metric: Union[BLEU, CHRF]
|
235 |
+
if metric == "bleu":
|
236 |
+
corpus_metric_score_metric = BLEU(
|
237 |
+
lowercase=whisper_normalize_text, tokenize=tokenizer_name
|
238 |
+
) # lowercase applied if we use whisper_normalize_text
|
239 |
+
elif metric == "chrF++":
|
240 |
+
corpus_metric_score_metric = CHRF(word_order=2)
|
241 |
+
|
242 |
+
corpus_metric_score = corpus_metric_score_metric.corpus_score(
|
243 |
+
hyp_text_series.to_list(), [ref_text_series.to_list()]
|
244 |
+
)
|
245 |
+
corpus_metric_score_signature = corpus_metric_score_metric.get_signature()
|
246 |
+
corpus_metric_score_signature.info["whisper_normalize"] = whisper_normalize_text
|
247 |
+
|
248 |
+
return corpus_metric_score, corpus_metric_score_signature
|
249 |
+
|
250 |
+
|
251 |
+
def compute_quality_metrics(
|
252 |
+
output_manifest_tsv_path: Path,
|
253 |
+
output_path: Path,
|
254 |
+
tgt_lang: str,
|
255 |
+
task: str,
|
256 |
+
device: Device,
|
257 |
+
whisper_model_name: str = "large",
|
258 |
+
whisper_normalize_text_output: bool = False,
|
259 |
+
ref_text_col_name: str = "ref_tgt_text",
|
260 |
+
pred_text_col_name: Optional[str] = "pred_tgt_text",
|
261 |
+
pred_audio_col_name: str = "pred_tgt_audio",
|
262 |
+
) -> str:
|
263 |
+
"""Wraps asr and s2t bleu functions to call it with TSV manifest composed on expressivity side
|
264 |
+
Args:
|
265 |
+
output_manifest_tsv_path (Path): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
|
266 |
+
output_path (Path): Directory to write files with metrics
|
267 |
+
tgt_lang (str): what language we evaluate on
|
268 |
+
task (str): Task we are currently evaluating for
|
269 |
+
device (Device): Device to use for inference
|
270 |
+
whisper_model_name (str): Whisper model name. Defaults to "large".
|
271 |
+
whisper_normalize_text_output (bool): Normalizes text output using whisper_normalizer if set to true
|
272 |
+
ref_text_col_name (str): Column name in the tsv corresponding to reference target text
|
273 |
+
pred_text_col_name (str): Column name in the tsv corresponding to predicted target text
|
274 |
+
pred_audio_col_name (str): Column name in the tsv corresponding to predicted target audio.
|
275 |
+
Setting this value to none will skip speech metrics
|
276 |
+
"""
|
277 |
+
df = pd.read_csv(
|
278 |
+
output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
|
279 |
+
)
|
280 |
+
task = task.upper()
|
281 |
+
|
282 |
+
if not output_path.exists():
|
283 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
284 |
+
|
285 |
+
if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
|
286 |
+
metric = "chrF++" if task == "T2TT" else "bleu"
|
287 |
+
text_metric, text_metric_signature = compute_corpus_metric_score(
|
288 |
+
hyp_text_series=df[pred_text_col_name],
|
289 |
+
ref_text_series=df[ref_text_col_name],
|
290 |
+
lang=tgt_lang,
|
291 |
+
whisper_normalize_text=whisper_normalize_text_output,
|
292 |
+
metric=metric,
|
293 |
+
)
|
294 |
+
text_metric_json = text_metric.format(
|
295 |
+
signature=text_metric_signature.format(), is_json=True
|
296 |
+
)
|
297 |
+
|
298 |
+
if task == "T2TT":
|
299 |
+
filename = "t2tt_chrf.json"
|
300 |
+
cur_task = "T2TT"
|
301 |
+
else:
|
302 |
+
filename = (
|
303 |
+
"s2tt_bleu_normalized.json"
|
304 |
+
if whisper_normalize_text_output
|
305 |
+
else "s2tt_bleu.json"
|
306 |
+
)
|
307 |
+
cur_task = "S2TT"
|
308 |
+
|
309 |
+
with open(output_path / filename, "w") as f:
|
310 |
+
f.write(text_metric_json)
|
311 |
+
|
312 |
+
logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
|
313 |
+
|
314 |
+
if task in ["T2ST", "S2ST"]:
|
315 |
+
whisper_model = init_whisper_model(device, whisper_model_name)
|
316 |
+
(
|
317 |
+
asr_bleu_normalized,
|
318 |
+
asr_bleu_normalized_signature,
|
319 |
+
transcripts_df,
|
320 |
+
) = compute_asr_bleu(
|
321 |
+
audio_paths_series=df[pred_audio_col_name],
|
322 |
+
ref_text_series=df[ref_text_col_name],
|
323 |
+
lang=tgt_lang,
|
324 |
+
asr_model=whisper_model,
|
325 |
+
whisper_normalize_text=True,
|
326 |
+
)
|
327 |
+
transcripts_df.to_csv(
|
328 |
+
(output_path / "whisper_audio_transcriptions.tsv"),
|
329 |
+
sep="\t",
|
330 |
+
index=False,
|
331 |
+
encoding="utf-8",
|
332 |
+
escapechar="\\",
|
333 |
+
)
|
334 |
+
|
335 |
+
asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name
|
336 |
+
|
337 |
+
asr_bleu_normalized_json = asr_bleu_normalized.format(
|
338 |
+
signature=asr_bleu_normalized_signature.format(), is_json=True
|
339 |
+
)
|
340 |
+
filename = f"{task.lower()}_asr_bleu_normalized.json"
|
341 |
+
|
342 |
+
with open(
|
343 |
+
output_path / filename,
|
344 |
+
"w",
|
345 |
+
) as f:
|
346 |
+
f.write(asr_bleu_normalized_json)
|
347 |
+
|
348 |
+
logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
|
349 |
+
|
350 |
+
if task == "ASR":
|
351 |
+
asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
|
352 |
+
hyp_text_series=df[pred_text_col_name],
|
353 |
+
ref_text_series=df[ref_text_col_name],
|
354 |
+
lang=tgt_lang,
|
355 |
+
whisper_normalize_text=whisper_normalize_text_output,
|
356 |
+
)
|
357 |
+
d = {
|
358 |
+
"name": "WER",
|
359 |
+
"score": asr_error_rate,
|
360 |
+
"signature": asr_error_rate_signature,
|
361 |
+
}
|
362 |
+
asr_error_rate_json = json.dumps(d, indent=1, ensure_ascii=False)
|
363 |
+
|
364 |
+
filename = "asr_error_rate.json"
|
365 |
+
|
366 |
+
with open(output_path / filename, "w") as f:
|
367 |
+
f.write(asr_error_rate_json)
|
368 |
+
|
369 |
+
logger.info(f"ASR : {asr_error_rate_json}")
|
370 |
+
|
371 |
+
return filename
|