|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import logging |
|
from pathlib import Path |
|
from typing import Optional, Tuple, Union |
|
|
|
import pandas as pd |
|
import whisper |
|
from fairseq2.typing import Device |
|
from jiwer import cer, wer |
|
from sacrebleu.metrics.base import Score, Signature |
|
from sacrebleu.metrics.bleu import BLEU |
|
from sacrebleu.metrics.chrf import CHRF |
|
from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2 |
|
from tqdm import tqdm |
|
from whisper import Whisper |
|
from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer |
|
|
|
# Module-wide logging setup: INFO level with a timestamped, per-logger prefix.
# Note: basicConfig only takes effect if the root logger is not yet configured.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

# Module-level logger, keyed by module name per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
def init_whisper_model(
    device: Device,
    whisper_model_name: str = "large",
) -> Whisper:
    """Load a Whisper ASR checkpoint onto the given device.

    Args:
        device (Device): Device to place the model on.
        whisper_model_name (str): Whisper checkpoint name. Defaults to "large".

    Returns:
        Whisper: The loaded ASR model, ready for transcription.
    """
    model = whisper.load_model(name=whisper_model_name, device=device)
    return model
|
|
|
|
|
def transcribe_series(
    audio_paths_series: pd.Series,
    asr_model: Whisper,
    audio_lang: str,
    beam_size: int = 1,
    temperature: float = 0.0,
) -> pd.Series:
    """Run ASR over every audio file referenced by a series of paths.

    Args:
        audio_paths_series (pd.Series): each entry is a path to an audio file.
        asr_model: ASR model used for transcription, e.g. Whisper.
        audio_lang (str): language spoken in the audio, passed to the ASR model.
        beam_size (int): whisper beam size. Defaults to 1.
        temperature (float): whisper temperature. Defaults to 0.0 to avoid
            fallback decoding (see below).

    Returns:
        pd.Series: one transcription per input row, aligned on the input index
        and named "<input_name>_transcribed".

    Note on reproducibility: Whisper implements decoding with fallback
    (https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L147).
    Decoding at a time step may be retried whenever a "fall back" criterion
    fires, with the number of retries driven by the temperature schedule
    (https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L41).
    With the default schedule temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), even
    beam_size=5 runs may fall back to sampling with temperature > 0, in which
    case beam search is not used for the fallback iteration. Passing an
    explicit temperature=0.0 overrides the schedule so the fallback loop runs
    exactly once, giving reproducible evaluation without sampling variation.
    Beware that this may introduce repetition loops in the transcriptions and
    lower the final ASR-BLEU score.
    """
    # Whisper expects 2-letter language codes; map down from 3-letter codes.
    if len(audio_lang) == 3:
        audio_lang = LANG3_LANG2[audio_lang]

    results = {}
    progress = tqdm(
        audio_paths_series.items(),
        desc=f"Transcribing {audio_paths_series.name} column",
        total=len(audio_paths_series),
    )
    for row_idx, audio_file in progress:
        decoded = asr_model.transcribe(
            audio_file,
            temperature=temperature,
            beam_size=beam_size,
            language=audio_lang,
        )
        results[row_idx] = decoded["text"].strip()

    out_series = pd.Series(results)
    out_series.name = f"{audio_paths_series.name}_transcribed"
    return out_series
|
|
|
|
|
def whisper_normalize_series(
    transcription_series: pd.Series, text_lang: str
) -> pd.Series:
    """Normalize every text entry of a series with a Whisper normalizer.

    English has a dedicated normalizer in the whisper package; every other
    language uses the basic one.

    Args:
        transcription_series (pd.Series): arbitrary text written in text_lang.
        text_lang (str): language of the text in the series.

    Returns:
        pd.Series: normalized text, preserving index and series name.
    """
    normalizer = (
        EnglishTextNormalizer() if text_lang == "eng" else BasicTextNormalizer()
    )

    normalized = {
        idx: normalizer(text) for idx, text in transcription_series.items()
    }

    normalized_series = pd.Series(normalized)
    normalized_series.name = transcription_series.name
    return normalized_series
