File size: 14,197 Bytes
87809bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 |
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import json
import logging
from pathlib import Path
from typing import Optional, Tuple, Union
import pandas as pd
import whisper
from fairseq2.typing import Device
from jiwer import cer, wer
from sacrebleu.metrics.base import Score, Signature
from sacrebleu.metrics.bleu import BLEU
from sacrebleu.metrics.chrf import CHRF
from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
from tqdm import tqdm
from whisper import Whisper
from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
def init_whisper_model(
    device: Device,
    whisper_model_name: str = "large",
) -> Whisper:
    """Load a Whisper ASR checkpoint and place it on the given device.

    Args:
        device (Device): Device on which the model weights are loaded.
        whisper_model_name (str): Whisper checkpoint name. Defaults to "large".

    Returns:
        Whisper: The loaded ASR model, ready for transcription.
    """
    model = whisper.load_model(name=whisper_model_name, device=device)
    return model
def transcribe_series(
    audio_paths_series: pd.Series,
    asr_model: Whisper,
    audio_lang: str,
    beam_size: int = 1,
    temperature: float = 0.0,
) -> pd.Series:
    """Transcribes each audio filepath from series and returns series of transcriptions

    Args:
        audio_paths_series (pd.Series): each line contains path to audio file.
        asr_model: ASR model to do the transcribing process e.g. Whisper
        audio_lang (str): what language is used in the given audio, used by ASR model
        beam_size (int): whisper beam size. Defaults to 1
        temperature (float): whisper temperature. Defaults to 0.0 to avoid fallback decoding (see details below).

    Returns:
        pd.Series: Series where each line has a transcription of corresponding audio from audio_paths_series

    Whisper model implements decoding with fallback: https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L147
    Decoding at a given step may be retried when any "fall back" criterion fires; the number of retry
    iterations follows the temperature schedule
    (https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L41).
    With the default schedule temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), even beam_size=5 decoding can
    fall back to temperature sampling, and beam search is not used in those fallback iterations.
    Passing an explicit temperature=0.0 overrides the schedule so the fallback loop runs exactly once,
    which makes evaluation reproducible (no sampling variation). Beware: this can introduce repetition
    loops in the transcriptions and hurt the final ASR-BLEU score.
    """
    # Whisper only understands two-letter language codes; map 3-letter codes down.
    if len(audio_lang) == 3:
        audio_lang = LANG3_LANG2[audio_lang]

    results = {}
    progress = tqdm(
        audio_paths_series.items(),
        desc=f"Transcribing {audio_paths_series.name} column",
        total=len(audio_paths_series),
    )
    for idx, audio_path in progress:
        decoded = asr_model.transcribe(
            audio_path,
            temperature=temperature,
            beam_size=beam_size,
            language=audio_lang,
        )
        results[idx] = decoded["text"].strip()

    transcriptions_series = pd.Series(results)
    transcriptions_series.name = f"{audio_paths_series.name}_transcribed"
    return transcriptions_series
def whisper_normalize_series(
    transcription_series: pd.Series, text_lang: str
) -> pd.Series:
    """Normalizes a text series using the whisper normalizers.

    English gets the dedicated EnglishTextNormalizer from the whisper package;
    every other language falls back to BasicTextNormalizer.

    Args:
        transcription_series (pd.Series): Each line contains arbitrary text written in text_lang
        text_lang (str): Language of the text in series

    Returns:
        pd.Series: Series with normalized text (same name and index as the input)
    """
    normalizer = (
        EnglishTextNormalizer() if text_lang == "eng" else BasicTextNormalizer()
    )
    normalized = {idx: normalizer(text) for idx, text in transcription_series.items()}
    normalized_series = pd.Series(normalized)
    normalized_series.name = transcription_series.name
    return normalized_series
