from transformers import Wav2Vec2ForCTC
from transformers import Wav2Vec2Processor
from datasets import load_dataset, load_metric
import re
import torchaudio
import librosa
import numpy as np
from datasets import load_dataset, load_metric
import torch Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n", " warnings.warn(\n" ] } ], "source": [ "from transformers import Wav2Vec2ForCTC\n", "from transformers import Wav2Vec2Processor\n", "from datasets import load_dataset, load_metric\n", "import re\n", "import torchaudio\n", "import librosa\n", "import numpy as np\n", "from datasets import load_dataset, load_metric\n", "import torch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:10:29.608803Z", "start_time": "2021-03-17T11:10:29.599700Z" } }, "outputs": [], "source": [ "chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n", "\n", "def remove_special_characters(batch):\n", " batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n", " return batch\n", "\n", "def speech_file_to_array_fn(batch):\n", " speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n", " batch[\"speech\"] = speech_array[0].numpy()\n", " batch[\"sampling_rate\"] = sampling_rate\n", " batch[\"target_text\"] = batch[\"text\"]\n", " return batch\n", "\n", "def resample(batch):\n", " batch[\"speech\"] = librosa.resample(np.asarray(batch[\"speech\"]), 48_000, 16_000)\n", " batch[\"sampling_rate\"] = 16_000\n", " return batch\n", "\n", "def prepare_dataset(batch):\n", " # check that all files have the correct sampling rate\n", " assert (\n", " len(set(batch[\"sampling_rate\"])) == 1\n", " ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n", "\n", " batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n", " \n", " with processor.as_target_processor():\n", " batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:11:02.120225Z", "start_time": 
"2021-03-17T11:10:56.182488Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.\n" ] } ], "source": [ "model = Wav2Vec2ForCTC.from_pretrained(\".\").to(\"cuda\")\n", "processor = Wav2Vec2Processor.from_pretrained(\".\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:12:18.847005Z", "start_time": "2021-03-17T11:12:14.919077Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration el-afd0a157f05ee080\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading and preparing dataset common_voice/el (download: 363.89 MiB, generated: 4.75 MiB, post-processed: Unknown size, total: 368.64 MiB) to /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", 
"text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Dataset common_voice downloaded and prepared to /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f. common_voice_test = load_dataset("common_voice", "el", data_dir="cv-corpus-6.1-2020-12-11", split="test") common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=["sentence"]) common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names) description='#5', max=190, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3b940160575143c7acfa142564e9f7d2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#3', max=190, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aa540f67ba894d7aa64e12fcdfab5ce0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#1', max=191, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": 
{ "application/vnd.jupyter.widget-view+json": { "model_id": "4962bdefdbbc44a7a44591480d8d6406", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, common_voice_test = common_voice_test.map(resample, num_proc=8) "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:13:25.145155Z", "start_time": "2021-03-17T11:13:18.091929Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", " return array(a, dtype, copy=False, order=order)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ae326a173a044b1494793e2a70d76a87", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#0', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "21ab1ef2af5a4a4fb23c68b0c5cf32f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#1', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d331c5f4f888477daceffe370f6cd89f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#3', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "6fa790118aa340e4afb9f83e71403a13", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#2', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c8092e2f59a9404596dc2bab206edf2c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#5', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "20f913f0caf8401098743b9e5051fc52", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#4', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7c7e15e24384494cb49a72106ce41ccd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#6', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "73245add55e24ee2a6dbe0713d5073d9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#7', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True) max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { 
"name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Dataset common_voice downloaded and prepared to /home/earendil/.cache/huggingface/datasets/common_voice/el-ac779bf2c9f7c09b/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f. # Change this value to try inference on different CommonVoice extracts
example = 678

# Index the row first and then the column. `dataset["column"][i]` would pull the
# entire column (every example's audio array) into memory just to read one item.
sample_input_values = common_voice_test[example]["input_values"]
input_dict = processor(sample_input_values, return_tensors="pt", sampling_rate=16_000, padding=True)

# Inference only: skip building an autograd graph.
with torch.no_grad():
    logits = model(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)

print("Prediction:")
print(processor.decode(pred_ids[0]))
# Sample output: πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς

print("\nReference:")
# Same row-first indexing for the raw transcription dataset.
print(common_voice_test_transcription[example]["sentence"].lower())
# Sample output: πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς.

def map_to_result(batch):
    """Transcribe one preprocessed example and store the text under "pred_str".

    Used with ``Dataset.map`` without ``batched=True`` (see the call below), so
    ``batch`` is a single example whose ``"input_values"`` entry holds the
    processed audio samples.

    Returns:
        The same ``batch`` dict with ``batch["pred_str"]`` set to the decoded
        transcription string.
    """
    # Move the inputs to wherever the model already lives. The previous code
    # called model.to("cuda") for every example, which walks all parameters
    # each time even though the model was placed on the GPU at load time.
    input_values = processor(
        batch["input_values"],
        sampling_rate=16_000,
        return_tensors="pt",
    ).input_values.to(model.device)

    with torch.no_grad():
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    # Only one example was fed in, so keep the first decoded string.
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]

    return batch


results = common_voice_test.map(map_to_result)

def compute_metrics(pred):
    """Compute the word error rate (WER) of CTC predictions against labels.

    Args:
        pred: prediction container exposing ``predictions`` (per-step scores;
            argmax over the last axis gives token ids) and ``label_ids``
            (token ids where ``-100`` marks positions to ignore). Presumably a
            ``transformers`` ``EvalPrediction``; confirm against the caller.

    Returns:
        dict: ``{"wer": <float>}`` as computed by ``wer_metric``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Map the -100 ignore index to the pad token so the labels can be decoded.
    # np.where returns a new array, so the caller's pred.label_ids is not
    # overwritten in place (the previous boolean-mask assignment mutated it).
    label_ids = np.where(
        pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids
    )

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

wer_metric = load_metric("wer")

# Normalize references the same way the targets were prepared: strip the
# ignored punctuation, then lowercase. Predictions never contain these
# characters, so leaving them in the references (e.g. "πάμε;" vs "πάμε")
# counts punctuation as word errors and inflates the reported WER.
reference_chars_to_ignore = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'
references = [
    re.sub(reference_chars_to_ignore, "", sentence).lower()
    for sentence in common_voice_test_transcription["sentence"]
]

print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=references)))