|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import tempfile |
|
import typing as tp |
|
import torchaudio |
|
from tqdm import tqdm |
|
from seamless_communication.cli.eval_utils.compute_metrics import init_whisper_model |
|
from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2 |
|
from seamless_communication.inference.translator import Modality |
|
import torch |
|
|
|
from pathlib import Path |
|
from seamless_communication.inference import Translator |
|
from fairseq2.data import Collater, DataPipeline, FileMapper |
|
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter |
|
from fairseq2.data.text import StrSplitter, read_text |
|
from fairseq2.typing import DataType, Device |
|
|
|
from seamless_communication.toxicity import load_etox_bad_word_checker |
|
|
|
from whisper.model import Whisper |
|
|
|
import logging |
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s", |
|
) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def main() -> None:
    """CLI entry point: transcribe a manifest of audio files (whisper or
    seamless ASR) and write per-utterance toxicity counts to a TSV file."""
    parser = argparse.ArgumentParser(
        description="ASR ETOX will compute the toxicity level of speech inputs."
    )
    parser.add_argument(
        "data_file",
        type=Path,
        help="Path to the input TSV manifest that list the audio files.",
    )
    parser.add_argument(
        "output_file",
        type=Path,
        help="Path to a TSV file where to save the results.",
    )
    parser.add_argument(
        "--lang",
        type=str,
        required=True,
        help="Language, language of the speech to transcribe",
    )
    parser.add_argument(
        "--audio_root_dir",
        type=str,
        help="Root directory for the audio filenames in the data file.",
    )
    parser.add_argument(
        "--audio_column",
        type=str,
        default="audio",
        help="Name of the column where the audiofile is listed in the input tsv.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="seamlessM4T_v2_large",
        help=(
            "Base model name (`seamlessM4T_medium`, "
            "`seamlessM4T_large`, `seamlessM4T_v2_large`), "
            " or whisper model, e.g. 'whisper_large'"
        ),
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Inference batch size.",
    )
    parser.add_argument(
        "--n_parallel",
        type=int,
        default=4,
        help="Number of data loading in parallel.",
    )
    # Unknown flags are deliberately tolerated (parse_known_args).
    args, _unknown = parser.parse_known_args()

    # Run in half precision on GPU, full precision on CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    dtype = torch.float16 if use_cuda else torch.float32

    # A "whisper_*" model name selects the whisper backend,
    # anything else is treated as a seamless Translator checkpoint.
    is_whisper = args.model_name.startswith("whisper_")
    whisper_model = None
    translator = None

    if is_whisper:
        logger.info("loading whisper model.")
        _, whisper_name = args.model_name.split("_", maxsplit=1)
        whisper_model = init_whisper_model(device, whisper_name)
    else:
        logger.info(f"loading {args.model_name} model.")
        translator = Translator(
            args.model_name,
            None,
            device,
            text_tokenizer=None,
            dtype=dtype,
            input_modality=Modality.SPEECH,
            output_modality=Modality.TEXT,
            apply_mintox=False,
        )

    logger.info("loading etox.")
    bad_word_checker = load_etox_bad_word_checker("mintox")

    pipeline = build_data_pipeline(
        data_file=args.data_file,
        audio_root_dir=args.audio_root_dir,
        batch_size=args.batch_size,
        is_whisper=is_whisper,
        device=device,
        dtype=dtype,
        n_parallel=args.n_parallel,
        audio_column=args.audio_column,
    )

    logger.info("running ASR-ETOX.")
    with open(args.output_file, "w", encoding="utf-8") as outf:
        # Header row of the output TSV.
        print("text", "toxicity", "bad_words", file=outf, sep="\t")
        for example in tqdm(pipeline, unit="line"):
            transcriptions = get_text(
                lang=args.lang,
                example=example,
                whisper_model=whisper_model,
                translator=translator,
                audio_column=args.audio_column,
            )
            for text in transcriptions:
                bad_words = bad_word_checker.get_bad_words(
                    text=str(text),
                    lang=args.lang,
                )
                print(
                    text,
                    len(bad_words),
                    ",".join(bad_words),
                    file=outf,
                    sep="\t",
                )
