File size: 5,889 Bytes
004e907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from transformers import SpeechEncoderDecoderModel
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.models.encoder_decoder.modeling_encoder_decoder import shift_tokens_right
from transformers.modeling_outputs import Seq2SeqLMOutput

class Wav2VecGPT2Model(SpeechEncoderDecoderModel):
    """
    Basically the same as `SpeechEncoderDecoderModel`, but learned position embeddings (initialized with GPT2's
    position embeddings `decoder.transformer.wpe`) are added to the projected encoder output, which is then
    layer-normalized before being used as `encoder_hidden_states` for the decoder.

    NOTE(review): the position embeddings and layer norm are applied only when the encoder output is projected,
    i.e. when `encoder_output_dim != decoder.config.hidden_size` and `decoder.config.cross_attention_hidden_size`
    is None. With equal encoder/decoder widths the encoder output is passed to the decoder unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        decoder_pos_emb = self.decoder.transformer.wpe
        # Size the table like the decoder's position table so the weight copy below always matches
        # (1024 rows for the standard GPT-2 checkpoints).
        self.encoder_outputs_pos_emb = nn.Embedding(decoder_pos_emb.num_embeddings, self.decoder.config.hidden_size)
        with torch.no_grad():
            self.encoder_outputs_pos_emb.weight.copy_(decoder_pos_emb.weight)
        self.enc_to_dec_proj_ln = nn.LayerNorm(self.decoder.config.hidden_size,
                                               eps=self.decoder.config.layer_norm_epsilon)

    def __getattribute__(self, name):
        # Fake class so it is recognized as seq2seq model: `obj.__class__` reports SpeechEncoderDecoderModel.
        # Presumably this satisfies library code that inspects `model.__class__` — confirm against the callers.
        # `type(obj)` is unaffected, so ordinary attribute lookup still finds the methods defined here.
        if name == '__class__':
            return SpeechEncoderDecoderModel
        return SpeechEncoderDecoderModel.__getattribute__(self, name)

    def _add_encoder_position_embeddings(self, encoder_hidden_states):
        """
        Add the learned position embeddings to projected encoder states and layer-normalize the result.

        Args:
            encoder_hidden_states: projected encoder output. Dim 1 is used as the sequence (frame) axis and the
                last dim must equal `decoder.config.hidden_size`.

        Returns:
            A tensor with the same shape as the input.

        Raises:
            ValueError: if the sequence is longer than the position-embedding table.
        """
        seq_len = encoder_hidden_states.shape[1]
        max_positions = self.encoder_outputs_pos_emb.num_embeddings
        if seq_len > max_positions:
            # Looking up positions past the table fails with an opaque index error (a device-side assert on CUDA).
            # TODO: consider truncating + warning instead. The encoder attention mask would then have to be computed
            # for the full length and sliced to match.
            raise ValueError(
                f"Encoder output has {seq_len} frames, but only {max_positions} position embeddings are available; "
                "use shorter inputs."
            )
        positions = torch.arange(0, seq_len, device=encoder_hidden_states.device)
        encoder_hidden_states = encoder_hidden_states + self.encoder_outputs_pos_emb(positions)
        return self.enc_to_dec_proj_ln(encoder_hidden_states)

    def forward(
        self,
        inputs=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        input_values=None,
        input_features=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the speech encoder (unless `encoder_outputs` is supplied), then the decoder.

        This mirrors `SpeechEncoderDecoderModel.forward`. The only difference is the position-embedding + layer-norm
        step applied to the projected encoder states.

        Args:
            inputs: encoder input. If None, it is taken from `input_values` or `input_features` (exactly one of them
                must then be given). Ignored when `encoder_outputs` is supplied.
            attention_mask: encoder input mask. It is converted to a mask over encoder output frames with
                `self.encoder._get_feature_vector_attention_mask`.
            labels: target token ids. When given and no decoder inputs are supplied, decoder input ids are built
                with `shift_tokens_right`. A cross-entropy loss over the decoder logits is also computed; label
                value -100 is ignored, which is the `CrossEntropyLoss` default.
            kwargs: `decoder_`-prefixed entries are forwarded (prefix stripped) to the decoder; all others go to
                the encoder.
            Remaining arguments are forwarded unchanged to the encoder and/or decoder.

        Returns:
            `Seq2SeqLMOutput` when `return_dict` is true. Otherwise the tuple `decoder_outputs + encoder_outputs`,
            prefixed with `(loss,)` when labels are given. `encoder_last_hidden_state` is the raw encoder output,
            taken before projection and position embeddings.

        Raises:
            ValueError: if the encoder inputs are missing or ambiguous, or if the projected encoder output is longer
                than the position-embedding table.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # Run the encoder whenever no precomputed outputs are given. The input only needs to be resolved from
        # `input_values`/`input_features` when `inputs` itself is missing.
        if encoder_outputs is None:
            if inputs is None:
                if input_values is not None and input_features is not None:
                    raise ValueError("You cannot specify both input_values and input_features at the same time")
                elif input_values is not None:
                    inputs = input_values
                elif input_features is not None:
                    inputs = input_features
                else:
                    raise ValueError("You have to specify either input_values or input_features")

            encoder_outputs = self.encoder(
                inputs,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states, then add position embeddings + layer norm
        if (
            self.encoder_output_dim != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
            encoder_hidden_states = self._add_encoder_position_embeddings(encoder_hidden_states)

        # compute correct encoder attention mask
        if attention_mask is not None:
            encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
                encoder_hidden_states.shape[1], attention_mask
            )
        else:
            encoder_attention_mask = None

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            # reshape (not view) so non-contiguous label tensors are accepted as well
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs[0],
            encoder_hidden_states=getattr(encoder_outputs, 'hidden_states', None),  # TODO: only temporary (inconsistent)
            encoder_attentions=getattr(encoder_outputs, 'attentions', None),
        )