|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
|
|
class MultiLevelDownstreamModel(nn.Module):
    """Classification head over a learned weighted sum of encoder layers.

    The hidden states returned by an upstream encoder are combined with
    softmax-normalised, learnable per-layer weights.  The combined sequence
    is passed through a stack of pointwise (kernel size 1) 1-D convolutions,
    averaged over the time axis (optionally ignoring padded frames), and
    finally mapped to ``num_labels`` logits by a two-layer MLP.

    Args:
        model_config: Upstream model configuration.  The attributes read
            here are ``output_hidden_states``, ``hidden_size``,
            ``classifier_proj_size``, ``num_hidden_layers`` and
            ``num_labels``.
        use_conv_output: When true, all ``num_hidden_layers + 1`` hidden
            states are combined.  When false, the first hidden state is
            dropped and only ``num_hidden_layers`` states are combined.
            NOTE(review): the first state is presumably the output of the
            upstream convolutional front-end -- confirm against the encoder.
    """

    def __init__(
        self,
        model_config,
        use_conv_output: Optional[bool] = True,
    ):
        super().__init__()
        assert model_config.output_hidden_states, "The upstream model must return all hidden states"

        self.model_config = model_config
        self.use_conv_output = use_conv_output

        hidden_size = self.model_config.hidden_size
        proj_size = self.model_config.classifier_proj_size

        # Pointwise convolutions: every frame is projected independently.
        self.model_seq = nn.Sequential(
            nn.Conv1d(hidden_size, proj_size, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(proj_size, proj_size, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(proj_size, proj_size, 1, padding=0),
        )

        # One learnable scalar per combined layer.  Both initialisations are
        # constant vectors, so the softmax in ``forward`` starts out uniform.
        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1
            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))

        self.out_layer = nn.Sequential(
            nn.Linear(proj_size, proj_size),
            nn.ReLU(),
            nn.Linear(proj_size, self.model_config.num_labels),
        )

    def forward(self, encoder_hidden_states, length=None):
        """Compute class logits from the upstream hidden states.

        Args:
            encoder_hidden_states: Sequence of equally shaped tensors, one
                per upstream layer.  Each must be 3-D with the time axis at
                dim 1 and ``hidden_size`` channels at dim 2 (they are
                transposed and fed to ``Conv1d``).
            length: Optional 1-D tensor with the number of valid frames of
                every batch item.  When given, frames at or beyond that
                index are excluded from the temporal average; otherwise a
                plain mean over the time axis is used.

        Returns:
            Tensor of logits whose last dimension is ``num_labels``.

        Raises:
            ValueError: If the number of hidden states (after optionally
                dropping the first one) does not match the number of
                learned layer weights.
        """
        stacked_feature = torch.stack(encoder_hidden_states, dim=0)
        if not self.use_conv_output:
            # Skip the first hidden state; only the remaining layers are mixed.
            stacked_feature = stacked_feature[1:]

        num_layers = self.weights.numel()
        if stacked_feature.size(0) != num_layers:
            raise ValueError(
                f"Expected {num_layers} hidden states to combine, "
                f"got {stacked_feature.size(0)}"
            )

        # Broadcast the normalised weights over the layer axis instead of
        # flattening and restoring the tensor; the result is the same.
        norm_weights = F.softmax(self.weights, dim=-1)
        broadcast_shape = (num_layers,) + (1,) * (stacked_feature.dim() - 1)
        features = (norm_weights.view(broadcast_shape) * stacked_feature).sum(dim=0)

        # Conv1d wants channels at dim 1, so swap time and channel axes
        # around the convolution stack.
        features = features.transpose(1, 2)
        features = self.model_seq(features)
        features = features.transpose(1, 2)

        if length is not None:
            # Stay on whatever device the features live on (CPU or any GPU)
            # rather than forcing the default CUDA device.
            length = length.to(features.device)
            frame_index = torch.arange(features.size(1), device=features.device)
            masks = frame_index.unsqueeze(0) < length.unsqueeze(1)
            masks = masks.to(features.dtype)
            # Clamp so a zero-length item yields zeros rather than NaN.
            denom = length.clamp(min=1).unsqueeze(1).to(features.dtype)
            features = (features * masks.unsqueeze(-1)).sum(1) / denom
        else:
            features = torch.mean(features, dim=1)

        predicted = self.out_layer(features)
        return predicted