picocreator committed • df2c883
Parent(s): 27839c7
fixing dim size handling for 7B / 14B
modeling_rwkv6.py: +5 -1
modeling_rwkv6.py CHANGED

@@ -123,12 +123,16 @@ class Rwkv6SelfAttention(nn.Module):
         self.time_maa_g = nn.Parameter(torch.empty(1, 1, hidden_size))
 
         TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g
+        if hidden_size == 4096: #7b
+            TIME_MIX_EXTRA_DIM = 64
         self.time_maa_w1 = nn.Parameter(torch.empty(hidden_size, TIME_MIX_EXTRA_DIM*5))
         self.time_maa_w2 = nn.Parameter(torch.empty(5, TIME_MIX_EXTRA_DIM, hidden_size))
 
         self.time_decay = nn.Parameter(torch.empty(1, 1, attention_hidden_size))
 
         TIME_DECAY_EXTRA_DIM = 64
+        if hidden_size == 4096: #7b
+            TIME_DECAY_EXTRA_DIM = 128
         self.time_decay_w1 = nn.Parameter(torch.empty(hidden_size, TIME_DECAY_EXTRA_DIM))
         self.time_decay_w2 = nn.Parameter(torch.empty(TIME_DECAY_EXTRA_DIM, attention_hidden_size))
 
@@ -743,4 +747,4 @@ class Rwkv6ForCausalLM(Rwkv6PreTrainedModel):
             state=outputs.state,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
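For reference, a minimal standalone sketch (not part of the commit) of how the bumped extra-dim constants change the shapes of the low-rank time-mix and time-decay projections at a 7B-class width. The choice of hidden_size = 4096 follows the diff's own check; setting attention_hidden_size equal to hidden_size is an assumption made here purely for illustration.

import torch
import torch.nn as nn

# Illustrative values: hidden_size matches the diff's 7B check;
# attention_hidden_size == hidden_size is an assumption for this sketch.
hidden_size = 4096
attention_hidden_size = hidden_size

# Defaults used for smaller models, per the diff.
TIME_MIX_EXTRA_DIM = 32    # generate TIME_MIX for w,k,v,r,g
TIME_DECAY_EXTRA_DIM = 64
if hidden_size == 4096:    # 7B-class width, as special-cased in the commit
    TIME_MIX_EXTRA_DIM = 64
    TIME_DECAY_EXTRA_DIM = 128

# Parameter shapes as declared in Rwkv6SelfAttention.__init__.
time_maa_w1 = nn.Parameter(torch.empty(hidden_size, TIME_MIX_EXTRA_DIM * 5))
time_maa_w2 = nn.Parameter(torch.empty(5, TIME_MIX_EXTRA_DIM, hidden_size))
time_decay_w1 = nn.Parameter(torch.empty(hidden_size, TIME_DECAY_EXTRA_DIM))
time_decay_w2 = nn.Parameter(torch.empty(TIME_DECAY_EXTRA_DIM, attention_hidden_size))

print(time_maa_w1.shape)    # torch.Size([4096, 320])
print(time_decay_w2.shape)  # torch.Size([128, 4096])

Note that the diff itself only special-cases hidden_size == 4096; the commit message mentions both 7B and 14B, presumably because both checkpoints share that width, though the diff does not state this explicitly.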