kiddothe2b committed
Commit
162074d
1 Parent(s): c1c87bf

Add HAT implementation files

Files changed (1)
  1. modelling_hat.py +4 -1
modelling_hat.py CHANGED
@@ -1186,6 +1186,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
+        self.max_sentence_length = config.max_sentence_length

         self.hi_transformer = HATModel(config)
         self.pooler = HATPooler(config, pooling=pooling)
@@ -1233,7 +1234,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel):
             return_dict=return_dict,
         )
         sequence_output = outputs[0]
-        pooled_outputs = self.pooler(sequence_output)
+        pooled_outputs = self.pooler(sequence_output[:, ::self.max_sentence_length])

         drp_loss = None
         if labels is not None:
@@ -1832,6 +1833,7 @@ class HATForSequenceClassification(HATPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
+        self.max_sentence_length = config.max_sentence_length
         self.pooling = pooling

         self.hi_transformer = HATModel(config)
@@ -2043,6 +2045,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
         super().__init__(config)

         self.pooling = pooling
+        self.max_sentence_length = config.max_sentence_length
         self.hi_transformer = HATModel(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
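The four additions serve one change: each head stores config.max_sentence_length so the encoder's token-level output can be subsampled to one hidden state per fixed-length sentence block before pooling. Below is a minimal sketch of that indexing with toy shapes; the dimensions, and the reading of the strided position as a per-sentence [CLS]-like vector, are assumptions for illustration rather than values taken from this repository.

import torch

# Toy dimensions for illustration only; the real values come from the HAT config.
batch_size, num_sentences, max_sentence_length, hidden_size = 2, 4, 128, 768

# HAT lays a document out as num_sentences fixed-length blocks, so the
# encoder output covers num_sentences * max_sentence_length token positions.
sequence_output = torch.randn(
    batch_size, num_sentences * max_sentence_length, hidden_size
)

# The committed change strides over the token axis with step max_sentence_length,
# keeping only the first position of each sentence block (assumed here to be the
# segment-level [CLS]-like token) before handing the result to the pooler.
sentence_outputs = sequence_output[:, ::max_sentence_length]
print(sentence_outputs.shape)  # torch.Size([2, 4, 768]): one vector per sentence

Before this commit, HATModelForDocumentRepresentation pooled over every token position; the stride restricts pooling to one representative vector per sentence. HATForSequenceClassification and HATForMultipleChoice gain the same attribute in this diff, presumably for the same use in code outside these hunks.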