|
|
|
|
|
def compute_asr_bleu(
    audio_paths_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    asr_model: Whisper,
    whisper_normalize_text: bool = True,
    beam_size: int = 1,
    temperature: float = 0.0,
    return_transcriptions: bool = True,
) -> Tuple[Score, Signature, pd.DataFrame]:
    """Transcribe predicted audio and score it against references (ASR-BLEU).

    ASR decoding hyper-parameters are hard coded to ensure reproducibility
    across evaluations.

    Args:
        audio_paths_series (pd.Series): each line contains a path to audio.
        ref_text_series (pd.Series): each line contains the text reference to
            compare the audio with.
        lang (str): the language of both audio and ref_text.
        asr_model: whisper ASR model.
        whisper_normalize_text (bool): normalize both hypotheses and
            references if True. Defaults to True.
        beam_size (int): beam_size for whisper generation.
        temperature (float): temperature sampling value for whisper generation.
        return_transcriptions (bool): also return a DataFrame of
            audio/transcript/reference when True; otherwise the third return
            value is None.
    """
    hypotheses = transcribe_series(
        audio_paths_series,
        asr_model,
        audio_lang=lang,
        beam_size=beam_size,
        temperature=temperature,
    )

    bleu_score, bleu_signature = compute_corpus_metric_score(
        hypotheses, ref_text_series, lang, whisper_normalize_text
    )
    # Record the ASR decoding configuration in the metric signature so the
    # reported score is fully reproducible.
    bleu_signature.info["whisper_asr_beam_size"] = beam_size
    bleu_signature.info["whisper_asr_temperature"] = temperature
    bleu_signature.info["whisper_asr_language"] = lang

    transcript_df = None
    if return_transcriptions:
        transcript_df = pd.concat(
            [audio_paths_series, hypotheses, ref_text_series],
            axis=1,
            keys=["audio", "transcript", "reference"],
        )

    return bleu_score, bleu_signature, transcript_df
|
|
|
|
|
def get_tokenizer(lang: str, metric: str = "bleu") -> str:
    """Return the tokenizer name to use for a language and metric.

    Character-based languages (Mandarin, Japanese, Thai, Lao, Burmese) always
    tokenize per character. For all other languages, BLEU uses sacrebleu's
    "13a" tokenizer and error-rate metrics use word-level tokenization.

    Args:
        lang (str): Three-letter code of the language.
        metric (str): Metric being computed. "bleu" selects the "13a" default;
            any other value (the caller passes "error_rate") selects "word".
            (The previous docstring claimed "asr" was the valid value, which no
            caller uses.)

    Returns:
        str: tokenizer name — "char", "13a", or "word".
    """
    # Languages without whitespace-delimited words must be scored per character.
    lang_tok_map = {
        "cmn": "char",
        "jpn": "char",
        "tha": "char",
        "lao": "char",
        "mya": "char",
    }
    default = (
        "13a" if metric == "bleu" else "word"
    )
    tok = lang_tok_map.get(lang, default)
    return tok
