Ivan Zhou committed
Commit 35942d5 · 1 Parent(s): e8b5f8d

Update model, config, and weights

Files changed (4)
  1. backpack_config.py +17 -0
  2. backpack_model.py +222 -0
  3. config.json +1 -1
  4. model.safetensors +2 -2
backpack_config.py ADDED
@@ -0,0 +1,17 @@
+ from transformers import GPT2Config
+
+
+ class BackpackGPT2Config(GPT2Config):
+     model_type = 'backpack-gpt2'
+
+     def __init__(self,
+                  vocab_size=50264,
+                  num_senses=16,
+                  sense_intermediate_scale=4,
+                  n_positions=512,
+                  scale_attn_by_inverse_layer_idx=True,
+                  **kwargs,
+                  ):
+         self.num_senses = num_senses
+         self.sense_intermediate_scale = sense_intermediate_scale
+         super().__init__(vocab_size=vocab_size, n_positions=n_positions, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, **kwargs)
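For orientation, a minimal sketch (not part of the commit) of how the new config class behaves once backpack_config.py is importable; registering it with AutoConfig is an assumed usage pattern, not something this diff sets up:

# Hypothetical smoke test for the config added above (assumes backpack_config.py is on the import path).
from transformers import AutoConfig
from backpack_config import BackpackGPT2Config

config = BackpackGPT2Config()  # defaults: vocab_size=50264, num_senses=16, sense_intermediate_scale=4, n_positions=512
print(config.model_type)                                   # 'backpack-gpt2'
print(config.num_senses, config.sense_intermediate_scale)  # 16 4

# Optional: let the Auto classes resolve the custom model_type locally.
AutoConfig.register("backpack-gpt2", BackpackGPT2Config)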
backpack_model.py ADDED
@@ -0,0 +1,222 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from transformers.activations import ACT2FN
+ from transformers.pytorch_utils import Conv1D
+ from transformers.utils import ModelOutput
+ from transformers import GPT2PreTrainedModel, GPT2Model
+ from .backpack_config import BackpackGPT2Config
+
+
+ ### Backpack-Specific
+ class BackpackGPT2PreTrainedModel(GPT2PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+     _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias"]
+
+     config_class = BackpackGPT2Config
+     base_model_prefix = "backpack"
+     is_parallelizable = True
+     supports_gradient_checkpointing = False
+     _no_split_modules = ["GPT2Block", "BackpackNoMixBlock"]
+
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+
+ class BackpackMLP(nn.Module):
+
+     def __init__(self, embed_dim, intermediate_dim, out_dim, config):
+         super().__init__()
+         self.c_fc = Conv1D(intermediate_dim, embed_dim)
+         self.c_proj = Conv1D(out_dim, intermediate_dim)
+         self.act = ACT2FN[config.activation_function]
+         self.dropout = nn.Dropout(config.resid_pdrop)
+
+     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+         hidden_states = self.c_fc(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.c_proj(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         return hidden_states
+
+ class BackpackNoMixBlock(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.mlp = BackpackMLP(config.n_embd, config.n_embd*4, config.n_embd, config)
+         self.resid_dropout1 = nn.Dropout(config.resid_pdrop)
+         self.resid_dropout2 = nn.Dropout(config.resid_pdrop)
+
+     def forward(self, hidden_states, residual):
+         residual = self.resid_dropout1(hidden_states) + residual
+         hidden_states = self.ln_1(residual)
+         mlp_out = self.mlp(hidden_states)
+         residual = self.resid_dropout2(mlp_out) + residual
+         hidden_states = self.ln_2(residual)
+         return hidden_states
+
+
+ class BackpackSenseNetwork(nn.Module):
+     def __init__(self, config, num_senses, device=None, dtype=None):
+         super().__init__()
+         self.num_senses = num_senses
+         #self.embeddings = embeddings
+         self.n_embd = config.n_embd
+
+         self.dropout = nn.Dropout(config.embd_pdrop)
+         self.block = BackpackNoMixBlock(config)
+         self.ln = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)
+         self.final_mlp = BackpackMLP(
+             embed_dim=config.n_embd,
+             intermediate_dim=config.sense_intermediate_scale*config.n_embd,
+             out_dim=config.n_embd*config.num_senses,
+             config=config,
+         )
+
+     def forward(self, input_embeds):
+         residual = self.dropout(input_embeds)
+         hidden_states = self.ln(residual)
+         hidden_states = self.block(hidden_states, residual)
+         senses = self.final_mlp(hidden_states)
+         bs, s, nvd = senses.shape
+         return senses.reshape(bs, s, self.num_senses, self.n_embd).transpose(1,2) # (bs, nv, s, d)
+
+
+ class BackpackWeightNetwork(nn.Module):
+
+     def __init__(self, num_senses, embed_dim):
+         super().__init__()
+         self.n_embd = embed_dim
+         self.num_senses = num_senses
+         self.c_attn = nn.Linear(embed_dim, 2*embed_dim)
+         self.softmax_scale = None
+
+     def forward(self, encoded):
+         b, s, d = encoded.shape
+         encoded = self.c_attn(encoded) # (b, s, 2*d)
+         encoded = encoded.reshape(b, s, 2, self.num_senses, d // self.num_senses) # (b, s, 2, nv, d//nv)
+         batch_size, seqlen = encoded.shape[0], encoded.shape[1]
+
+         # compute scores & mask
+         q, k = encoded.unbind(dim=2)
+         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+         causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+         scores = scores + causal_mask.to(dtype=scores.dtype)
+
+         return torch.softmax(scores, dim=-1, dtype=q.dtype)
+
+
+ @dataclass
+ class BackpackGPT2BaseModelOutput(ModelOutput):
+     hidden_states: torch.FloatTensor = None
+     contextualization: torch.FloatTensor = None
+
+ class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.embed_dim = config.n_embd
+
+         self.num_senses = config.num_senses
+         self.gpt2_model = GPT2Model(config)
+         self.sense_network = BackpackSenseNetwork(config, self.num_senses, self.gpt2_model.wte)
+         self.word_embeddings = self.gpt2_model.wte
+         self.position_embeddings = self.gpt2_model.wpe
+         self.sense_weight_net = BackpackWeightNetwork(self.num_senses, self.embed_dim)
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+         self.gradient_checkpointing = False
+
+     def get_num_senses(self):
+         return self.num_senses
+
+     def get_word_embeddings(self):
+         return self.word_embeddings
+
+     def get_sense_network(self):
+         return self.sense_network
+
+     def forward(self, input_ids, position_ids):
+         # Compute senses
+         sense_input_embeds = self.word_embeddings(input_ids)
+         senses = self.sense_network(sense_input_embeds) # (bs, nv, s, d)
+
+         # Compute contextualization weights
+         contextl_hidden_states = self.gpt2_model(input_ids, position_ids=position_ids).last_hidden_state # (bs, s, d)
+         contextualization = self.sense_weight_net(contextl_hidden_states) # (bs, nv, s, s)
+
+         # Compute resulting outputs
+         hidden_states = torch.sum(contextualization @ senses, dim=1) # (bs, nv, s, d) -> (bs, s, d)
+         return BackpackGPT2BaseModelOutput(
+             hidden_states=hidden_states,
+             contextualization=contextualization,
+         )
+
+     def run_with_custom_contextualization(self, input_ids, contextualization):
+         # Compute senses
+         sense_input_embeds = self.word_embeddings(input_ids)
+         senses = self.sense_network(sense_input_embeds) # (bs, nv, s, d)
+
+         # Compute resulting outputs
+         hidden_states = torch.sum(contextualization @ senses, dim=1) # (bs, nv, s, d) -> (bs, s, d)
+         return BackpackGPT2BaseModelOutput(
+             hidden_states=hidden_states,
+             contextualization=contextualization,
+         )
+
+ @dataclass
+ class BackpackGPT2LMHeadModelOutput(ModelOutput):
+     logits: torch.FloatTensor = None
+     contextualization: torch.FloatTensor = None
+
+ class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.backpack = BackpackGPT2Model(config)
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+         self.tie_weights()
+
+     def tie_weights(self):
+         self.lm_head.weight = self.backpack.word_embeddings.weight # also tied with the underlying transformer
+
+     def get_lm_head(self):
+         return self.lm_head
+
+     def forward(self, input_ids, position_ids=None):
+         outputs = self.backpack(input_ids, position_ids=position_ids)
+         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
+         lm_logits = self.lm_head(hidden_states) # (bs, s, V)
+         return BackpackGPT2LMHeadModelOutput(
+             logits=lm_logits,
+             contextualization=contextualization,
+         )
+         # CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
+         # return CausalLMOutput(logits=lm_logits)
+
+     def run_with_custom_contextualization(self, input_ids, contextualization):
+         outputs = self.backpack.run_with_custom_contextualization(input_ids, contextualization)
+         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
+         lm_logits = self.lm_head(hidden_states)
+         return BackpackGPT2LMHeadModelOutput(
+             logits=lm_logits,
+             contextualization=contextualization,
+         )
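A quick shape check of the classes above can be sketched as follows; this is an assumption about usage, not part of the commit. Because backpack_model.py uses a relative import, the sketch assumes both files sit inside a hypothetical package (here called backpack_pkg) with an __init__.py:

# Hypothetical smoke test; backpack_pkg is an assumed local package containing both new files.
import torch
from backpack_pkg.backpack_config import BackpackGPT2Config
from backpack_pkg.backpack_model import BackpackGPT2LMHeadModel

# Tiny dimensions keep the test cheap; the checked-in weights use much larger values.
config = BackpackGPT2Config(n_embd=128, n_layer=2, n_head=4, num_senses=4)
model = BackpackGPT2LMHeadModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))  # (bs, s)
with torch.no_grad():
    out = model(input_ids)

print(out.logits.shape)             # torch.Size([1, 8, 50264]) -> (bs, s, V)
print(out.contextualization.shape)  # torch.Size([1, 4, 8, 8])  -> (bs, nv, s, s)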
config.json CHANGED
@@ -1 +1 @@
- {"return_dict": true, "output_hidden_states": false, "output_attentions": false, "torchscript": false, "torch_dtype": null, "use_bfloat16": false, "tf_legacy_loss": false, "pruned_heads": {}, "tie_word_embeddings": true, "is_encoder_decoder": false, "is_decoder": false, "cross_attention_hidden_size": null, "add_cross_attention": false, "tie_encoder_decoder": false, "max_length": 20, "min_length": 0, "do_sample": false, "early_stopping": false, "num_beams": 1, "num_beam_groups": 1, "diversity_penalty": 0.0, "temperature": 1.0, "top_k": 50, "top_p": 1.0, "typical_p": 1.0, "repetition_penalty": 1.0, "length_penalty": 1.0, "no_repeat_ngram_size": 0, "encoder_no_repeat_ngram_size": 0, "bad_words_ids": null, "num_return_sequences": 1, "chunk_size_feed_forward": 0, "output_scores": false, "return_dict_in_generate": false, "forced_bos_token_id": null, "forced_eos_token_id": null, "remove_invalid_values": false, "exponential_decay_length_penalty": null, "suppress_tokens": null, "begin_suppress_tokens": null, "architectures": null, "finetuning_task": null, "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, "label2id": {"LABEL_0": 0, "LABEL_1": 1}, "tokenizer_class": null, "prefix": null, "bos_token_id": null, "pad_token_id": null, "eos_token_id": null, "sep_token_id": null, "decoder_start_token_id": null, "task_specific_params": null, "problem_type": null, "_name_or_path": "", "transformers_version": "4.29.2", "vocab_size": 50264, "seq_len": 512, "hidden_dim": 768, "num_layers": 12, "num_heads": 12, "mlp_scale": 4, "initializer_range": 0.02, "embed_pdrop": 0.0, "resid_pdrop": 0.0, "attn_pdrop": 0.0, "layer_norm_epsilon": 1e-05, "activation_function": "gelu_new", "scale_attn_by_inverse_layer_idx": true, "upcast_attn": false, "num_senses": 16, "sense_intermediate_scale": 4, "model_type": ""}
+ {"return_dict": true, "output_hidden_states": false, "output_attentions": false, "torchscript": false, "torch_dtype": null, "use_bfloat16": false, "tf_legacy_loss": false, "pruned_heads": {}, "tie_word_embeddings": true, "is_encoder_decoder": false, "is_decoder": false, "cross_attention_hidden_size": null, "add_cross_attention": false, "tie_encoder_decoder": false, "max_length": 20, "min_length": 0, "do_sample": false, "early_stopping": false, "num_beams": 1, "num_beam_groups": 1, "diversity_penalty": 0.0, "temperature": 1.0, "top_k": 50, "top_p": 1.0, "typical_p": 1.0, "repetition_penalty": 1.0, "length_penalty": 1.0, "no_repeat_ngram_size": 0, "encoder_no_repeat_ngram_size": 0, "bad_words_ids": null, "num_return_sequences": 1, "chunk_size_feed_forward": 0, "output_scores": false, "return_dict_in_generate": false, "forced_bos_token_id": null, "forced_eos_token_id": null, "remove_invalid_values": false, "exponential_decay_length_penalty": null, "suppress_tokens": null, "begin_suppress_tokens": null, "architectures": null, "finetuning_task": null, "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, "label2id": {"LABEL_0": 0, "LABEL_1": 1}, "tokenizer_class": null, "prefix": null, "bos_token_id": null, "pad_token_id": null, "eos_token_id": null, "sep_token_id": null, "decoder_start_token_id": null, "task_specific_params": null, "problem_type": null, "_name_or_path": "", "transformers_version": "4.29.2", "vocab_size": 50264, "n_positions": 512, "n_layer": 36, "n_head": 20, "n_embd": 1280, "initializer_range": 0.02, "attn_pdrop": 0.0, "embd_pdrop": 0.0, "layer_norm_epsilon": 1e-05, "activation_function": "gelu_new", "scale_attn_by_inverse_layer_idx": true, "reorder_and_upcast_attn": false, "auto_map": {"AutoConfig": "backpack_config.BackpackGPT2Config", "AutoModelForCausalLM": "backpack_model.BackpackGPT2Model"}, "model_type": "backpack-gpt2"}
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:18e949b43b912d0763ce1999a605b42485e659bd29c17e0544b4a75b56565f15
- size 680349136
+ oid sha256:844eb078f8af73181515736354aedcd84d99b1dd21e1218da5e7d4454df46463
+ size 836334888