def compute_asr_bleu(
    audio_paths_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    asr_model: Whisper,
    whisper_normalize_text: bool = True,
    beam_size: int = 1,
    temperature: float = 0.0,
    return_transcriptions: bool = True,
) -> Tuple[Score, Signature, pd.DataFrame]:
    """Computes corpus-level ASR-BLEU: transcribe audio, then score against references.

    ASR decoding hyper-parameters are hard coded to ensure reproducibility across evaluations

    Args:
        audio_paths_series (pd.Series): each line contains path to audio
        ref_text_series (pd.Series): each line contains the text reference to compare audio with
        lang (str): the language of both audio and ref_text
        asr_model: whisper ASR model
        whisper_normalize_text (bool): normalize both text hypotheses and reference if True. Defaults to True.
        beam_size (int): beam_size for whisper generation
        temperature (float): Temperature sampling value for whisper generation
        return_transcriptions (bool): attach a DataFrame of audio paths, transcripts
            and references if True; otherwise the third return value is None.

    Returns:
        Tuple[Score, Signature, pd.DataFrame]: BLEU score, its sacrebleu signature,
        and the transcripts DataFrame (or None).
    """
    transcripts = transcribe_series(
        audio_paths_series,
        asr_model,
        audio_lang=lang,
        beam_size=beam_size,
        temperature=temperature,
    )
    bleu_score, bleu_signature = compute_corpus_metric_score(
        transcripts, ref_text_series, lang, whisper_normalize_text
    )
    # Record the ASR decoding setup in the signature so the score is reproducible.
    bleu_signature.info["whisper_asr_beam_size"] = beam_size
    bleu_signature.info["whisper_asr_temperature"] = temperature
    bleu_signature.info["whisper_asr_language"] = lang

    if return_transcriptions:
        transcript_df = pd.concat(
            [audio_paths_series, transcripts, ref_text_series],
            axis=1,
            keys=["audio", "transcript", "reference"],
        )
    else:
        transcript_df = None
    return bleu_score, bleu_signature, transcript_df
def get_tokenizer(lang: str, metric: str = "bleu") -> str:
    """Return the tokenizer name appropriate for a language and metric.

    Languages written without whitespace-delimited words (Chinese, Japanese,
    Thai, Lao, Burmese) are always tokenized at the character level; other
    languages use the metric's default.

    Args:
        lang (str): Three letter code of the language
        metric (str): Metric being computed. "bleu" selects sacrebleu's default
            "13a" tokenizer; any other value (the error-rate path passes
            "error_rate") selects word-level tokenization.
            NOTE: the original docstring claimed valid values were "bleu" and
            "asr", but the code only distinguishes "bleu" vs everything else.

    Returns:
        str: One of "char", "13a" or "word".
    """
    # Languages that must be scored character-by-character.
    char_level_langs = {"cmn", "jpn", "tha", "lao", "mya"}
    if lang in char_level_langs:
        return "char"
    # 13a is the default tokenizer for bleu; word-level is used for WER.
    return "13a" if metric == "bleu" else "word"
def compute_asr_error_rate(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text: bool = True,
) -> Tuple[float, str]:
    """Wraps normalization functions and computes ASR WER/CER score

    CER is used for character-level languages (see get_tokenizer), WER otherwise.

    Args:
        hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
        ref_text_series (pd.Series): each line contains the corresponding text reference
        lang (str): language of both hypotheses and references
        whisper_normalize_text (bool, optional): normalize both text hypotheses and reference if True. Defaults to True.

    Returns:
        Tuple[float, str]: the error rate and a short human-readable summary string.
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)
    tokenizer_name = get_tokenizer(lang, metric="error_rate")
    metric_name = wer if tokenizer_name == "word" else cer
    # jiwer's wer/cer signature is (reference, hypothesis); the previous code
    # passed hypotheses as the reference, silently swapping insertion/deletion
    # counts and producing a wrong (asymmetric) error rate.
    metric_score = metric_name(ref_text_series.to_list(), hyp_text_series.to_list())
    return metric_score, f"{metric_name.__name__} is {metric_score}"
def compute_corpus_metric_score(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text: bool = True,
    metric: str = "bleu",
) -> Tuple[Score, Signature]:
    """Wraps normalization functions and compute corpus-level BLEU/chrF++ score

    Args:
        hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
        ref_text_series (pd.Series): each line contains the corresponding text reference
        lang (str): language of both hypotheses and references
        whisper_normalize_text (bool, optional): normalize both text hypotheses and reference if True. Defaults to True.
        metric (str): which corpus metric to compute, "bleu" or "chrF++". Defaults to "bleu".

    Returns:
        (MetricScore, MetricScoreSignature)

    Raises:
        ValueError: if metric is neither "bleu" nor "chrF++".
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    tokenizer_name = get_tokenizer(lang)
    corpus_metric_score_metric: Union[BLEU, CHRF]
    if metric == "bleu":
        corpus_metric_score_metric = BLEU(
            lowercase=whisper_normalize_text, tokenize=tokenizer_name
        )  # lowercase applied if we use whisper_normalize_text
    elif metric == "chrF++":
        corpus_metric_score_metric = CHRF(word_order=2)
    else:
        # Previously an unsupported metric crashed later with a confusing
        # UnboundLocalError; fail fast with an explicit message instead.
        raise ValueError(f"Unsupported metric: {metric!r}. Expected 'bleu' or 'chrF++'.")

    corpus_metric_score = corpus_metric_score_metric.corpus_score(
        hyp_text_series.to_list(), [ref_text_series.to_list()]
    )
    corpus_metric_score_signature = corpus_metric_score_metric.get_signature()
    corpus_metric_score_signature.info["whisper_normalize"] = whisper_normalize_text
    return corpus_metric_score, corpus_metric_score_signature
