Commit: Upload AbLang
Changed files:
- encoderblocks.py (+2 −2)
- model.py (+3 −3)
encoderblocks.py
CHANGED
@@ -13,7 +13,7 @@ class AbRepOutput():
|
|
13 |
"""
|
14 |
Dataclass used to store AbRep output.
|
15 |
"""
|
16 |
-
|
17 |
all_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
18 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
19 |
|
@@ -36,7 +36,7 @@ class EncoderBlocks(PreTrainedModel):
|
|
36 |
all_hidden_states = all_hidden_states + (hidden_states,) # Takes out each hidden states after each EncoderBlock
|
37 |
if output_attentions:
|
38 |
all_self_attentions = all_self_attentions + (attentions,) # Takes out attention layers for analysis
|
39 |
-
return AbRepOutput(
|
40 |
|
41 |
|
42 |
class EncoderBlock(PreTrainedModel):
|
|
|
13 |
"""
|
14 |
Dataclass used to store AbRep output.
|
15 |
"""
|
16 |
+
last_hidden_state: torch.FloatTensor
|
17 |
all_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
18 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
19 |
|
|
|
36 |
all_hidden_states = all_hidden_states + (hidden_states,) # Takes out each hidden states after each EncoderBlock
|
37 |
if output_attentions:
|
38 |
all_self_attentions = all_self_attentions + (attentions,) # Takes out attention layers for analysis
|
39 |
+
return AbRepOutput(last_hidden_state=hidden_states, all_hidden_states=all_hidden_states, attentions=all_self_attentions)
|
40 |
|
41 |
|
42 |
class EncoderBlock(PreTrainedModel):
|
model.py
CHANGED
@@ -47,8 +47,8 @@ def apply_cls_embeddings(inputs, outputs):
|
|
47 |
for i in d:
|
48 |
mask[i, d[i]] = 0
|
49 |
mask[:, 0] = 0.0 # make cls token invisible
|
50 |
-
mask = mask.unsqueeze(-1).expand(outputs.
|
51 |
-
sum_embeddings = torch.sum(outputs.
|
52 |
sum_mask = torch.clamp(mask.sum(1), min=1e-9)
|
53 |
-
outputs.
|
54 |
return outputs
|
|
|
47 |
for i in d:
|
48 |
mask[i, d[i]] = 0
|
49 |
mask[:, 0] = 0.0 # make cls token invisible
|
50 |
+
mask = mask.unsqueeze(-1).expand(outputs.last_hidden_state.size())
|
51 |
+
sum_embeddings = torch.sum(outputs.last_hidden_state * mask, 1)
|
52 |
sum_mask = torch.clamp(mask.sum(1), min=1e-9)
|
53 |
+
outputs.last_hidden_state[:, 0, :] = sum_embeddings / sum_mask
|
54 |
return outputs
|