# aspram/compute_wer.py
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
from datasets import load_dataset, load_metric, Audio
import fire

from aspram.utils import clean_characters, prepare_dataset
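
# Example invocation (fire exposes compute_wer's keyword arguments as CLI
# flags; the checkpoint name and decoder values below are placeholders):
#
#   python -m aspram.compute_wer \
#       --repo_name 20220414-210228_lm \
#       --beam_width 64 --alpha 0.5 --beta 1.5 \
#       --output_file predictions.csv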


def compute_wer(
    *,
    repo_name: str,
    dataset: str = "yerevann/common_voice_9_0",
    cuda: bool = True,
    batch_size: int = 8,
    beam_width: int = 1,
    j: int = 1,
    sample_rate: int = 16_000,
    alpha: float = None,
    beta: float = None,
    unk_score_offset: float = None,
    lm_score_boundary: bool = None,
    beam_prune_logp: float = None,
    token_min_logp: float = None,
    output_file: str = None,
):
    print(f'loading model {repo_name}')
    model = Wav2Vec2ForCTC.from_pretrained(repo_name)
    print('done')

    if cuda:
        print('CUDA mode')
        model.cuda()

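    # Checkpoints whose names end in '_lm' bundle a pyctcdecode language
    # model; those need the LM-aware processor and beam-search decoding.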
    if repo_name.endswith('_lm'):
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(repo_name, sample_rate=sample_rate)
        with_lm = True
    else:
        processor = Wav2Vec2Processor.from_pretrained(repo_name, sample_rate=sample_rate)
        with_lm = False

    common_voice_test = load_dataset(
        dataset,
        "hy-AM",
        split="test",
        use_auth_token=True,
    )
    common_voice_test = common_voice_test.map(clean_characters)
    # Decode and resample the audio on the fly at the model's expected rate.
    common_voice_test = common_voice_test.cast_column(
        "audio", Audio(sampling_rate=sample_rate)
    )
    common_voice_test = common_voice_test.map(
        prepare_dataset,
        remove_columns=common_voice_test.column_names,
        fn_kwargs=dict(processor=processor),
    )
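
    # prepare_dataset (from aspram.utils) is assumed to emit the
    # "input_values" feature that predict() consumes below; remove_columns
    # drops the original columns so map() replaces the dataset wholesale.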
    def predict(batch):
        input_dict = processor(
            batch["input_values"],
            return_tensors="pt",
            padding=True,
            sampling_rate=sample_rate,
        )
        with torch.no_grad():
            x = input_dict.input_values
            if cuda:
                x = x.cuda()
            logits = model(x).logits
        if with_lm:
            # pyctcdecode beam search over the CTC logits, rescored by the
            # bundled n-gram LM (alpha: LM weight, beta: word insertion bonus).
            pred = processor.batch_decode(
                logits.cpu().numpy(),
                beam_width=beam_width,
                alpha=alpha,
                beta=beta,
                unk_score_offset=unk_score_offset,
                lm_score_boundary=lm_score_boundary,
                num_processes=j,
                beam_prune_logp=beam_prune_logp,
                token_min_logp=token_min_logp,
            ).text
        else:
            # Greedy CTC decoding: argmax over the vocabulary at every frame.
            pred = processor.batch_decode(
                logits.cpu().numpy().argmax(-1),
            )
        return {
            'sentence': pred
        }
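
    # datasets feeds predict() batches of batch_size utterances at a time;
    # the returned "sentence" column holds the decoded transcriptions.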
    with_predictions = common_voice_test.map(predict, batched=True, batch_size=batch_size)

    def detokenize(sample):
        # SentencePiece marks word starts with '▁'; strip the spaces between
        # pieces and map '▁' back to spaces to recover plain text.
        if '▁' in sample['sentence']:
            sample['sentence'] = sample['sentence'].replace(' ', '').replace('▁', ' ')
        return sample

    with_predictions = with_predictions.map(detokenize)

    # Reload the raw test split for the reference transcriptions, since the
    # prepared copy above had its original columns removed.
    common_voice_test_transcription = load_dataset(
        dataset,
        "hy-AM",
        split="test",
        use_auth_token=True,
    )

    # Normalize hypotheses and references identically before scoring.
    with_predictions = with_predictions.map(
        clean_characters, fn_kwargs=dict(lower=True, only_mesropatar=True)
    )
    common_voice_test_transcription = common_voice_test_transcription.map(
        clean_characters, fn_kwargs=dict(lower=True, only_mesropatar=True)
    )

    predictions = with_predictions['sentence']
    references = common_voice_test_transcription['sentence']

    wer_metric = load_metric("wer")
    cer_metric = load_metric("cer")

    for ref, pred in zip(references, predictions):
        print(f' REF:\t{ref}')
        print(f'PRED:\t{pred}')
        print('\n')
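
    # WER = (substitutions + deletions + insertions) / reference word count;
    # CER is the same edit-distance ratio computed over characters.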
    wer = wer_metric.compute(predictions=predictions, references=references)
    cer = cer_metric.compute(predictions=predictions, references=references)
    print("wer: ", wer)
    print("cer: ", cer)

    if output_file is not None:
        # One row per utterance: reference in 'sentence', hypothesis in 'predictions'.
        df = common_voice_test_transcription.to_pandas()['sentence'].to_frame()
        df["predictions"] = with_predictions.to_pandas()['sentence']
        df.to_csv(output_file)


# Example hyperparameter sweep over the decoder settings (illustrative grid,
# not tuned values):
#
#   for beam_prune_logp in (-10, -100, -2000):
#       for alpha in (0.5, 1.0, 1.5):
#           for beta in (0.5, 1.0, 1.5):
#               for beam_width in (1, 4, 16, 64):
#                   compute_wer(repo_name=..., alpha=alpha, beta=beta,
#                               beam_width=beam_width,
#                               beam_prune_logp=beam_prune_logp)


if __name__ == "__main__":
    fire.Fire(compute_wer)