{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac7631cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import warnings\n",
    "\n",
    "import librosa\n",
    "import torch\n",
    "from datasets import load_dataset, load_metric\n",
    "from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor\n",
    "\n",
    "LANG_ID = \"zh-CN\"\n",
    "MODEL_ID = \"zh-CN-output-aishell\"\n",
    "SAMPLING_RATE = 16_000  # rate audio is resampled to and passed to the processor\n",
    "NUM_PROC = 15  # worker processes for the audio-decoding `map`\n",
    "BATCH_SIZE = 8  # examples per forward pass in `evaluate`\n",
    "# DEVICE was used below but never defined, so this cell raised NameError\n",
    "# on a fresh kernel.\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "test_dataset = load_dataset(\"common_voice\", LANG_ID, split=\"test\")\n",
    "\n",
    "wer = load_metric(\"wer\")\n",
    "cer = load_metric(\"cer\")\n",
    "\n",
    "processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)\n",
    "model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)\n",
    "model.to(DEVICE)\n",
    "\n",
    "\n",
    "def speech_file_to_array_fn(batch):\n",
    "    \"\"\"Load one example's audio and normalize its transcript.\n",
    "\n",
    "    Reads ``batch[\"path\"]`` resampled to SAMPLING_RATE into ``batch[\"speech\"]``.\n",
    "    Removes every character of ``batch[\"sentence\"]`` outside U+4E00..U+9FA5 and\n",
    "    the digits 0-9, lowercases it and appends one trailing space.\n",
    "    \"\"\"\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.simplefilter(\"ignore\")  # silence warnings raised while decoding audio\n",
    "        speech_array, _ = librosa.load(batch[\"path\"], sr=SAMPLING_RATE)\n",
    "    batch[\"speech\"] = speech_array\n",
    "    batch[\"sentence\"] = (\n",
    "        re.sub(\"([^\u4e00-\u9fa5\u0030-\u0039])\", \"\", batch[\"sentence\"]).lower() + \" \"\n",
    "    )\n",
    "    return batch\n",
    "\n",
    "\n",
    "# Keep only `path` and `sentence`; deriving the drop list from the dataset\n",
    "# avoids a ValueError when a `datasets` version adds or drops metadata columns.\n",
    "test_dataset = test_dataset.map(\n",
    "    speech_file_to_array_fn,\n",
    "    num_proc=NUM_PROC,\n",
    "    remove_columns=[c for c in test_dataset.column_names if c not in (\"path\", \"sentence\")],\n",
    ")\n",
    "\n",
    "\n",
    "def evaluate(batch):\n",
    "    \"\"\"Transcribe a batch with greedy (argmax) CTC decoding into ``pred_strings``.\"\"\"\n",
    "    inputs = processor(\n",
    "        batch[\"speech\"], sampling_rate=SAMPLING_RATE, return_tensors=\"pt\", padding=True\n",
    "    )\n",
    "    # Not every feature-extractor config returns an attention mask.\n",
    "    attention_mask = inputs.get(\"attention_mask\")\n",
    "    if attention_mask is not None:\n",
    "        attention_mask = attention_mask.to(DEVICE)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        logits = model(inputs.input_values.to(DEVICE), attention_mask=attention_mask).logits\n",
    "\n",
    "    pred_ids = torch.argmax(logits, dim=-1)\n",
    "    batch[\"pred_strings\"] = processor.batch_decode(pred_ids)\n",
    "    return batch\n",
    "\n",
    "\n",
    "result = test_dataset.map(evaluate, batched=True, batch_size=BATCH_SIZE)\n",
    "\n",
    "predictions = [x.lower() for x in result[\"pred_strings\"]]\n",
    "references = [x.lower() for x in result[\"sentence\"]]\n",
    "\n",
    "# The stock `wer` metric has no `chunk_size` argument (that came from a custom\n",
    "# wer.py script); passing it makes compute() raise TypeError.\n",
    "print(f\"WER: {wer.compute(predictions=predictions, references=references) * 100}\")\n",
    "print(f\"CER: {cer.compute(predictions=predictions, references=references) * 100}\")"
   ]
  },
  { "cell_type": "code", "execution_count": 15, "id": "7db04701", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "11/08/2022 09:41:20 - INFO - huggingsound.speech_recognition.model - Loading model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "11/08/2022 09:41:23 - WARNING - root - bos_token not in provided tokens. It will be added to the list of tokens\n", "11/08/2022 09:41:23 - WARNING - root - eos_token not in provided tokens. 
It will be added to the list of tokens\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████████| 1/1 [00:00<00:00, 2.11it/s]\n" ] } ],
   "source": [
    "from huggingsound import SpeechRecognitionModel\n",
    "\n",
    "# Separate name so this huggingsound wrapper does not overwrite `model`\n",
    "# (the transformers CTC model loaded in the evaluation cell).\n",
    "zh_model = SpeechRecognitionModel(\"./wav2vec2-large-xlsr-chinese\")\n",
    "audio_paths = [\"1.wav\"]\n",
    "transcriptions = zh_model.transcribe(audio_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "23316152",
   "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "'你喜欢饭吗'" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "transcriptions[0]['transcription']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "730d4afa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import Wav2Vec2Processor, HubertForCTC\n",
    "from datasets import load_dataset\n",
    "\n",
    "# HuBERT CTC checkpoint (English, per its directory name). Note: this rebinds\n",
    "# `processor` and `model`, which the next cell reads.\n",
    "processor = Wav2Vec2Processor.from_pretrained(\"./english_fine_tune\")\n",
    "model = HubertForCTC.from_pretrained(\"./english_fine_tune\")"
   ]
  },
  { "cell_type": "code", "execution_count": 25, "id": "f45768e8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "It is strongly recommended to pass the ``sampling_rate`` argument to this function. 
Failing to do so can result in silent errors that might be hard to debug.\n" ] } ],
   "source": [
    "import librosa\n",
    "\n",
    "# Resample to the rate the processor's feature extractor is configured for and\n",
    "# pass it explicitly; omitting `sampling_rate` is what triggers the\n",
    "# \"strongly recommended to pass the sampling_rate\" warning.\n",
    "target_sr = processor.feature_extractor.sampling_rate\n",
    "input_audio, sr = librosa.load('english.wav', sr=target_sr)\n",
    "input_values = processor(input_audio, sampling_rate=sr, return_tensors=\"pt\").input_values  # batch size 1\n",
    "with torch.no_grad():  # inference only: skip building the autograd graph\n",
    "    logits = model(input_values).logits\n",
    "predicted_ids = torch.argmax(logits, dim=-1)\n",
    "transcription = processor.decode(predicted_ids[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8bd98a38",
   "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "'WITHOUT THE DATA SET THE ARTICLE IS USELESS'" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "transcription"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}