# Commit 8f96165 (JuanJoseMV): add model logic implementation
import torch
import torch.nn as nn
from typing import Optional, Union, Tuple
from transformers.modeling_outputs import SequenceClassifierOutput
from .wav2vec2_wrapper import Wav2VecWrapper
from .multilevel_classifier import MultiLevelDownstreamModel
class CustomModelForAudioClassification(nn.Module):
    """Audio classifier: a Wav2Vec2 encoder feeding a multi-level downstream head.

    The encoder must be configured with ``output_hidden_states=True`` because
    the downstream classifier consumes all intermediate hidden states, not
    just the final layer.
    """

    def __init__(self, config):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this required check.
        if not config.output_hidden_states:
            raise ValueError("The upstream model must return all hidden states")
        self.config = config
        self.encoder = Wav2VecWrapper(config)
        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)

    def forward(
        self,
        input_features: Optional[torch.LongTensor],
        length: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Run the encoder (unless precomputed outputs are given) and classify.

        Args:
            input_features: Raw audio features passed to the encoder.
            length: Optional per-example lengths forwarded to the encoder.
            encoder_outputs: Optional precomputed encoder outputs; when given,
                the encoder call is skipped and these are used directly.
                NOTE(review): the annotation says tuple-of-tuples, but usage
                (`**` expansion and the 'encoder_hidden_states' key below)
                requires a mapping — confirm the intended type with callers.

        Returns:
            SequenceClassifierOutput with ``logits``, ``hidden_states`` from
            the encoder, and ``loss=None`` (no labels are accepted here, so
            loss computation is left to the caller).
        """
        # BUG FIX: previously, when `encoder_outputs` was supplied,
        # `encoder_output` was never assigned and the classifier call raised
        # NameError. Reuse the precomputed outputs instead of re-encoding.
        if encoder_outputs is None:
            encoder_output = self.encoder(
                input_features,
                length=length,
            )
        else:
            encoder_output = encoder_outputs

        logits = self.classifier(**encoder_output)

        # No labels argument exists in this signature; loss is computed
        # externally by the training loop.
        loss = None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_output['encoder_hidden_states'],
        )