from dataclasses import dataclass
from typing import Optional

import torch
from nemo.collections.asr.models import EncDecRNNTBPEModel
from omegaconf import DictConfig
from transformers.utils import ModelOutput


@dataclass
class RNNTOutput(ModelOutput):
    """
    Base class for RNNT outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    wer: Optional[float] = None
    wer_num: Optional[float] = None
    wer_denom: Optional[float] = None
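
    # These fields are filled in by `RNNTBPEModel.forward` below: `loss` carries the transducer
    # loss (plus any registered auxiliary losses), while the `wer*` fields are only populated
    # when WER is actually computed (the non-fused path in eval mode, or the fused joint step).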


# Adapted from https://github.com/NVIDIA/NeMo/blob/66c7677cd4a68d78965d4905dd1febbf5385dff3/nemo/collections/asr/models/rnnt_bpe_models.py#L33
class RNNTBPEModel(EncDecRNNTBPEModel):
    def __init__(self, cfg: DictConfig):
        super().__init__(cfg=cfg, trainer=None)

    def encoding(
            self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None
    ):
        """
        Forward pass of the acoustic model. Note that for RNNT models, the forward pass of the model is a
        three-step process, and this method only performs the first step - the forward pass of the acoustic model.

        Please refer to `forward` to see the full step for training, which performs the forward pass of the
        acoustic model, the prediction network and then the joint network. Finally, it computes the loss and
        possibly computes the detokenized text via the `decoding` step.

        Please refer to `validation_step` to see the full step for inference, which performs the forward pass of
        the acoustic model, the prediction network and then the joint network. Finally, it computes the decoded
        tokens via the `decoding` step and possibly computes the batch metrics.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of the audio
                sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape (B, D, T) that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths of the
                processed audio sequences.

        Returns:
            A tuple of 2 elements -
            1) The encoded acoustic features produced by the encoder, of shape [B, D, T].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        if not (has_input_signal ^ has_processed_signal):
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                "with ``processed_signal`` and ``processed_signal_length`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal, length=input_signal_length,
            )

        # Spec augment is not applied during evaluation/testing
        if self.spec_augmentation is not None and self.training:
            processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
        return encoded, encoded_len
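
    # Usage sketch for `encoding` (illustrative only; `model`, `audio`, `audio_len`, `feats` and
    # `feat_len` are hypothetical tensors). Exactly one of the two input forms may be supplied:
    #
    #     encoded, encoded_len = model.encoding(input_signal=audio, input_signal_length=audio_len)
    #     encoded, encoded_len = model.encoding(processed_signal=feats, processed_signal_length=feat_len)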

    def forward(self, input_ids, input_lengths=None, labels=None, label_lengths=None):
        # encoding() only performs encoder forward
        encoded, encoded_len = self.encoding(input_signal=input_ids, input_signal_length=input_lengths)
        del input_ids

        # During training, loss must be computed, so decoder forward is necessary
        decoder, target_length, states = self.decoder(targets=labels, target_length=label_lengths)

        # If experimental fused Joint-Loss-WER is not used
        if not self.joint.fuse_loss_wer:
            # Compute full joint and loss
            joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
            loss_value = self.loss(
                log_probs=joint, targets=labels, input_lengths=encoded_len, target_lengths=target_length
            )
            # Add auxiliary losses, if registered
            loss_value = self.add_auxiliary_losses(loss_value)
            wer = wer_num = wer_denom = None
            if not self.training:
                self.wer.update(encoded, encoded_len, labels, target_length)
                wer, wer_num, wer_denom = self.wer.compute()
                self.wer.reset()

        else:
            # If experimental fused Joint-Loss-WER is used
            # Fused joint step
            loss_value, wer, wer_num, wer_denom = self.joint(
                encoder_outputs=encoded,
                decoder_outputs=decoder,
                encoder_lengths=encoded_len,
                transcripts=labels,
                transcript_lengths=label_lengths,
                compute_wer=not self.training,
            )
            # Add auxiliary losses, if registered
            loss_value = self.add_auxiliary_losses(loss_value)

        return RNNTOutput(loss=loss_value, wer=wer, wer_num=wer_num, wer_denom=wer_denom)

    def transcribe(
            self, input_ids, input_lengths=None, labels=None, label_lengths=None,
            return_hypotheses: bool = False, partial_hypothesis: Optional[list] = None
    ):
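        """
        Decode a batch of raw audio into text hypotheses using the model's configured RNNT
        decoding strategy. `labels` and `label_lengths` are accepted for interface parity
        with `forward` but are not used during decoding.

        Returns:
            A tuple (best hypotheses, all hypotheses) as produced by
            `decoding.rnnt_decoder_predictions_tensor`.
        """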
        encoded, encoded_len = self.encoding(input_signal=input_ids, input_signal_length=input_lengths)
        del input_ids
        best_hyp, all_hyp = self.decoding.rnnt_decoder_predictions_tensor(
            encoded,
            encoded_len,
            return_hypotheses=return_hypotheses,
            partial_hypotheses=partial_hypothesis,
        )
        return best_hyp, all_hyp
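

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a NeMo-style model
# config, e.g. loaded from a YAML file whose layout matches
# EncDecRNNTBPEModel's expectations; the config path and tensor shapes below
# are hypothetical.
#
#     from omegaconf import OmegaConf
#
#     cfg = OmegaConf.load("conf/rnnt_bpe.yaml")       # hypothetical config path
#     model = RNNTBPEModel(cfg=cfg.model)
#
#     audio = torch.randn(2, 16000)                    # [B, T] raw waveforms (e.g. 1 s at 16 kHz)
#     audio_len = torch.tensor([16000, 12000])
#     tokens = torch.randint(1, 128, (2, 20))          # [B, U] BPE label ids
#     token_len = torch.tensor([20, 15])
#
#     # Training step: forward returns an RNNTOutput carrying the transducer loss
#     out = model(input_ids=audio, input_lengths=audio_len,
#                 labels=tokens, label_lengths=token_len)
#     out.loss.backward()
#
#     # Inference: decode hypotheses from raw audio
#     model.eval()
#     with torch.no_grad():
#         best_hyp, all_hyp = model.transcribe(input_ids=audio, input_lengths=audio_len)
# ---------------------------------------------------------------------------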