# multimodal_emotion_recognition/src/model/multilevel_classifier.py
from typing import Optional
import torch
import torch.nn as nn
from torch.nn import functional as F


class MultiLevelDownstreamModel(nn.Module):
    """Classification head that learns a weighted sum over the hidden states of an
    upstream speech encoder and maps the pooled result to emotion logits."""

    def __init__(
        self,
        model_config,
        use_conv_output: bool = True,
    ):
        super().__init__()
        assert model_config.output_hidden_states, "The upstream model must return all hidden states"

        self.model_config = model_config
        self.use_conv_output = use_conv_output
        # Frame-wise projection: kernel-size-1 convolutions map the upstream
        # hidden size down to the classifier projection size.
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, self.model_config.classifier_proj_size, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size, 1, padding=0),
        )
        # Learnable per-layer weights for the weighted sum over hidden states.
        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + input embeddings
            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))
        self.out_layer = nn.Sequential(
            nn.Linear(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size),
            nn.ReLU(),
            nn.Linear(self.model_config.classifier_proj_size, self.model_config.num_labels),
        )

    def forward(self, encoder_hidden_states, length=None):
        # Stack the per-layer hidden states into (num_layers, batch, time, hidden).
        if self.use_conv_output:
            stacked_feature = torch.stack(encoder_hidden_states, dim=0)
        else:
            stacked_feature = torch.stack(encoder_hidden_states, dim=0)[1:]  # exclude the convolution output

        # Weighted sum over layers with softmax-normalised weights.
        _, *origin_shape = stacked_feature.shape
        if self.use_conv_output:
            stacked_feature = stacked_feature.view(self.model_config.num_hidden_layers + 1, -1)
        else:
            stacked_feature = stacked_feature.view(self.model_config.num_hidden_layers, -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)

        # Conv1d expects (batch, channels, time), so transpose around the projection.
        features = features.transpose(1, 2)
        features = self.model_seq(features)
        features = features.transpose(1, 2)
        # Pool over time: a length-masked mean when frame lengths are provided,
        # otherwise a plain mean over all frames.
        if length is not None:
            length = length.to(features.device)
            masks = torch.arange(features.size(1), device=features.device).expand(length.size(0), -1) < length.unsqueeze(1)
            masks = masks.float()
            features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
        else:
            features = torch.mean(features, dim=1)

        predicted = self.out_layer(features)
        return predicted
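

# --- Usage sketch (not part of the original module) ---
# A minimal, hypothetical example of driving this head with a wav2vec2-style
# upstream encoder from Hugging Face `transformers`. The config values below
# (e.g. four emotion classes) are assumptions for illustration only.
if __name__ == "__main__":
    from transformers import Wav2Vec2Config, Wav2Vec2Model

    config = Wav2Vec2Config(output_hidden_states=True, num_labels=4)  # assumed label count
    upstream = Wav2Vec2Model(config)                   # randomly initialised, just for the sketch
    head = MultiLevelDownstreamModel(config)

    waveform = torch.randn(2, 16000)                   # two one-second clips at 16 kHz
    hidden_states = upstream(waveform).hidden_states   # tuple of num_hidden_layers + 1 tensors
    logits = head(hidden_states)                       # shape: (batch, num_labels)
    print(logits.shape)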