kiddothe2b committed on
Commit 9e40d21
1 Parent(s): 6512a17

Add HAT implementation files

Files changed (1)
  1. modelling_hat.py +3 -3
modelling_hat.py CHANGED
@@ -1089,15 +1089,15 @@ class HATForMaskedLM(HATPreTrainedModel):
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
-
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
 
     def get_input_embeddings(self):
-        return self.embeddings.word_embeddings
+        return self.hi_transformer.embeddings.word_embeddings
 
     def set_input_embeddings(self, value):
-        self.embeddings.word_embeddings = value
+        self.hi_transformer.embeddings.word_embeddings = value
 
     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """Tie or clone module weights depending of whether we are using TorchScript or not"""