{ "cells": [ { "cell_type": "markdown", "id": "c8c824ea", "metadata": {}, "source": [ "

This notebook is for testing models against common_voice (v7)

" ] }, { "cell_type": "markdown", "id": "a9180c0b", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "\n", "\n", "\n", "##### TEST WITH RASMUS 1B model with language model added using our own common_voice v7 (processed before event) ###" ] }, { "cell_type": "code", "execution_count": 1, "id": "0ea4e3d0", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "# Environment settings: \n", "import pandas as pd\n", "pd.set_option('display.max_column', None)\n", "pd.set_option('display.max_rows', None)\n", "pd.set_option('display.max_seq_items', None)\n", "pd.set_option('display.max_colwidth', 500)\n", "pd.set_option('expand_frame_repr', True)\n", "\n", "from datasets import concatenate_datasets, load_dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "0a810556", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c56f4efc99cc4320a0ddee55f8c6dfce", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='
\\n\", \"…\", \"–\", \"°\", \"´\", \"ʾ\", \"‹\", \"›\", \"©\", \"®\", \"—\", \"→\", \"。\",\n", " \"、\", \"﹂\", \"﹁\", \"‧\", \"~\", \"﹏\", \",\", \"{\", \"}\", \"(\", \")\", \"[\", \"]\", \"【\", \"】\", \"‥\", \"〽\",\n", " \"『\", \"』\", \"〝\", \"〟\", \"⟨\", \"⟩\", \"〜\", \":\", \"!\", \"?\", \"♪\", \"؛\", \"/\", \"\\\\\", \"º\", \"−\", \"^\", \"ʻ\", \"ˆ\"]\n", "\n", "\n", "chars_to_remove_regex = f\"[{re.escape(''.join(CHARS_TO_IGNORE))}]\"\n", "\n", "def remove_special_characters(batch):\n", " batch[\"sentence\"] = re.sub(chars_to_remove_regex, '', batch[\"sentence\"]).lower()\n", " return batch\n", "\n", "common_voice_test = common_voice_test.map(remove_special_characters)" ] }, { "cell_type": "code", "execution_count": 12, "id": "42423e76", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'split', 'audio', 'dataset_name', 'filename', '__index_level_0__'],\n", " num_rows: 1599\n", "})" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "common_voice_test" ] }, { "cell_type": "code", "execution_count": 13, "id": "21d0937f", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "common_voice_test_audio = common_voice_test.cast_column(\"audio\", Audio(sampling_rate=16_000))\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "02f295cf", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " audio = batch[\"audio\"]\n", "\n", " # batched output is \"un-batched\"\n", " batch[\"input_values\"] = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n", " batch[\"input_length\"] = 
len(batch[\"input_values\"])\n", " batch[\"sentence\"] = batch[\"sentence\"]\n", " \n", " with processor.as_target_processor():\n", " batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 15, "id": "1341388b", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8798d1f90aa241129972f072d31f7686", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1599 [00:00 Test from mozilla common_voice v_7_0 directly from hub \n", "Using currently the \"old\" preprocessing and not the \"audio\" method\", \"…\", \"–\", \"°\", \"´\", \"ʾ\", \"‹\", \"›\", \"©\", \"®\", \"—\", \"→\", \"。\",\n", " \"、\", \"﹂\", \"﹁\", \"‧\", \"~\", \"﹏\", \",\", \"{\", \"}\", \"(\", \")\", \"[\", \"]\", \"【\", \"】\", \"‥\", \"〽\",\n", " \"『\", \"』\", \"〝\", \"〟\", \"⟨\", \"⟩\", \"〜\", \":\", \"!\", \"?\", \"♪\", \"؛\", \"/\", \"\\\\\", \"º\", \"−\", \"^\", \"ʻ\", \"ˆ\"]\n", "\n", "\n", "chars_to_remove_regex = f\"[{re.escape(''.join(CHARS_TO_IGNORE))}]\"\n", "\n", "def remove_special_characters(batch):\n", " batch[\"sentence\"] = re.sub(chars_to_remove_regex, '', batch[\"sentence\"]).lower()\n", " return batch\n", "\n", "common_voice_dataset = common_voice_dataset.map(remove_special_characters)" ] }, { "cell_type": "code", "execution_count": 27, "id": "396369b6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/fi/7.0.0/33e08856cfa0d0665e837bcad73ffd920a0bc713ce8c5fffb55dbdf1c084d5ba/cache-b66d07bf277a5504.arrow\n" ] } ], "source": [ "def resample_audios(batch):\n", " sr = batch['audio']['sampling_rate']\n", " batch['audio']['array'] = F.resample(torch.tensor(batch[\"audio\"][\"array\"]), sr, 16_000).numpy()\n", " return batch\n", "\n", "common_voice_dataset = common_voice_dataset.map(resample_audios)" ] }, { "cell_type": 
"code", "execution_count": 28, "id": "2852ca1c", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " batch[\"input_values\"] = processor(batch[\"audio\"][\"array\"], sampling_rate=16000).input_values[0]\n", " batch[\"input_length\"] = len(batch[\"input_values\"])\n", " batch[\"sentence\"] = batch[\"sentence\"]\n", " \n", " with processor.as_target_processor():\n", " batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 29, "id": "7832bd68", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "58ff04eeda134bec9ffdc7bcf71fc66c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1599 [00:00 ASR PIPELINE PREDICTIONS (Same kind as in eval.py)\n", "\n" ] }, { "cell_type": "code", "execution_count": 49, "id": "5302f579", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using the latest cached version of the module from /workspace/.cache/huggingface/modules/datasets_modules/datasets/mozilla-foundation--common_voice_7_0/33e08856cfa0d0665e837bcad73ffd920a0bc713ce8c5fffb55dbdf1c084d5ba (last modified on Sun Jan 23 16:17:44 2022) since it couldn't be found locally at mozilla-foundation/common_voice_7_0., or remotely on the Hugging Face Hub.\n", "Reusing dataset common_voice (/workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/fi/7.0.0/33e08856cfa0d0665e837bcad73ffd920a0bc713ce8c5fffb55dbdf1c084d5ba)\n" ] } ], "source": [ "common_voice_dataset = load_dataset(\"mozilla-foundation/common_voice_7_0\", \"fi\", split=\"test\")\n", "\n", "common_voice_dataset = common_voice_dataset.cast_column(\"audio\", Audio(sampling_rate=16_000))" ] }, { "cell_type": "code", "execution_count": 54, "id": "6c048f8b", "metadata": {}, "outputs": [], "source": [ "def normalize_text(text: str) -> str:\n", " \"\"\"DO ADAPT FOR YOUR USE CASE. 
this function normalizes the target text.\"\"\"\n", "\n", " chars_to_ignore_regex = [\",\", \"?\", \"¿\", \".\", \"!\", \"¡\", \";\", \";\", \":\", '\"\"', \"%\", '\"', \"�\", \"ʿ\", \"·\", \"჻\", \"~\", \"՞\",\n", " \"؟\", \"،\", \"।\", \"॥\", \"«\", \"»\", \"„\", \"“\", \"”\", \"「\", \"」\", \"‘\", \"’\", \"《\", \"》\", \"(\", \")\", \"[\", \"]\",\n", " \"{\", \"}\", \"=\", \"`\", \"_\", \"+\", \"<\", \">\", \"…\", \"–\", \"°\", \"´\", \"ʾ\", \"‹\", \"›\", \"©\", \"®\", \"—\", \"→\", \"。\",\n", " \"、\", \"﹂\", \"﹁\", \"‧\", \"~\", \"﹏\", \",\", \"{\", \"}\", \"(\", \")\", \"[\", \"]\", \"【\", \"】\", \"‥\", \"〽\",\n", " \"『\", \"』\", \"〝\", \"〟\", \"⟨\", \"⟩\", \"〜\", \":\", \"!\", \"?\", \"♪\", \"؛\", \"/\", \"\\\\\", \"º\", \"−\", \"^\", \"ʻ\", \"ˆ\"] \n", "\n", "\n", " chars_to_remove_regex = f\"[{re.escape(''.join(chars_to_ignore_regex))}]\"\n", " \n", " \n", " \n", " # remove punctuation\n", " text = re.sub(chars_to_remove_regex, '', text)\n", " \n", " text = text.lower()\n", " \n", " # Let's also make sure we split on all kinds of newlines, spaces, etc...\n", " #text = \" \".join(text.split())\n", " \n", " return text" ] }, { "cell_type": "code", "execution_count": 55, "id": "9fa432f2", "metadata": {}, "outputs": [], "source": [ "# map function to decode audio\n", "def map_to_pred(batch):\n", " prediction = asr(\n", " batch[\"audio\"][\"array\"]\n", " )\n", "\n", " batch[\"prediction\"] = prediction[\"text\"]\n", " batch[\"target\"] = normalize_text(batch[\"sentence\"])\n", " return batch" ] }, { "cell_type": "code", "execution_count": 56, "id": "d80cf431", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4dc38b78128f49938ec47a41be469153", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1599 [00:00