kiddothe2b committed • Commit 6512a17 • Parent(s): e3a99d1

Add HAT implementation files

modelling_hat.py CHANGED (+7 -1)
@@ -1089,10 +1089,16 @@ class HATForMaskedLM(HATPreTrainedModel):
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
-
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
 
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """Tie or clone module weights depending of whether we are using TorchScript or not"""
         if self.config.torchscript:
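
These two accessors are more than symmetry with get_output_embeddings()/set_output_embeddings(): they are the hooks that transformers.PreTrainedModel helpers such as resize_token_embeddings() and tie_weights() call to locate and replace the word-embedding matrix. Below is a minimal sketch of what the change enables, assuming this file is served from a Hub checkpoint loaded with trust_remote_code=True; the checkpoint id is illustrative and not confirmed by this commit.

from transformers import AutoModelForMaskedLM, AutoTokenizer

repo = "kiddothe2b/hierarchical-transformer-base-4096"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo, trust_remote_code=True)

# resize_token_embeddings() relies on get_input_embeddings() /
# set_input_embeddings() to swap in a larger word-embedding matrix
# when new tokens are added to the vocabulary.
tokenizer.add_tokens(["<new_tok>"])
model.resize_token_embeddings(len(tokenizer))

# tie_weights() uses the same hooks (via _tie_or_clone_weights) to point
# the MLM decoder at the resized input embeddings so both share one matrix.
model.tie_weights()

# Holds when the config ties word embeddings (the default) and
# config.torchscript is off, since the torchscript branch clones instead.
assert (model.get_output_embeddings().weight.data_ptr()
        == model.get_input_embeddings().weight.data_ptr())

Without these hooks, the base PreTrainedModel implementation can only delegate to a base model attribute and otherwise ends in NotImplementedError, so this small +7/-1 change is what makes vocabulary resizing and weight tying usable on HATForMaskedLM.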