{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Boosting_Wav2Vec2_with_n_grams_in_🤗_Transformers.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Boosting Wav2Vec2 with n-grams in 🤗 Transformers\n",
        "\n",
        "Builds a 5-gram KenLM language model from a text corpus hosted on the Hugging Face Hub, wraps it in a `pyctcdecode` decoder, attaches the decoder to an existing Wav2Vec2 processor, and pushes the resulting `Wav2Vec2ProcessorWithLM` files to the model repository."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Python dependencies for loading the corpus and the pretrained processor.\n",
        "%pip install datasets transformers"
      ],
      "metadata": {
        "id": "OWGc_zfyq5_T"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# KenLM Python bindings (installed from the GitHub archive) and the CTC beam-search decoder.\n",
        "%pip install https://github.com/kpu/kenlm/archive/master.zip pyctcdecode"
      ],
      "metadata": {
        "id": "TvDJ7CYpzSJQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Authenticate with the Hugging Face Hub; needed later to clone and push the model repo.\n",
        "from huggingface_hub import notebook_login\n",
        "\n",
        "notebook_login()"
      ],
      "metadata": {
        "id": "JHTeonOGXiGq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build toolchain and libraries for compiling the KenLM command-line tools.\n",
        "# -y keeps the install from stopping at the confirmation prompt.\n",
        "!sudo apt install -y build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev"
      ],
      "metadata": {
        "id": "FKMMWfVQp_gP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Download and unpack the KenLM sources into ./kenlm\n",
        "!wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz"
      ],
      "metadata": {
        "id": "J8mm4ExzqIaZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Compile KenLM. `mkdir -p` makes the cell safe to re-run when the build dir already exists.\n",
        "!mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j2\n",
        "!ls kenlm/build/bin"
      ],
      "metadata": {
        "id": "MS4mqMyZqVAI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "username = \"hf-test\"  # change to your username\n",
        "target_lang = \"sv\"\n",
        "\n",
        "dataset = load_dataset(f\"{username}/{target_lang}_corpora_parliament_processed\", split=\"train\")\n",
        "\n",
        "# Write the whole corpus as one space-joined text file for lmplz.\n",
        "# Explicit UTF-8 so non-ASCII characters are written the same way on every platform.\n",
        "with open(\"text.txt\", \"w\", encoding=\"utf-8\") as corpus_file:\n",
        "    corpus_file.write(\" \".join(dataset[\"text\"]))"
      ],
      "metadata": {
        "id": "VIgErMqApENm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Estimate a 5-gram language model (-o 5) from text.txt and write it in ARPA format.\n",
        "!kenlm/build/bin/lmplz -o 5 <\"text.txt\" > \"5gram.arpa\""
      ],
      "metadata": {
        "id": "_MdDNBlZrPOm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!head -20 5gram.arpa"
      ],
      "metadata": {
        "id": "TRnV8Miusl--"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Add an end-of-sentence token to the ARPA file:\n",
        "#   1. bump the unigram count in the header (\"ngram 1=N\" -> \"ngram 1=N+1\");\n",
        "#   2. copy the first \"<s>\" entry and write it again with \"</s>\" in its place.\n",
        "# NOTE(review): assumes the ARPA produced above has \"<s>\" but no \"</s>\" entry -- confirm with the `head` output.\n",
        "with open(\"5gram.arpa\", \"r\", encoding=\"utf-8\") as read_file, open(\"5gram_correct.arpa\", \"w\", encoding=\"utf-8\") as write_file:\n",
        "    has_added_eos = False\n",
        "    for line in read_file:\n",
        "        if not has_added_eos and \"ngram 1=\" in line:\n",
        "            count = line.strip().split(\"=\")[-1]\n",
        "            # Anchor on \"=\" so only the count is rewritten, never the \"1\" in \"ngram 1\".\n",
        "            write_file.write(line.replace(f\"={count}\", f\"={int(count) + 1}\"))\n",
        "        elif not has_added_eos and \"<s>\" in line:\n",
        "            write_file.write(line)\n",
        "            write_file.write(line.replace(\"<s>\", \"</s>\"))\n",
        "            has_added_eos = True\n",
        "        else:\n",
        "            write_file.write(line)"
      ],
      "metadata": {
        "id": "_7u7dVPkvyRZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!head -20 5gram_correct.arpa"
      ],
      "metadata": {
        "id": "YF1RSm-Pxst5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoProcessor\n",
        "\n",
        "processor = AutoProcessor.from_pretrained(\"marinone94/xls-r-300m-sv-robust\")"
      ],
      "metadata": {
        "id": "paV71gdAtkDC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Lower-case the tokenizer vocabulary and order it by token id,\n",
        "# so the i-th key of the dict is the token whose id is i.\n",
        "vocab_dict = processor.tokenizer.get_vocab()\n",
        "sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}"
      ],
      "metadata": {
        "id": "ZKwKxMoitoGS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from pyctcdecode import build_ctcdecoder\n",
        "\n",
        "# Beam-search decoder over the id-ordered labels, scored with the corrected 5-gram model.\n",
        "decoder = build_ctcdecoder(\n",
        "    labels=list(sorted_vocab_dict.keys()),\n",
        "    kenlm_model_path=\"5gram_correct.arpa\",\n",
        ")"
      ],
      "metadata": {
        "id": "zTLzCLB2tQP7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import Wav2Vec2ProcessorWithLM\n",
        "\n",
        "# Reuse the pretrained feature extractor and tokenizer; only the decoder is new.\n",
        "processor_with_lm = Wav2Vec2ProcessorWithLM(\n",
        "    feature_extractor=processor.feature_extractor,\n",
        "    tokenizer=processor.tokenizer,\n",
        "    decoder=decoder,\n",
        ")"
      ],
      "metadata": {
        "id": "VBVf50EzZgAQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# git-lfs for the model repository, tree for inspecting the directory layout below.\n",
        "!sudo apt-get install -y git-lfs tree"
      ],
      "metadata": {
        "id": "BZZm3ECc5TMP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from huggingface_hub import Repository\n",
        "\n",
        "# Clone the model repository into ./xls-r-300m-sv-robust\n",
        "repo = Repository(local_dir=\"xls-r-300m-sv-robust\", clone_from=\"marinone94/xls-r-300m-sv-robust\")"
      ],
      "metadata": {
        "id": "fIfcunhF4YM6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Save the processor (including the language model files) into the cloned repository.\n",
        "processor_with_lm.save_pretrained(\"xls-r-300m-sv-robust\")"
      ],
      "metadata": {
        "id": "UZ1sWfPH2oce"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Inspect the directory that was just written (same path as the clone/save cells above).\n",
        "!tree -h xls-r-300m-sv-robust/"
      ],
      "metadata": {
        "id": "ClyENOYFcC_C"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Convert the ARPA file to KenLM's binary format.\n",
        "!kenlm/build/bin/build_binary xls-r-300m-sv-robust/language_model/5gram_correct.arpa xls-r-300m-sv-robust/language_model/5gram.bin"
      ],
      "metadata": {
        "id": "X9qg4FPt2zi8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Remove the ARPA file so only the binary model is left in the repo, then re-check the layout.\n",
        "!rm xls-r-300m-sv-robust/language_model/5gram_correct.arpa && tree -h xls-r-300m-sv-robust/"
      ],
      "metadata": {
        "id": "Zn4J-4OZdMPc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "repo.push_to_hub(commit_message=\"Upload 5-gram lm-boosted decoder\")"
      ],
      "metadata": {
        "id": "WEV1sx6ee3aT"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}