{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "dd5128ea", "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "id": "b026bf65", "metadata": {}, "outputs": [], "source": [ "target_lang=\"ga-IE\" # change to your target lang" ] }, { "cell_type": "code", "execution_count": 101, "id": "dcd259e1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration ga-pl-lang1=ga,lang2=pl\n", "Reusing dataset opus_dgt (/workspace/cache/hf/datasets/opus_dgt/ga-pl-lang1=ga,lang2=pl/0.0.0/a4db75cea3712eb5d4384f0539db82abf897c6b6da5e5e81693e8fd201efc346)\n" ] } ], "source": [ "from datasets import load_dataset\n", "\n", "# dataset = load_dataset(\"mozilla-foundation/common_voice_8_0\", \n", "# \"ga-IE\", \n", "# split=\"train\", \n", "# use_auth_token = True)\n", "\n", "# dataset = load_dataset(\"opus_dgt\", lang1=\"ga\", lang2=\"pl\", split = 'train')" ] }, { "cell_type": "code", "execution_count": 3, "id": "980f597f", "metadata": {}, "outputs": [], "source": [ "# ga_txt = [i['ga'] for i in dataset['translation']]\n", "# ga_txt = pd.Series(ga_txt)\n", "\n", "chars_to_ignore_regex = '[,?.!\\-\\;\\:\"“%‘”�—’…–]' # change to the ignored characters of your fine-tuned model\n", "\n", "import re\n", "\n", "def extract_text(batch):\n", " text = batch[\"translation\"]\n", " ga_text = text['ga']\n", " batch[\"text\"] = re.sub(chars_to_ignore_regex, \"\", ga_text.lower())\n", " return batch\n", "\n", "# dataset = dataset.map(extract_text, remove_columns=dataset.column_names)\n", "\n", "# dataset.push_to_hub(f\"{target_lang}_opus_dgt_train\", split=\"train\")" ] }, { "cell_type": "markdown", "id": "6bc6ad37", "metadata": {}, "source": [ "## N-gram KenLM" ] }, { "cell_type": "code", "execution_count": 4, "id": "8d206f65", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0c3dbd6368014788bff9249dd460d03e", 
"version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/1.60k [00:00 \"5gram.arpa\"" ] }, { "cell_type": "code", "execution_count": 8, "id": "5a1f7707", "metadata": {}, "outputs": [], "source": [ "with open(\"5gram.arpa\", \"r\") as read_file, open(\"5gram_correct.arpa\", \"w\") as write_file:\n", " has_added_eos = False\n", " for line in read_file:\n", " if not has_added_eos and \"ngram 1=\" in line:\n", " count=line.strip().split(\"=\")[-1]\n", " write_file.write(line.replace(f\"{count}\", f\"{int(count)+1}\"))\n", " elif not has_added_eos and \"\" in line:\n", " write_file.write(line)\n", " write_file.write(line.replace(\"\", \"\"))\n", " has_added_eos = True\n", " else:\n", " write_file.write(line)" ] }, { "cell_type": "code", "execution_count": 9, "id": "41d18e68", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\\data\\\n", "ngram 1=70781\n", "ngram 2=652306\n", "ngram 3=1669326\n", "ngram 4=2514789\n", "ngram 5=3053088\n", "\n", "\\1-grams:\n", "-5.8501472\t\t0\n", "0\t\t-0.11565505\n", "0\t\t-0.11565505\n", "-5.4088216\tmiontuairisc\t-0.20133564\n", "-4.6517477\tcheartaitheach\t-0.24842946\n", "-2.1893916\tmaidir\t-1.7147961\n", "-2.1071756\tle\t-0.7007309\n", "-4.156014\tcoinbhinsiún\t-0.31064242\n", "-1.8876181\tar\t-0.9045828\n", "-4.62287\tdhlínse\t-0.24268326\n", "-1.6051095\tagus\t-0.8729715\n", "-4.1465816\taithint\t-0.21693327\n" ] } ], "source": [ "!head -20 5gram_correct.arpa" ] }, { "cell_type": "code", "execution_count": 10, "id": "7f046bf8", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoProcessor\n", "\n", "processor = AutoProcessor.from_pretrained(\"./\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "040e764f", "metadata": {}, "outputs": [], "source": [ "vocab_dict = processor.tokenizer.get_vocab()\n", "sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}" ] }, { "cell_type": "code", 
"execution_count": 12, "id": "4670cffe", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?\n", "Unigrams and labels don't seem to agree.\n" ] } ], "source": [ "from pyctcdecode import build_ctcdecoder\n", "\n", "decoder = build_ctcdecoder(\n", " labels=list(sorted_vocab_dict.keys()),\n", " kenlm_model_path=\"5gram_correct.arpa\",\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "47a55861", "metadata": {}, "outputs": [], "source": [ "from transformers import Wav2Vec2ProcessorWithLM\n", "\n", "processor_with_lm = Wav2Vec2ProcessorWithLM(\n", " feature_extractor=processor.feature_extractor,\n", " tokenizer=processor.tokenizer,\n", " decoder=decoder\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "id": "c1fcdaa6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/wav2vec-cv7-1b-ir/./ is already a clone of https://huggingface.co/jcmc/wav2vec-cv7-1b-ir. 
Make sure you pull the latest changes with `repo.git_pull()`.\n" ] } ], "source": [ "from huggingface_hub import Repository\n", "\n", "repo = Repository(local_dir=\"./\", clone_from=\"jcmc/wav2vec-cv7-1b-ir\")" ] }, { "cell_type": "code", "execution_count": 16, "id": "a9d242c9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/workspace/wav2vec-cv7-1b-ir'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pwd" ] }, { "cell_type": "code", "execution_count": 17, "id": "719546e1", "metadata": {}, "outputs": [], "source": [ "processor_with_lm.save_pretrained(\"./\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "fb1297ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reading ./language_model/5gram_correct.arpa\n", "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n", "****************************************************************************************************\n", "SUCCESS\n" ] } ], "source": [ "!../kenlm/build/bin/build_binary ./language_model/5gram_correct.arpa ./language_model/5gram.bin" ] }, { "cell_type": "code", "execution_count": 20, "id": "464b2582", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Adding files tracked by Git LFS: ['5gram.arpa', '5gram_correct.arpa', 'text.txt']. 
This may take a bit of time if the files are large.\n" ] },
{ "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "923b145932464690841cbd628875e90d", "version_major": 2, "version_minor": 0 },
  "text/plain": [ "Upload file 5gram_correct.arpa:   0%|          | 3.39k/359M [00:00<?, ?B/s]" ] },
  "metadata": {}, "output_type": "display_data" },
{ "name": "stderr", "output_type": "stream", "text": [
 "To https://huggingface.co/jcmc/wav2vec-cv7-1b-ir\n",
 "   main -> main\n",
 "\n"
] },
{ "data": { "text/plain": [ "'https://huggingface.co/jcmc/wav2vec-cv7-1b-ir/commit/cee330588cadf6700b6e7cf42971cde5342da76e'" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "repo.push_to_hub(commit_message=\"Upload lm-boosted decoder\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }