File size: 1,300 Bytes
8f96165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn

from typing import Optional, Union, Tuple
from transformers.modeling_outputs import SequenceClassifierOutput
from .wav2vec2_wrapper import Wav2VecWrapper
from .multilevel_classifier import MultiLevelDownstreamModel

class CustomModelForAudioClassification(nn.Module):
    """Audio classifier built from a wav2vec2-style encoder and a multi-level head.

    The encoder (``Wav2VecWrapper``) runs on the input features. Its output
    mapping is unpacked as keyword arguments into ``MultiLevelDownstreamModel``,
    which produces the logits. No loss is computed here, so ``loss`` in the
    returned ``SequenceClassifierOutput`` is always ``None``.
    """

    def __init__(self, config):
        """Build the encoder and the classifier head from ``config``.

        Args:
            config: Configuration shared by the encoder and the classifier.
                ``config.output_hidden_states`` must be truthy.

        Raises:
            ValueError: If ``config.output_hidden_states`` is falsy.
        """
        super().__init__()
        # Explicit raise rather than ``assert`` so the check still runs under
        # ``python -O``.
        if not config.output_hidden_states:
            raise ValueError("The upstream model must return all hidden states")
        self.config = config
        self.encoder = Wav2VecWrapper(config)

        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)

    def forward(
        self,
        input_features: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Classify audio, running the encoder unless its outputs are supplied.

        Args:
            input_features: Encoder input. Required when ``encoder_outputs``
                is not given. The dtype is assumed from the original annotation
                and should be confirmed against ``Wav2VecWrapper``.
            length: Optional lengths, forwarded to the encoder as ``length=``.
            encoder_outputs: Optional precomputed encoder output. When given,
                the encoder is skipped. Despite the tuple annotation, it is used
                as a mapping: it is unpacked with ``**`` into the classifier and
                indexed with ``'encoder_hidden_states'``.

        Returns:
            ``SequenceClassifierOutput`` with ``loss=None``, the classifier
            logits, and ``hidden_states`` taken from the encoder output's
            ``'encoder_hidden_states'`` entry.

        Raises:
            ValueError: If neither ``input_features`` nor ``encoder_outputs``
                is provided.
        """
        if encoder_outputs is None:
            if input_features is None:
                raise ValueError(
                    "Either input_features or encoder_outputs must be provided"
                )
            outputs = self.encoder(
                input_features,
                length=length,
            )
        else:
            # Previously the supplied value was ignored and ``encoder_output``
            # was left unbound, which raised NameError below.
            outputs = encoder_outputs

        logits = self.classifier(**outputs)

        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs['encoder_hidden_states'],
        )