oweller2
committed on
Commit
•
cf03b9b
1
Parent(s):
4753b37
init
Browse files- modeling_flexbert.py +13 -4
modeling_flexbert.py
CHANGED
@@ -1537,17 +1537,26 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
|
|
1537 |
self._init_weights(reset_params=False)
|
1538 |
|
1539 |
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
|
|
1540 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|
1541 |
-
|
1542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1543 |
else:
|
1544 |
assert isinstance(reset_params, bool)
|
1545 |
self.bert._init_weights(reset_params=reset_params)
|
1546 |
self.lm_head._init_weights(reset_params=reset_params)
|
1547 |
-
|
1548 |
if not self.config.tie_word_embeddings:
|
1549 |
init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
|
1550 |
-
|
1551 |
@classmethod
|
1552 |
def from_composer(
|
1553 |
cls,
|
|
|
1537 |
self._init_weights(reset_params=False)
|
1538 |
|
1539 |
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
1540 |
+
# Handle the XOR condition
|
1541 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|
1542 |
+
|
1543 |
+
if module is not None:
|
1544 |
+
# Add basic initialization for common module types
|
1545 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
1546 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
1547 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
1548 |
+
module.bias.data.zero_()
|
1549 |
+
elif isinstance(module, nn.LayerNorm):
|
1550 |
+
module.bias.data.zero_()
|
1551 |
+
module.weight.data.fill_(1.0)
|
1552 |
else:
|
1553 |
assert isinstance(reset_params, bool)
|
1554 |
self.bert._init_weights(reset_params=reset_params)
|
1555 |
self.lm_head._init_weights(reset_params=reset_params)
|
1556 |
+
|
1557 |
if not self.config.tie_word_embeddings:
|
1558 |
init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
|
1559 |
+
|
1560 |
@classmethod
|
1561 |
def from_composer(
|
1562 |
cls,
|