ivanzhouyq committed
Commit 51b83bb
1 Parent(s): f97a1de

scale by num_senses and update weights

Files changed (2):
  1. backpack_model.py +7 -10
  2. model.safetensors +2 -2
backpack_model.py CHANGED
@@ -159,6 +159,10 @@ class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
 
         # Compute resulting outputs
         hidden_states = torch.sum(contextualization @ senses, dim=1) # (bs, nv, s, d) -> (bs, s, d)
+
+        # scale hidden_states by 1 / num_senses
+        hidden_states = hidden_states / self.num_senses
+
         return BackpackGPT2BaseModelOutput(
             hidden_states=hidden_states,
             contextualization=contextualization,
@@ -187,31 +191,24 @@ class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.backpack = BackpackGPT2Model(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         # Model parallel
         self.model_parallel = False
         self.device_map = None
 
-        self.tie_weights()
-
-    def tie_weights(self):
-        self.lm_head.weight = self.backpack.word_embeddings.weight # also tied with the underlying underlying transf
-
     def get_lm_head(self):
         return self.lm_head
 
     def forward(self, input_ids, position_ids=None):
         outputs = self.backpack(input_ids, position_ids=position_ids)
         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
-        lm_logits = self.lm_head(hidden_states) # (bs, s, V)
+        # unembed the hidden_states
+        lm_logits = torch.einsum('bsd,nd->bsn', hidden_states, self.backpack.word_embeddings.weight) # (bs, s, V)
         return BackpackGPT2LMHeadModelOutput(
             logits=lm_logits,
             contextualization=contextualization,
         )
-    # CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
-    # return CausalLMOutput(logits=lm_logits)
-
+
     def run_with_custom_contextualization(self, input_ids, contextualization):
         outputs = self.backpack.run_with_custom_contextualization(input_ids, contextualization)
         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
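Note on the two hunks above: the first scales the summed sense vectors by 1 / num_senses, turning the sum over the nv sense channels into a mean, so the scale of hidden_states no longer grows with the number of senses. The second drops the separately registered lm_head and unembeds directly against backpack.word_embeddings.weight, which is numerically equivalent to the removed weight-tied, bias-free nn.Linear head. A minimal sketch of both, with illustrative placeholder sizes (bs, nv, s, d, V below are made up for the example, not the checkpoint's real config):

import torch

# illustrative sizes only: batch, num senses, seq len, n_embd, vocab_size
bs, nv, s, d, V = 2, 16, 5, 8, 50

contextualization = torch.randn(bs, nv, s, s)
senses = torch.randn(bs, nv, s, d)
word_embedding_weight = torch.randn(V, d)  # stands in for backpack.word_embeddings.weight

# hunk 1: sum over the nv sense channels, then scale by 1 / num_senses (a mean)
hidden_states = torch.sum(contextualization @ senses, dim=1) / nv  # (bs, s, d)

# hunk 2: the new einsum unembedding ...
lm_logits = torch.einsum('bsd,nd->bsn', hidden_states, word_embedding_weight)  # (bs, s, V)

# ... matches the removed weight-tied, bias-free linear head up to float tolerance
lm_head = torch.nn.Linear(d, V, bias=False)
lm_head.weight = torch.nn.Parameter(word_embedding_weight)
assert torch.allclose(lm_logits, lm_head(hidden_states), atol=1e-5)

One caveat visible in the diff: get_lm_head() still returns self.lm_head, which this commit stops assigning, so calling it would now raise an AttributeError.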
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:844eb078f8af73181515736354aedcd84d99b1dd21e1218da5e7d4454df46463
-size 836334888
+oid sha256:912f6a4da875f80a90a238d8714e1a27f79f257693509a9dfd3dd7e2a39165e7
+size 990745984