Update modeling_jais.py
Browse files- modeling_jais.py +6 -0
modeling_jais.py
CHANGED
@@ -535,6 +535,11 @@ class JAISPreTrainedModel(PreTrainedModel):
|
|
535 |
stddev = self.config.initializer_range * mup_init_scale / math.sqrt(2 * self.config.n_layer)
|
536 |
p.data.normal_(mean=0.0, std=stddev)
|
537 |
|
|
|
|
|
|
|
|
|
|
|
538 |
JAIS_START_DOCSTRING = r"""
|
539 |
|
540 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
@@ -861,6 +866,7 @@ class JAISModel(JAISPreTrainedModel):
|
|
861 |
hidden_states = inputs_embeds + position_embeds
|
862 |
else:
|
863 |
hidden_states = inputs_embeds
|
|
|
864 |
scale_factor_hidden = torch.tensor(float(self.embeddings_scale), dtype=hidden_states.dtype, device=hidden_states.device)
|
865 |
hidden_states = hidden_states * scale_factor_hidden
|
866 |
|
|
|
535 |
stddev = self.config.initializer_range * mup_init_scale / math.sqrt(2 * self.config.n_layer)
|
536 |
p.data.normal_(mean=0.0, std=stddev)
|
537 |
|
538 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
539 |
+
if isinstance(module, JAISModel):
|
540 |
+
module.gradient_checkpointing = value
|
541 |
+
|
542 |
+
|
543 |
JAIS_START_DOCSTRING = r"""
|
544 |
|
545 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
|
866 |
hidden_states = inputs_embeds + position_embeds
|
867 |
else:
|
868 |
hidden_states = inputs_embeds
|
869 |
+
'''hidden_states *= torch.tensor(float(self.embeddings_scale), dtype=hidden_states.dtype, device=hidden_states.device) '''
|
870 |
scale_factor_hidden = torch.tensor(float(self.embeddings_scale), dtype=hidden_states.dtype, device=hidden_states.device)
|
871 |
hidden_states = hidden_states * scale_factor_hidden
|
872 |
|