Commit 51b83bb
Parent(s): f97a1de
scale by num_senses and update weights

- backpack_model.py +7 -10
- model.safetensors +2 -2
backpack_model.py
CHANGED
@@ -159,6 +159,10 @@ class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
 
         # Compute resulting outputs
         hidden_states = torch.sum(contextualization @ senses, dim=1) # (bs, nv, s, d) -> (bs, s, d)
+
+        # scale hidden_states by 1 / num_senses
+        hidden_states = hidden_states / self.num_senses
+
         return BackpackGPT2BaseModelOutput(
             hidden_states=hidden_states,
             contextualization=contextualization,
@@ -187,31 +191,24 @@ class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.backpack = BackpackGPT2Model(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         # Model parallel
         self.model_parallel = False
         self.device_map = None
 
-        self.tie_weights()
-
-    def tie_weights(self):
-        self.lm_head.weight = self.backpack.word_embeddings.weight # also tied with the underlying underlying transf
-
     def get_lm_head(self):
         return self.lm_head
 
     def forward(self, input_ids, position_ids=None):
         outputs = self.backpack(input_ids, position_ids=position_ids)
         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
-
+        # unembed the hidden_states
+        lm_logits = torch.einsum('bsd,nd->bsn', hidden_states, self.backpack.word_embeddings.weight)
         return BackpackGPT2LMHeadModelOutput(
             logits=lm_logits,
             contextualization=contextualization,
         )
-
-        # return CausalLMOutput(logits=lm_logits)
-
+
     def run_with_custom_contextualization(self, input_ids, contextualization):
         outputs = self.backpack.run_with_custom_contextualization(input_ids, contextualization)
         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
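Put together, the commit makes two functional changes: the summed contextualized senses are rescaled by 1 / num_senses (a mean over the nv sense vectors instead of a sum), and the separate lm_head / tie_weights machinery is replaced by unembedding directly against self.backpack.word_embeddings.weight. Below is a minimal standalone sketch of what those added lines compute, using made-up shapes and random tensors rather than the repository's modules:

import torch

# Made-up shapes: bs=2 batch, nv=16 senses, s=8 sequence, d=64 hidden, n=100 vocab.
bs, nv, s, d, n = 2, 16, 8, 64, 100
contextualization = torch.softmax(torch.randn(bs, nv, s, s), dim=-1)  # (bs, nv, s, s) mixing weights
senses = torch.randn(bs, nv, s, d)                                    # (bs, nv, s, d) sense vectors
word_embeddings = torch.randn(n, d)                                   # shared embedding matrix

# Change 1: after summing the contextualized senses over the nv dimension,
# divide by the number of senses, i.e. take a mean instead of a sum.
hidden_states = torch.sum(contextualization @ senses, dim=1)  # (bs, s, d)
hidden_states = hidden_states / nv

# Change 2: drop the separate lm_head and unembed directly against the word embeddings.
lm_logits = torch.einsum('bsd,nd->bsn', hidden_states, word_embeddings)  # (bs, s, n)

# The einsum is the same as a matmul with the transposed embedding table.
assert torch.allclose(lm_logits, hidden_states @ word_embeddings.T, atol=1e-5)
print(lm_logits.shape)  # torch.Size([2, 8, 100])

Functionally this is still weight tying: the logits come straight from the shared embedding matrix, which is what the removed tie_weights() call previously arranged through the lm_head Linear.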
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:912f6a4da875f80a90a238d8714e1a27f79f257693509a9dfd3dd7e2a39165e7
+size 990745984
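The model.safetensors entry is a Git LFS pointer, so only the object's sha256 and byte size change in this commit. To check a locally pulled copy against the new pointer, a small verification sketch (standard library only; assumes git lfs pull has replaced the pointer with the real file):

import hashlib, os

path = "model.safetensors"
print(os.path.getsize(path))  # expected: 990745984

h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
print(h.hexdigest())  # expected: 912f6a4da875f80a90a238d8714e1a27f79f257693509a9dfd3dd7e2a39165e7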