diff --git "a/train_en.ipynb" "b/train_en.ipynb" new file mode 100644--- /dev/null +++ "b/train_en.ipynb" @@ -0,0 +1,1591 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "888b8263", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset, load_metric, Audio, Dataset\n", + "import os\n", + "import torchaudio\n", + "from tqdm.auto import tqdm" + ] + }, + { + "cell_type": "markdown", + "id": "1e7ea2f9", + "metadata": {}, + "source": [ + "# Load English Data" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "33c4cc04", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset librispeech_asr (/workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25)\n", + "Reusing dataset librispeech_asr (/workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25)\n" + ] + } + ], + "source": [ + "common_voice_train = load_dataset('librispeech_asr', 'clean', split='train.100', use_auth_token=True)\n", + "common_voice_valid = load_dataset('librispeech_asr', 'clean', split='validation', use_auth_token=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "9bb7700a", + "metadata": {}, + "outputs": [], + "source": [ + "common_voice_train = common_voice_train.remove_columns([\"speaker_id\", \"chapter_id\", \"id\"])\n", + "common_voice_valid = common_voice_valid.remove_columns([\"speaker_id\", \"chapter_id\", \"id\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "24841b75", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'file': '374-180298-0020.flac',\n", + " 'audio': {'path': '374-180298-0020.flac',\n", + " 'array': array([-0.00045776, -0.00015259, 0.00045776, ..., -0.00079346,\n", + " -0.00082397, -0.00073242]),\n", + " 'sampling_rate': 16000},\n", + " 'text': 'I FLUNG MYSELF INTO THIS RAPID NOISY AND VOLCANIC LIFE WHICH HAD FORMERLY TERRIFIED ME WHEN I THOUGHT OF IT AND WHICH HAD BECOME FOR ME THE NECESSARY COMPLEMENT OF MY LOVE FOR MARGUERITE WHAT ELSE COULD I HAVE DONE'}" + ] + }, + "execution_count": 96, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "common_voice_train[20]" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "10216d93", + "metadata": {}, + "outputs": [], + "source": [ + "def normalize_text(batch):\n", + " batch['sentence'] = batch['text'].lower()\n", + " return batch" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "e273df9e", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-39821e14da4ccfc6.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-e4f620dfcf745abf.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-3d1e276c48d7700d.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-be0164e2ebe18904.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-7ae9fcac3a3f91a6.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-6730d9dfdb618c5d.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-1d1f12207c9e8fb0.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-4ce6cce328a9926d.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-c501669f7a8c025f.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-c2dae441b29ce89e.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-946bf579b56edd11.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-ac18c4a28fdb4bdd.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-68ab7133b03d8e77.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-66e212cd3ec8d831.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-d488b939ca1bb2fe.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-9d69c01d32100ac3.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-157418d55680a70b.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-ab2be882c2068767.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-02fea36ebcc8dacc.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-4d2d40cb1751761c.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-ae4e5194a6b192b5.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-d8d04b05cf6a2f82.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-d5115676a5b3530b.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-e6a127133fe4b3bf.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-0fe341c7990c7d0a.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-0378733bb5599baf.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-5101b06be8e38155.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-7b7f564a00c65d6c.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-239f07ad965fdac0.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-f8f1e208eb5ed9aa.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-1bdfe386c0fd5293.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-79fe40e71e41b753.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-2ff358bfdad016d3.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-1ba1853b00d71b5c.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-a948acf1174f6b88.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-07cde6276597a1ea.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-7a681976b966ded2.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-3324b72ee4399962.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-391503b49f67ea50.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-f8b4377cf88d2a6b.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-9056a3e05d21caa4.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-761c3f4d780c2c64.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-5b3caf31cc4a300f.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-8cca2516996d11fa.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-718723bfa205616f.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-195e15858820d5ac.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-2a1ee4478dea00e9.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-87c0e1e7dd27a6fb.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-bf146de2d4b323b7.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-fb390a112062d6b8.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-cb5bbb30626c0365.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-c962e9e10741026d.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-3adbe2317c04995f.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-d9a69f4fc5db2c17.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-5c7ad25eba032698.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-e45d17faeb7a48bb.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-ae562a5ca95e49d9.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-3e7ab0836a6648c0.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-d1f01db6745a51a3.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-97d147d31af2e730.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-f1bbbe800a5f07e4.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-63434e5cecc9beba.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-06b694fb9c0c3664.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-3b830e7dd7005f0e.arrow\n" + ] + } + ], + "source": [ + "common_voice_train = common_voice_train.map(normalize_text, num_proc=32)\n", + "common_voice_valid = common_voice_valid.map(normalize_text, num_proc=32)" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "c7855db2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'file': '374-180298-0011.flac',\n", + " 'audio': {'path': '374-180298-0011.flac',\n", + " 'array': array([-1.52587891e-04, -3.05175781e-05, 3.05175781e-05, ...,\n", + " -3.96728516e-04, 1.28173828e-03, 1.09863281e-03]),\n", + " 'sampling_rate': 16000},\n", + " 'text': 'FORGIVE ME IF I GIVE YOU ALL THESE DETAILS BUT YOU WILL SEE THAT THEY WERE THE CAUSE OF WHAT WAS TO FOLLOW WHAT I TELL YOU IS A TRUE AND SIMPLE STORY AND I LEAVE TO IT ALL THE NAIVETE OF ITS DETAILS',\n", + " 'sentence': 'forgive me if i give you all these details but you will see that they were the cause of what was to follow what i tell you is a true and simple story and i leave to it all the naivete of its details'}" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "common_voice_train[11]" + ] + }, + { + "cell_type": "markdown", + "id": "91e691cf", + "metadata": {}, + "source": [ + "### Clean Up the Text" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "1e5928ca", + "metadata": {}, + "outputs": [], + "source": [ + "# Remove character\n", + "import re\n", + "chars_to_remove_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�\\'\\。]'\n", + "def remove_special_characters(batch):\n", + " sentence = re.sub(chars_to_remove_regex, '', batch[\"sentence\"])\n", + " batch['sentence'] = sentence\n", + " return batch" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "b174637a", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-e4a181970b0ccf3a.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-366a9a8ad0d4f020.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-9c81fa6805a57e36.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-98662e2e0d6de08c.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-2ed4f597158e7f97.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-a7eb91fc97196d6a.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-dfa73be9a208ca63.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-b47a5233137ac8c2.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-8f953dac9f62a5b2.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-498064aa789e131a.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-fa701fadd4e6a183.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-00e9fb896fc7930c.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-aa07024d4bb6d697.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-ee50eefef45501a4.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-ca19eaee616b5703.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-8b5e0695957ea96f.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-906417ae8832c6c4.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-a95ff672eb2f214e.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-1ba14805577c5d51.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-54219d547b1ec85f.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-2e68540358007e7e.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-fc03c2cbbcebc863.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-e096bd4232ad94aa.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-b02c234c214c6213.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-681e04bd16072174.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-ab1c3a9e9292f5bd.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-e380978163720d58.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-88b0df8a76f97a13.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-4d0dd38deb43b5cf.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-3888690f581d999b.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-08959b59931f0862.arrow\n", + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-087d98a6c6dcee8f.arrow\n" + ] + } + ], + "source": [ + "common_voice_train = common_voice_train.map(remove_special_characters, num_proc=16)\n", + "common_voice_valid = common_voice_valid.map(remove_special_characters, num_proc=16)" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "60464061", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'file': '374-180298-0001.flac',\n", + " 'audio': {'path': '374-180298-0001.flac',\n", + " 'array': array([-9.15527344e-05, -1.52587891e-04, -1.52587891e-04, ...,\n", + " -2.13623047e-04, -1.83105469e-04, -2.74658203e-04]),\n", + " 'sampling_rate': 16000},\n", + " 'text': \"MARGUERITE TO BE UNABLE TO LIVE APART FROM ME IT WAS THE DAY AFTER THE EVENING WHEN SHE CAME TO SEE ME THAT I SENT HER MANON LESCAUT FROM THAT TIME SEEING THAT I COULD NOT CHANGE MY MISTRESS'S LIFE I CHANGED MY OWN\",\n", + " 'sentence': 'marguerite to be unable to live apart from me it was the day after the evening when she came to see me that i sent her manon lescaut from that time seeing that i could not change my mistresss life i changed my own'}" + ] + }, + "execution_count": 102, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "common_voice_train[1]" + ] + }, + { + "cell_type": "markdown", + "id": "e81047af", + "metadata": {}, + "source": [ + "### Build Character" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "abc1d92e", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "278c083d4f1744128d6a1de797034bdd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/28539 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 113, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import IPython.display as ipd\n", + "import numpy as np\n", + "import random\n", + "\n", + "rand_int = random.randint(0, len(common_voice_train)-1)\n", + "\n", + "print(\"Target text:\", common_voice_train[rand_int][\"sentence\"])\n", + "print(\"Input array shape:\", common_voice_train[rand_int][\"audio\"][\"array\"].shape)\n", + "print(\"Sampling rate:\", common_voice_train[rand_int][\"audio\"][\"sampling_rate\"])\n", + "ipd.Audio(data=common_voice_train[rand_int][\"audio\"][\"array\"], autoplay=False, rate=16000)" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "id": "927dbf96", + "metadata": {}, + "outputs": [], + "source": [ + "# This does not prepare the input for the Transformer model.\n", + "# This will resample the data and convert the sentence into indices\n", + "# Batch here is just for one entry (row)\n", + "def prepare_dataset(batch):\n", + " audio = batch[\"audio\"]\n", + " \n", + " # batched output is \"un-batched\"\n", + " batch[\"input_values\"] = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n", + " batch[\"input_length\"] = len(batch[\"input_values\"])\n", + " \n", + " with processor.as_target_processor():\n", + " batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n", + " return batch" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "id": "0b73a58a", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "70dca39efd2148eaa755c2f6a14de114", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "0ex [00:00, ?ex/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/8c6e15bda76db687d2a7c7198808151adecbb4d890ff463033a2e6f788c0ba25/cache-440d93538cd91d0a.arrow\n" + ] + } + ], + "source": [ + "common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)\n", + "common_voice_valid = common_voice_valid.map(prepare_dataset, remove_columns=common_voice_valid.column_names)" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "id": "dd807bc7", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a32c0e5e4c14038b81e8c1eae653ff8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/29 [00:00 Dict[str, torch.Tensor]:\n", + " # split inputs and labels since they have to be of different lenghts and need\n", + " # different padding methods\n", + " input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n", + " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", + "\n", + " batch = self.processor.pad(\n", + " input_features,\n", + " padding=self.padding,\n", + " return_tensors=\"pt\",\n", + " )\n", + "\n", + " with self.processor.as_target_processor():\n", + " labels_batch = self.processor.pad(\n", + " label_features,\n", + " padding=self.padding,\n", + " return_tensors=\"pt\",\n", + " )\n", + "\n", + " # replace padding with -100 to ignore loss correctly\n", + " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", + "\n", + " batch[\"labels\"] = labels\n", + "\n", + " return batch" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "id": "5e435f4d", + "metadata": {}, + "outputs": [], + "source": [ + "data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "id": "94202896", + "metadata": {}, + "outputs": [], + "source": [ + "wer_metric = load_metric(\"wer\")\n", + "# cer_metric = load_metric(\"cer\")" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "id": "126e6222", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_metrics(pred):\n", + " pred_logits = pred.predictions\n", + " pred_ids = np.argmax(pred_logits, axis=-1)\n", + "\n", + " pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id\n", + "\n", + " pred_str = tokenizer.batch_decode(pred_ids)\n", + " label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)\n", + " \n", + " wer = wer_metric.compute(predictions=pred_str, references=label_str)\n", + "# cer = cer_metric.compute(predictions=pred_str, references=label_str)\n", + "\n", + " return {\"wer\": wer}\n", + "# return {\"cer\": cer}" + ] + }, + { + "cell_type": "code", + "execution_count": 142, + "id": "5797fd64", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "loading configuration file https://huggingface.co/facebook/wav2vec2-xls-r-300m/resolve/main/config.json from cache at /workspace/.cache/huggingface/transformers/dabc27df63e37bd2a7a221c7774e35f36a280fbdf917cf54cadfc7df8c786f6f.a3e4c3c967d9985881e0ae550a5f6f668f897db5ab2e0802f9b97973b15970e6\n", + "Model config Wav2Vec2Config {\n", + " \"activation_dropout\": 0.0,\n", + " \"adapter_kernel_size\": 3,\n", + " \"adapter_stride\": 2,\n", + " \"add_adapter\": false,\n", + " \"apply_spec_augment\": true,\n", + " \"architectures\": [\n", + " \"Wav2Vec2ForPreTraining\"\n", + " ],\n", + " \"attention_dropout\": 0.1,\n", + " \"bos_token_id\": 1,\n", + " \"classifier_proj_size\": 256,\n", + " \"codevector_dim\": 768,\n", + " \"contrastive_logits_temperature\": 0.1,\n", + " \"conv_bias\": true,\n", + " \"conv_dim\": [\n", + " 512,\n", + " 512,\n", + " 512,\n", + " 512,\n", + " 512,\n", + " 512,\n", + " 512\n", + " ],\n", + " \"conv_kernel\": [\n", + " 10,\n", + " 3,\n", + " 3,\n", + " 3,\n", + " 3,\n", + " 2,\n", + " 2\n", + " ],\n", + " \"conv_stride\": [\n", + " 5,\n", + " 2,\n", + " 2,\n", + " 2,\n", + " 2,\n", + " 2,\n", + " 2\n", + " ],\n", + " \"ctc_loss_reduction\": \"mean\",\n", + " \"ctc_zero_infinity\": false,\n", + " \"diversity_loss_weight\": 0.1,\n", + " \"do_stable_layer_norm\": true,\n", + " \"eos_token_id\": 2,\n", + " \"feat_extract_activation\": \"gelu\",\n", + " \"feat_extract_dropout\": 0.0,\n", + " \"feat_extract_norm\": \"layer\",\n", + " \"feat_proj_dropout\": 0.0,\n", + " \"feat_quantizer_dropout\": 0.0,\n", + " \"final_dropout\": 0.0,\n", + " \"gradient_checkpointing\": false,\n", + " \"hidden_act\": \"gelu\",\n", + " \"hidden_dropout\": 0.1,\n", + " \"hidden_size\": 1024,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 4096,\n", + " \"layer_norm_eps\": 1e-05,\n", + " \"layerdrop\": 0.0,\n", + " \"mask_feature_length\": 64,\n", + " \"mask_feature_min_masks\": 0,\n", + " \"mask_feature_prob\": 0.25,\n", + " \"mask_time_length\": 10,\n", + " \"mask_time_min_masks\": 2,\n", + " \"mask_time_prob\": 0.75,\n", + " \"model_type\": \"wav2vec2\",\n", + " \"num_adapter_layers\": 3,\n", + " \"num_attention_heads\": 16,\n", + " \"num_codevector_groups\": 2,\n", + " \"num_codevectors_per_group\": 320,\n", + " \"num_conv_pos_embedding_groups\": 16,\n", + " \"num_conv_pos_embeddings\": 128,\n", + " \"num_feat_extract_layers\": 7,\n", + " \"num_hidden_layers\": 24,\n", + " \"num_negatives\": 100,\n", + " \"output_hidden_size\": 1024,\n", + " \"pad_token_id\": 28,\n", + " \"proj_codevector_dim\": 768,\n", + " \"tdnn_dilation\": [\n", + " 1,\n", + " 2,\n", + " 3,\n", + " 1,\n", + " 1\n", + " ],\n", + " \"tdnn_dim\": [\n", + " 512,\n", + " 512,\n", + " 512,\n", + " 512,\n", + " 1500\n", + " ],\n", + " \"tdnn_kernel\": [\n", + " 5,\n", + " 3,\n", + " 3,\n", + " 1,\n", + " 1\n", + " ],\n", + " \"torch_dtype\": \"float32\",\n", + " \"transformers_version\": \"4.17.0.dev0\",\n", + " \"use_weighted_layer_sum\": false,\n", + " \"vocab_size\": 31,\n", + " \"xvector_output_dim\": 512\n", + "}\n", + "\n", + "loading weights file https://huggingface.co/facebook/wav2vec2-xls-r-300m/resolve/main/pytorch_model.bin from cache at /workspace/.cache/huggingface/transformers/1e6a6507f3b689035cd4b247e2a37c154e27f39143f31357a49b4e38baeccc36.1edb32803799e27ed554eb7dd935f6745b1a0b17b0ea256442fe24db6eb546cd\n", + "Some weights of the model checkpoint at facebook/wav2vec2-xls-r-300m were not used when initializing Wav2Vec2ForCTC: ['quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_hid.weight', 'project_hid.bias', 'project_q.bias', 'quantizer.weight_proj.weight', 'project_q.weight']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.bias', 'lm_head.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "from transformers import Wav2Vec2ForCTC\n", + "\n", + "model = Wav2Vec2ForCTC.from_pretrained(\n", + " \"facebook/wav2vec2-xls-r-300m\", \n", + " attention_dropout=0.1,\n", + " layerdrop=0.0,\n", + " feat_proj_dropout=0.0,\n", + " mask_time_prob=0.75, \n", + " mask_time_length=10,\n", + " mask_feature_prob=0.25,\n", + " mask_feature_length=64,\n", + " ctc_loss_reduction=\"mean\",\n", + " pad_token_id=processor.tokenizer.pad_token_id,\n", + " vocab_size=len(processor.tokenizer)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "id": "e66e718d", + "metadata": {}, + "outputs": [], + "source": [ + "model.freeze_feature_encoder()" + ] + }, + { + "cell_type": "code", + "execution_count": 147, + "id": "6cdb6148", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "PyTorch: setting up devices\n", + "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n" + ] + } + ], + "source": [ + "from transformers import TrainingArguments\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir='.',\n", + " group_by_length=True,\n", + " per_device_train_batch_size=8,\n", + " gradient_accumulation_steps=4,\n", + " evaluation_strategy=\"steps\",\n", + " gradient_checkpointing=True,\n", + " fp16=True,\n", + " num_train_epochs=50,\n", + " save_steps=500,\n", + " eval_steps=500,\n", + " logging_steps=100,\n", + " learning_rate=5e-5,\n", + " warmup_steps=1000,\n", + " save_total_limit=3,\n", + " load_best_model_at_end=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 148, + "id": "f396bd8f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using amp half precision backend\n" + ] + } + ], + "source": [ + "from transformers import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " data_collator=data_collator,\n", + " args=training_args,\n", + " compute_metrics=compute_metrics,\n", + " train_dataset=common_voice_train,\n", + " eval_dataset=common_voice_valid,\n", + " tokenizer=processor.feature_extractor,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 149, + "id": "50550e52", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The following columns in the training set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " warnings.warn(\n", + "***** Running training *****\n", + " Num examples = 3857\n", + " Num Epochs = 50\n", + " Instantaneous batch size per device = 8\n", + " Total train batch size (w. parallel, distributed & accumulation) = 32\n", + " Gradient Accumulation steps = 4\n", + " Total optimization steps = 6000\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [6000/6000 3:56:13, Epoch 49/50]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation LossWer
5002.9365002.9397950.999872
10001.5444000.5947150.428913
15001.1367000.2750930.236642
20000.9972000.2032340.179661
25000.9118000.1785940.147944
30000.8664000.1640960.140763
35000.8251000.1536810.126742
40000.7930000.1524650.124434
45000.7850000.1469750.118449
50000.7612000.1446020.117722
55000.7478000.1449030.117594
60000.7443000.1444080.116697

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-500\n", + "Configuration saved in ./checkpoint-500/config.json\n", + "Model weights saved in ./checkpoint-500/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-500/preprocessor_config.json\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-1000\n", + "Configuration saved in ./checkpoint-1000/config.json\n", + "Model weights saved in ./checkpoint-1000/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-1000/preprocessor_config.json\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-1500\n", + "Configuration saved in ./checkpoint-1500/config.json\n", + "Model weights saved in ./checkpoint-1500/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-1500/preprocessor_config.json\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-2000\n", + "Configuration saved in ./checkpoint-2000/config.json\n", + "Model weights saved in ./checkpoint-2000/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-2000/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-500] due to args.save_total_limit\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-2500\n", + "Configuration saved in ./checkpoint-2500/config.json\n", + "Model weights saved in ./checkpoint-2500/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-2500/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-1000] due to args.save_total_limit\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-3000\n", + "Configuration saved in ./checkpoint-3000/config.json\n", + "Model weights saved in ./checkpoint-3000/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-3000/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-1500] due to args.save_total_limit\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-3500\n", + "Configuration saved in ./checkpoint-3500/config.json\n", + "Model weights saved in ./checkpoint-3500/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-3500/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-2000] due to args.save_total_limit\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-4000\n", + "Configuration saved in ./checkpoint-4000/config.json\n", + "Model weights saved in ./checkpoint-4000/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-4000/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-2500] due to args.save_total_limit\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-4500\n", + "Configuration saved in ./checkpoint-4500/config.json\n", + "Model weights saved in ./checkpoint-4500/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-4500/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-3000] due to args.save_total_limit\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-5000\n", + "Configuration saved in ./checkpoint-5000/config.json\n", + "Model weights saved in ./checkpoint-5000/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-5000/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-3500] due to args.save_total_limit\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-5500\n", + "Configuration saved in ./checkpoint-5500/config.json\n", + "Model weights saved in ./checkpoint-5500/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-5500/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-4000] due to args.save_total_limit\n", + "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n", + "***** Running Evaluation *****\n", + " Num examples = 1812\n", + " Batch size = 8\n", + "Saving model checkpoint to ./checkpoint-6000\n", + "Configuration saved in ./checkpoint-6000/config.json\n", + "Model weights saved in ./checkpoint-6000/pytorch_model.bin\n", + "Configuration saved in ./checkpoint-6000/preprocessor_config.json\n", + "Deleting older checkpoint [checkpoint-4500] due to args.save_total_limit\n", + "\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n", + "Loading best model from ./checkpoint-6000 (score: 0.14440837502479553).\n" + ] + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=6000, training_loss=1.1765391832987468, metrics={'train_runtime': 14177.2496, 'train_samples_per_second': 13.603, 'train_steps_per_second': 0.423, 'total_flos': 2.9510893171822916e+19, 'train_loss': 1.1765391832987468, 'epoch': 49.99})" + ] + }, + "execution_count": 149, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 150, + "id": "57f2a4e2", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "tokenizer config file saved in ./tokenizer_config.json\n", + "Special tokens file saved in ./special_tokens_map.json\n", + "added tokens file saved in ./added_tokens.json\n" + ] + }, + { + "data": { + "text/plain": [ + "('./tokenizer_config.json',\n", + " './special_tokens_map.json',\n", + " './vocab.json',\n", + " './added_tokens.json')" + ] + }, + "execution_count": 150, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.save_pretrained('.')" + ] + }, + { + "cell_type": "code", + "execution_count": 151, + "id": "5d14e7f1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Configuration saved in ./preprocessor_config.json\n", + "tokenizer config file saved in ./tokenizer_config.json\n", + "Special tokens file saved in ./special_tokens_map.json\n", + "added tokens file saved in ./added_tokens.json\n" + ] + } + ], + "source": [ + "processor.save_pretrained('.')" + ] + }, + { + "cell_type": "code", + "execution_count": 152, + "id": "97ab4059", + "metadata": {}, + "outputs": [], + "source": [ + "kwargs = {\n", + " \"finetuned_from\": \"facebook/wav2vec2-xls-r-300m\",\n", + " \"tasks\": \"speech-recognition\",\n", + " \"tags\": [\"automatic-speech-recognition\", \"librispeech_asr\", \"robust-speech-event\", \"en\"],\n", + " \"dataset_args\": f\"Config: clean, Training split: train.100, Eval split: validation\",\n", + " \"dataset\": \"librispeech_asr\",\n", + " \"language\": \"en\"\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "id": "62fc6680", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Dropping the following result as it does not have all the necessary fields:\n", + "{}\n" + ] + } + ], + "source": [ + "trainer.create_model_card(**kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "id": "ba5d5f5d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving model checkpoint to .\n", + "Configuration saved in ./config.json\n", + "Model weights saved in ./pytorch_model.bin\n", + "Configuration saved in ./preprocessor_config.json\n" + ] + } + ], + "source": [ + "trainer.save_model('.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7618702f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c8b7927f", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer.push_to_hub('.')" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "341a70d4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Configuration saved in vitouphy/xls-r-300m-ja/config.json\n", + "Model weights saved in vitouphy/xls-r-300m-ja/pytorch_model.bin\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6e6bb4dfb7ea43818e83f52252cf939b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Upload file pytorch_model.bin: 0%| | 3.39k/1.18G [00:00 main\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "'https://huggingface.co/vitouphy/xls-r-300m-ja/commit/1e678ca0c4b03aa3bca71af6fd2c0aa738b7aa7b'" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.push_to_hub('vitouphy/xls-r-300m-ja')" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "f4b4919d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "1" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "f0d11e5d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving model checkpoint to .\n", + "Configuration saved in ./config.json\n", + "Model weights saved in ./pytorch_model.bin\n", + "Configuration saved in ./preprocessor_config.json\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'Trainer' object has no attribute 'repo'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [57]\u001b[0m, in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mvitouphy/xls-r-300m-ja\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/transformers/trainer.py:2792\u001b[0m, in \u001b[0;36mTrainer.push_to_hub\u001b[0;34m(self, commit_message, blocking, **kwargs)\u001b[0m\n\u001b[1;32m 2789\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_world_process_zero():\n\u001b[1;32m 2790\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[0;32m-> 2792\u001b[0m git_head_commit_url \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrepo\u001b[49m\u001b[38;5;241m.\u001b[39mpush_to_hub(\n\u001b[1;32m 2793\u001b[0m commit_message\u001b[38;5;241m=\u001b[39mcommit_message, blocking\u001b[38;5;241m=\u001b[39mblocking, auto_lfs_prune\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 2794\u001b[0m )\n\u001b[1;32m 2795\u001b[0m \u001b[38;5;66;03m# push separately the model card to be independant from the rest of the model\u001b[39;00m\n\u001b[1;32m 2796\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mshould_save:\n", + "\u001b[0;31mAttributeError\u001b[0m: 'Trainer' object has no attribute 'repo'" + ] + } + ], + "source": [ + "trainer.push_to_hub('vitouphy/xls-r-300m-ja')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9256963c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}