ivanzhouyq committed
Commit de682e6
1 Parent(s): 64192d1

Upload model

backpack_config.py ADDED
@@ -0,0 +1,23 @@
+ from transformers import GPT2Config
+
+
+ class BackpackGPT2Config(GPT2Config):
+     model_type = "backpack-gpt2"
+
+     def __init__(
+         self,
+         num_senses: int = 16,
+         sense_intermediate_scale: int = 4,
+         vocab_size: int = 50264,
+         n_positions: int = 512,
+         scale_attn_by_inverse_layer_idx: bool = True,
+         **kwargs,
+     ):
+         self.num_senses = num_senses
+         self.sense_intermediate_scale = sense_intermediate_scale
+         super().__init__(
+             vocab_size=vocab_size,
+             n_positions=n_positions,
+             scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
+             **kwargs,
+         )
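
Usage sketch (not part of this commit): with backpack_config.py importable locally, the config can be registered and instantiated directly; the values below mirror this commit's config.json, and any other GPT2Config argument passes through **kwargs.

from transformers import AutoConfig
from backpack_config import BackpackGPT2Config

# Optional: let AutoConfig resolve the custom model_type when working from a local copy.
AutoConfig.register("backpack-gpt2", BackpackGPT2Config)

# Values mirror this commit's config.json; n_positions keeps its default of 512.
config = BackpackGPT2Config(
    num_senses=48,
    sense_intermediate_scale=7,
    n_embd=1280,
    n_layer=36,
    n_head=20,
)
print(config.model_type, config.num_senses, config.n_positions)  # backpack-gpt2 48 512
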
backpack_model.py ADDED
@@ -0,0 +1,251 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from transformers.activations import ACT2FN
+ from transformers.pytorch_utils import Conv1D
+ from transformers.utils import ModelOutput
+ from transformers import GPT2PreTrainedModel, GPT2Model
+ from .backpack_config import BackpackGPT2Config
+
+
+ ### Backpack-Specific
+ class BackpackGPT2PreTrainedModel(GPT2PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias"]
+
+     config_class = BackpackGPT2Config
+     base_model_prefix = "backpack"
+     is_parallelizable = True
+     supports_gradient_checkpointing = False
+     _no_split_modules = ["GPT2Block", "BackpackNoMixBlock"]
+
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+
+
+ class BackpackMLP(nn.Module):
+     def __init__(self, embed_dim, intermediate_dim, out_dim, config):
+         super().__init__()
+         self.c_fc = Conv1D(intermediate_dim, embed_dim)
+         self.c_proj = Conv1D(out_dim, intermediate_dim)
+         self.act = ACT2FN[config.activation_function]
+         self.dropout = nn.Dropout(config.resid_pdrop)
+
+     def forward(
+         self, hidden_states: Optional[Tuple[torch.FloatTensor]]
+     ) -> torch.FloatTensor:
+         hidden_states = self.c_fc(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.c_proj(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         return hidden_states
+
+
+ class BackpackNoMixBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.mlp = BackpackMLP(config.n_embd, config.n_embd * 4, config.n_embd, config)
+         self.resid_dropout1 = nn.Dropout(config.resid_pdrop)
+         self.resid_dropout2 = nn.Dropout(config.resid_pdrop)
+
+     def forward(self, hidden_states, residual):
+         residual = self.resid_dropout1(hidden_states) + residual
+         hidden_states = self.ln_1(residual)
+         mlp_out = self.mlp(hidden_states)
+         residual = self.resid_dropout2(mlp_out) + residual
+         hidden_states = self.ln_2(residual)
+         return hidden_states
+
+
+ class BackpackSenseNetwork(nn.Module):
+     def __init__(self, config, num_senses, device=None, dtype=None):
+         super().__init__()
+         self.num_senses = num_senses
+         # self.embeddings = embeddings
+         self.n_embd = config.n_embd
+
+         self.dropout = nn.Dropout(config.embd_pdrop)
+         self.block = BackpackNoMixBlock(config)
+         self.ln = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)
+         self.final_mlp = BackpackMLP(
+             embed_dim=config.n_embd,
+             intermediate_dim=config.sense_intermediate_scale * config.n_embd,
+             out_dim=config.n_embd * config.num_senses,
+             config=config,
+         )
+
+     def forward(self, input_embeds):
+         residual = self.dropout(input_embeds)
+         hidden_states = self.ln(residual)
+         hidden_states = self.block(hidden_states, residual)
+         senses = self.final_mlp(hidden_states)
+         bs, s, nvd = senses.shape
+         return senses.reshape(bs, s, self.num_senses, self.n_embd).transpose(
+             1, 2
+         )  # (bs, nv, s, d)
+
+
+ class BackpackWeightNetwork(nn.Module):
+     def __init__(self, num_senses, embed_dim):
+         super().__init__()
+         self.n_embd = embed_dim
+         self.num_senses = num_senses
+         self.embed_per_sense = embed_dim // num_senses
+         self.c_attn = nn.Linear(embed_dim, 2 * num_senses * self.embed_per_sense)
+         self.softmax_scale = None
+
+     def forward(self, encoded):
+         b, s, d = encoded.shape
+         encoded = self.c_attn(encoded)  # (b, s, 2*d)
+         encoded = encoded.reshape(
+             b, s, 2, self.num_senses, self.embed_per_sense
+         )  # (b, s, 2, nv, d//nv)
+         batch_size, seqlen = encoded.shape[0], encoded.shape[1]
+
+         # compute scores & mask
+         q, k = encoded.unbind(dim=2)
+         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+         causal_mask = torch.triu(
+             torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
+         )
+         scores = scores + causal_mask.to(dtype=scores.dtype)
+
+         return torch.softmax(scores, dim=-1, dtype=q.dtype)
+
+
+ @dataclass
+ class BackpackGPT2BaseModelOutput(ModelOutput):
+     hidden_states: torch.FloatTensor = None
+     contextualization: torch.FloatTensor = None
+
+
+ class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.embed_dim = config.n_embd
+
+         self.num_senses = config.num_senses
+         self.gpt2_model = GPT2Model(config)
+         self.sense_network = BackpackSenseNetwork(config, self.num_senses)
+         self.word_embeddings = self.gpt2_model.wte
+         self.position_embeddings = self.gpt2_model.wpe
+         self.sense_weight_net = BackpackWeightNetwork(self.num_senses, self.embed_dim)
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+         self.gradient_checkpointing = False
+
+     def get_num_senses(self):
+         return self.num_senses
+
+     def get_word_embeddings(self):
+         return self.word_embeddings
+
+     def get_sense_network(self):
+         return self.sense_network
+
+     def forward(self, input_ids, position_ids: Optional[torch.LongTensor] = None):
+         # Compute senses
+         sense_input_embeds = self.word_embeddings(input_ids)
+         senses = self.sense_network(sense_input_embeds)  # (bs, nv, s, d)
+
+         # Compute contextualization weights
+         contextl_hidden_states = self.gpt2_model(
+             input_ids, position_ids=position_ids
+         ).last_hidden_state  # (bs, s, d)
+         contextualization = self.sense_weight_net(
+             contextl_hidden_states
+         )  # (bs, nv, s, s)
+
+         # Compute resulting outputs
+         hidden_states = torch.sum(
+             contextualization @ senses, dim=1
+         )  # (bs, nv, s, d) -> (bs, s, d)
+
+         # average over senses: divide by num_senses
+         hidden_states = hidden_states / self.num_senses
+
+         return BackpackGPT2BaseModelOutput(
+             hidden_states=hidden_states,
+             contextualization=contextualization,
+         )
+
+     def run_with_custom_contextualization(self, input_ids, contextualization):
+         # Compute senses
+         sense_input_embeds = self.word_embeddings(input_ids)
+         senses = self.sense_network(sense_input_embeds)  # (bs, nv, s, d)
+
+         # Compute resulting outputs
+         hidden_states = torch.sum(
+             contextualization @ senses, dim=1
+         )  # (bs, nv, s, d) -> (bs, s, d)
+         return BackpackGPT2BaseModelOutput(
+             hidden_states=hidden_states,
+             contextualization=contextualization,
+         )
+
+
+ @dataclass
+ class BackpackGPT2LMHeadModelOutput(ModelOutput):
+     logits: torch.FloatTensor = None
+     contextualization: torch.FloatTensor = None
+
+
+ class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.backpack = BackpackGPT2Model(config)
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+     def get_lm_head(self):
+         # The unembedding is tied to the word embeddings; there is no separate lm_head module.
+         return self.backpack.word_embeddings
+
+     def forward(self, input_ids, position_ids=None):
+         outputs = self.backpack(input_ids, position_ids=position_ids)
+         hidden_states, contextualization = (
+             outputs.hidden_states,
+             outputs.contextualization,
+         )
+         # unembed the hidden_states
+         lm_logits = torch.einsum(
+             "bsd,nd->bsn", hidden_states, self.backpack.word_embeddings.weight
+         )
+         return BackpackGPT2LMHeadModelOutput(
+             logits=lm_logits,
+             contextualization=contextualization,
+         )
+
+     def run_with_custom_contextualization(self, input_ids, contextualization):
+         outputs = self.backpack.run_with_custom_contextualization(
+             input_ids, contextualization
+         )
+         hidden_states, contextualization = (
+             outputs.hidden_states,
+             outputs.contextualization,
+         )
+         # Unembed with the tied word-embedding matrix, as in forward()
+         # (this class defines no separate lm_head).
+         lm_logits = torch.einsum(
+             "bsd,nd->bsn", hidden_states, self.backpack.word_embeddings.weight
+         )
+         return BackpackGPT2LMHeadModelOutput(
+             logits=lm_logits,
+             contextualization=contextualization,
+         )
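
Usage sketch (not part of this commit): the auto_map entries in config.json route AutoConfig and AutoModelForCausalLM to the classes above, so loading requires trust_remote_code=True. The repository id below is a placeholder for wherever this commit lives, and the uniform contextualization is only an illustrative intervention.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "user/backpack-lm"  # placeholder: substitute the Hub repo containing this commit

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

input_ids = tokenizer("The Backpack decomposes each word into", return_tensors="pt").input_ids

with torch.no_grad():
    out = model(input_ids)
# BackpackGPT2LMHeadModelOutput: logits (bs, s, vocab), contextualization (bs, num_senses, s, s)
print(out.logits.shape, out.contextualization.shape)

# Intervene on the contextualization, e.g. weight every position uniformly for each sense.
bs, seq_len = input_ids.shape
uniform = torch.full((bs, model.config.num_senses, seq_len, seq_len), 1.0 / seq_len)
with torch.no_grad():
    custom = model.run_with_custom_contextualization(input_ids, uniform)
print(custom.logits.shape)
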
config.json ADDED
@@ -0,0 +1 @@
+ {"return_dict": true, "output_hidden_states": false, "output_attentions": false, "torchscript": false, "torch_dtype": null, "use_bfloat16": false, "tf_legacy_loss": false, "pruned_heads": {}, "tie_word_embeddings": true, "is_encoder_decoder": false, "is_decoder": false, "cross_attention_hidden_size": null, "add_cross_attention": false, "tie_encoder_decoder": false, "max_length": 20, "min_length": 0, "do_sample": false, "early_stopping": false, "num_beams": 1, "num_beam_groups": 1, "diversity_penalty": 0.0, "temperature": 1.0, "top_k": 50, "top_p": 1.0, "typical_p": 1.0, "repetition_penalty": 1.0, "length_penalty": 1.0, "no_repeat_ngram_size": 0, "encoder_no_repeat_ngram_size": 0, "bad_words_ids": null, "num_return_sequences": 1, "chunk_size_feed_forward": 0, "output_scores": false, "return_dict_in_generate": false, "forced_bos_token_id": null, "forced_eos_token_id": null, "remove_invalid_values": false, "exponential_decay_length_penalty": null, "suppress_tokens": null, "begin_suppress_tokens": null, "architectures": null, "finetuning_task": null, "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, "label2id": {"LABEL_0": 0, "LABEL_1": 1}, "tokenizer_class": null, "prefix": null, "bos_token_id": null, "pad_token_id": null, "eos_token_id": null, "sep_token_id": null, "decoder_start_token_id": null, "task_specific_params": null, "problem_type": null, "_name_or_path": "", "transformers_version": "4.29.2", "vocab_size": 50264, "n_positions": 512, "n_layer": 36, "n_head": 20, "n_embd": 1280, "initializer_range": 0.02, "attn_pdrop": 0.0, "embd_pdrop": 0.0, "layer_norm_epsilon": 1e-05, "activation_function": "gelu_new", "scale_attn_by_inverse_layer_idx": true, "reorder_and_upcast_attn": false, "num_senses": 48, "sense_intermediate_scale": 7, "auto_map": {"AutoConfig": "backpack_config.BackpackGPT2Config", "AutoModelForCausalLM": "backpack_model.BackpackGPT2LMHeadModel"}, "model_type": "backpack-gpt2"}
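
A quick sanity check (not part of this commit) on the module sizes this config implies, mirroring the constructors in backpack_model.py; the numbers are taken directly from config.json.

# Values from config.json
n_embd, num_senses, sense_intermediate_scale = 1280, 48, 7

# BackpackSenseNetwork.final_mlp widths: n_embd -> scale * n_embd -> n_embd * num_senses
print(sense_intermediate_scale * n_embd)  # 8960
print(n_embd * num_senses)                # 61440

# BackpackWeightNetwork per-sense query/key width: n_embd // num_senses
print(n_embd // num_senses)               # 26
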
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:95a0528c52856a01a9de3bb640e77beb40a478324423f71570de9f1557316247
+ size 5666995936
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": "<|endoftext|>",
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 1024,
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }
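
Tokenizer sketch (not part of this commit): tokenizer_config.json advertises a model_max_length of 1024 while config.json sets n_positions to 512, so the example below caps inputs at the model's position limit; the repository id is again a placeholder.

from transformers import AutoTokenizer

repo_id = "user/backpack-lm"  # placeholder, as in the earlier sketch

tok = AutoTokenizer.from_pretrained(repo_id)
enc = tok(
    "Backpack language models weight sense vectors with a contextualization network.",
    truncation=True,
    max_length=512,  # config.json: n_positions = 512, below the tokenizer's advertised 1024
    return_tensors="pt",
)
print(enc.input_ids.shape)
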
vocab.json ADDED
The diff for this file is too large to render. See raw diff