qilowoq committed on
Commit 0344cb1 · 1 Parent(s): e23e2b4

Upload AbLang

Files changed (2):
  1. encoderblocks.py +2 -2
  2. model.py +3 -3
encoderblocks.py CHANGED
@@ -13,7 +13,7 @@ class AbRepOutput():
     """
     Dataclass used to store AbRep output.
     """
-    last_hidden_states: torch.FloatTensor
+    last_hidden_state: torch.FloatTensor
     all_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
@@ -36,7 +36,7 @@ class EncoderBlocks(PreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,) # Takes out each hidden states after each EncoderBlock
             if output_attentions:
                 all_self_attentions = all_self_attentions + (attentions,) # Takes out attention layers for analysis
-        return AbRepOutput(last_hidden_states=hidden_states, all_hidden_states=all_hidden_states, attentions=all_self_attentions)
+        return AbRepOutput(last_hidden_state=hidden_states, all_hidden_states=all_hidden_states, attentions=all_self_attentions)
 
 
 class EncoderBlock(PreTrainedModel):
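The rename brings AbRep's output in line with the singular `last_hidden_state` field that `transformers` uses in `BaseModelOutput`, so downstream code written against standard Hugging Face encoder outputs can read AbLang's output without special-casing. Below is a minimal sketch of the dataclass as it stands after this commit; the imports and the `@dataclass` decorator are filled in as assumptions, and the (2, 160, 768) shape is purely illustrative:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch

@dataclass
class AbRepOutput():
    """
    Dataclass used to store AbRep output.
    """
    last_hidden_state: torch.FloatTensor
    all_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

# Hypothetical usage with a (batch, seq_len, hidden) tensor.
outputs = AbRepOutput(last_hidden_state=torch.randn(2, 160, 768))
print(outputs.last_hidden_state.shape)  # torch.Size([2, 160, 768])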
model.py CHANGED
@@ -47,8 +47,8 @@ def apply_cls_embeddings(inputs, outputs):
     for i in d:
         mask[i, d[i]] = 0
     mask[:, 0] = 0.0 # make cls token invisible
-    mask = mask.unsqueeze(-1).expand(outputs.last_hidden_states.size())
-    sum_embeddings = torch.sum(outputs.last_hidden_states * mask, 1)
+    mask = mask.unsqueeze(-1).expand(outputs.last_hidden_state.size())
+    sum_embeddings = torch.sum(outputs.last_hidden_state * mask, 1)
     sum_mask = torch.clamp(mask.sum(1), min=1e-9)
-    outputs.last_hidden_states[:, 0, :] = sum_embeddings / sum_mask
+    outputs.last_hidden_state[:, 0, :] = sum_embeddings / sum_mask
     return outputs
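For context, `apply_cls_embeddings` overwrites the CLS slot of `last_hidden_state` with a masked mean of the remaining token embeddings, so the CLS position ends up holding a sequence-level representation. Below is a self-contained sketch of that pooling step under assumptions: the `attention_mask`-based masking stands in for the `d`-based masking in the hunk above (the construction of `d` and `mask` lives outside this diff), and the function name is hypothetical.

import torch

def cls_as_mean_of_tokens(last_hidden_state, attention_mask):
    # Build a float mask of visible tokens; clone so the caller's mask is untouched.
    mask = attention_mask.clone().float()
    mask[:, 0] = 0.0  # make cls token invisible, as in the diff
    # Broadcast the mask over the hidden dimension and mean-pool the visible tokens.
    mask = mask.unsqueeze(-1).expand(last_hidden_state.size())
    sum_embeddings = torch.sum(last_hidden_state * mask, 1)
    sum_mask = torch.clamp(mask.sum(1), min=1e-9)  # guard against division by zero
    # Write the pooled vector into the cls position in place, mirroring the diff.
    last_hidden_state[:, 0, :] = sum_embeddings / sum_mask
    return last_hidden_state

hidden = torch.randn(2, 5, 8)
attn = torch.tensor([[1, 1, 1, 1, 0],
                     [1, 1, 1, 0, 0]])
pooled = cls_as_mean_of_tokens(hidden, attn)
# Sequence 0: positions 1-3 are visible, so the cls slot holds their mean.
assert torch.allclose(pooled[0, 0], hidden[0, 1:4].mean(0))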