|
|
|
|
|
def get_text(
    lang: str,
    example: tp.Dict[str, tp.Any],
    whisper_model: tp.Optional["Whisper"],
    translator: tp.Optional["Translator"],
    audio_column: str,
) -> tp.List[tp.Any]:
    """Transcribe one pipeline batch and return the transcriptions.

    Exactly one of ``whisper_model`` / ``translator`` is expected to be set
    (the caller loads one backend or the other).

    :param lang: 3-letter language code of the speech to transcribe.
    :param example: one collated batch from the data pipeline.
    :param whisper_model: loaded whisper model, or ``None`` when using seamless.
    :param translator: loaded seamless Translator, or ``None`` when using whisper.
    :param audio_column: key under which the audio data lives in ``example``.
    :returns: list of transcribed texts for this batch.
    :raises ValueError: if neither backend was provided.
    """
    if whisper_model:
        # whisper's transcribe API takes a file path, so round-trip the
        # decoded waveform through a temporary wav file.
        # NOTE(review): only `seqs[0]` is written out, so with batch_size > 1
        # the rest of the bucket is dropped — confirm whisper runs are meant
        # to use an effective batch size of 1.
        with tempfile.NamedTemporaryFile(suffix=".wav") as temp:
            torchaudio.save(
                temp.name,
                example[audio_column]["data"]["waveform"]["seqs"][0]
                .transpose(0, 1)
                .cpu(),
                int(example[audio_column]["data"]["sample_rate"][0]),
                format="wav",
            )
            results = whisper_model.transcribe(
                temp.name,
                # whisper expects 2-letter language codes.
                language=LANG3_LANG2[lang],
            )
        return [results["text"]]
    if translator is None:
        raise ValueError(
            "get_text needs either a whisper model or a seamless translator."
        )
    # Seamless ASR: decode directly from the precomputed fbank features.
    (text_output, _speech_output) = translator.predict(
        example[audio_column]["data"]["fbank"],
        "ASR",
        lang,
        src_lang=lang,
    )
    return text_output
|
|
|
|
|
def build_data_pipeline(
    data_file: Path,
    audio_root_dir: str,
    batch_size: int,
    is_whisper: bool,
    device: Device,
    dtype: DataType,
    audio_column: str = "audio",
    n_parallel: int = 4,
) -> DataPipeline:
    """Build the fairseq2 pipeline that reads the TSV manifest, loads and
    decodes each audio file, and yields padded batches.

    For the seamless backend the waveforms are additionally converted to
    fbank features; for whisper the raw waveforms are kept and grouped into
    buckets of ``batch_size``.
    """
    # Column names come from the manifest's first (header) row.
    with data_file.open("r", encoding="utf-8") as f:
        header = f.readline().strip("\n").split("\t")

    builder = read_text(data_file, rtrim=True).skip(1).map(StrSplitter(names=header))

    # Resolve the audio column (a path, relative to audio_root_dir) to an
    # open file handle, keeping a small cache of descriptors.
    builder.map(
        FileMapper(root_dir=audio_root_dir, cached_fd_count=10),
        selector=audio_column,
        num_parallel_calls=n_parallel,
    )

    decode_audio = AudioDecoder(dtype=torch.float32, device=device)

    convert_to_fbank = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
        device=device,
        dtype=dtype,
    )

    # Whisper consumes raw waveforms; seamless wants fbank features.
    steps = [decode_audio] if is_whisper else [decode_audio, convert_to_fbank]

    builder.map(
        steps,
        selector=f"{audio_column}.data",
        num_parallel_calls=n_parallel,
    )

    if is_whisper:
        builder.bucket(bucket_size=batch_size)

    builder.map(
        Collater(pad_value=0, pad_to_multiple=1),
        num_parallel_calls=n_parallel,
    )

    builder.prefetch(4)

    return builder.and_return()
|
|
|
|
|
# Script entry point: run the ASR-ETOX evaluation with CLI arguments.
if __name__ == "__main__":
    main()
|
|