{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Wav2Vec2 + KenLM n-gram decoding for Irish (ga-IE)\n",
    "\n",
    "1. Transcribe one Common Voice 7.0 `ga-IE` training clip with the Wav2Vec2 CTC checkpoint stored in the working directory.\n",
    "2. Train a 5-gram KenLM language model on the training sentences.\n",
    "3. Wrap the processor and the language model in a `Wav2Vec2ProcessorWithLM` and save it.\n",
    "\n",
    "**Prerequisites:** a Hugging Face login that may access Common Voice 7.0 (`use_auth_token=True`), a KenLM build under `../kenlm/build/bin`, and Git LFS.\n",
    "\n",
    "Stored outputs were cleared because the code changed; use *Restart Kernel, then Run All* to regenerate them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "347417aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "import torch\n",
    "from datasets import Audio, load_dataset\n",
    "from pyctcdecode import build_ctcdecoder\n",
    "from transformers import (\n",
    "    AutoProcessor,\n",
    "    Wav2Vec2ForCTC,\n",
    "    Wav2Vec2Processor,\n",
    "    Wav2Vec2ProcessorWithLM,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfg-constants",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_DIR = \"./\"  # directory the acoustic model and its processor are loaded from\n",
    "TARGET_SAMPLING_RATE = 16_000  # rate handed to the processor; the clips are stored at 48 kHz\n",
    "SAMPLE_INDEX = 2  # training clip used for the sanity-check transcription\n",
    "\n",
    "KENLM_LMPLZ = \"../kenlm/build/bin/lmplz\"\n",
    "NGRAM_ORDER = 5\n",
    "LM_CORPUS_PATH = \"text.txt\"\n",
    "ARPA_RAW_PATH = \"5gram.arpa\"\n",
    "ARPA_FIXED_PATH = \"5gram_correct.arpa\"\n",
    "LM_PROCESSOR_DIR = \"xls-r-1b-ir\"\n",
    "\n",
    "PUSH_TO_HUB = False  # set to True to commit and push the LM-boosted processor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load-data",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "The `audio` column is cast to the target sampling rate so that every clip is resampled when it is decoded. Feeding the stored 48 kHz samples to the processor while declaring them as 16 kHz produces unusable transcriptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "131dee3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\n",
    "    \"mozilla-foundation/common_voice_7_0\",\n",
    "    \"ga-IE\",\n",
    "    split=\"train\",\n",
    "    use_auth_token=True,\n",
    ").cast_column(\"audio\", Audio(sampling_rate=TARGET_SAMPLING_RATE))\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05c3ae92",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_sample = dataset[SAMPLE_INDEX]\n",
    "audio_sample[\"sentence\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2edcf22",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_sample"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-acoustic-model",
   "metadata": {},
   "source": [
    "## Transcribe one clip without a language model\n",
    "\n",
    "Pull the Git LFS files of this checkout before loading the model from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3d9bb64",
   "metadata": {},
   "outputs": [],
   "source": [
    "!git lfs pull"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "936da5cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)\n",
    "model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e24cacdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the sample's real rate (not a literal) so a mismatch is reported instead of silently ignored.\n",
    "inputs = processor(\n",
    "    audio_sample[\"audio\"][\"array\"],\n",
    "    sampling_rate=audio_sample[\"audio\"][\"sampling_rate\"],\n",
    "    return_tensors=\"pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9972307",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    logits = model(**inputs).logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b78e3ea6",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1692f0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_ids = torch.argmax(logits, dim=-1)\n",
    "transcription = processor.batch_decode(predicted_ids)\n",
    "\n",
    "transcription[0].lower()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-language-model",
   "metadata": {},
   "source": [
    "## Train the n-gram language model\n",
    "\n",
    "The training sentences are written to a text file and passed to KenLM's `lmplz`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecf01625",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The decoder labels are lower-cased further down, so the corpus is lower-cased to match.\n",
    "# TODO: strip punctuation the same way it was stripped when the acoustic model was trained.\n",
    "with open(LM_CORPUS_PATH, \"w\", encoding=\"utf-8\") as file:\n",
    "    file.write(\" \".join(sentence.lower() for sentence in dataset[\"sentence\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40067117",
   "metadata": {},
   "outputs": [],
   "source": [
    "!{KENLM_LMPLZ} -o {NGRAM_ORDER} < \"{LM_CORPUS_PATH}\" > \"{ARPA_RAW_PATH}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dd24709",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add an end-of-sentence unigram to the ARPA file: raise the unigram count by one\n",
    "# and copy the first line containing \"<s>\" with the token replaced by \"</s>\".\n",
    "with open(ARPA_RAW_PATH, \"r\", encoding=\"utf-8\") as read_file, open(\n",
    "    ARPA_FIXED_PATH, \"w\", encoding=\"utf-8\"\n",
    ") as write_file:\n",
    "    has_added_eos = False\n",
    "    for line in read_file:\n",
    "        if not has_added_eos and line.startswith(\"ngram 1=\"):\n",
    "            # Rewrite the whole header line; a substring replace could also hit the \"1\" in \"ngram 1\".\n",
    "            count = int(line.strip().split(\"=\")[-1])\n",
    "            write_file.write(f\"ngram 1={count + 1}\\n\")\n",
    "        elif not has_added_eos and \"<s>\" in line:\n",
    "            write_file.write(line)\n",
    "            write_file.write(line.replace(\"<s>\", \"</s>\"))\n",
    "            has_added_eos = True\n",
    "        else:\n",
    "            write_file.write(line)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-decoder",
   "metadata": {},
   "source": [
    "## Build and save the LM-boosted processor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab7fc7d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "processor = AutoProcessor.from_pretrained(MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d994ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels must be ordered by token id so that position i matches logit column i.\n",
    "vocab_dict = processor.tokenizer.get_vocab()\n",
    "sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f2c0244",
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder = build_ctcdecoder(\n",
    "    labels=list(sorted_vocab_dict.keys()),\n",
    "    kenlm_model_path=ARPA_FIXED_PATH,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "017e8d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "processor_with_lm = Wav2Vec2ProcessorWithLM(\n",
    "    feature_extractor=processor.feature_extractor,\n",
    "    tokenizer=processor.tokenizer,\n",
    "    decoder=decoder,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4c2228",
   "metadata": {},
   "outputs": [],
   "source": [
    "processor_with_lm.save_pretrained(LM_PROCESSOR_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-publish",
   "metadata": {},
   "source": [
    "## Publish (opt-in)\n",
    "\n",
    "Publishing is a side effect, so it only runs when `PUSH_TO_HUB` is set to `True` in the configuration cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "786587f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): assumes the working directory is the git checkout of the model\n",
    "# repository (the `git lfs pull` above runs in the same directory) -- confirm before enabling.\n",
    "if PUSH_TO_HUB:\n",
    "    subprocess.run([\"git\", \"add\", LM_PROCESSOR_DIR], check=True)\n",
    "    subprocess.run([\"git\", \"commit\", \"-m\", \"Upload lm-boosted decoder\"], check=True)\n",
    "    subprocess.run([\"git\", \"push\"], check=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}