from torchaudio.models import Conformer
from torchaudio.models.rnnt import _TimeReduction
from transformers import PretrainedConfig, PreTrainedModel
import torch
from torch import nn
from typing import List, Tuple, Optional


class ConformerConfig(PretrainedConfig):
    """Hugging Face configuration for :class:`ConformerEncoder`.

    No defaults are declared here: every hyperparameter must be supplied as a
    keyword argument (``PretrainedConfig`` keeps extra kwargs as attributes).
    ``ConformerEncoder`` reads the following attributes:

    * ``input_dim``, ``time_reduction_stride``
    * ``conformer_input_dim``, ``conformer_num_layers``, ``conformer_ffn_dim``,
      ``conformer_num_heads``, ``conformer_depthwise_conv_kernel_size``,
      ``conformer_dropout``
    * ``output_dim`` -- width of the output projection, i.e. the number of
      CTC classes (blank included)
    * ``pad_token_id`` (used as the CTC blank index), ``ctc_loss_reduction``,
      ``ctc_zero_infinity`` -- only needed when ``labels`` are passed to
      ``forward``
    """

    model_type = 'conformer'


class ConformerEncoder(PreTrainedModel):
    """Time reduction -> linear -> torchaudio Conformer -> linear, plus optional CTC loss."""

    config_class = ConformerConfig

    def __init__(
        self,
        config,
    ) -> None:
        super().__init__(config)
        # _TimeReduction merges `time_reduction_stride` adjacent frames into one,
        # so the projection below sees input_dim * stride features per frame.
        self.time_reduction = _TimeReduction(config.time_reduction_stride)
        self.input_linear = torch.nn.Linear(
            config.input_dim * config.time_reduction_stride,
            config.conformer_input_dim,
        )
        self.conformer = Conformer(
            num_layers=config.conformer_num_layers,
            input_dim=config.conformer_input_dim,
            ffn_dim=config.conformer_ffn_dim,
            num_heads=config.conformer_num_heads,
            depthwise_conv_kernel_size=config.conformer_depthwise_conv_kernel_size,
            dropout=config.conformer_dropout,
            use_group_norm=True,
            convolution_first=True,
        )
        # Projects Conformer features to `output_dim` logits (the CTC classes).
        self.output_linear = torch.nn.Linear(config.conformer_input_dim, config.output_dim)

    def forward(
        self,
        inputs: torch.Tensor,
        lengths: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """Encode a padded feature batch and optionally compute the CTC loss.

        Args:
            inputs: padded features, presumably (batch, time, input_dim) as
                torchaudio's ``_TimeReduction`` / ``Conformer`` expect -- TODO
                confirm against callers.
            lengths: number of valid frames for each batch element.
            labels: optional padded target ids. Negative entries are treated
                as padding and dropped before the loss is computed.

        Returns:
            ``(logits, output_lengths)``. When ``labels`` is given, the return
            value is ``(loss, logits, output_lengths)``. ``output_lengths`` are
            the per-element lengths after time reduction, as reported by the
            Conformer.

        Raises:
            ValueError: if any non-padding label is >= ``config.output_dim``.
        """
        reduced, reduced_lengths = self.time_reduction(inputs, lengths)
        projected = self.input_linear(reduced)
        x, input_lengths = self.conformer(projected, reduced_lengths)
        logits = self.output_linear(x)

        loss = None
        if labels is not None:
            # Negative ids (e.g. -100) mark padding: drop them and count the
            # remaining ids per row to get each target's length.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # Fail fast with a readable error. Out-of-range ids would otherwise
            # hit ctc_loss as an opaque device-side assert or a meaningless loss.
            # Guard on numel(): .max() raises on an empty tensor.
            num_classes = self.config.output_dim
            if flattened_targets.numel() > 0:
                max_label = int(flattened_targets.max())
                if max_label >= num_classes:
                    raise ValueError(
                        f"Label values must be < output_dim ({num_classes}), "
                        f"got {max_label}."
                    )

            # ctc_loss wants log-probabilities laid out (time, batch, classes).
            # Compute them in float32 so the loss is stable under mixed precision.
            log_probs = nn.functional.log_softmax(
                logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            # Force PyTorch's native CTC kernel. cuDNN's CTC path only applies
            # under restrictive input conditions (same pattern as Hugging Face's
            # Wav2Vec2ForCTC).
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        output = (logits, input_lengths)
        return ((loss,) + output) if loss is not None else output