{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [
"# Wav2Vec2 CTC inference on the Greek Common Voice test split\n",
"\n",
"Loads a fine-tuned `Wav2Vec2ForCTC` checkpoint and its processor from the current directory, preprocesses the Common Voice 6.1 `el` test split (punctuation removal, audio loading, resampling, feature extraction), and decodes one example for comparison with its reference transcript."
] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:32:40.765119Z", "start_time": "2021-03-14T19:32:39.314790Z" } }, "outputs": [], "source": [
"import re\n",
"\n",
"import librosa\n",
"import numpy as np\n",
"import torch\n",
"import torchaudio\n",
"from datasets import load_dataset\n",
"from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor"
] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:32:40.774860Z", "start_time": "2021-03-14T19:32:40.771235Z" } }, "outputs": [], "source": [
"# Punctuation stripped from transcripts before tokenisation.\n",
"# Raw string: the backslashes are meant for the regex, not for Python string escapes.\n",
"chars_to_ignore_regex = r'[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n",
"\n",
"def remove_special_characters(batch):\n",
"    \"\"\"Strip ignored punctuation from `sentence`, lowercase it, and store it as `text` with a trailing space.\"\"\"\n",
"    batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n",
"    return batch\n",
"\n",
"def speech_file_to_array_fn(batch):\n",
"    \"\"\"Load the audio file at `path`.\n",
"\n",
"    Stores the first channel as a numpy array in `speech`, the file's rate in\n",
"    `sampling_rate`, and copies `text` into `target_text`.\n",
"    \"\"\"\n",
"    speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
"    batch[\"speech\"] = speech_array[0].numpy()\n",
"    batch[\"sampling_rate\"] = sampling_rate\n",
"    batch[\"target_text\"] = batch[\"text\"]\n",
"    return batch\n",
"\n",
"def resample(batch, target_sampling_rate=16_000):\n",
"    \"\"\"Resample `speech` from its recorded `sampling_rate` to `target_sampling_rate`.\n",
"\n",
"    The source rate is read from the example (set by `speech_file_to_array_fn`)\n",
"    rather than assumed, so clips not recorded at 48 kHz are converted correctly.\n",
"    \"\"\"\n",
"    batch[\"speech\"] = librosa.resample(\n",
"        np.asarray(batch[\"speech\"]),\n",
"        orig_sr=batch[\"sampling_rate\"],\n",
"        target_sr=target_sampling_rate,\n",
"    )\n",
"    batch[\"sampling_rate\"] = target_sampling_rate\n",
"    return batch\n",
"\n",
"def prepare_dataset(batch):\n",
"    \"\"\"Batched map: convert `speech` to model `input_values` and `target_text` to CTC `labels`.\n",
"\n",
"    Uses the global `processor` (loaded in a later cell, before this is mapped).\n",
"    \"\"\"\n",
"    expected_rate = processor.feature_extractor.sampling_rate\n",
"    # Every clip must already be at the feature extractor's rate, not merely share a common rate.\n",
"    assert set(batch[\"sampling_rate\"]) == {expected_rate}, (\n",
"        f\"Make sure all inputs have the same sampling rate of {expected_rate}.\"\n",
"    )\n",
"\n",
"    batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=expected_rate).input_values\n",
"\n",
"    with processor.as_target_processor():\n",
"        batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n",
"    return batch"
] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:32:49.565850Z", "start_time": "2021-03-14T19:32:41.891601Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.\n" ] } ], "source": [
"model = Wav2Vec2ForCTC.from_pretrained(\".\").to(\"cuda\")\n",
"processor = Wav2Vec2Processor.from_pretrained(\".\")"
] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:03.514113Z", "start_time": "2021-03-14T19:33:00.953049Z" } }, "outputs": [], "source": [
"common_voice_test = load_dataset(\"common_voice\", \"el\", data_dir=\"cv-corpus-6.1-2020-12-11\", split=\"test\")"
] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:03.528699Z", "start_time": "2021-03-14T19:33:03.525034Z" } }, "outputs": [], "source": [
"common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])"
] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time":
"2021-03-14T19:33:03.542260Z", "start_time": "2021-03-14T19:33:03.538498Z" } }, "outputs": [], "source": [
"common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=[\"sentence\"])"
] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:03.561798Z", "start_time": "2021-03-14T19:33:03.554256Z" } }, "outputs": [], "source": [
"common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)"
] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:04.357229Z", "start_time": "2021-03-14T19:33:03.570805Z" } }, "outputs": [], "source": [
"common_voice_test = common_voice_test.map(resample, num_proc=8)"
] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:11.205598Z", "start_time": "2021-03-14T19:33:04.368615Z" } }, "outputs": [], "source": [
"common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)"
] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:14.391497Z", "start_time": "2021-03-14T19:33:11.216118Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data
configuration el-ac779bf2c9f7c09b\n", "Reusing dataset common_voice (/home/earendil/.cache/huggingface/datasets/common_voice/el-ac779bf2c9f7c09b/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564)\n" ] } ], "source": [
"common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=\"./cv-corpus-6.1-2020-12-11\", split=\"test\")"
] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:39.856174Z", "start_time": "2021-03-14T19:33:14.402825Z" } }, "outputs": [], "source": [
"# Change this value to try inference on different CommonVoice extracts\n",
"example = 678\n",
"\n",
"# Index the row first: `ds[column][i]` would materialise the whole column just to read one example.\n",
"input_dict = processor(common_voice_test[example][\"input_values\"], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n",
"\n",
"# Inference only: disable autograd so no graph or activations are kept on the GPU.\n",
"with torch.no_grad():\n",
"    logits = model(input_dict.input_values.to(\"cuda\")).logits\n",
"\n",
"pred_ids = torch.argmax(logits, dim=-1)"
] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T19:33:39.887236Z", "start_time": "2021-03-14T19:33:39.881958Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction:\n", "πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς\n", "\n", "Reference:\n", "πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς.\n" ] } ], "source": [
"print(\"Prediction:\")\n",
"print(processor.decode(pred_ids[0]))\n",
"# πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς\n",
"\n",
"print(\"\\nReference:\")\n",
"# Row-first indexing again, so only one transcript is read.\n",
"print(common_voice_test_transcription[example][\"sentence\"].lower())\n",
"# πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς."
] } ], "metadata": { "kernelspec": { "display_name": "cuda110", "language": "python", "name": "cuda110" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }