# Uploaded by victan with huggingface_hub (revision 87809bd):
# seamless_communication/cli/eval_utils/compute_metrics.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import json
import logging
from pathlib import Path
from typing import Optional, Tuple, Union
import pandas as pd
import whisper
from fairseq2.typing import Device
from jiwer import cer, wer
from sacrebleu.metrics.base import Score, Signature
from sacrebleu.metrics.bleu import BLEU
from sacrebleu.metrics.chrf import CHRF
from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
from tqdm import tqdm
from whisper import Whisper
from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
# Configure the root logger at import time (a no-op if handlers are already
# installed): INFO level, with timestamp, level and logger name in each record.
logging.basicConfig(
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
    level=logging.INFO,
)

# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
def init_whisper_model(
    device: Device,
    whisper_model_name: str = "large",
) -> Whisper:
    """Load a Whisper ASR model onto the given device.

    Args:
        device (Device): device forwarded to ``whisper.load_model``.
        whisper_model_name (str): model name forwarded to ``whisper.load_model``.
            Defaults to "large".
    Returns:
        Whisper: the loaded model.
    """
    asr_model = whisper.load_model(name=whisper_model_name, device=device)
    return asr_model
def transcribe_series(
    audio_paths_series: pd.Series,
    asr_model: Whisper,
    audio_lang: str,
    beam_size: int = 1,
    temperature: float = 0.0,
) -> pd.Series:
    """Transcribe every audio file referenced by a series.

    Args:
        audio_paths_series (pd.Series): each entry is a path to an audio file.
        asr_model: ASR model that performs the transcription, e.g. Whisper.
        audio_lang (str): language spoken in the audio. A three-letter code is
            mapped through LANG3_LANG2 before it is handed to the model.
        beam_size (int): whisper beam size. Defaults to 1.
        temperature (float): whisper temperature. Defaults to 0.0 to avoid
            fallback decoding (see details below).
    Returns:
        pd.Series: stripped transcriptions keyed by the labels of
            audio_paths_series and named "<input series name>_transcribed".

    Whisper implements decoding with fallback:
    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L147
    Decoding of a segment may be repeated whenever at least one "fall back"
    criterion fires. The number of fallback iterations is driven by the
    temperature schedule:
    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L41
    By default the schedule is temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), so even
    with beam_size 5 decoding may fall back and switch to sampling with
    temperature > 0, in which case beam search is not used for that iteration.
    Passing temperature=0.0 explicitly overrides the schedule, so there is a single
    decoding pass and no fallback. This keeps evaluation reproducible (no sampling
    variation), but beware that it may introduce repetition loops in the
    transcriptions and lead to a worse ASR-BLEU score.
    """
    # Three-letter codes are translated through LANG3_LANG2 to make them work
    # with whisper; anything else is passed through untouched.
    whisper_lang = LANG3_LANG2[audio_lang] if len(audio_lang) == 3 else audio_lang

    progress = tqdm(
        audio_paths_series.items(),
        desc=f"Transcribing {audio_paths_series.name} column",
        total=len(audio_paths_series),
    )

    hypotheses = {}
    for row_label, path in progress:
        result = asr_model.transcribe(
            path,
            temperature=temperature,
            beam_size=beam_size,
            language=whisper_lang,
        )
        hypotheses[row_label] = result["text"].strip()

    return pd.Series(hypotheses, name=f"{audio_paths_series.name}_transcribed")
def whisper_normalize_series(
    transcription_series: pd.Series, text_lang: str
) -> pd.Series:
    """Normalize a text series with a whisper text normalizer.

    "eng" uses whisper's EnglishTextNormalizer; every other language code uses
    BasicTextNormalizer.

    Args:
        transcription_series (pd.Series): each entry is arbitrary text written in text_lang.
        text_lang (str): language of the text in the series.
    Returns:
        pd.Series: normalized text keyed by the labels of transcription_series,
            carrying the same series name.
    """
    normalizer_cls = EnglishTextNormalizer if text_lang == "eng" else BasicTextNormalizer
    normalize = normalizer_cls()

    normalized = {
        row_label: normalize(raw_text)
        for row_label, raw_text in transcription_series.items()
    }
    return pd.Series(normalized, name=transcription_series.name)
def compute_asr_bleu(
    audio_paths_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    asr_model: Whisper,
    whisper_normalize_text: bool = True,
    beam_size: int = 1,
    temperature: float = 0.0,
    return_transcriptions: bool = True,
) -> Tuple[Score, Signature, Optional[pd.DataFrame]]:
    """Transcribe audio and compute corpus-level ASR-BLEU against text references.

    The audio is transcribed with transcribe_series and the transcriptions are
    scored with compute_corpus_metric_score (default metric, i.e. BLEU). The ASR
    decoding settings that were used are recorded in the returned signature so
    that evaluations can be compared.

    Args:
        audio_paths_series (pd.Series): each entry is a path to an audio file.
        ref_text_series (pd.Series): each entry is the text reference for the audio.
        lang (str): language of both the audio and the reference text.
        asr_model: whisper ASR model.
        whisper_normalize_text (bool): normalize both hypotheses and references if True.
            Defaults to True.
        beam_size (int): beam size for whisper generation. Defaults to 1.
        temperature (float): temperature for whisper generation. Defaults to 0.0.
        return_transcriptions (bool): also return a DataFrame of the transcriptions.
            Defaults to True.
    Returns:
        Tuple of (score, signature, transcripts). ``transcripts`` is a DataFrame
        with "audio", "transcript" and "reference" columns, or None when
        return_transcriptions is False.
    """
    audio_transcriptions = transcribe_series(
        audio_paths_series,
        asr_model,
        audio_lang=lang,
        beam_size=beam_size,
        temperature=temperature,
    )
    asr_bleu, asr_bleu_signature = compute_corpus_metric_score(
        audio_transcriptions, ref_text_series, lang, whisper_normalize_text
    )

    # Record the ASR decoding configuration alongside the metric signature.
    asr_bleu_signature.info["whisper_asr_beam_size"] = beam_size
    asr_bleu_signature.info["whisper_asr_temperature"] = temperature
    asr_bleu_signature.info["whisper_asr_language"] = lang

    transcript_df: Optional[pd.DataFrame] = None
    if return_transcriptions:
        # Column-wise concat; rows are aligned on the series' index labels.
        transcript_df = pd.concat(
            [
                audio_paths_series,
                audio_transcriptions,
                ref_text_series,
            ],
            axis=1,
            keys=["audio", "transcript", "reference"],
        )
    return asr_bleu, asr_bleu_signature, transcript_df
def get_tokenizer(lang: str, metric: str = "bleu") -> str:
    """Pick the tokenizer name used to score text in a language.

    Args:
        lang (str): Three letter code of the language
        metric (str): Metric being computed. "bleu" selects the BLEU default;
            any other value selects the word-level default used for error rates.
    Returns:
        str: "char" for cmn, jpn, tha, lao and mya; otherwise "13a" when
            metric is "bleu" and "word" for every other metric.
    """
    # Languages that are scored at character level regardless of the metric.
    char_level_langs = {"cmn", "jpn", "tha", "lao", "mya"}
    if lang in char_level_langs:
        return "char"

    # 13a is the default tokenizer for bleu and word is the default (wer) for asr
    if metric == "bleu":
        return "13a"
    return "word"
def compute_asr_error_rate(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text: bool = True,
) -> Tuple[float, str]:
    """Wraps normalization functions and computes ASR WER/CER score

    WER is used when get_tokenizer selects the "word" tokenizer for the language;
    otherwise CER is used.

    Args:
        hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
        ref_text_series (pd.Series): each line contains the text reference for the prediction
        lang (str): language of both hypotheses and references
        whisper_normalize_text (bool, optional): normalize both text hypotheses and reference if True. Defaults to True.
    Returns:
        Tuple[float, str]: the error rate and a "<metric> is <score>" description,
            where <metric> is the name of the jiwer function that was used.
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    tokenizer_name = get_tokenizer(lang, metric="error_rate")
    metric_name = wer if tokenizer_name == "word" else cer

    # jiwer's wer/cer take the reference first and the hypothesis second. The
    # error rate is normalized by the reference length, so the order matters.
    metric_score = metric_name(ref_text_series.to_list(), hyp_text_series.to_list())
    return metric_score, f"{metric_name.__name__} is {metric_score}"
def compute_corpus_metric_score(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text: bool = True,
    metric: str = "bleu",
) -> Tuple[Score, Signature]:
    """Wraps normalization functions and compute corpus-level BLEU/chrF++ score

    Args:
        hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
        ref_text_series (pd.Series): each line contains the text reference for the prediction
        lang (str): language of both hypotheses and references
        whisper_normalize_text (bool, optional): normalize both text hypotheses and reference if True. Defaults to True.
        metric (str, optional): "bleu" or "chrF++". Defaults to "bleu".
    Returns:
        (MetricScore, MetricScoreSignature)
    Raises:
        ValueError: if metric is neither "bleu" nor "chrF++".
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    corpus_metric_score_metric: Union[BLEU, CHRF]
    if metric == "bleu":
        # The language-specific tokenizer is only used by BLEU.
        tokenizer_name = get_tokenizer(lang)
        corpus_metric_score_metric = BLEU(
            lowercase=whisper_normalize_text, tokenize=tokenizer_name
        )  # lowercase applied if we use whisper_normalize_text
    elif metric == "chrF++":
        corpus_metric_score_metric = CHRF(word_order=2)
    else:
        # Without this branch the scorer variable stays unbound and the call
        # below fails with an opaque UnboundLocalError.
        raise ValueError(
            f'Unsupported metric {metric!r}; expected "bleu" or "chrF++"'
        )

    # A single reference stream is supplied for the whole corpus.
    corpus_metric_score = corpus_metric_score_metric.corpus_score(
        hyp_text_series.to_list(), [ref_text_series.to_list()]
    )
    corpus_metric_score_signature = corpus_metric_score_metric.get_signature()
    corpus_metric_score_signature.info["whisper_normalize"] = whisper_normalize_text

    return corpus_metric_score, corpus_metric_score_signature