|
|
|
|
|
def compute_asr_error_rate(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text: bool = True,
) -> Tuple[float, str]:
    """Wraps normalization functions and computes ASR WER/CER score.

    Uses CER for character-based languages (per `get_tokenizer`) and WER
    otherwise.

    Args:
        hyp_text_series (pd.Series): each line contains an s2t model prediction
            or first-pass prediction.
        ref_text_series (pd.Series): each line contains the reference text.
        lang (str): language of both hypotheses and references.
        whisper_normalize_text (bool, optional): normalize both text hypotheses
            and references if True. Defaults to True.

    Returns:
        Tuple[float, str]: the error-rate score and a human-readable summary.
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    tokenizer_name = get_tokenizer(lang, metric="error_rate")
    metric_name = wer if tokenizer_name == "word" else cer
    # BUG FIX: jiwer's wer/cer signature is (reference, hypothesis). The
    # previous call passed hypotheses first, which swapped the roles and
    # normalized the edit distance by hypothesis length instead of reference
    # length (WER/CER are not symmetric).
    metric_score = metric_name(ref_text_series.to_list(), hyp_text_series.to_list())
    return metric_score, f"{metric_name.__name__} is {metric_score}"
|
|
|
|
|
def compute_corpus_metric_score(
    hyp_text_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    whisper_normalize_text: bool = True,
    metric: str = "bleu",
) -> Tuple[Score, Signature]:
    """Wraps normalization functions and computes corpus-level BLEU/chrF++ score.

    Args:
        hyp_text_series (pd.Series): each line contains an s2t model prediction
            or first-pass prediction.
        ref_text_series (pd.Series): each line contains the reference text.
        lang (str): language of both hypotheses and references; selects the
            tokenizer (character-based languages are tokenized per character).
        whisper_normalize_text (bool, optional): normalize both text hypotheses
            and references if True. Defaults to True.
        metric (str): metric to compute, either "bleu" or "chrF++".

    Returns:
        Tuple[Score, Signature]: corpus score and its sacrebleu signature.

    Raises:
        ValueError: if `metric` is neither "bleu" nor "chrF++".
    """
    if whisper_normalize_text:
        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
        ref_text_series = whisper_normalize_series(ref_text_series, lang)

    tokenizer_name = get_tokenizer(lang)
    corpus_metric_score_metric: Union[BLEU, CHRF]
    if metric == "bleu":
        # Lowercase to match whisper normalization so scoring stays consistent.
        corpus_metric_score_metric = BLEU(
            lowercase=whisper_normalize_text, tokenize=tokenizer_name
        )
    elif metric == "chrF++":
        corpus_metric_score_metric = CHRF(word_order=2)
    else:
        # Previously an unknown metric fell through to an opaque NameError on
        # the unbound `corpus_metric_score_metric`; fail fast instead.
        raise ValueError(f"Unknown metric: {metric}. Expected 'bleu' or 'chrF++'.")

    corpus_metric_score = corpus_metric_score_metric.corpus_score(
        hyp_text_series.to_list(), [ref_text_series.to_list()]
    )
    corpus_metric_score_signature = corpus_metric_score_metric.get_signature()
    corpus_metric_score_signature.info["whisper_normalize"] = whisper_normalize_text

    return corpus_metric_score, corpus_metric_score_signature
|
|
|
|
|
def compute_quality_metrics(
    output_manifest_tsv_path: Path,
    output_path: Path,
    tgt_lang: str,
    task: str,
    device: Device,
    whisper_model_name: str = "large",
    whisper_normalize_text_output: bool = False,
    ref_text_col_name: str = "ref_tgt_text",
    pred_text_col_name: Optional[str] = "pred_tgt_text",
    pred_audio_col_name: str = "pred_tgt_audio",
) -> str:
    """Wraps asr and s2t bleu functions to call them with a TSV manifest.

    Writes one JSON file per computed metric into `output_path` and, for
    speech-output tasks, a TSV with the Whisper transcriptions.

    Args:
        output_manifest_tsv_path (Path): manifest TSV holding references and
            predictions under the column names given below.
        output_path (Path): Directory to write files with metrics.
        tgt_lang (str): what language we evaluate on.
        task (str): Task we are currently evaluating for (case-insensitive;
            one of S2TT, S2ST, T2TT, T2ST, ASR).
        device (Device): Device to use for inference.
        whisper_model_name (str): Whisper model name. Defaults to "large".
        whisper_normalize_text_output (bool): Normalizes text output using
            whisper_normalizer if set to true.
        ref_text_col_name (str): Column name in the tsv corresponding to
            reference target text.
        pred_text_col_name (str): Column name in the tsv corresponding to
            predicted target text. Setting this value to None skips the
            text metrics.
        pred_audio_col_name (str): Column name in the tsv corresponding to
            predicted target audio.

    Returns:
        str: name of the last metric JSON file written into `output_path`.

    Raises:
        ValueError: if `task` (with the given columns) matches no supported
            evaluation path, so no metric file was produced.
    """
    df = pd.read_csv(
        output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
    )
    task = task.upper()

    # exist_ok makes a prior exists() check redundant.
    output_path.mkdir(parents=True, exist_ok=True)

    # Tracks the last metric file written; None means no branch matched.
    # Previously an unmatched task fell through to `return filename` and
    # raised an opaque unbound-local NameError.
    filename: Optional[str] = None

    # Text-output evaluation: corpus BLEU (speech tasks) or chrF++ (T2TT).
    if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
        metric = "chrF++" if task == "T2TT" else "bleu"
        text_metric, text_metric_signature = compute_corpus_metric_score(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
            metric=metric,
        )
        text_metric_json = text_metric.format(
            signature=text_metric_signature.format(), is_json=True
        )

        if task == "T2TT":
            filename = "t2tt_chrf.json"
            cur_task = "T2TT"
        else:
            filename = (
                "s2tt_bleu_normalized.json"
                if whisper_normalize_text_output
                else "s2tt_bleu.json"
            )
            cur_task = "S2TT"

        with open(output_path / filename, "w") as f:
            f.write(text_metric_json)

        logger.info(f"{cur_task} {metric}:\n{text_metric_json}")

    # Speech-output evaluation: transcribe predicted audio, then ASR-BLEU.
    if task in ["T2ST", "S2ST"]:
        whisper_model = init_whisper_model(device, whisper_model_name)
        (
            asr_bleu_normalized,
            asr_bleu_normalized_signature,
            transcripts_df,
        ) = compute_asr_bleu(
            audio_paths_series=df[pred_audio_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            asr_model=whisper_model,
            whisper_normalize_text=True,
        )
        # Persist the raw transcriptions alongside the metrics for inspection.
        transcripts_df.to_csv(
            (output_path / "whisper_audio_transcriptions.tsv"),
            sep="\t",
            index=False,
            encoding="utf-8",
            escapechar="\\",
        )

        asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name

        asr_bleu_normalized_json = asr_bleu_normalized.format(
            signature=asr_bleu_normalized_signature.format(), is_json=True
        )
        filename = f"{task.lower()}_asr_bleu_normalized.json"

        with open(
            output_path / filename,
            "w",
        ) as f:
            f.write(asr_bleu_normalized_json)

        logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")

    # ASR evaluation: WER (or CER for character-based languages).
    if task == "ASR":
        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
        )
        d = {
            "name": "WER",
            "score": asr_error_rate,
            "signature": asr_error_rate_signature,
        }
        asr_error_rate_json = json.dumps(d, indent=1, ensure_ascii=False)

        filename = "asr_error_rate.json"

        with open(output_path / filename, "w") as f:
            f.write(asr_error_rate_json)

        logger.info(f"ASR : {asr_error_rate_json}")

    if filename is None:
        raise ValueError(f"Unsupported task for metric computation: {task}")
    return filename
|
|