kiddothe2b commited on
Commit
6512a17
1 Parent(s): e3a99d1

Add HAT implementation files

Browse files
Files changed (1) hide show
  1. modelling_hat.py +7 -1
modelling_hat.py CHANGED
@@ -1089,10 +1089,16 @@ class HATForMaskedLM(HATPreTrainedModel):
1089
 
1090
  def get_output_embeddings(self):
1091
  return self.lm_head.decoder
1092
-
1093
  def set_output_embeddings(self, new_embeddings):
1094
  self.lm_head.decoder = new_embeddings
1095
 
 
 
 
 
 
 
1096
  def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
1097
  """Tie or clone module weights depending of whether we are using TorchScript or not"""
1098
  if self.config.torchscript:
 
1089
 
1090
  def get_output_embeddings(self):
1091
  return self.lm_head.decoder
1092
+
1093
  def set_output_embeddings(self, new_embeddings):
1094
  self.lm_head.decoder = new_embeddings
1095
 
1096
+ def get_input_embeddings(self):
1097
+ return self.embeddings.word_embeddings
1098
+
1099
+ def set_input_embeddings(self, value):
1100
+ self.embeddings.word_embeddings = value
1101
+
1102
  def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
1103
  """Tie or clone module weights depending of whether we are using TorchScript or not"""
1104
  if self.config.torchscript: