from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.mt5.modeling_mt5 import MT5ClassificationHead, MT5Config, MT5EncoderModel
from transformers.models.t5.modeling_t5 import T5ClassificationHead, T5Config, T5EncoderModel


def clean_t5_config(config: Union[T5Config, MT5Config], model_type: str) -> None:
    """
    Strip the decoder/seq2seq settings from a (m)T5 config so it can drive an encoder-only classifier.

    :param config: The T5 or mT5 configuration object (modified in place).
    :param model_type: Either 't5' or 'mt5'.
    """
    assert model_type in ['t5', 'mt5'], f"ERROR: Unknown model type '{model_type}'"
    config.pooling_mode = 'mean'
    config.model_type = model_type
    config.use_cache = False
    config.is_encoder_decoder = False
    config.num_decoder_layers = None
    if hasattr(config, 'task_specific_params'):
        delattr(config, 'task_specific_params')
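
# Example (illustrative) of preparing a config for the encoder-only classifiers below.
# The checkpoint name and label count are placeholders, not values mandated by this module:
#
#   config = AutoConfig.from_pretrained("t5-base", num_labels=2)
#   clean_t5_config(config, model_type='t5')
#   model = T5EncoderForSequenceClassification.from_pretrained("t5-base", config=config)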


class T5EncoderForSequenceClassification(T5EncoderModel):
    """
    T5 encoder for sequence classification tasks.

    :param config: The T5 configuration object.
    """
    def __init__(self, config: T5Config):
        super().__init__(config)
        self.pool_layer = PoolLayer(config.pooling_mode)
        self.classification_head = T5ClassificationHead(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """
        Forward pass of the T5 encoder for sequence classification tasks.

        :param input_ids: The input token IDs.
        :param attention_mask: The attention mask.
        :param head_mask: The head mask.
        :param inputs_embeds: The input embeddings.
        :param labels: The target labels.
        :param output_attentions: Whether to output attentions.
        :param output_hidden_states: Whether to output hidden states.
        :param return_dict: Whether to return a dictionary.
        :returns: The logits for the classification task or a dictionary containing the outputs.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Note: `labels` is accepted for interface compatibility but no loss is computed here;
        # loss computation is left to the caller.
        loss = None

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index with [0] so this also works when the encoder returns a plain tuple (return_dict=False).
        sequence_output = self.pool_layer(outputs[0], attention_mask)
        logits = self.classification_head(sequence_output)

        if not return_dict:
            # The encoder tuple has no pooled output: hidden states / attentions follow the last hidden state.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class MT5EncoderForSequenceClassification(MT5EncoderModel):
    """
    mT5 encoder for sequence classification tasks.

    :param config: The mT5 configuration object.
    """
    def __init__(self, config: MT5Config):
        super().__init__(config)
        self.pool_layer = PoolLayer(config.pooling_mode)
        self.classification_head = MT5ClassificationHead(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """
        Forward pass of the mT5 encoder for sequence classification tasks.

        :param input_ids: The input token IDs.
        :param attention_mask: The attention mask.
        :param head_mask: The head mask.
        :param inputs_embeds: The input embeddings.
        :param labels: The target labels.
        :param output_attentions: Whether to output attentions.
        :param output_hidden_states: Whether to output hidden states.
        :param return_dict: Whether to return a dictionary.
        :returns: The logits for the classification task or a dictionary containing the outputs.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Note: `labels` is accepted for interface compatibility but no loss is computed here;
        # loss computation is left to the caller.
        loss = None

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index with [0] so this also works when the encoder returns a plain tuple (return_dict=False).
        sequence_output = self.pool_layer(outputs[0], attention_mask)
        logits = self.classification_head(sequence_output)

        if not return_dict:
            # The encoder tuple has no pooled output: hidden states / attentions follow the last hidden state.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PoolLayer(nn.Module):
    """
    Pooling layer on top of the computed token embeddings.

    :param pooling_mode: The pooling strategy to use.
    """
    def __init__(self, pooling_mode: str):
        super().__init__()
        assert pooling_mode in ['first', 'mean', 'max'], f"ERROR: Unknown pooling strategy '{pooling_mode}'"
        self.pooling_mode = pooling_mode

    def forward(self, token_embeddings: torch.Tensor, attention_masks: torch.Tensor) -> torch.Tensor:
        """
        Compute the passage vector by pooling the token embeddings.

        :param token_embeddings: A 3D tensor of size [batch_size, seq_len, d_model].
        :param attention_masks: A 2D tensor of size [batch_size, seq_len].
        :returns: A 2D tensor of size [batch_size, d_model].
        """
        if self.pooling_mode == 'first':
            text_vectors = token_embeddings[:, 0, :]
        elif self.pooling_mode == 'max':
            # Mask the [PAD] positions with large negative values so they are never selected as a channel maximum.
            # masked_fill avoids editing the encoder output in place (in-place edits can break autograd).
            attention_masks_expanded = attention_masks.unsqueeze(-1).expand(token_embeddings.size())
            fill_value = -1e9 if token_embeddings.dtype == torch.float32 else -1e4
            token_embeddings = token_embeddings.masked_fill(attention_masks_expanded == 0, fill_value)
            # Take the maxima along the 'seq_len' dimension (-> Tensor[batch_size, d_model]).
            text_vectors = torch.max(token_embeddings, dim=1).values
        else:
            # Zero out the [PAD] positions so they do not contribute to the sum for a channel.
            attention_masks_expanded = attention_masks.unsqueeze(-1).expand(token_embeddings.size())
            token_embeddings = token_embeddings.masked_fill(attention_masks_expanded == 0, 0.0)
            # Sum along the 'seq_len' dimension (-> Tensor[batch_size, d_model]).
            sum_embeddings = torch.sum(token_embeddings, dim=1)
            # Divide each passage vector by the number of non-padded tokens.
            sum_mask = attention_masks_expanded.sum(dim=1)  # -> Tensor[batch_size, d_model], each value is the passage length.
            sum_mask = torch.clamp(sum_mask, min=1e-7)  # Lower-bound to avoid division by zero.
            text_vectors = sum_embeddings / sum_mask
        return text_vectors
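
# Illustrative sketch of the pooling semantics (toy shapes, not executed at import time):
#
#   pool = PoolLayer('mean')
#   token_embeddings = torch.randn(2, 4, 8)                       # [batch_size, seq_len, d_model]
#   attention_masks = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # [batch_size, seq_len]
#   passage_vectors = pool(token_embeddings, attention_masks)     # -> [2, 8], padding ignored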


if __name__ == "__main__":
    # Placeholder values for a quick manual check; swap in your own checkpoint and tokenizer kwargs.
    model_name = "t5-base"
    tokenizer_args = {}

    config = AutoConfig.from_pretrained(model_name)
    if isinstance(config, T5Config):
        clean_t5_config(config, model_type='t5')
        model = T5EncoderForSequenceClassification.from_pretrained(model_name, config=config)
    elif isinstance(config, MT5Config):
        clean_t5_config(config, model_type='mt5')
        model = MT5EncoderForSequenceClassification.from_pretrained(model_name, config=config)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args)
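
    # Quick smoke test (illustrative; the sentence, like the checkpoint above, is a placeholder).
    inputs = tokenizer("This is a simple test sentence.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)  # -> torch.Size([1, config.num_labels])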