suayptalha committed
Commit d1adb77 · verified · 1 Parent(s): bc73b41

Update modeling_minGRULM.py

Files changed (1)
  1. modeling_minGRULM.py +27 -2
modeling_minGRULM.py CHANGED

@@ -7,7 +7,6 @@ from typing import Optional
 from .configuration_minGRULM import MinGRULMConfig
 from minGRU_pytorch.minGRULM import minGRULM
 
-
 # Wrapper class for device compatibility
 class MinGRULMWrapped(nn.Module):
     def __init__(self, min_gru_model):
@@ -62,6 +61,9 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
         # Language modeling head
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
+        # Initialize weights (if required for missing layers)
+        self.initialize_layers()
+
         self.post_init()
 
     def get_input_embeddings(self):
@@ -103,4 +105,27 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
-        )
+        )
+
+    def initialize_layers(self):
+        """
+        Initialize missing layers in the model, such as custom layers or parts of the minGRULM.
+        If layers are already initialized, we can skip them.
+        """
+        # Example: Initialize layers manually if needed
+        for name, module in self.model.min_gru_model.named_children():
+            if isinstance(module, nn.Module):
+                if 'token_emb' in name:
+                    # Token embeddings, if needed, you can initialize with a custom scheme
+                    nn.init.xavier_uniform_(module.weight)
+                elif isinstance(module, nn.Linear):
+                    # Initialize Linear layers if not initialized already
+                    if module.weight is not None:
+                        nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+                # Initialize other layers similarly, depending on the type
+                elif isinstance(module, nn.LayerNorm):
+                    # Initialize LayerNorm layers
+                    nn.init.constant_(module.weight, 1.0)
+                    nn.init.constant_(module.bias, 0)
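
For readers who want to try the change, a hypothetical usage sketch follows; it is not part of the commit. It assumes the two files are importable as a plain package and that MinGRULMConfig accepts vocab_size and d_model, the only config fields the diff itself references; any other required arguments are omitted. Constructing the model runs initialize_layers() inside __init__, just before post_init():

# Hypothetical usage sketch, not from the repository.
# Assumes configuration_minGRULM.py and modeling_minGRULM.py are importable
# as a package; the vocab_size/d_model values below are illustrative only.
from configuration_minGRULM import MinGRULMConfig
from modeling_minGRULM import MinGRULMForCausalLM

config = MinGRULMConfig(vocab_size=50257, d_model=512)
model = MinGRULMForCausalLM(config)   # initialize_layers() runs in __init__

# The helper can also be re-applied manually, e.g. after swapping out a
# direct child module of model.model.min_gru_model:
model.initialize_layers()

One design note: named_children() is shallow, so only the immediate submodules of min_gru_model (such as the token embedding) are re-initialized; nn.Linear or nn.LayerNorm layers nested inside deeper blocks keep their default PyTorch initialization. A variant that should reach them would iterate named_modules() instead.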