{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Intro-to-Transducers.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "metadata": { "id": "qI25LKhSNg02" }, "source": [ "\"\"\"\n", "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", "\n", "Instructions for setting up Colab are as follows:\n", "1. Open a new Python 3 notebook.\n", "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", "4. Run this cell to set up dependencies.\n", "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n", "\"\"\"\n", "# If you're using Google Colab and not running locally, run this cell.\n", "import os\n", "\n", "# Install dependencies\n", "!pip install wget\n", "!apt-get install sox libsndfile1 ffmpeg\n", "!pip install text-unidecode\n", "!pip install matplotlib>=3.3.2\n", "\n", "## Install NeMo\n", "BRANCH = 'r1.17.0'\n", "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Ac55MjAM5cls" }, "source": [ "# In a conda environment, you would use the following command\n", "# Update Numba to > 0.53\n", "# conda install -c conda-forge numba\n", "# or\n", "# conda update -c conda-forge numba\n", "\n", "# For pip based environments,\n", "# Update Numba to > 0.54\n", "!pip install --upgrade numba>=0.54.1" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "dqbpRwpnQ-1D" }, "source": [ "# Intro to Transducers\n", "\n", "By following the earlier tutorials for Automatic Speech Recognition in NeMo, one would have probably noticed that we always end up using [Connectionist Temporal Classification (CTC) loss](https://distill.pub/2017/ctc/) in order to train the model. Speech Recognition can be formulated in many different ways, and CTC is a more popular approach because it is a monotonic loss - an acoustic feature at timestep $t_1$ and $t_2$ will correspond to a target token at timestep $u_1$ and only then $u_2$. This monotonic property significantly simplifies the training of ASR models and speeds up convergence. However, it has certain drawbacks that we will discuss below.\n", "\n", "In general, ASR can be described as a sequence-to-sequence prediction task - the original sequence is an audio sequence (often transformed into mel spectrograms). The target sequence is a sequence of characters (or subword tokens). Attention models are capable of the same sequence-to-sequence prediction tasks. They can even perform better than CTC due to their autoregressive decoding. However, they lack certain inductive biases that can be leveraged to stabilize and speed up training (such as the monotonicity exhibited by the CTC loss). Furthermore, by design, attention models require the entire sequence to be available to align the sequence to the output, thereby preventing their use for streaming inference.\n", "\n", "Then comes the [Transducer Loss](https://arxiv.org/abs/1211.3711). 
Proposed by Alex Graves, it aims to resolve the shortcomings of CTC loss while also improving transcription accuracy through autoregressive decoding.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "OaPS4_xSRGNv" }, "source": [ "## Drawbacks of Connectionist Temporal Classification (CTC)\n", "\n", "CTC is an excellent loss to train ASR models in a stable manner but comes with certain limitations on model design. If we presume speech recognition to be a sequence-to-sequence problem, let $T$ be the sequence length of the acoustic model's output, and let $U$ be the sequence length of the target text transcript (post tokenization, either as characters or subwords). \n", "\n", "-------\n", "\n", "1) CTC imposes the limitation: $T \ge U$. Normally, this assumption holds because $T$ is generally a lot longer than the final text transcription. However, there are many cases where this assumption fails.\n", "\n", "- The acoustic model performs so much downsampling that $T < U$. Why would we want to perform so much downsampling? For convolutions, longer sequences take more stride steps and more memory. For Attention-based models (say Conformer), the memory cost of computing the attention step is quadratic in $T$. So more downsampling significantly helps relieve the memory requirements. There are ways to bypass this limitation, as discussed in the `ASR_with_Subword_Tokenization` notebook, but even that has limits.\n", "\n", "- The target sequence is generally very long. Think of languages such as German, which have very long translations for short English words. In the task of ASR, if there is more than 2x downsampling and character tokenization is used, the model will often fail to learn due to this CTC limitation.\n", "\n", "2) Tokens predicted by models trained with just the CTC loss are assumed to be *conditionally independent*. This means that, unlike language models where *h*-*e*-*l*-*l* as input would probably predict *o* to complete *hello*, CTC-trained models do not condition the prediction of a token on the tokens decoded before it. So CTC-trained models often have misspellings or missing tokens when transcribing the audio segment to text. \n", "\n", "- Since we often use the Word Error Rate (WER) metric when evaluating models, even a single misspelling contributes significantly to the \"word\" being incorrect. \n", "\n", "- To alleviate this issue, we have to resort to Beam Search via an external language model. While this often works and significantly improves transcription accuracy, it is a slow process and involves large N-gram or Neural language models.\n",
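"\n", "To make point (2) concrete, here is a small illustrative sketch of greedy CTC decoding, with random scores standing in for a real acoustic model. Every frame is decoded by an isolated argmax, then repeats are merged and blanks dropped - no prediction ever depends on the characters decoded before it:\n", "\n", "```python\n", "import torch\n", "\n", "V = 28  # vocabulary size; index V acts as the CTC blank\n", "log_probs = torch.randn(10, V + 1).log_softmax(dim=-1)  # (T, V+1) per-frame log-probabilities\n", "\n", "# Greedy CTC decoding: per-frame argmax, merge repeated ids, drop blanks.\n", "# Each frame's argmax is taken in isolation - it never sees the tokens decoded before it.\n", "frame_ids = log_probs.argmax(dim=-1).tolist()\n", "decoded = []\n", "prev = None\n", "for idx in frame_ids:\n", "    if idx != prev and idx != V:\n", "        decoded.append(idx)\n", "    prev = idx\n", "print(decoded)\n", "```"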
" ] }, { "cell_type": "markdown", "metadata": { "id": "5EVBcBDNf658" }, "source": [ "--------\n", "\n", "Let's see CTC loss's limitation (1) in action:" ] }, { "cell_type": "code", "metadata": { "id": "4whMzIjYf4w8" }, "source": [ "import torch\n", "import torch.nn as nn" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "aGdKAFe7gGY4" }, "source": [ "T = 10 # acoustic sequence length\n", "U = 16 # target sequence length\n", "V = 28 # vocabulary size\n", "\n", "def get_sample(T, U, V, require_grad=True):\n", " torch.manual_seed(0)\n", "\n", " acoustic_seq = torch.randn(1, T, V + 1, requires_grad=require_grad)\n", " acoustic_seq_len = torch.tensor([T], dtype=torch.int32) # actual seq length in padded tensor (here no padding is done)\n", "\n", " target_seq = torch.randint(low=0, high=V, size=(1, U))\n", " target_seq_len = torch.tensor([U], dtype=torch.int32)\n", "\n", " return acoustic_seq, acoustic_seq_len, target_seq, target_seq_len" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "DTYIb-7ngo_L" }, "source": [ "# First, we use CTC loss in the general sense.\n", "loss = torch.nn.CTCLoss(blank=V, zero_infinity=False)\n", "\n", "acoustic_seq, acoustic_seq_len, target_seq, target_seq_len = get_sample(T, U, V)\n", "\n", "# CTC loss expects acoustic sequence to be in shape (T, B, V)\n", "val = loss(acoustic_seq.transpose(1, 0), target_seq, acoustic_seq_len, target_seq_len)\n", "print(\"CTC Loss :\", val)\n", "\n", "val.backward()\n", "print(\"Grad of Acoustic model (over V):\", acoustic_seq.grad[0, 0, :])" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "lBDvC2RykFC4" }, "source": [ "# Next, we use CTC loss with `zero_infinity` flag set.\n", "loss = torch.nn.CTCLoss(blank=V, zero_infinity=True)\n", "\n", "acoustic_seq, acoustic_seq_len, target_seq, target_seq_len = get_sample(T, U, V)\n", "\n", "# CTC loss expects acoustic sequence to be in shape (T, B, V)\n", "val = loss(acoustic_seq.transpose(1, 0), target_seq, acoustic_seq_len, target_seq_len)\n", "print(\"CTC Loss :\", val)\n", "\n", "val.backward()\n", "print(\"Grad of Acoustic model (over V):\", acoustic_seq.grad[0, 0, :])" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "SQe6WYnWkSAZ" }, "source": [ "-------\n", "\n", "As we saw, CTC loss in general case will not be able to compute the loss or the gradient when $T \\ge U$. In the PyTorch specific implementation of CTC Loss, we can specify a flag `zero_infinity`, which explicitly checks for such cases, zeroes out the loss and the gradient if such a case occurs. The flag allows us to train a batch of samples where some samples may accidentally violate this limitation, but training will not halt, and gradients will not become NAN." ] }, { "cell_type": "markdown", "metadata": { "id": "EGnc5HqnZ-GZ" }, "source": [ "## What is the Transducer Loss ?" ] }, { "cell_type": "markdown", "metadata": { "id": "0W12xF_CqcVF" }, "source": [ "![](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/images/transducer.png?raw=true)" ] }, { "cell_type": "markdown", "metadata": { "id": "RoOQJtIkqxbA" }, "source": [ "A model that seeks to use the Transducer loss is composed of three models that interact with each other. They are:\n", "\n", "-------\n", "\n", "1) **Acoustic model** : This is nearly the same acoustic model used for CTC models. The output shape of these models is generally $(Batch, \\, T, \\, AM-Hidden)$. 
You will note that unlike for CTC, the output of the acoustic model is no longer passed through a decoder layer which would have the shape $(Batch, \, T, \, Vocabulary + 1)$.\n", "\n", "2) **Prediction / Decoder model**: The prediction model accepts a sequence of target tokens (in the case of ASR, text tokens) and is usually a causal, auto-regressive model that produces hidden features of shape $(Batch, \, U, \, Pred-Hidden)$.\n", "\n", "3) **Joint model**: This model accepts the outputs of the Acoustic model and the Prediction model and joins them to compute a joint probability distribution over the vocabulary space, from which the alignment from the Acoustic sequence to the Target sequence is computed. The output of this model is of the shape $(Batch, \, T, \, U, \, Vocabulary + 1)$.\n", "\n", "--------\n", "\n", "During training, the transducer loss is computed on the output of the joint model, which computes the joint probability distribution of a target vocabulary token $v_{t, u}$ (for all $v \in V$) being predicted given the acoustic feature at timestep $t \le T$ and the prediction network features at timestep $u \le U$.\n", "\n", "--------\n", "\n", "During inference, we perform a single forward pass over the Acoustic Network to obtain the features of shape $(Batch, \, T, \, AM-Hidden)$, and autoregressively perform the forward passes of the Prediction Network and the Joint Network to decode several $u \le U$ target tokens per acoustic timestep $t \le T$. We will discuss decoding in the following sections.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "yBxtxt2Ztuoo" }, "source": [ "---------\n", "\n", "**Note**: For an excellent in-depth explanation of how Transducer loss works, how it computes the alignment, and how the gradient of this alignment is calculated, we highly encourage you to read this post about [Sequence-to-sequence learning with Transducers by Loren Lugosch](https://lorenlugosch.github.io/posts/2020/11/transducer/).\n", "\n", "---------" ] }, { "cell_type": "markdown", "metadata": { "id": "VgdYFkeyRGP-" }, "source": [ "## Benefits of Transducer Loss\n", "\n", "Now that we understand what a Transducer model is composed of and how it is trained, the next question that comes to mind is - what is the benefit of the Transducer loss?\n", "\n", "------\n", "\n", "1) It is a monotonic loss (similar to CTC). Monotonicity speeds up convergence and does not require auxiliary losses to stabilize training (as is often required when training sequence-to-sequence models with only an attention-based loss).\n", "\n", "2) Autoregressive decoding enables the model to implicitly have a dependency between predicted tokens (the conditional independence assumption of CTC-trained models no longer applies). As such, missing characters or incorrect spellings are less frequent (but still exist since no model is perfect).\n", "\n", "3) It no longer has the $T \ge U$ limitation that CTC imposed. This is because the total joint probability distribution is calculated now - mapping every acoustic timestep $t \le T$ to one or more target timesteps $u \le U$. This means that for each timestep $t$, the model has at most $U$ tokens that it can predict, and therefore in the extreme case, it can predict a total of $T \times U$ tokens!" ] }, { "cell_type": "markdown", "metadata": { "id": "wisUfV8aRGSY" }, "source": [ "## Drawbacks of Transducer Loss\n", "\n", "All of these benefits come with certain costs. As is (almost) always the case in machine learning, there is no free lunch.\n",
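"\n", "To make the main cost concrete before we list it, here is a hypothetical shape-level sketch (random tensors and a plain additive joint, not NeMo's actual modules) of how the acoustic and prediction outputs combine, together with the memory arithmetic for the joint tensor at realistic sizes:\n", "\n", "```python\n", "import torch\n", "\n", "# Tiny, made-up sizes so the sketch actually runs; real models are far larger.\n", "B, T, U, V = 2, 50, 20, 28\n", "enc_hidden, pred_hidden, joint_hidden = 16, 16, 16\n", "\n", "f = torch.randn(B, T, enc_hidden)        # Acoustic / encoder output:  (B, T, enc_hidden)\n", "g = torch.randn(B, U + 1, pred_hidden)   # Prediction network output:  (B, U + 1, pred_hidden)\n", "\n", "# A minimal joint: project both to a shared space, broadcast-add, then map to vocab + blank.\n", "proj_f = torch.nn.Linear(enc_hidden, joint_hidden)\n", "proj_g = torch.nn.Linear(pred_hidden, joint_hidden)\n", "to_vocab = torch.nn.Linear(joint_hidden, V + 1)\n", "\n", "joint = to_vocab(torch.relu(proj_f(f)[:, :, None, :] + proj_g(g)[:, None, :, :]))\n", "print(joint.shape)  # (B, T, U + 1, V + 1)\n", "\n", "# At Librispeech-like scale (B=32, T~1600, U~450, V=28) this single tensor is already:\n", "print(32 * 1600 * 450 * 29 * 4 / 1e9, 'GB at float32')  # ~2.7 GB, before gradients\n", "```\n",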
"\n", "-------\n", "\n", "1) During training, the Joint model is required to compute a joint matrix of shape $(Batch, \, T, \, U, \, Vocabulary + 1)$. For a general dataset like Librispeech, $T \sim 1600$, $U \sim 450$ (with character encoding) and the vocabulary size $V \sim 28+1$. Considering a batch size of 32, the total memory cost comes out to roughly **2.7 GB** at float precision. The model would also need another **2.7 GB** for the gradients. Of course, the model needs more memory still for the actual Acoustic model + Prediction model + their gradients. Note, however - this issue can be *partially* resolved with some simple tricks, which are discussed in the next tutorial. Also, this memory cost is no longer an issue during inference!\n", "\n", "2) Autoregressive decoding is slow - much slower than for CTC models, which require just a simple argmax of the output tensor. So while we do get superior transcription quality, we sacrifice decoding speed." ] }, { "cell_type": "markdown", "metadata": { "id": "RuASpPlD2con" }, "source": [ "--------\n", "\n", "Let's check that the RNNT loss no longer shows the limitations of CTC loss - " ] }, { "cell_type": "code", "metadata": { "id": "XXodnve02c8h" }, "source": [ "T = 10 # acoustic sequence length\n", "U = 16 # target sequence length\n", "V = 28 # vocabulary size\n", "\n", "def get_rnnt_sample(T, U, V, require_grad=True):\n", "    torch.manual_seed(0)\n", "\n", "    joint_tensor = torch.randn(1, T, U + 1, V + 1, requires_grad=require_grad)\n", "    acoustic_seq_len = torch.tensor([T], dtype=torch.int32) # actual seq length in padded tensor (here no padding is done)\n", "\n", "    target_seq = torch.randint(low=0, high=V, size=(1, U))\n", "    target_seq_len = torch.tensor([U], dtype=torch.int32)\n", "\n", "    return joint_tensor, acoustic_seq_len, target_seq, target_seq_len" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "w-9Qx01G21oK" }, "source": [ "import nemo.collections.asr as nemo_asr" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "6hb7q81f21qj" }, "source": [ "joint_tensor, acoustic_seq_len, target_seq, target_seq_len = get_rnnt_sample(T, U, V)\n", "\n", "# RNNT loss expects the joint tensor to be of shape (B, T, U + 1, V + 1)\n", "loss = nemo_asr.losses.rnnt.RNNTLoss(num_classes=V)\n", "\n", "# Check out the keyword arguments required to call the RNNT loss\n", "print(\"Transducer loss input types :\", loss.input_types)\n", "print()\n", "\n", "val = loss(log_probs=joint_tensor, targets=target_seq, input_lengths=acoustic_seq_len, target_lengths=target_seq_len)\n", "print(\"Transducer Loss :\", val)\n", "\n", "val.backward()\n", "print(\"Grad of Joint tensor (over V):\", joint_tensor.grad[0, 0, 0, :])" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "5pfrpy7wRGUc" }, "source": [ "# Configure a Transducer Model\n", "\n", "We now understand a bit more about the transducer loss. Next, we will take a deep dive into how to set up the config for a transducer model.\n", "\n", "Transducer configs contain a fair bit more detail compared to CTC configs. However, the vast majority of the defaults can be copied and pasted into your configs to have a perfectly functioning transducer model!\n", "\n", "------\n", "\n", "Let us download one of the transducer configs already available in NeMo to analyze the components."
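, "\n", "The next cell fetches the config with the `!wget` shell command. If that command is unavailable in your environment, the `wget` Python package installed in the setup cell can download the same file - an optional sketch, assuming the `BRANCH` variable defined at the top of the notebook:\n", "\n", "```python\n", "import os\n", "import wget\n", "\n", "config_url = f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/asr/conf/contextnet_rnnt/contextnet_rnnt.yaml'\n", "if not os.path.exists('contextnet_rnnt.yaml'):\n", "    wget.download(config_url)\n", "```"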
] }, { "cell_type": "code", "metadata": { "id": "cgJQXfwy7LO_" }, "source": [ "import os\n", "\n", "if not os.path.exists(\"contextnet_rnnt.yaml\"):\n", " !wget https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/contextnet_rnnt/contextnet_rnnt.yaml" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "CJ2-ORS17XbF" }, "source": [ "from omegaconf import OmegaConf, open_dict\n", "\n", "cfg = OmegaConf.load('contextnet_rnnt.yaml')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "h5TsAJQk6o4N" }, "source": [ "## Model Defaults\n", "\n", "Since the transducer model is comprised of three separate models working in unison, it is practical to have some shared section of the config. That shared section is called `model.model_defaults`." ] }, { "cell_type": "code", "metadata": { "id": "N8tWZ9eb75Gx" }, "source": [ "print(OmegaConf.to_yaml(cfg.model.model_defaults))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "m8IxgJFj7_gc" }, "source": [ "-------\n", "\n", "Of the many components shared here, the last three values are the primary components that a transducer model **must** possess. They are :\n", "\n", "1) `enc_hidden`: The hidden dimension of the final layer of the Encoder network.\n", "\n", "2) `pred_hidden`: The hidden dimension of the final layer of the Prediction network.\n", "\n", "3) `joint_hidden`: The hidden dimension of the intermediate layer of the Joint network.\n", "\n", "--------\n", "\n", "One can access these values inside the config by using OmegaConf interpolation as follows :\n", "\n", "```yaml\n", "model:\n", " ...\n", " decoder:\n", " ...\n", " prednet:\n", " pred_hidden: ${model.model_defaults.pred_hidden}\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "uRckIz_eRGWr" }, "source": [ "## Acoustic Model\n", "\n", "As we discussed before, the transducer model is comprised of three models combined. One of these models is the Acoustic (encoder) model. We should be able to drop in any CTC Acoustic model config into this section of the transducer config.\n", "\n", "The only condition that needs to be met is that **the final layer of the acoustic model must have the dimension defined in `model_defaults.enc_hidden`**." ] }, { "cell_type": "markdown", "metadata": { "id": "o505IGX4RGYy" }, "source": [ "## Decoder / Prediction Model\n", "\n", "The Prediction model is generally an autoregressive, causal model that consumes text tokens and returns embeddings that will be used by the Joint model. \n", "\n", "**This config can be dropped into any custom transducer model with no modification.**" ] }, { "cell_type": "code", "metadata": { "id": "A2a9Y5CCArLs" }, "source": [ "print(OmegaConf.to_yaml(cfg.model.decoder))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "ry5a_Z-zAvll" }, "source": [ "------\n", "\n", "This config will build an LSTM based Transducer Decoder model. Let us discuss some of the important arguments:\n", "\n", "1) `blank_as_pad`: In ordinary transducer models, the embedding matrix does not acknowledge the `Transducer Blank` token (similar to CTC Blank). However, this causes the autoregressive loop to be more complicated and less efficient. Instead, this flag which is set by default, will add the `Transducer Blank` token to the embedding matrix - and use it as a pad value (zeros tensor). 
This enables more efficient inference without harming training.\n", "\n", "2) `prednet.pred_hidden`: The hidden dimension of the LSTM and the output dimension of the Prediction network.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FtdYE25cW1j_" }, "source": [ "## Joint Model\n", "\n", "The Joint model is a simple feed-forward Multi-Layer Perceptron network. This MLP accepts the output of the Acoustic and Prediction models and computes a joint probability distribution over the entire vocabulary space.\n", "\n", "**This config can be dropped into any custom transducer model with no modification.**" ] }, { "cell_type": "code", "metadata": { "id": "pP8fL1bED3Dv" }, "source": [ "print(OmegaConf.to_yaml(cfg.model.joint))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "rA11eez_FYUl" }, "source": [ "------\n", "\n", "The Joint model config has several essential components which we discuss below :\n", "\n", "1) `log_softmax`: Due to the cost of computing softmax on such large tensors, the Numba CUDA implementation of RNNT loss will implicitly compute the log softmax when called (so its inputs should be logits). The CPU version of the loss doesn't face such memory issues so it requires log-probabilities instead. Since the behaviour is different for CPU-GPU, the `None` value will automatically switch behaviour dependent on whether the input tensor is on a CPU or GPU device.\n", "\n", "2) `preserve_memory`: This flag will call `torch.cuda.empty_cache()` at certain critical sections when computing the Joint tensor. While this operation might allow us to preserve some memory, the empty_cache() operation is tremendously slow and will slow down training by an order of magnitude or more. It is available to use but not recommended.\n", "\n", "3) `fuse_loss_wer`: This flag performs \"batch splitting\" and then \"fused loss + metric\" calculation. It will be discussed in detail in the next tutorial that will train a Transducer model.\n", "\n", "4) `fused_batch_size`: When the above flag is set to True, the model will have two distinct \"batch sizes\". The batch size provided in the three data loader configs (`model.*_ds.batch_size`) will now be the `Acoustic model` batch size, whereas the `fused_batch_size` will be the batch size of the `Prediction model`, the `Joint model`, the `transducer loss` module and the `decoding` module.\n", "\n", "5) `jointnet.joint_hidden`: The hidden intermediate dimension of the joint network." ] }, { "cell_type": "markdown", "metadata": { "id": "cmIwDscCW1mP" }, "source": [ "## Transducer Decoding\n", "\n", "Models which have been trained with CTC can transcribe text simply by performing a regular argmax over the output of their decoder.\n", "\n", "For transducer-based models, the three networks must operate in a synchronized manner in order to transcribe the acoustic features.\n", "\n", "The following section of the config describes how to change the decoding logic of the transducer model.\n", "\n", "\n", "**This config can be dropped into any custom transducer model with no modification.**" ] }, { "cell_type": "code", "metadata": { "id": "LQjfXJsrIqFJ" }, "source": [ "print(OmegaConf.to_yaml(cfg.model.decoding))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "6FXtn41wIvu8" }, "source": [ "-------\n", "\n", "The most important component at the top level is the `strategy`. It can take one of many values:\n", "\n", "1) `greedy`: This is sample-level greedy decoding. 
It is generally exceptionally slow as each sample in the batch will be decoded independently. For publications, this should be used alongside a batch size of 1 for exact results.\n", "\n", "2) `greedy_batch`: This is the general default and should nearly match the `greedy` decoding scores (if the acoustic features are not affected by feature mixing in batch mode). Even for small batch sizes, this strategy is significantly faster than `greedy`.\n", "\n", "3) `beam`: Runs beam search with the implicit language model of the Prediction model. It will generally be quite slow, and might need some tuning of the beam size to get better transcriptions.\n", "\n", "4) `tsd`: Time synchronous decoding. Please refer to the paper: [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Time synchronous decoding (TSD) execution time grows with the factor T * max_symmetric_expansions. Since T is larger for longer sequences, it can take a long time for beams to obtain good results. TSD also requires more memory to execute.\n", "\n", "5) `alsd`: Alignment-length synchronous decoding. Please refer to the paper: [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with a growth factor of T + U_max, where U_max is the maximum target length expected during execution. Generally, T + U_max < T * max_symmetric_expansions. However, ALSD beams are non-unique. Therefore, larger beam sizes are required to achieve the same (or close to the same) decoding accuracy as TSD. For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.\n", "\n", "-------\n", "\n", "Below, we discuss the sub-configs for the various decoding strategies." ] }, { "cell_type": "markdown", "metadata": { "id": "PXzY7laMW1oo" }, "source": [ "### Greedy Decoding\n", "\n", "When `strategy` is one of `greedy` or `greedy_batch`, an additional subconfig of `decoding.greedy` can be used to set an important decoding value." ] }, { "cell_type": "code", "metadata": { "id": "778R5oy6Ipha" }, "source": [ "print(OmegaConf.to_yaml(cfg.model.decoding.greedy))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "vItXzTbZKwyB" }, "source": [ "-------\n", "\n", "The argument `max_symbols` is the maximum number of `target token` decoding steps $u \le U$ per acoustic timestep $t \le T$. Note that during training, this was implicitly constrained by the shape of the joint matrix (max_symbols = $U$). However, there is no such $U$ upper bound during inference (we don't have the ground truth $U$).\n", "\n", "So we explicitly set a heuristic upper bound on how many decoding steps can be performed per acoustic timestep. Generally, a value of 5 or above is sufficient." ] }, { "cell_type": "markdown", "metadata": { "id": "ebFogfLvW1q9" }, "source": [ "### Beam Decoding\n", "\n", "Next, we discuss the subconfig when `strategy` is one of `beam`, `tsd` or `alsd`."
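, "\n", "Switching between any of these strategies is just an override on the `cfg` object loaded earlier - a small illustrative sketch (the values are arbitrary):\n", "\n", "```python\n", "# Hypothetical example: select beam search decoding in the loaded config.\n", "cfg.model.decoding.strategy = 'beam'   # or 'greedy', 'greedy_batch', 'tsd', 'alsd'\n", "cfg.model.decoding.beam.beam_size = 4  # small beams are usually a reasonable starting point\n", "\n", "print(cfg.model.decoding.strategy)\n", "```"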
] }, { "cell_type": "code", "metadata": { "id": "w073zT8ILtki" }, "source": [ "print(OmegaConf.to_yaml(cfg.model.decoding.beam))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "AvOeRhsULtrx" }, "source": [ "------\n", "\n", "There are several important arguments in this section :\n", "\n", "1) `beam_size`: This determines the beam size for all types of beam decoding strategy. Since this is implemented in PyTorch, large beam sizes will take exorbitant amounts of time.\n", "\n", "2) `score_norm`: Whether to normalize scores prior to pruning the beam.\n", "\n", "3) `return_best_hypothesis`: If beam search is being performed, we can choose to return just the best hypothesis or all the hypotheses.\n", "\n", "4) `tsd_max_sym_exp`: The maximum symmetric expansions allowed per timestep during beam search. Larger values should be used to attempt decoding of longer sequences, but this in turn increases execution time and memory usage.\n", "\n", "5) `alsd_max_target_len`: The maximum expected target sequence length during beam search. Larger values allow decoding of longer sequences at the expense of execution time and memory.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MkoHp0dQW1tP" }, "source": [ "## Transducer Loss\n", "\n", "Finally, we reach the Transducer loss config itself. This section configures the type of Transducer loss itself, along with possible sub-sections.\n", "\n", "**This config can be dropped into any custom transducer model with no modification.**" ] }, { "cell_type": "code", "metadata": { "id": "l3Uk11uHOa4O" }, "source": [ "print(OmegaConf.to_yaml(cfg.model.loss))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "1z_mYV9UOk_7" }, "source": [ "---------\n", "\n", "The loss config is based on a resolver pattern and can be used as follows:\n", "\n", "1) `loss_name`: `default` is generally a good option. Will select one of the available resolved losses and match the kwargs from a sub-configs passed via explicit `{loss_name}_kwargs` sub-config. \n", "\n", "2) `{loss_name}_kwargs`: This sub-config is passed to the resolved loss above and can be used to configure the resolved loss." ] }, { "cell_type": "markdown", "metadata": { "id": "7w3Z3-IaRGaz" }, "source": [ "### WarpRNNT Numba Loss\n", "\n", "The default transducer loss implemented in NeMo is a Numba port of the excellent CUDA implementation of Transducer Loss found in https://github.com/HawkAaron/warp-transducer.\n", "\n", "It should suffice for most use cases (CPU / GPU) transducer training." ] }, { "cell_type": "markdown", "metadata": { "id": "u9bPDu2-XYPB" }, "source": [ "### FastEmit Regularization\n", "\n", "Recently proposed regularization approach - [FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization](https://arxiv.org/abs/2010.11148) allows us near-direct control over the latency of transducer models.\n", "\n", "Refer to the above paper for results and recommendations of `fastemit_lambda`." ] }, { "cell_type": "markdown", "metadata": { "id": "iGG9UKTZXlme" }, "source": [ "# Next Steps\n", "\n", "After that deep dive into how to configure Transducer models, the next tutorial will use one such config to build a transducer model and train it on a small dataset. We will then move on to exploring various decoding strategies and how to evaluate the model." ] } ] }