{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "02b2d284", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" ] }, { "cell_type": "code", "execution_count": 2, "id": "4966a667", "metadata": {}, "outputs": [], "source": [ "# !wget https://huggingface.co/huseinzol05/language-model-bahasa-manglish-combined/resolve/main/model.klm\n", "# !pip3 install pyctcdecode==0.1.0 pypi-kenlm==0.1.20210121" ] }, { "cell_type": "code", "execution_count": 3, "id": "42d8d861", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/.local/lib/python3.8/site-packages/apex/pyprof/__init__.py:5: FutureWarning: pyprof will be removed by the end of June, 2022\n", " warnings.warn(\"pyprof will be removed by the end of June, 2022\", FutureWarning)\n" ] } ], "source": [ "import transformers\n", "from transformers import (\n", " HfArgumentParser,\n", " Trainer,\n", " TrainingArguments,\n", " Wav2Vec2CTCTokenizer,\n", " Wav2Vec2FeatureExtractor,\n", " Wav2Vec2ForCTC,\n", " Wav2Vec2Processor,\n", " is_apex_available,\n", " set_seed,\n", " AutoModelForCTC,\n", " TFWav2Vec2ForCTC,\n", " TFWav2Vec2PreTrainedModel,\n", " Wav2Vec2PreTrainedModel,\n", ")\n", "from scipy.special import log_softmax" ] }, { "cell_type": "code", "execution_count": 4, "id": "0d6b421c", "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 5, "id": "060fb120", "metadata": {}, "outputs": [], "source": [ "import string\n", "import json\n", "\n", "CTC_VOCAB = [''] + list(string.ascii_lowercase + string.digits) + [' ']\n", "vocab_dict = {v: k for k, v in enumerate(CTC_VOCAB)}\n", "vocab_dict[\"|\"] = vocab_dict[\" \"]\n", "del vocab_dict[\" \"]\n", "vocab_dict[\"[UNK]\"] = len(vocab_dict)\n", "vocab_dict[\"[PAD]\"] = len(vocab_dict)\n", "\n", "with open(\"ctc-vocab.json\", \"w\") as vocab_file:\n", " json.dump(vocab_dict, vocab_file)\n", "\n", "tokenizer = Wav2Vec2CTCTokenizer(\n", " \"ctc-vocab.json\",\n", " unk_token=\"[UNK]\",\n", " pad_token=\"[PAD]\",\n", " word_delimiter_token=\"|\",\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "c16b890f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(765, 3579, 614)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from glob import glob\n", "malay = sorted(glob('malay-test/*.wav'), key = lambda x: int(x.split('/')[1].replace('.wav', '')))\n", "singlish = sorted(glob('singlish-test/*.wav'), key = lambda x: int(x.split('/')[1].replace('.wav', '')))\n", "mandarin = sorted(glob('mandarin-test/*.wav'), key = lambda x: int(x.split('/')[1].replace('.wav', '')))\n", "len(malay), len(singlish), len(mandarin)" ] }, { "cell_type": "code", "execution_count": 7, "id": "29568a5f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(765, 3579, 614)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with open('malay-test.json') as fopen:\n", " malay_label = json.load(fopen)\n", "with open('singlish-test.json') as fopen:\n", " singlish_label = json.load(fopen)\n", "with open('mandarin-test.json') as fopen:\n", " mandarin_label = json.load(fopen)\n", " \n", "len(malay_label), len(singlish_label), len(mandarin_label)" ] }, { "cell_type": "code", "execution_count": 8, "id": "bdac1296", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('singlish-test/3057.wav', 'the teenagers paddled hard on their boat'),\n", " ('malay-test/705.wav', 'kenapa justin trudeau seperti kemaluan wanita'),\n", " ('singlish-test/2631.wav',\n", " 'a letter by a mans daughter pleading for leniency was submitted'),\n", " ('singlish-test/659.wav', 'and theres thousands of people to meet'),\n", " ('singlish-test/809.wav', 'how much lower are the prices'),\n", " ('singlish-test/2040.wav',\n", " 'suddenly a gun shot was fired from a distance which sent the dogs fleeing in an instant'),\n", " ('singlish-test/1616.wav',\n", " 'a stronger dollar pressures gold making it more expensive for holders of other currencies'),\n", " ('singlish-test/1816.wav',\n", " 'family as a priority has become real for me and not just a cliche'),\n", " ('malay-test/147.wav',\n", " 'adakah anda percaya bahawa donald trump adalah kedatangan kedua jesus christ'),\n", " ('singlish-test/3468.wav',\n", " 'but much of the technology required for such a fantastic instrument didnt yet exist')]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.utils import shuffle\n", "\n", "audio = malay + singlish + mandarin\n", "labels = malay_label + singlish_label + mandarin_label\n", "audio, labels = shuffle(audio, labels)\n", "test_set = list(zip(audio, labels))\n", "test_set[:10]" ] }, { "cell_type": "code", "execution_count": 9, "id": "69cb17cc", "metadata": {}, "outputs": [], "source": [ "import soundfile as sf\n", "import numpy as np\n", "\n", "def norm_audio(x):\n", " return (x - x.mean()) / np.sqrt(x.var() + 1e-7)\n", "\n", "def sequence_1d(\n", " seq, maxlen=None, padding: str = 'post', pad_int=0, return_len=False\n", "):\n", " if padding not in ['post', 'pre']:\n", " raise ValueError('padding only supported [`post`, `pre`]')\n", "\n", " if not maxlen:\n", " maxlen = max([len(s) for s in seq])\n", "\n", " padded_seqs, length = [], []\n", " for s in seq:\n", " if isinstance(s, np.ndarray):\n", " s = s.tolist()\n", " if padding == 'post':\n", " padded_seqs.append(s + [pad_int] * (maxlen - len(s)))\n", " if padding == 'pre':\n", " padded_seqs.append([pad_int] * (maxlen - len(s)) + s)\n", " length.append(len(s))\n", " if return_len:\n", " return np.array(padded_seqs), length\n", " return np.array(padded_seqs)\n", "\n", "def batching(audios):\n", " audios = [sf.read(a)[0] for a in audios]\n", " batch, lens = sequence_1d(audios,return_len=True)\n", " attentions = [[1] * l for l in lens]\n", " attentions = sequence_1d(attentions)\n", " normed_input_values = []\n", "\n", " for vector, length in zip(batch, attentions.sum(-1)):\n", " normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)\n", " if length < normed_slice.shape[0]:\n", " normed_slice[length:] = 0.0\n", "\n", " normed_input_values.append(normed_slice)\n", "\n", " normed_input_values = np.array(normed_input_values)\n", " return normed_input_values.astype(np.float32), attentions" ] }, { "cell_type": "code", "execution_count": 10, "id": "f97f22e4", "metadata": {}, "outputs": [], "source": [ "model = AutoModelForCTC.from_pretrained(\n", " './checkpoint-115000',\n", " ctc_loss_reduction=\"mean\",\n", " pad_token_id=tokenizer.pad_token_id,\n", " vocab_size=len(tokenizer),\n", ").cuda()" ] }, { "cell_type": "code", "execution_count": 11, "id": "20fee479", "metadata": {}, "outputs": [], "source": [ "_ = model.eval()" ] }, { "cell_type": "code", "execution_count": 12, "id": "51703510", "metadata": {}, "outputs": [], "source": [ "batch_size = 4\n", "batch_x = audio[:batch_size]\n", "normed_input_values, attentions = batching(batch_x)" ] }, { "cell_type": "code", "execution_count": 13, "id": "065fce75", "metadata": {}, "outputs": [], "source": [ "o_pt = model(torch.from_numpy(normed_input_values.astype(np.float32)).cuda(), \n", " attention_mask = torch.from_numpy(attentions).cuda())\n", "o_pt = o_pt.logits.detach().cpu().numpy()\n", "o_pt = log_softmax(o_pt, axis = -1)" ] }, { "cell_type": "code", "execution_count": 14, "id": "b7851fc9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['the teenagers paddled hard on their boat',\n", " 'kenapa justin tradio seperti kemaluan wanita',\n", " 'a letter bya mans daughter pleading for lenien te was submitted',\n", " 'and theres thousands of people to meet']" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred_ids = np.argmax(o_pt, axis = -1)\n", "tokenizer.batch_decode(pred_ids)" ] }, { "cell_type": "code", "execution_count": 15, "id": "3efd715e", "metadata": {}, "outputs": [], "source": [ "unique_vocab = list(vocab_dict.keys())\n", "unique_vocab[-3] = ' ' \n", "unique_vocab[-2] = '?'\n", "unique_vocab[-1] = '_'" ] }, { "cell_type": "code", "execution_count": 16, "id": "3024298f", "metadata": {}, "outputs": [], "source": [ "from pyctcdecode import build_ctcdecoder\n", "import kenlm\n", "\n", "kenlm_model = kenlm.Model('model.klm')\n", "decoder = build_ctcdecoder(\n", " unique_vocab,\n", " kenlm_model,\n", " alpha=0.2,\n", " beta=1.0,\n", " ctc_token_idx=tokenizer.pad_token_id\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "id": "6100ea60", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 to know more about this years budget click here\n", "1 you can bake shortbread cookies just with sugar butter and flour\n", "2 all good citizens should learn how to change a light bulb\n", "3 as a child madam surley was constantly teased by other children over her appearance\n" ] } ], "source": [ "for k in range(len(o_pt)):\n", " out = decoder.decode_beams(o_pt[k], prune_history=True)\n", " d_lm2, lm_state, timesteps, logit_score, lm_score = out[0]\n", " print(k, d_lm2)" ] }, { "cell_type": "code", "execution_count": 18, "id": "4672ac73", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['to know more about this years budget click here',\n", " 'you can bake shortbread cookies just with sugar butter and flour',\n", " 'all good citizens should learn how to change a light bulb',\n", " 'as a child madam shirley was constantly teased by other children over her appearance']" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[:batch_size]" ] }, { "cell_type": "code", "execution_count": 19, "id": "5d47692d", "metadata": {}, "outputs": [], "source": [ "def calculate_cer(actual, hyp):\n", " \"\"\"\n", " Calculate CER using `python-Levenshtein`.\n", " \"\"\"\n", " import Levenshtein as Lev\n", "\n", " actual = actual.replace(' ', '')\n", " hyp = hyp.replace(' ', '')\n", " return Lev.distance(actual, hyp) / len(actual)\n", "\n", "\n", "def calculate_wer(actual, hyp):\n", " \"\"\"\n", " Calculate WER using `python-Levenshtein`.\n", " \"\"\"\n", " import Levenshtein as Lev\n", "\n", " b = set(actual.split() + hyp.split())\n", " word2char = dict(zip(b, range(len(b))))\n", "\n", " w1 = [chr(word2char[w]) for w in actual.split()]\n", " w2 = [chr(word2char[w]) for w in hyp.split()]\n", "\n", " return Lev.distance(''.join(w1), ''.join(w2)) / len(actual.split())" ] }, { "cell_type": "code", "execution_count": 20, "id": "c01ea2e4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1240/1240 [04:23<00:00, 4.71it/s]\n" ] } ], "source": [ "from tqdm import tqdm\n", "\n", "wer, cer = [], []\n", "wer_lm, cer_lm = [], []\n", "\n", "for i in tqdm(range(0, len(audio), batch_size)):\n", " torch.cuda.empty_cache()\n", " \n", " batch_x = audio[i: i + batch_size]\n", " batch_y = labels[i: i + batch_size]\n", " normed_input_values, attentions = batching(batch_x)\n", " inputs = torch.from_numpy(normed_input_values.astype(np.float32)).cuda()\n", " attention_mask = torch.from_numpy(attentions).cuda()\n", " o_pt = model(inputs, attention_mask = attention_mask)\n", " o_pt = o_pt.logits.detach().cpu().numpy()\n", " o_pt = log_softmax(o_pt, axis = -1)\n", " pred_ids = np.argmax(o_pt, axis = -1)\n", " pred = tokenizer.batch_decode(pred_ids)\n", " for k in range(len(o_pt)):\n", " out = decoder.decode_beams(o_pt[k], prune_history=True)\n", " d_lm2, lm_state, timesteps, logit_score, lm_score = out[0]\n", " \n", " wer.append(calculate_wer(batch_y[k], pred[k]))\n", " cer.append(calculate_cer(batch_y[k], pred[k]))\n", " \n", " wer_lm.append(calculate_wer(batch_y[k], d_lm2))\n", " cer_lm.append(calculate_cer(batch_y[k], d_lm2))" ] }, { "cell_type": "code", "execution_count": 21, "id": "6c6ce8ef", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.1322198446007387,\n", " 0.0481054244857041,\n", " 0.09880169127621556,\n", " 0.041196586938584696)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(wer), np.mean(cer), np.mean(wer_lm), np.mean(cer_lm)" ] }, { "cell_type": "code", "execution_count": 22, "id": "cf53914e", "metadata": {}, "outputs": [], "source": [ "index_malay = [no for no, i in enumerate(audio) if 'malay-test/' in i]\n", "index_singlish = [no for no, i in enumerate(audio) if 'singlish-test/' in i]\n", "index_mandarin = [no for no, i in enumerate(audio) if 'mandarin-test/' in i]" ] }, { "cell_type": "code", "execution_count": 23, "id": "b1558987", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.19561999547293663,\n", " 0.051636391937588406,\n", " 0.12710746406824835,\n", " 0.03917689630621449)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(np.array(wer)[index_malay]), np.mean(np.array(cer)[index_malay]), np.mean(np.array(wer_lm)[index_malay]), np.mean(np.array(cer_lm)[index_malay])" ] }, { "cell_type": "code", "execution_count": 24, "id": "f340cde7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.12763802881676573,\n", " 0.0494915200071987,\n", " 0.09677160640413336,\n", " 0.04271234986432335)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(np.array(wer)[index_singlish]), np.mean(np.array(cer)[index_singlish]), np.mean(np.array(wer_lm)[index_singlish]), np.mean(np.array(cer_lm)[index_singlish])" ] }, { "cell_type": "code", "execution_count": 26, "id": "cbc2539f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.07993515937860181,\n", " 0.035626554824269824,\n", " 0.07536807168546154,\n", " 0.03487760945087219)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(np.array(wer)[index_mandarin]), np.mean(np.array(cer)[index_mandarin]), np.mean(np.array(wer_lm)[index_mandarin]), np.mean(np.array(cer_lm)[index_mandarin])" ] }, { "cell_type": "code", "execution_count": null, "id": "4c543d0c", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7270a78ff7874222b18f538069750bc1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Upload file pytorch_model.bin: 0%| | 4.00k/1.18G [00:00 main\n", "\n" ] }, { "data": { "text/plain": [ "'https://huggingface.co/mesolitica/wav2vec2-xls-r-300m-mixed/commit/adf65347379e5902f7488753aef24d4e9d16daff'" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "processor.push_to_hub('wav2vec2-xls-r-300m-mixed', organization='mesolitica')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }