kiddothe2b commited on
Commit
c1c87bf
1 Parent(s): af99e83

Add HAT implementation files

Browse files
Files changed (1) hide show
  1. modelling_hat.py +4 -9
modelling_hat.py CHANGED
@@ -1839,8 +1839,6 @@ class HATForSequenceClassification(HATPreTrainedModel):
1839
  config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1840
  )
1841
  self.dropout = nn.Dropout(classifier_dropout)
1842
- if self.pooling != 'cls':
1843
- self.sentencizer = HATSentencizer(config)
1844
  self.pooler = HATPooler(config, pooling=pooling)
1845
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1846
 
@@ -1885,13 +1883,12 @@ class HATForSequenceClassification(HATPreTrainedModel):
1885
  return_dict=return_dict,
1886
  )
1887
  sequence_output = outputs[0]
1888
- if self.pooling not in ['first', 'last']:
1889
- sentence_outputs = self.sentencizer(sequence_output)
1890
- pooled_output = self.pooler(sentence_outputs)
1891
- elif self.pooling == 'first':
1892
  pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, 0, :], 1))
1893
  elif self.pooling == 'last':
1894
  pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
 
 
1895
 
1896
  pooled_output = self.dropout(pooled_output)
1897
  logits = self.classifier(pooled_output)
@@ -2051,8 +2048,6 @@ class HATForMultipleChoice(HATPreTrainedModel):
2051
  config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
2052
  )
2053
  self.dropout = nn.Dropout(classifier_dropout)
2054
- if self.pooling not in ['first', 'last']:
2055
- self.sentencizer = HATSentencizer(config)
2056
  self.pooler = HATPooler(config, pooling=pooling)
2057
  self.classifier = nn.Linear(config.hidden_size, 1)
2058
 
@@ -2113,7 +2108,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
2113
  elif self.pooling == 'last':
2114
  pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
2115
  else:
2116
- pooled_output = self.pooler(self.sentencizer(sequence_output))
2117
 
2118
  pooled_output = self.dropout(pooled_output)
2119
  logits = self.classifier(pooled_output)
 
1839
  config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1840
  )
1841
  self.dropout = nn.Dropout(classifier_dropout)
 
 
1842
  self.pooler = HATPooler(config, pooling=pooling)
1843
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1844
 
 
1883
  return_dict=return_dict,
1884
  )
1885
  sequence_output = outputs[0]
1886
+ if self.pooling == 'first':
 
 
 
1887
  pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, 0, :], 1))
1888
  elif self.pooling == 'last':
1889
  pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
1890
+ else:
1891
+ pooled_output = self.pooler(sequence_output[:, ::self.max_sentence_length])
1892
 
1893
  pooled_output = self.dropout(pooled_output)
1894
  logits = self.classifier(pooled_output)
 
2048
  config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
2049
  )
2050
  self.dropout = nn.Dropout(classifier_dropout)
 
 
2051
  self.pooler = HATPooler(config, pooling=pooling)
2052
  self.classifier = nn.Linear(config.hidden_size, 1)
2053
 
 
2108
  elif self.pooling == 'last':
2109
  pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
2110
  else:
2111
+ pooled_output = self.pooler(sequence_output[:, ::self.max_sentence_length])
2112
 
2113
  pooled_output = self.dropout(pooled_output)
2114
  logits = self.classifier(pooled_output)