|
import os |
|
import numpy as np |
|
import re |
|
import argparse |
|
|
|
# Strip any literal "CUDA" substring from CUDA_VISIBLE_DEVICES before torch is
# imported (presumably normalizing values like "CUDA0" to "0" — TODO confirm
# against the job scheduler that sets this variable). Use .get() so the script
# also runs when the variable is unset; the original indexed os.environ
# directly and raised KeyError in that case.
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "").replace("CUDA", "")
|
|
|
from transformers import pipeline |
|
from transformers.models.whisper.english_normalizer import BasicTextNormalizer |
|
from datasets import load_dataset, Audio |
|
|
|
# Whisper's own text normalizer; used in main() to produce the "_normW" output
# files alongside the raw and simply-normalized hypotheses.
whisper_norm = BasicTextNormalizer()
|
|
|
def simple_norm(utt: str) -> str:
    """Return *utt* with punctuation removed, whitespace collapsed to single
    spaces, and everything lowercased."""
    cleaned = re.sub(r'[^\w\s]', '', utt)
    tokens = cleaned.split()
    return " ".join(tokens).lower()
|
|
|
def data(dataset):
    """Yield pipeline-ready samples from *dataset*.

    Each yielded dict merges the item's "audio" fields (array, sampling rate,
    ...) with the ground-truth transcript under "reference" and the utterance
    id under "utt_id", so both are passed through the ASR pipeline alongside
    the audio.

    (Fix: dropped the unused ``enumerate`` index from the original loop.)
    """
    for item in dataset:
        yield {**item["audio"], "reference": item["text"], "utt_id": item["id"]}
|
|
|
def get_ckpt(path, ckpt_id):
    """Return the path of a "checkpoint-<N>" directory under *path*.

    A non-zero *ckpt_id* selects that specific checkpoint; ``ckpt_id == 0``
    selects the highest-numbered ``checkpoint-*`` subdirectory.

    Raises:
        ValueError: if ``ckpt_id == 0`` and *path* contains no checkpoint
            directories (the original crashed with an opaque IndexError).

    (Fix: the original referenced an undefined name ``ckpt`` instead of
    ``ckpt_id``, so the non-zero branch always raised NameError.)
    """
    if ckpt_id != 0:
        return os.path.join(path, "checkpoint-%i" % ckpt_id)

    ckpts = [
        int(d.split('-')[-1])
        for d in os.listdir(path)
        if d.startswith("checkpoint-")
    ]
    if not ckpts:
        raise ValueError("no checkpoint-* directories found in %s" % path)
    # max() replaces the original sorted(...)[-1] — same result, O(n).
    return os.path.join(path, "checkpoint-%s" % max(ckpts))
|
|
|
def main(args: argparse.Namespace) -> None:
    """Transcribe the test split of the selected dataset with a Whisper model
    and write three hypothesis files (raw, Whisper-normalized, and
    simply-normalized) under ``<expdir>/results/<dataset>/``.
    """
    batch_size = args.batch_size

    # Map the cpu/gpu flag to the transformers pipeline device index
    # (-1 = CPU, 0 = first visible CUDA device).
    if args.device == "cpu":
        device_id = -1
    elif args.device == "gpu":
        device_id = 0
    else:
        raise NotImplementedError("unknown device %s, should be cpu/gpu" % args.device)

    model_dir = os.path.join(args.expdir, args.model_size)

    # NOTE(review): args.checkpoint and get_ckpt() are never used here — the
    # pipeline loads straight from the model directory, not from a specific
    # checkpoint-* subdirectory. Confirm whether
    # model = get_ckpt(model_dir, args.checkpoint) was intended.
    model = model_dir

    whisper_asr = pipeline(
        "automatic-speech-recognition", model=model, device=device_id
    )

    # Force decoding in the requested language / transcribe task instead of
    # letting the model auto-detect.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    # Local dataset loader scripts; extend this chain to add datasets.
    if args.dataset == 'cgn-dev':
        dataset_path = "./cgn-dev/cgn-dev.py"
    elif args.dataset == 'subs-annot':
        dataset_path = "./subs-annot/subs-annot.py"
    else:
        raise NotImplementedError('unknown dataset %s' % args.dataset)

    cache_dir = "/esat/audioslave/jponcele/hf_cache"
    dataset = load_dataset(dataset_path, name="raw", split="test", cache_dir=cache_dir, streaming=True)
    # Whisper expects 16 kHz input; resample lazily while streaming.
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    utterances = []
    predictions = []
    references = []

    # Stream utterances through the pipeline in batches. "utt_id" and
    # "reference" are extra keys passed through by data(); the [0] indexing
    # suggests they come back wrapped in single-element lists — TODO confirm
    # against the pipeline's pass-through behavior.
    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(out["text"])
        utterances.append(out["utt_id"][0])
        references.append(out["reference"][0])

    # NOTE(review): `references` is collected but never written out below —
    # either write a reference file for scoring or drop the list.

    result_dir = os.path.join(args.expdir, "results", args.dataset)
    os.makedirs(result_dir, exist_ok=True)

    # Raw hypotheses: one "<utt_id> <hypothesis>" line per utterance.
    with open(os.path.join(result_dir, "whisper_%s.txt" % args.model_size), "w") as pd:
        for i, utt in enumerate(utterances):
            pred = predictions[i]
            pd.write(utt + ' ' + pred + '\n')

    # Same hypotheses after Whisper's BasicTextNormalizer.
    with open(os.path.join(result_dir, "whisper_%s_normW.txt" % args.model_size), "w") as pd:
        for i, utt in enumerate(utterances):
            pred = whisper_norm(predictions[i])
            pd.write(utt + ' ' + pred + '\n')

    # Same hypotheses after the simple lowercase/strip-punctuation normalizer.
    with open(os.path.join(result_dir, "whisper_%s_normS.txt" % args.model_size), "w") as pd:
        for i, utt in enumerate(utterances):
            pred = simple_norm(predictions[i])
            pd.write(utt + ' ' + pred + '\n')
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI definition as (option, kwargs) pairs, applied in order below.
    cli_options = [
        ("--expdir", dict(
            type=str,
            default="/esat/audioslave/jponcele/whisper/finetuning_event/CGN",
            help="Directory with finetuned models")),
        ("--model_size", dict(
            type=str,
            default="tiny",
            help="Model size")),
        ("--checkpoint", dict(
            type=int,
            default=0,
            help="Load specific checkpoint. 0 means latest")),
        ("--dataset", dict(
            type=str,
            default="cgn-dev",
            help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets")),
        ("--device", dict(
            type=str,
            default="cpu",
            help="cpu/gpu")),
        ("--batch_size", dict(
            type=int,
            default=16,
            help="Number of samples to go through each streamed batch.")),
        ("--language", dict(
            type=str,
            default="dutch",
            help="Two letter language code for the transcription language, e.g. use 'en' for English.")),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    main(parser.parse_args())
|
|
|
|