Commit
·
9e40d21
1
Parent(s):
6512a17
Add HAT implementation files
Browse files- modelling_hat.py +3 -3
modelling_hat.py
CHANGED
@@ -1089,15 +1089,15 @@ class HATForMaskedLM(HATPreTrainedModel):
|
|
1089 |
|
1090 |
def get_output_embeddings(self):
|
1091 |
return self.lm_head.decoder
|
1092 |
-
|
1093 |
def set_output_embeddings(self, new_embeddings):
|
1094 |
self.lm_head.decoder = new_embeddings
|
1095 |
|
1096 |
def get_input_embeddings(self):
|
1097 |
-
return self.embeddings.word_embeddings
|
1098 |
|
1099 |
def set_input_embeddings(self, value):
|
1100 |
-
self.embeddings.word_embeddings = value
|
1101 |
|
1102 |
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
1103 |
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
|
|
|
1089 |
|
1090 |
def get_output_embeddings(self):
|
1091 |
return self.lm_head.decoder
|
1092 |
+
|
1093 |
def set_output_embeddings(self, new_embeddings):
|
1094 |
self.lm_head.decoder = new_embeddings
|
1095 |
|
1096 |
def get_input_embeddings(self):
|
1097 |
+
return self.hi_transformer.embeddings.word_embeddings
|
1098 |
|
1099 |
def set_input_embeddings(self, value):
|
1100 |
+
self.hi_transformer.embeddings.word_embeddings = value
|
1101 |
|
1102 |
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
1103 |
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
|