{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:10:25.794375Z", "start_time": "2021-03-17T11:10:24.301013Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n", " warnings.warn(\n" ] } ], "source": [ "from transformers import Wav2Vec2ForCTC\n", "from transformers import Wav2Vec2Processor\n", "from datasets import load_dataset, load_metric\n", "import re\n", "import torchaudio\n", "import librosa\n", "import numpy as np\n", "from datasets import load_dataset, load_metric\n", "import torch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:10:29.608803Z", "start_time": "2021-03-17T11:10:29.599700Z" } }, "outputs": [], "source": [ "chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n", "\n", "def remove_special_characters(batch):\n", " batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n", " return batch\n", "\n", "def speech_file_to_array_fn(batch):\n", " speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n", " batch[\"speech\"] = speech_array[0].numpy()\n", " batch[\"sampling_rate\"] = sampling_rate\n", " batch[\"target_text\"] = batch[\"text\"]\n", " return batch\n", "\n", "def resample(batch):\n", " batch[\"speech\"] = librosa.resample(np.asarray(batch[\"speech\"]), 48_000, 16_000)\n", " batch[\"sampling_rate\"] = 16_000\n", " return batch\n", "\n", "def prepare_dataset(batch):\n", " # check that all files have the correct sampling rate\n", " assert (\n", " len(set(batch[\"sampling_rate\"])) == 1\n", " ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n", "\n", " batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n", " \n", " with processor.as_target_processor():\n", " batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:11:02.120225Z", "start_time": "2021-03-17T11:10:56.182488Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.\n" ] } ], "source": [ "model = Wav2Vec2ForCTC.from_pretrained(\".\").to(\"cuda\")\n", "processor = Wav2Vec2Processor.from_pretrained(\".\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:12:18.847005Z", "start_time": "2021-03-17T11:12:14.919077Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration el-afd0a157f05ee080\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading and preparing dataset common_voice/el (download: 363.89 MiB, generated: 4.75 MiB, post-processed: Unknown size, total: 368.64 MiB) to /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Dataset common_voice downloaded and prepared to /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f. Subsequent calls will reuse this data.\n" ] } ], "source": [ "common_voice_test = load_dataset(\"common_voice\", \"el\", data_dir=\"cv-corpus-6.1-2020-12-11\", split=\"test\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:12:18.860240Z", "start_time": "2021-03-17T11:12:18.857252Z" } }, "outputs": [], "source": [ "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:12:18.928497Z", "start_time": "2021-03-17T11:12:18.869198Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9869698af86e44bca75c4252996ff1a3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=1522), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=[\"sentence\"])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:12:40.824595Z", "start_time": "2021-03-17T11:12:18.937930Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d232b2bb009543e0bb2542bce273c554", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=1522), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:13:18.078738Z", "start_time": "2021-03-17T11:12:40.834398Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ffd787bc4ed048ae8f4977f2c539bedb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#0', max=191, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "79c51995d4f84ad8812230480d14b8cd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#2', max=190, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "52963d9cfd814346af070b2cc4e105cf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#5', max=190, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3b940160575143c7acfa142564e9f7d2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#3', max=190, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aa540f67ba894d7aa64e12fcdfab5ce0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#1', max=191, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4962bdefdbbc44a7a44591480d8d6406", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#4', max=190, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e77f088bfe5644548fe2c4277d0c86da", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#7', max=190, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5827f93e99994fe9919aac53f0fb9444", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#6', max=190, style=ProgressStyle(description_width='initial'…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "common_voice_test = common_voice_test.map(resample, num_proc=8)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:13:25.145155Z", "start_time": "2021-03-17T11:13:18.091929Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", " return array(a, dtype, copy=False, order=order)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ae326a173a044b1494793e2a70d76a87", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#0', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "21ab1ef2af5a4a4fb23c68b0c5cf32f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#1', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d331c5f4f888477daceffe370f6cd89f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#3', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fa790118aa340e4afb9f83e71403a13", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#2', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c8092e2f59a9404596dc2bab206edf2c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#5', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "20f913f0caf8401098743b9e5051fc52", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#4', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7c7e15e24384494cb49a72106ce41ccd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#6', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "73245add55e24ee2a6dbe0713d5073d9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='#7', max=24, style=ProgressStyle(description_width='initial')…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:14:12.721500Z", "start_time": "2021-03-17T11:14:08.198478Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration el-ac779bf2c9f7c09b\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading and preparing dataset common_voice/el (download: 363.89 MiB, generated: 4.75 MiB, post-processed: Unknown size, total: 368.64 MiB) to /home/earendil/.cache/huggingface/datasets/common_voice/el-ac779bf2c9f7c09b/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Dataset common_voice downloaded and prepared to /home/earendil/.cache/huggingface/datasets/common_voice/el-ac779bf2c9f7c09b/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f. Subsequent calls will reuse this data.\n" ] } ], "source": [ "common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=\"./cv-corpus-6.1-2020-12-11\", split=\"test\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:39.856174Z", "start_time": "2021-03-14T19:33:14.402825Z" } }, "outputs": [], "source": [ "# Change this value to try inference on different CommonVoice extracts\n", "example = 678\n", "\n", "input_dict = processor(common_voice_test[\"input_values\"][example], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n", "\n", "logits = model(input_dict.input_values.to(\"cuda\")).logits\n", "\n", "pred_ids = torch.argmax(logits, dim=-1)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:39.887236Z", "start_time": "2021-03-14T19:33:39.881958Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction:\n", "πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς\n", "\n", "Reference:\n", "πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς.\n" ] } ], "source": [ "print(\"Prediction:\")\n", "print(processor.decode(pred_ids[0]))\n", "# πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς\n", "\n", "print(\"\\nReference:\")\n", "print(common_voice_test_transcription[\"sentence\"][example].lower())\n", "# πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:15:35.637739Z", "start_time": "2021-03-17T11:14:14.689842Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1f7ba9e12187401f870555d20a6a9458", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=1522), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "def map_to_result(batch):\n", " model.to(\"cuda\")\n", " input_values = processor(\n", " batch[\"input_values\"], \n", " sampling_rate=16_000, \n", " return_tensors=\"pt\"\n", " ).input_values.to(\"cuda\")\n", "\n", " with torch.no_grad():\n", " logits = model(input_values).logits\n", "\n", " pred_ids = torch.argmax(logits, dim=-1)\n", " batch[\"pred_str\"] = processor.batch_decode(pred_ids)[0]\n", "\n", " return batch\n", "\n", "results = common_voice_test.map(map_to_result)\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2021-03-17T11:17:11.951524Z", "start_time": "2021-03-17T11:17:08.856552Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test WER: 0.396\n" ] } ], "source": [ "def compute_metrics(pred):\n", " pred_logits = pred.predictions\n", " pred_ids = np.argmax(pred_logits, axis=-1)\n", "\n", " pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", " pred_str = processor.batch_decode(pred_ids)\n", " # we do not want to group tokens when computing the metrics\n", " label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n", "\n", " wer = wer_metric.compute(predictions=pred_str, references=label_str)\n", "\n", " return {\"wer\": wer}\n", "\n", "wer_metric = load_metric(\"wer\")\n", "\n", "print(\"Test WER: {:.3f}\".format(wer_metric.compute(predictions=results[\"pred_str\"], references= [item.lower() for item in common_voice_test_transcription['sentence']])))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "cuda110", "language": "python", "name": "cuda110" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }