"""Evaluate a wav2vec2 CTC checkpoint on the "hy-AM" test split and report WER / CER.

The entry point is exposed as a command-line interface through python-fire, e.g.::

    python <this_script>.py --repo_name=<model dir or hub id> [--output_file=preds.csv]
"""
import weakref
from typing import Optional

import fire
import numpy as np
import pandas as pd
import torch
from datasets import Audio, load_dataset, load_metric
from tqdm import tqdm
from transformers import (
    AutoModelForCTC,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
)

from aspram.utils import clean_characters, prepare_dataset


def _detokenize(sample):
    """Undo subword segmentation in a predicted ``sample['sentence']``.

    When the prediction contains the SentencePiece-style word-boundary marker
    '▁', the text is treated as space-separated subword pieces: the separating
    spaces are dropped and every '▁' becomes a single space. Predictions without
    the marker are returned untouched. The sample dict is updated in place and
    returned, as ``Dataset.map`` expects.
    """
    sentence = sample['sentence']
    if '▁' in sentence:
        # A leading '▁' leaves a leading space here; presumably the later
        # clean_characters pass normalises whitespace -- TODO confirm.
        sample['sentence'] = sentence.replace(' ', '').replace('▁', ' ')
    return sample


def exec(
    *,
    repo_name: str,
    dataset: str = "yerevann/common_voice_9_0",
    cuda: bool = True,
    batch_size: int = 8,
    beam_width: int = 1,
    j: int = 1,
    sample_rate: int = 16_000,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    unk_score_offset: Optional[float] = None,
    lm_score_boundary: Optional[bool] = None,
    beam_prune_logp: Optional[float] = None,
    token_min_logp: Optional[float] = None,
    output_file: Optional[str] = None,
):
    """Transcribe the "hy-AM" test split with a CTC model and print WER / CER.

    Every REF/PRED pair is printed, followed by the corpus-level WER and CER.
    (The name shadows the builtin ``exec``; it is kept because it is the
    function handed to ``fire.Fire`` below.)

    Args:
        repo_name: Local directory or hub id of the fine-tuned model. If it ends
            with ``"_lm"``, a ``Wav2Vec2ProcessorWithLM`` is loaded and logits
            are decoded with the LM beam search. Otherwise greedy argmax
            decoding is used.
        dataset: Dataset id passed to ``load_dataset`` (config ``"hy-AM"``,
            split ``"test"``, loaded with ``use_auth_token=True``).
        cuda: Move the model and each input batch to the GPU.
        batch_size: Number of examples per inference batch.
        beam_width, alpha, beta, unk_score_offset, lm_score_boundary,
        beam_prune_logp, token_min_logp: Forwarded unchanged to
            ``processor.batch_decode`` in LM mode only. ``None`` presumably
            means "use the decoder's default" -- see the transformers docs.
        j: Passed as ``num_processes`` to the LM decoder.
        sample_rate: Rate the audio column is cast/resampled to. It is also
            passed to the processor.
        output_file: If given, a CSV with columns ``sentence`` (normalised
            reference) and ``predictions`` (normalised hypothesis) is written
            here.
    """
    print(f'loading model {repo_name}')
    model = Wav2Vec2ForCTC.from_pretrained(repo_name)
    print('done')
    if cuda:
        print('CUDA mode')
        model.cuda()

    with_lm = repo_name.endswith('_lm')
    processor_cls = Wav2Vec2ProcessorWithLM if with_lm else Wav2Vec2Processor
    # NOTE(review): the feature extractor's own parameter is named
    # ``sampling_rate``; this ``sample_rate`` kwarg may be ignored -- confirm.
    processor = processor_cls.from_pretrained(repo_name, sample_rate=sample_rate)

    # Load the split once: Dataset.map returns new datasets, so the raw split
    # can feed both the audio pipeline and the reference transcriptions.
    raw_test = load_dataset(
        dataset,
        "hy-AM",
        split="test",
        use_auth_token=True,
    )

    common_voice_test = raw_test.map(clean_characters)
    common_voice_test = common_voice_test.cast_column(
        "audio", Audio(sampling_rate=sample_rate)
    )
    common_voice_test = common_voice_test.map(
        prepare_dataset,
        remove_columns=common_voice_test.column_names,
        fn_kwargs=dict(processor=processor),
    )

    def predict(batch):
        """Run the model on one batch and return ``{'sentence': [texts...]}``."""
        inputs = processor(
            batch["input_values"],
            return_tensors="pt",
            padding=True,
            sampling_rate=sample_rate,
        )
        input_values = inputs.input_values
        # Returned only when the feature extractor has return_attention_mask=True.
        # In that case the model must see it, otherwise padded examples are
        # transcribed differently depending on batch_size.
        attention_mask = inputs.get("attention_mask")
        if cuda:
            input_values = input_values.cuda()
            if attention_mask is not None:
                attention_mask = attention_mask.cuda()

        with torch.no_grad():
            logits = model(input_values, attention_mask=attention_mask).logits
        logits = logits.cpu().numpy()

        if with_lm:
            pred = processor.batch_decode(
                logits,
                beam_width=beam_width,
                alpha=alpha,
                beta=beta,
                unk_score_offset=unk_score_offset,
                lm_score_boundary=lm_score_boundary,
                num_processes=j,
                beam_prune_logp=beam_prune_logp,
                token_min_logp=token_min_logp,
            ).text
        else:
            # Greedy CTC decoding: most likely token id per frame.
            pred = processor.batch_decode(logits.argmax(-1))
        return {'sentence': pred}

    with_predictions = common_voice_test.map(predict, batched=True, batch_size=batch_size)
    with_predictions = with_predictions.map(_detokenize)

    # Normalise hypotheses and references identically before scoring.
    with_predictions = with_predictions.map(
        clean_characters, fn_kwargs=dict(lower=True, only_mesropatar=True)
    )
    common_voice_test_transcription = raw_test.map(
        clean_characters, fn_kwargs=dict(lower=True, only_mesropatar=True)
    )

    predictions = with_predictions['sentence']
    references = common_voice_test_transcription['sentence']

    wer_metric = load_metric("wer")
    cer_metric = load_metric("cer")

    for ref, pred in zip(references, predictions):
        print(f' REF:\t{ref}')
        print(f'PRED:\t{pred}')
        print('\n')

    wer = wer_metric.compute(predictions=predictions, references=references)
    cer = cer_metric.compute(predictions=predictions, references=references)

    print("wer: ", wer)
    print("cer: ", cer)

    if output_file is not None:
        # Build the table from the extracted text columns instead of converting
        # whole datasets (audio / input_values included) to pandas.
        df = pd.DataFrame({"sentence": references, "predictions": predictions})
        df.to_csv(output_file)


if __name__ == "__main__":
    fire.Fire(exec)