from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F


class MultiLevelDownstreamModel(nn.Module):
    """Downstream head that classifies from all hidden states of an upstream encoder.

    The per-layer hidden states are combined with a learned softmax-weighted
    sum, projected by pointwise convolutions, mean-pooled over time, and fed
    to a small classifier.
    """

    def __init__(self, model_config, use_conv_output: bool = True):
        super().__init__()
        assert model_config.output_hidden_states, "The upstream model must return all hidden states"

        self.model_config = model_config
        self.use_conv_output = use_conv_output

        # Kernel-size-1 convolutions act as per-frame linear projections
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, self.model_config.classifier_proj_size, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size, 1, padding=0)
        )

        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + convolutional feature-extractor output
            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))
        # Both initializations yield a uniform layer distribution after the softmax in forward().
        
        self.out_layer = nn.Sequential(
            nn.Linear(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size),
            nn.ReLU(),
            nn.Linear(self.model_config.classifier_proj_size, self.model_config.num_labels),
        )

    def forward(self, encoder_hidden_states, length: Optional[torch.Tensor] = None):
        # Stack all hidden states: (num_layers, batch, time, hidden_size)
        if self.use_conv_output:
            stacked_feature = torch.stack(encoder_hidden_states, dim=0)
        else:
            stacked_feature = torch.stack(encoder_hidden_states, dim=0)[1:]  # exclude the convolution output

        _, *origin_shape = stacked_feature.shape

        # Flatten all dimensions except the layer axis for the weighted sum
        if self.use_conv_output:
            stacked_feature = stacked_feature.view(self.model_config.num_hidden_layers + 1, -1)
        else:
            stacked_feature = stacked_feature.view(self.model_config.num_hidden_layers, -1)

        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)
        
        features = features.transpose(1, 2)  # Conv1d expects (batch, channels, time)
        features = self.model_seq(features)
        features = features.transpose(1, 2)
        
        if length is not None:
            # Masked mean over the valid time steps of each sequence
            length = length.to(features.device)
            masks = torch.arange(features.size(1), device=features.device).expand(length.size(0), -1) < length.unsqueeze(1)
            masks = masks.float()
            features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
        else:
            features = torch.mean(features, dim=1)
        
        predicted = self.out_layer(features)
        return predicted
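

# Minimal usage sketch (not part of the original module). The SimpleNamespace
# config is a stand-in for the upstream model's config (e.g. a transformers
# Wav2Vec2/WavLM config with output_hidden_states=True); all attribute values
# and shapes below are illustrative assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        output_hidden_states=True,
        hidden_size=768,
        classifier_proj_size=256,
        num_hidden_layers=12,
        num_labels=4,
    )
    model = MultiLevelDownstreamModel(config)

    batch, seq_len = 2, 100
    # One tensor per transformer layer plus the convolutional output, each
    # shaped (batch, time, hidden_size), as an upstream encoder would return
    # in its hidden_states tuple.
    hidden_states = [
        torch.randn(batch, seq_len, config.hidden_size)
        for _ in range(config.num_hidden_layers + 1)
    ]
    lengths = torch.tensor([100, 60])  # valid frames per sequence

    logits = model(hidden_states, length=lengths)
    print(logits.shape)  # torch.Size([2, 4])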