import torch
import torch.nn as nn
from typing import Optional, Union, Tuple

from transformers.modeling_outputs import SequenceClassifierOutput

from .wav2vec2_wrapper import Wav2VecWrapper
from .multilevel_classifier import MultiLevelDownstreamModel


class CustomModelForAudioClassification(nn.Module):
    """Audio classifier: a wav2vec2-style encoder feeding a multi-level
    downstream classification head.

    The classifier head consumes the hidden states of every encoder layer,
    so the config must request all hidden states.
    """

    def __init__(self, config):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, and this is a hard requirement of the classifier head.
        if not config.output_hidden_states:
            raise ValueError("The upstream model must return all hidden states")
        self.config = config
        self.encoder = Wav2VecWrapper(config)
        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)

    def forward(
        self,
        input_features: Optional[torch.LongTensor],
        length: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Encode the audio (unless precomputed outputs are given) and classify.

        Args:
            input_features: raw audio features for the encoder; may be None
                when ``encoder_outputs`` is supplied.
            length: per-example lengths forwarded to the encoder
                (presumably unpadded sample counts — confirm in Wav2VecWrapper).
            encoder_outputs: optional precomputed encoder outputs; when given,
                the encoder forward pass is skipped. NOTE(review): it is
                unpacked with ``**`` below, so despite the Tuple annotation it
                must be a mapping keyed like the encoder's return value.

        Returns:
            SequenceClassifierOutput with ``logits`` and the encoder hidden
            states. ``loss`` is always None — no labels are accepted here.
        """
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                length=length,
            )
        # Bug fix: the original bound the encoder result to a separate
        # `encoder_output` name, which was left undefined (NameError) whenever
        # precomputed `encoder_outputs` were passed in. Reuse the parameter so
        # both paths feed the classifier.
        logits = self.classifier(**encoder_outputs)
        loss = None  # No label argument in the signature, so no loss is computed.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs['encoder_hidden_states']
        )