import torch
import torch.nn as nn

from typing import Optional, Union, Tuple

from transformers.modeling_outputs import SequenceClassifierOutput

from .wav2vec2_wrapper import Wav2VecWrapper
from .multilevel_classifier import MultiLevelDownstreamModel


class CustomModelForAudioClassification(nn.Module):
    """Wav2vec2 encoder topped with a multi-level classification head."""

    def __init__(self, config):
        super().__init__()
        # The multi-level classifier aggregates all encoder layers, so the
        # upstream model must be configured to return every hidden state.
        assert config.output_hidden_states, "The upstream model must return all hidden states"
        self.config = config
        self.encoder = Wav2VecWrapper(config)
        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)

    def forward(
        self,
        input_features: Optional[torch.FloatTensor],
        length: Optional[torch.LongTensor] = None,
        # Precomputed encoder outputs, if supplied, must be the same dict of
        # keyword arguments that Wav2VecWrapper returns.
        encoder_outputs: Optional[dict] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        # Encode the raw audio unless precomputed encoder outputs were
        # supplied; reusing `encoder_outputs` keeps the variable defined on
        # both paths.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                length=length,
            )

        logits = self.classifier(**encoder_outputs)

        # No labels are accepted here; loss computation is left to the caller.
        loss = None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs['encoder_hidden_states'],
        )
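

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal sketch of how this model might be driven, assuming `config` is a
# wav2vec2-style config with `output_hidden_states=True` plus whatever fields
# Wav2VecWrapper and MultiLevelDownstreamModel expect; the 16 kHz sample rate
# and clip lengths below are hypothetical.
#
#     from transformers import Wav2Vec2Config
#
#     config = Wav2Vec2Config(output_hidden_states=True)
#     model = CustomModelForAudioClassification(config)
#     waveform = torch.randn(2, 16000)        # two 1 s clips at 16 kHz
#     lengths = torch.tensor([16000, 12000])  # valid samples per clip
#     out = model(waveform, length=lengths)
#     out.logits                              # shape: (batch, num_labels)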