kiddothe2b committed
Commit e3a99d1
1 Parent(s): 895ac06

Add HAT implementation files

Files changed (1):
  1. modelling_hat.py +20 -0
modelling_hat.py CHANGED
@@ -1093,6 +1093,26 @@ class HATForMaskedLM(HATPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
 
+    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
+        """Tie or clone module weights depending of whether we are using TorchScript or not"""
+        if self.config.torchscript:
+            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
+        else:
+            output_embeddings.weight = input_embeddings.weight
+
+        if getattr(output_embeddings, "bias", None) is not None:
+            output_embeddings.bias.data = nn.functional.pad(
+                output_embeddings.bias.data,
+                (
+                    0,
+                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
+                ),
+                "constant",
+                0,
+            )
+        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+            output_embeddings.out_features = input_embeddings.num_embeddings
+
     @add_start_docstrings_to_model_forward(HAT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
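
For context, the added _tie_or_clone_weights mirrors the helper of the same name in transformers.PreTrainedModel: it makes the MLM decoder share the input embedding matrix (or clone it when exporting to TorchScript), zero-pads the decoder bias if the vocabulary has been resized, and keeps out_features in sync with num_embeddings. Below is a minimal, self-contained sketch of that logic using plain nn.Embedding / nn.Linear stand-ins; the names and sizes are illustrative assumptions, not values taken from modelling_hat.py.

from torch import nn

# Illustrative stand-ins for the model's input embeddings and MLM decoder head;
# the sizes below are assumptions, not values from modelling_hat.py.
vocab_size, hidden_size = 100, 16
input_embeddings = nn.Embedding(vocab_size, hidden_size)
output_embeddings = nn.Linear(hidden_size, vocab_size, bias=True)

use_torchscript = False  # stands in for self.config.torchscript

if use_torchscript:
    # TorchScript export does not support shared parameters, so the weights are cloned.
    output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
else:
    # Sharing the same Parameter ties the two matrices: updating one updates the other.
    output_embeddings.weight = input_embeddings.weight

# Zero-pad the decoder bias if the (possibly resized) vocabulary is larger than the bias.
if getattr(output_embeddings, "bias", None) is not None:
    output_embeddings.bias.data = nn.functional.pad(
        output_embeddings.bias.data,
        (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
        "constant",
        0,
    )

# Keep the Linear layer's bookkeeping consistent with the embedding table.
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
    output_embeddings.out_features = input_embeddings.num_embeddings

# When not scripting, both modules now point at the same storage.
assert output_embeddings.weight.data_ptr() == input_embeddings.weight.data_ptr()

This is the same pattern PreTrainedModel.tie_weights relies on when config.tie_word_embeddings is enabled, and it avoids duplicating the (often large) vocabulary projection in the MLM head.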