def compute_quality_metrics(
    output_manifest_tsv_path: Path,
    output_path: Path,
    tgt_lang: str,
    task: str,
    device: Device,
    whisper_model_name: str = "large",
    whisper_normalize_text_output: bool = False,
    ref_text_col_name: str = "ref_tgt_text",
    pred_text_col_name: Optional[str] = "pred_tgt_text",
    pred_audio_col_name: str = "pred_tgt_audio",
) -> str:
    """Compute the quality metrics of a task from a TSV manifest and write them to disk.

    Depending on the task (case-insensitive):
      * S2TT / S2ST / T2TT (only when pred_text_col_name is set): corpus BLEU, or
        chrF++ for T2TT, of the predicted text, written as JSON.
      * T2ST / S2ST: the predicted audio is transcribed with Whisper, the
        transcriptions are written to "whisper_audio_transcriptions.tsv" and the
        normalized ASR-BLEU is written as JSON.
      * ASR: WER/CER of the predicted text, written as JSON.

    Args:
        output_manifest_tsv_path (Path): TSV manifest containing the columns named by
            ref_text_col_name, pred_text_col_name and pred_audio_col_name (as needed by the task)
        output_path (Path): Directory to write files with metrics
        tgt_lang (str): what language we evaluate on
        task (str): Task we are currently evaluating for
        device (Device): Device to use for inference
        whisper_model_name (str): Whisper model name. Defaults to "large".
        whisper_normalize_text_output (bool): Normalizes text output using whisper_normalizer if set to true
        ref_text_col_name (str): Column name in the tsv corresponding to reference target text
        pred_text_col_name (str): Column name in the tsv corresponding to predicted target text.
            Setting this value to None skips the text metrics of S2TT / S2ST / T2TT
        pred_audio_col_name (str): Column name in the tsv corresponding to predicted target audio.
    Returns:
        str: name of the last metrics file written inside output_path.
    Raises:
        ValueError: if no metric was computed for the given task / column configuration.
    """
    df = pd.read_csv(
        output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
    )
    task = task.upper()

    output_path.mkdir(parents=True, exist_ok=True)

    # Stays None until one of the task branches below writes a metrics file.
    filename: Optional[str] = None

    if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
        metric = "chrF++" if task == "T2TT" else "bleu"
        text_metric, text_metric_signature = compute_corpus_metric_score(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
            metric=metric,
        )
        text_metric_json = text_metric.format(
            signature=text_metric_signature.format(), is_json=True
        )

        if task == "T2TT":
            filename = "t2tt_chrf.json"
            cur_task = "T2TT"
        else:
            # The text output of S2ST is reported under the S2TT name as well.
            filename = (
                "s2tt_bleu_normalized.json"
                if whisper_normalize_text_output
                else "s2tt_bleu.json"
            )
            cur_task = "S2TT"

        with open(output_path / filename, "w", encoding="utf-8") as f:
            f.write(text_metric_json)

        logger.info(f"{cur_task} {metric}:\n{text_metric_json}")

    if task in ["T2ST", "S2ST"]:
        whisper_model = init_whisper_model(device, whisper_model_name)
        (
            asr_bleu_normalized,
            asr_bleu_normalized_signature,
            transcripts_df,
        ) = compute_asr_bleu(
            audio_paths_series=df[pred_audio_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            asr_model=whisper_model,
            whisper_normalize_text=True,
        )
        # return_transcriptions is left at its default (True), so a DataFrame is returned.
        transcripts_df.to_csv(
            (output_path / "whisper_audio_transcriptions.tsv"),
            sep="\t",
            index=False,
            encoding="utf-8",
            escapechar="\\",
        )

        asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name

        asr_bleu_normalized_json = asr_bleu_normalized.format(
            signature=asr_bleu_normalized_signature.format(), is_json=True
        )
        filename = f"{task.lower()}_asr_bleu_normalized.json"

        with open(output_path / filename, "w", encoding="utf-8") as f:
            f.write(asr_bleu_normalized_json)

        logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")

    if task == "ASR":
        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
        )
        # NOTE(review): "name" is always "WER" even when compute_asr_error_rate
        # used CER for the language; the signature string carries the metric
        # that was really used. Kept as-is for output compatibility.
        d = {
            "name": "WER",
            "score": asr_error_rate,
            "signature": asr_error_rate_signature,
        }
        asr_error_rate_json = json.dumps(d, indent=1, ensure_ascii=False)
        filename = "asr_error_rate.json"

        with open(output_path / filename, "w", encoding="utf-8") as f:
            f.write(asr_error_rate_json)

        logger.info(f"ASR : {asr_error_rate_json}")

    if filename is None:
        # No branch ran: fail with an explicit message instead of an
        # UnboundLocalError on the return statement.
        raise ValueError(
            f"No metric was computed for task {task!r} "
            f"(pred_text_col_name={pred_text_col_name!r})"
        )

    return filename