import torch
import torch.nn as nn
from typing import Optional

from transformers.modeling_outputs import SequenceClassifierOutput

from .wav2vec2_wrapper import Wav2VecWrapper
from .multilevel_classifier import MultiLevelDownstreamModel


class CustomModelForAudioClassification(nn.Module):
    """Wraps a wav2vec2 encoder with a multi-level downstream classification head."""

    def __init__(self, config):
        super().__init__()
        # The multi-level classifier consumes hidden states from every encoder layer.
        assert config.output_hidden_states, "The upstream model must return all hidden states"
        self.config = config
        self.encoder = Wav2VecWrapper(config)
        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)

    def forward(
        self,
        input_features: Optional[torch.FloatTensor],
        length: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[dict] = None,  # mapping as returned by Wav2VecWrapper
    ) -> SequenceClassifierOutput:
        # Run the encoder only when precomputed outputs are not supplied;
        # otherwise reuse the `encoder_outputs` passed by the caller.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                length=length,
            )
        logits = self.classifier(**encoder_outputs)
        loss = None  # loss computation is left to the caller / training loop
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs['encoder_hidden_states'],
        )