{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f3a9c0d1",
   "metadata": {},
   "source": [
    "# Khmer XLS-R 300M CTC model: single-utterance sanity check\n",
    "\n",
    "Loads the fine-tuned CTC checkpoint and its processor, preprocesses the `km_kh_male` test split, runs greedy (argmax) CTC decoding on one utterance and prints the prediction next to the reference transcript.\n",
    "\n",
    "All tunable values live in the configuration cell, so the notebook can be re-run top to bottom on a fresh kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33e4a305",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCTC, Wav2Vec2ForCTC, Wav2Vec2Processor\n",
    "from datasets import load_dataset, load_metric, Audio\n",
    "from pyctcdecode import build_ctcdecoder\n",
    "from pydub import AudioSegment\n",
    "from pydub.playback import play\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import kenlm\n",
    "import pandas as pd\n",
    "import random\n",
    "import soundfile as sf\n",
    "\n",
    "# NOTE(review): Wav2Vec2ForCTC, load_metric, build_ctcdecoder, AudioSegment, play, kenlm,\n",
    "# pandas, random and soundfile are not used below; drop them unless another experiment needs them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c1e4b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader may want to change.\n",
    "MODEL_DIR = \".\"  # local fine-tuned checkpoint; alternative: the Hub checkpoint \"vitouphy/xls-r-300m-km\"\n",
    "TEST_CSV = \"km_kh_male/line_index_test.csv\"\n",
    "SAMPLING_RATE = 16_000  # the audio column is resampled to this rate\n",
    "SAMPLE_INDEX = 25  # test utterance inspected in the decoding section\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d6b8e30",
   "metadata": {},
   "source": [
    "## Model and processor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "328d0662",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the model on DEVICE: the inference cell sends its inputs there too, and a CPU model\n",
    "# fed CUDA tensors raises a device-mismatch error.\n",
    "model = AutoModelForCTC.from_pretrained(MODEL_DIR).to(DEVICE).eval()  # inference only\n",
    "processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a4d2c71",
   "metadata": {},
   "source": [
    "## Test data\n",
    "\n",
    "The CSV lists wav paths (`path`) and transcripts (`text`). The `path` column is cast to an `Audio` feature so each example is decoded and resampled to `SAMPLING_RATE`, then renamed to `audio`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cfef23c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single CSV passed via data_files is exposed as the \"train\" split.\n",
    "test_raw = (\n",
    "    load_dataset(\"csv\", data_files=TEST_CSV, split=\"train\")\n",
    "    .remove_columns([\"Unnamed: 0\", \"drop\"])\n",
    "    .rename_column(\"text\", \"sentence\")\n",
    "    .cast_column(\"path\", Audio(sampling_rate=SAMPLING_RATE))\n",
    "    .rename_column(\"path\", \"audio\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29e6bb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_raw[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0554b8d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    \"\"\"Convert one raw example into model inputs and CTC labels.\n",
    "\n",
    "    Used with ``Dataset.map`` on single (non-batched) examples. Adds:\n",
    "      - ``input_values``: the waveform passed through the processor's feature extractor,\n",
    "      - ``input_length``: the number of values in ``input_values``,\n",
    "      - ``labels``: token ids of ``sentence``.\n",
    "    \"\"\"\n",
    "    audio = batch[\"audio\"]\n",
    "\n",
    "    # The processor always returns a batch; unwrap the single example.\n",
    "    batch[\"input_values\"] = processor(np.array(audio[\"array\"]), sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n",
    "    batch[\"input_length\"] = len(batch[\"input_values\"])\n",
    "\n",
    "    # Route the transcript to the tokenizer instead of the feature extractor.\n",
    "    # NOTE: newer transformers releases deprecate as_target_processor() in favour of processor(text=...).\n",
    "    with processor.as_target_processor():\n",
    "        batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26a6659",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a new dataset instead of overwriting test_raw, which still holds the audio and transcript text.\n",
    "test_ds = test_raw.map(prepare_dataset, remove_columns=test_raw.column_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f7a3b9",
   "metadata": {},
   "source": [
    "## Greedy CTC decoding of one utterance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adf215c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = test_ds[SAMPLE_INDEX]\n",
    "\n",
    "# `input_values` already went through the feature extractor in prepare_dataset, so only add a\n",
    "# batch dimension instead of running the processor (which may normalize them again) a second time.\n",
    "input_values = torch.tensor(sample[\"input_values\"]).unsqueeze(0).to(DEVICE)\n",
    "\n",
    "with torch.no_grad():  # inference only: no autograd graph needed\n",
    "    logits = model(input_values).logits\n",
    "\n",
    "# Greedy decoding: most likely token id per frame of the single utterance.\n",
    "pred_ids = torch.argmax(logits, dim=-1)[0].cpu()\n",
    "pred_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dd986a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the raw frame ids: the CTC tokenizer merges repeated ids first and only then removes the\n",
    "# pad (CTC blank) token. Stripping pad ids beforehand would merge genuinely doubled characters.\n",
    "prediction = processor.decode(pred_ids)\n",
    "\n",
    "# Labels are a plain token sequence, not CTC frames: repeated tokens must be kept as-is.\n",
    "reference = processor.decode(sample[\"labels\"], group_tokens=False)\n",
    "\n",
    "print(\"Prediction:\")\n",
    "print(prediction)\n",
    "print(\"\\nReference:\")\n",
    "print(reference)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}