File size: 8,494 Bytes
5b1efcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9697cc
 
5b1efcc
 
 
 
 
a9697cc
5b1efcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.pytorch_utils import Conv1D
from transformers.utils import (
    ModelOutput,
    logging,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel
from .configuration_backpack_gpt2 import BackpackGPT2Config

logger = logging.get_logger(__name__)


### Backpack-Specific
class BackpackGPT2PreTrainedModel(GPT2PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias"]

    config_class = BackpackGPT2Config
    base_model_prefix = "backpack"
    is_parallelizable = True
    supports_gradient_checkpointing = False
    _no_split_modules = ["GPT2Block", "BackpackNoMixBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

class BackpackMLP(nn.Module):

  def __init__(self, embed_dim, intermediate_dim, out_dim, config):
        super().__init__()
        self.c_fc = Conv1D(intermediate_dim, embed_dim)
        self.c_proj = Conv1D(out_dim, intermediate_dim)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

  def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
      hidden_states = self.c_fc(hidden_states)
      hidden_states = self.act(hidden_states)
      hidden_states = self.c_proj(hidden_states)
      hidden_states = self.dropout(hidden_states)
      return hidden_states

class BackpackNoMixBlock(nn.Module):

  def __init__(self, config):
    super().__init__()
    self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
    self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
    self.mlp = BackpackMLP(config.n_embd, config.n_embd*4, config.n_embd, config)
    self.resid_dropout1 = nn.Dropout(config.resid_pdrop)
    self.resid_dropout2 = nn.Dropout(config.resid_pdrop)

  def forward(self, hidden_states, residual):
    residual = self.resid_dropout1(hidden_states) + residual
    hidden_states = self.ln_1(residual)
    mlp_out = self.mlp(hidden_states)
    residual = self.resid_dropout2(mlp_out) + residual
    hidden_states = self.ln_2(residual)
    return hidden_states


class BackpackSenseNetwork(nn.Module):
    def __init__(self, config, num_senses, device=None, dtype=None):
        super().__init__()
        self.num_senses = num_senses
        #self.embeddings = embeddings
        self.n_embd = config.n_embd

        self.dropout = nn.Dropout(config.embd_pdrop)
        self.block = BackpackNoMixBlock(config)
        self.ln = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)
        self.final_mlp = BackpackMLP(
            embed_dim=config.n_embd,
            intermediate_dim=config.sense_intermediate_scale*config.n_embd,
            out_dim=config.n_embd*config.num_senses,
            config=config,
            )

    def forward(self, input_embeds):
      residual = self.dropout(input_embeds)
      hidden_states = self.ln(residual)
      hidden_states = self.block(hidden_states, residual)
      senses = self.final_mlp(hidden_states)
      bs, s, nvd = senses.shape
      return senses.reshape(bs, s, self.num_senses, self.n_embd).transpose(1,2) # (bs, nv, s, d)

class BackpackWeightNetwork(nn.Module):

  def __init__(self, num_senses, embed_dim):
    super().__init__()
    self.n_embd = embed_dim
    self.num_senses = num_senses
    self.embed_per_sense = embed_dim // num_senses
    self.c_attn = nn.Linear(embed_dim, 2 * num_senses * self.embed_per_sense)
    self.softmax_scale = None

  def forward(self, encoded):
    b, s, d = encoded.shape
    encoded = self.c_attn(encoded) # (b, s, 2*d)
    encoded = encoded.reshape(b, s, 2, self.num_senses, self.embed_per_sense) #(b, s, 2, nv, d//nv)
    batch_size, seqlen = encoded.shape[0], encoded.shape[1]

    # compute scores & mask
    q, k = encoded.unbind(dim=2)
    softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
    causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
    scores = scores + causal_mask.to(dtype=scores.dtype)

    return torch.softmax(scores, dim=-1, dtype=q.dtype)
  

@dataclass
class BackpackGPT2BaseModelOutput(ModelOutput):
    hidden_states: torch.FloatTensor = None
    contextualization: torch.FloatTensor = None

class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.n_embd

        self.num_senses = config.num_senses
        self.gpt2_model = GPT2Model(config)
        self.sense_network = BackpackSenseNetwork(config, self.num_senses, self.gpt2_model.wte)
        self.word_embeddings = self.gpt2_model.wte
        self.position_embeddings = self.gpt2_model.wpe
        self.sense_weight_net = BackpackWeightNetwork(self.num_senses, self.embed_dim)
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    def get_num_senses(self):
        return self.num_senses

    def get_word_embeddings(self):
        return self.word_embeddings

    def get_sense_network(self):
        return self.sense_network

    def forward(self, input_ids, position_ids):
        # Compute senses
        sense_input_embeds = self.word_embeddings(input_ids)
        senses = self.sense_network(sense_input_embeds) # (bs, nv, s, d)

        # Compute contextualization weights
        contextl_hidden_states = self.gpt2_model(input_ids, position_ids=position_ids).last_hidden_state # (bs, s, d)
        contextualization = self.sense_weight_net(contextl_hidden_states) # (bs, nv, s, s)

        # Compute resulting outputs
        hidden_states = torch.sum(contextualization @ senses, dim=1) # (bs, nv, s, d) -> (bs, s, d)
        return BackpackGPT2BaseModelOutput(
            hidden_states=hidden_states,
            contextualization=contextualization,
        )
    
    def run_with_custom_contextualization(self, input_ids, contextualization):
        # Compute senses
        sense_input_embeds = self.word_embeddings(input_ids)
        senses = self.sense_network(sense_input_embeds) # (bs, nv, s, d)

        # Compute resulting outputs
        hidden_states = torch.sum(contextualization @ senses, dim=1) # (bs, nv, s, d) -> (bs, s, d)
        return BackpackGPT2BaseModelOutput(
            hidden_states=hidden_states,
            contextualization=contextualization,
        )

@dataclass
class BackpackGPT2LMHeadModelOutput(ModelOutput):
    logits: torch.FloatTensor = None
    contextualization: torch.FloatTensor = None

class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
  _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]

  def __init__(self, config):
    super().__init__(config)
    self.backpack = BackpackGPT2Model(config)
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    # Model parallel
    self.model_parallel = False
    self.device_map = None

    self.tie_weights()

  def tie_weights(self):
      self.lm_head.weight = self.backpack.word_embeddings.weight # also tied with the underlying underlying transf

  def get_lm_head(self):
      return self.lm_head

  def forward(self, input_ids, position_ids=None):
      outputs = self.backpack(input_ids, position_ids=position_ids)
      hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
      lm_logits = self.lm_head(hidden_states) # (bs, s, V)
      return BackpackGPT2LMHeadModelOutput(
            logits=lm_logits,
            contextualization=contextualization,
        )

  def run_with_custom_contextualization(self, input_ids, contextualization):
      outputs = self.backpack.run_with_custom_contextualization(input_ids, contextualization)
      hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
      lm_logits = self.lm_head(hidden_states)
      return BackpackGPT2LMHeadModelOutput(
        logits=lm_logits,
        contextualization=contextualization,
    )