def compute_quality_metrics(
    output_manifest_tsv_path: Path,
    output_path: Path,
    tgt_lang: str,
    task: str,
    device: Device,
    whisper_model_name: str = "large",
    whisper_normalize_text_output: bool = False,
    ref_text_col_name: str = "ref_tgt_text",
    pred_text_col_name: Optional[str] = "pred_tgt_text",
    pred_audio_col_name: str = "pred_tgt_audio",
) -> str:
    """Wraps asr and s2t bleu functions to call it with TSV manifest composed on expressivity side

    Args:
        output_manifest_tsv_path (Path): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
        output_path (Path): Directory to write files with metrics
        tgt_lang (str): what language we evaluate on
        task (str): Task we are currently evaluating for (case-insensitive; one of
            S2TT/S2ST/T2TT/T2ST/ASR)
        device (Device): Device to use for inference
        whisper_model_name (str): Whisper model name. Defaults to "large".
        whisper_normalize_text_output (bool): Normalizes text output using whisper_normalizer if set to true
        ref_text_col_name (str): Column name in the tsv corresponding to reference target text
        pred_text_col_name (str): Column name in the tsv corresponding to predicted target text
        pred_audio_col_name (str): Column name in the tsv corresponding to predicted target audio.
            Setting this value to none will skip speech metrics

    Returns:
        str: filename (relative to output_path) of the last metric JSON written.

    NOTE(review): if `task` matches none of the handled branches, `filename` is
    never assigned and the final return raises NameError — confirm callers only
    pass the supported task names.
    """
    # quoting=3 is csv.QUOTE_NONE: the manifest text columns are not quoted.
    df = pd.read_csv(
        output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
    )
    task = task.upper()

    if not output_path.exists():
        output_path.mkdir(parents=True, exist_ok=True)

    # Text-output tasks: corpus-level chrF++ (T2TT) or BLEU (S2TT/S2ST first pass).
    if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
        metric = "chrF++" if task == "T2TT" else "bleu"
        text_metric, text_metric_signature = compute_corpus_metric_score(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
            metric=metric,
        )
        text_metric_json = text_metric.format(
            signature=text_metric_signature.format(), is_json=True
        )

        if task == "T2TT":
            filename = "t2tt_chrf.json"
            cur_task = "T2TT"
        else:
            # Mark normalized runs in the filename so both variants can coexist.
            filename = (
                "s2tt_bleu_normalized.json"
                if whisper_normalize_text_output
                else "s2tt_bleu.json"
            )
            cur_task = "S2TT"

        with open(output_path / filename, "w") as f:
            f.write(text_metric_json)

        logger.info(f"{cur_task} {metric}:\n{text_metric_json}")

    # Speech-output tasks: transcribe predicted audio with whisper, then BLEU
    # against the text reference (ASR-BLEU, always whisper-normalized).
    if task in ["T2ST", "S2ST"]:
        whisper_model = init_whisper_model(device, whisper_model_name)
        (
            asr_bleu_normalized,
            asr_bleu_normalized_signature,
            transcripts_df,
        ) = compute_asr_bleu(
            audio_paths_series=df[pred_audio_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            asr_model=whisper_model,
            whisper_normalize_text=True,
        )
        # Persist the raw transcriptions for manual inspection/debugging.
        transcripts_df.to_csv(
            (output_path / "whisper_audio_transcriptions.tsv"),
            sep="\t",
            index=False,
            encoding="utf-8",
            escapechar="\\",
        )

        asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name

        asr_bleu_normalized_json = asr_bleu_normalized.format(
            signature=asr_bleu_normalized_signature.format(), is_json=True
        )
        filename = f"{task.lower()}_asr_bleu_normalized.json"

        with open(
            output_path / filename,
            "w",
        ) as f:
            f.write(asr_bleu_normalized_json)

        logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")

    # ASR task: WER/CER between predicted and reference text.
    if task == "ASR":
        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
        )
        d = {
            "name": "WER",
            "score": asr_error_rate,
            "signature": asr_error_rate_signature,
        }
        asr_error_rate_json = json.dumps(d, indent=1, ensure_ascii=False)

        filename = "asr_error_rate.json"
        with open(output_path / filename, "w") as f:
            f.write(asr_error_rate_json)

        logger.info(f"ASR : {asr_error_rate_json}")

    return filename
|