{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "57176d39", "metadata": {}, "outputs": [], "source": [ "from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForCTC, Wav2Vec2Processor, AutoProcessor, Wav2Vec2ProcessorWithLM\n", "from datasets import load_dataset, load_metric, Audio\n", "from pyctcdecode import build_ctcdecoder\n", "from pydub import AudioSegment\n", "from pydub.playback import play\n", "\n", "import numpy as np\n", "import torch\n", "import kenlm\n", "import pandas as pd\n", "import random\n", "import soundfile as sf\n", "from tqdm.auto import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "id": "dbc1f98a", "metadata": {}, "outputs": [], "source": [ "# KENLM_MODEL_LOC = '/workspace/xls-r-300m-km/data/km_text_word_unigram.arpa'\n", "# KENLM_MODEL_LOC = '/workspace/xls-r-300m-km/data/km_wiki_ngram.arpa'\n", "KENLM_MODEL_LOC = '/workspace/xls-r-300m-km/data/kmwiki_5gram.binary'" ] }, { "cell_type": "code", "execution_count": 3, "id": "54d76e5f", "metadata": {}, "outputs": [], "source": [ "processor = AutoProcessor.from_pretrained(\"vitouphy/xls-r-300m-km\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "c76a5c8e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'|': 0, 'ក': 1, 'ខ': 2, 'គ': 3, 'ឃ': 4, 'ង': 5, 'ច': 6, 'ឆ': 7, 'ជ': 8, 'ឈ': 9, 'ញ': 10, 'ដ': 11, 'ឋ': 12, 'ឌ': 13, 'ឍ': 14, 'ណ': 15, 'ត': 16, 'ថ': 17, 'ទ': 18, 'ធ': 19, 'ន': 20, 'ប': 21, 'ផ': 22, 'ព': 23, 'ភ': 24, 'ម': 25, 'យ': 26, 'រ': 27, 'ល': 28, 'វ': 29, 'ស': 30, 'ហ': 31, 'ឡ': 32, 'អ': 33, 'ឥ': 34, 'ឧ': 35, 'ឪ': 36, 'ឫ': 37, 'ឬ': 38, 'ឭ': 39, 'ឮ': 40, 'ឯ': 41, 'ឱ': 42, 'ា': 43, 'ិ': 44, 'ី': 45, 'ឹ': 46, 'ឺ': 47, 'ុ': 48, 'ូ': 49, 'ួ': 50, 'ើ': 51, 'ឿ': 52, 'ៀ': 53, 'េ': 54, 'ែ': 55, 'ៃ': 56, 'ោ': 57, 'ៅ': 58, 'ំ': 59, 'ះ': 60, 'ៈ': 61, '៉': 62, '៊': 63, '់': 64, '៌': 65, '៍': 66, '៎': 67, '៏': 68, '័': 69, '្': 70, '[unk]': 71, '[pad]': 72, '': 73, '': 74}\n" ] } ], "source": [ "vocab_dict = processor.tokenizer.get_vocab()\n", "sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}\n", "print(sorted_vocab_dict)" ] }, { "cell_type": "code", "execution_count": 5, "id": "8b640127", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Unigrams not provided and cannot be automatically determined from LM file (only arpa format). Decoding accuracy might be reduced.\n", "Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. 
Is this correct?\n", "No known unigrams provided, decoding results might be a lot worse.\n" ] } ], "source": [ "decoder = build_ctcdecoder(\n", " labels=list(sorted_vocab_dict.keys()),\n", " kenlm_model_path=KENLM_MODEL_LOC,\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "2560c32d", "metadata": {}, "outputs": [], "source": [ "processor_with_lm = Wav2Vec2ProcessorWithLM(\n", " feature_extractor=processor.feature_extractor,\n", " tokenizer=processor.tokenizer,\n", " decoder=decoder\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "badc19a1", "metadata": {}, "outputs": [], "source": [ "processor_with_lm.save_pretrained(\".\")" ] }, { "cell_type": "markdown", "id": "89e517c8", "metadata": {}, "source": [ "## Save Model" ] }, { "cell_type": "code", "execution_count": 9, "id": "ed9535c8", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bc5bf68946064e97b869d44b02e7af19", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/1.18G [00:00
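  {
   "cell_type": "markdown",
   "id": "lm-decode-check-note",
   "metadata": {},
   "source": [
    "## Quick Decoding Check (illustrative)\n",
    "\n",
    "The next cell is an added sketch rather than part of the original run: it loads the fine-tuned checkpoint, transcribes one 16 kHz clip, and compares plain greedy CTC decoding against the KenLM-boosted beam search provided by `processor_with_lm`. The audio path is a placeholder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lm-decode-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: compare greedy CTC decoding with KenLM-boosted\n",
    "# beam-search decoding. The audio path below is a placeholder.\n",
    "model = Wav2Vec2ForCTC.from_pretrained(\"vitouphy/xls-r-300m-km\")\n",
    "model.eval()\n",
    "\n",
    "speech, sampling_rate = sf.read(\"/workspace/xls-r-300m-km/data/example_16k.wav\")  # placeholder path\n",
    "assert sampling_rate == 16_000\n",
    "\n",
    "inputs = processor(speech, sampling_rate=16_000, return_tensors=\"pt\")\n",
    "with torch.no_grad():\n",
    "    logits = model(**inputs).logits\n",
    "\n",
    "# Greedy (argmax) decoding with the plain CTC tokenizer\n",
    "pred_ids = torch.argmax(logits, dim=-1)\n",
    "greedy_text = processor.tokenizer.batch_decode(pred_ids)[0]\n",
    "\n",
    "# Beam-search decoding through pyctcdecode + the 5-gram KenLM model\n",
    "lm_text = processor_with_lm.batch_decode(logits.numpy()).text[0]\n",
    "\n",
    "print(\"greedy :\", greedy_text)\n",
    "print(\"with LM:\", lm_text)"
   ]
  }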