suayptalha committed · verified
Commit 9fdcc4b · Parent(s): d1adb77

Update modeling_minGRULM.py

Files changed (1):
  1. modeling_minGRULM.py (+2 -27)
modeling_minGRULM.py CHANGED
@@ -7,6 +7,7 @@ from typing import Optional
 from .configuration_minGRULM import MinGRULMConfig
 from minGRU_pytorch.minGRULM import minGRULM
 
+
 # Wrapper class for device compatibility
 class MinGRULMWrapped(nn.Module):
     def __init__(self, min_gru_model):
@@ -61,9 +62,6 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
         # Language modeling head
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
-        # Initialize weights (if required for missing layers)
-        self.initialize_layers()
-
         self.post_init()
 
     def get_input_embeddings(self):
@@ -105,27 +103,4 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
-        )
-
-    def initialize_layers(self):
-        """
-        Initialize missing layers in the model, such as custom layers or parts of the minGRULM.
-        If layers are already initialized, we can skip them.
-        """
-        # Example: Initialize layers manually if needed
-        for name, module in self.model.min_gru_model.named_children():
-            if isinstance(module, nn.Module):
-                if 'token_emb' in name:
-                    # Token embeddings, if needed, you can initialize with a custom scheme
-                    nn.init.xavier_uniform_(module.weight)
-                elif isinstance(module, nn.Linear):
-                    # Initialize Linear layers if not initialized already
-                    if module.weight is not None:
-                        nn.init.xavier_uniform_(module.weight)
-                    if module.bias is not None:
-                        nn.init.zeros_(module.bias)
-                # Initialize other layers similarly, depending on the type
-                elif isinstance(module, nn.LayerNorm):
-                    # Initialize LayerNorm layers
-                    nn.init.constant_(module.weight, 1.0)
-                    nn.init.constant_(module.bias, 0)
+        )
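
Why this change works: post_init() on a transformers.PreTrainedModel subclass already runs the library's initialization pass (init_weights(), which applies the class's _init_weights hook to every submodule), so the manual initialize_layers() call was redundant, and because it ran unconditionally in __init__ it could also re-initialize weights that from_pretrained() had just loaded. A minimal sketch of where this logic conventionally lives, assuming MinGRULMPreTrainedModel follows the usual transformers layout; the hook body below mirrors the scheme from the removed method (Xavier for embeddings and linears, ones/zeros for LayerNorm) and is illustrative, not copied from this repository:

import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_minGRULM import MinGRULMConfig  # same import as in the diff above


class MinGRULMPreTrainedModel(PreTrainedModel):
    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        # post_init() -> init_weights() calls this once per submodule,
        # so no hand-rolled loop over named_children() is needed.
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.xavier_uniform_(module.weight)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)

With a hook like this on the base class, the self.post_init() call that this commit keeps is sufficient on its own: fresh models get the intended initialization, and pretrained checkpoints are left untouched.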