|
import weakref |
|
import torch |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2Processor |
|
from transformers import AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ForCTC |
|
|
|
from datasets import load_dataset, load_metric, Audio |
|
|
|
import fire |
|
|
|
from aspram.utils import clean_characters, prepare_dataset |
|
|
|
|
|
|
|
|
|
|
|
def exec( |
|
*, |
|
repo_name: str, |
|
dataset: str = "yerevann/common_voice_9_0", |
|
cuda: bool = True, |
|
batch_size: int = 8, |
|
beam_width: int = 1, |
|
j: int = 1, |
|
sample_rate: int = 16_000, |
|
alpha: float = None, |
|
beta: float = None, |
|
unk_score_offset: float = None, |
|
lm_score_boundary: bool = None, |
|
beam_prune_logp: float = None, |
|
token_min_logp: float = None, |
|
output_file : str = None, |
|
): |
|
|
|
|
|
|
|
print(f'loading model {repo_name}') |
|
model = Wav2Vec2ForCTC.from_pretrained(repo_name) |
|
print('done') |
|
if cuda: |
|
print('CUDA mode') |
|
model.cuda() |
|
|
|
if repo_name.endswith('_lm'): |
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained(repo_name, sample_rate=sample_rate) |
|
with_lm = True |
|
else: |
|
processor = Wav2Vec2Processor.from_pretrained(repo_name, sample_rate=sample_rate) |
|
with_lm = False |
|
|
|
common_voice_test = load_dataset( |
|
dataset, |
|
"hy-AM", |
|
split="test", |
|
use_auth_token=True, |
|
) |
|
common_voice_test = common_voice_test.map(clean_characters) |
|
common_voice_test = common_voice_test.cast_column( |
|
"audio", Audio(sampling_rate=sample_rate) |
|
) |
|
common_voice_test = common_voice_test.map( |
|
prepare_dataset, |
|
remove_columns=common_voice_test.column_names, |
|
fn_kwargs=dict(processor=processor) |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(batch): |
|
|
|
input_dict = processor( |
|
batch["input_values"], |
|
return_tensors="pt", |
|
padding=True, |
|
sampling_rate=sample_rate |
|
) |
|
|
|
with torch.no_grad(): |
|
x = input_dict.input_values |
|
if cuda: |
|
x = x.cuda() |
|
logits = model(x).logits |
|
|
|
if with_lm: |
|
|
|
|
|
|
|
|
|
pred = processor.batch_decode( |
|
logits.cpu().numpy(), |
|
beam_width=beam_width, |
|
alpha=alpha, |
|
beta=beta, |
|
unk_score_offset=unk_score_offset, |
|
lm_score_boundary=lm_score_boundary, |
|
num_processes=j, |
|
beam_prune_logp=beam_prune_logp, |
|
token_min_logp=token_min_logp, |
|
|
|
).text |
|
else: |
|
pred = processor.batch_decode( |
|
logits.cpu().numpy().argmax(-1), |
|
) |
|
|
|
|
|
|
|
return { |
|
'sentence': pred |
|
} |
|
|
|
with_predictions = common_voice_test.map(predict, batched=True, batch_size=batch_size) |
|
|
|
def detokenize(sample): |
|
if '▁' in sample['sentence']: |
|
print("------ ", sample) |
|
sample['sentence'] = sample['sentence'].replace(' ', '').replace('▁', ' ') |
|
print("------ ", sample) |
|
return sample |
|
|
|
with_predictions = with_predictions.map(detokenize) |
|
|
|
common_voice_test_transcription = load_dataset( |
|
dataset, |
|
"hy-AM", |
|
split="test", |
|
use_auth_token=True, |
|
) |
|
|
|
with_predictions = with_predictions.map(clean_characters, fn_kwargs=dict(lower=True, only_mesropatar=True)) |
|
common_voice_test_transcription = common_voice_test_transcription.map(clean_characters, fn_kwargs=dict(lower=True, only_mesropatar=True)) |
|
|
|
predictions = with_predictions['sentence'] |
|
references = common_voice_test_transcription['sentence'] |
|
|
|
wer_metric = load_metric("wer") |
|
cer_metric = load_metric("cer") |
|
|
|
for ref, pred in zip(references, predictions): |
|
print(f' REF:\t{ref}') |
|
print(f'PRED:\t{pred}') |
|
print('\n') |
|
|
|
wer = wer_metric.compute(predictions=predictions, references=references) |
|
cer = cer_metric.compute(predictions=predictions, references=references) |
|
print("wer: ", wer) |
|
print("cer: ", cer) |
|
|
|
df = common_voice_test_transcription.to_pandas()['sentence'] |
|
df = df.to_frame() |
|
df["predictions"] = with_predictions.to_pandas()['sentence'] |
|
|
|
|
|
|
|
if output_file is not None: |
|
df.to_csv(output_file) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
fire.Fire(exec) |