suayptalha committed
Update modeling_minGRULM.py

modeling_minGRULM.py CHANGED (+27 -2)
@@ -7,7 +7,6 @@ from typing import Optional
 from .configuration_minGRULM import MinGRULMConfig
 from minGRU_pytorch.minGRULM import minGRULM
 
-
 # Wrapper class for device compatibility
 class MinGRULMWrapped(nn.Module):
     def __init__(self, min_gru_model):
@@ -62,6 +61,9 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
         # Language modeling head
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
+        # Initialize weights (if required for missing layers)
+        self.initialize_layers()
+
         self.post_init()
 
     def get_input_embeddings(self):
@@ -103,4 +105,27 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
-        )
+        )
+
+    def initialize_layers(self):
+        """
+        Initialize missing layers in the model, such as custom layers or parts of the minGRULM.
+        If layers are already initialized, we can skip them.
+        """
+        # Example: Initialize layers manually if needed
+        for name, module in self.model.min_gru_model.named_children():
+            if isinstance(module, nn.Module):
+                if 'token_emb' in name:
+                    # Token embeddings, if needed, you can initialize with a custom scheme
+                    nn.init.xavier_uniform_(module.weight)
+                elif isinstance(module, nn.Linear):
+                    # Initialize Linear layers if not initialized already
+                    if module.weight is not None:
+                        nn.init.xavier_uniform_(module.weight)
+                    if module.bias is not None:
+                        nn.init.zeros_(module.bias)
+                # Initialize other layers similarly, depending on the type
+                elif isinstance(module, nn.LayerNorm):
+                    # Initialize LayerNorm layers
+                    nn.init.constant_(module.weight, 1.0)
+                    nn.init.constant_(module.bias, 0)
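For context on the added method: it follows the usual PyTorch pattern of walking a module's submodules and applying torch.nn.init functions per layer type. The sketch below shows that pattern in isolation on a toy nn.Sequential stand-in; it is not the actual minGRU_pytorch module structure, which this commit only assumes exposes token_emb, nn.Linear, and nn.LayerNorm children.

import torch.nn as nn

# Toy stand-in for the wrapped model; the real minGRULM layout is an assumption here.
toy_model = nn.Sequential(
    nn.Linear(16, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 16),
)

def init_layers(model: nn.Module) -> None:
    """Re-initialize direct child layers by type, mirroring the initialize_layers() above."""
    for _, module in model.named_children():
        if isinstance(module, nn.Linear):
            # Xavier-uniform weights, zero biases for linear layers
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # Unit scale, zero shift for normalization layers
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)

init_layers(toy_model)
print(toy_model[0].bias.abs().sum())  # tensor(0., ...): biases were zeroed

Note that named_children() only visits direct submodules; if the layers to re-initialize sit deeper inside nested blocks, named_modules() (which recurses) would be the call to use instead.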