pranjalchitale
commited on
Commit
·
1ec8585
1
Parent(s):
20ef3db
Fix TieWeights
Browse files- modeling_indictrans.py +3 -3
modeling_indictrans.py
CHANGED
@@ -1655,9 +1655,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
|
|
1655 |
|
1656 |
self.post_init()
|
1657 |
|
1658 |
-
def
|
1659 |
-
if
|
1660 |
-
|
1661 |
|
1662 |
def get_encoder(self):
|
1663 |
return self.model.encoder
|
|
|
1655 |
|
1656 |
self.post_init()
|
1657 |
|
1658 |
+
def tie_weights(self):
|
1659 |
+
if self.config.share_decoder_input_output_embed:
|
1660 |
+
self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)
|
1661 |
|
1662 |
def get_encoder(self):
|
1663 |
return self.model.encoder
|