Update modeling_rwkv5.py
Browse files- modeling_rwkv5.py +0 -20
modeling_rwkv5.py
CHANGED
@@ -92,7 +92,6 @@ def rwkv_linear_attention_v5_2(B, H, S, T, n_head, hidden, time_decay, time_firs
|
|
92 |
time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
|
93 |
lxw = lxw.float()
|
94 |
lxb = lxb.float()
|
95 |
-
# if seq_mode:
|
96 |
out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
|
97 |
for t in range(T):
|
98 |
rt = receptance[:,:,t:t+1,:]
|
@@ -106,25 +105,6 @@ def rwkv_linear_attention_v5_2(B, H, S, T, n_head, hidden, time_decay, time_firs
|
|
106 |
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H*S)
|
107 |
out = out.to(dtype=hidden.dtype) * gate
|
108 |
out = out @ ow
|
109 |
-
# else:
|
110 |
-
# a = key @ value
|
111 |
-
# # print('key.shape: ', key.shape)
|
112 |
-
# # print('value.shape: ', value.shape)
|
113 |
-
# # print('receptance.shape: ', receptance.shape)
|
114 |
-
# # print('a.shape: ', a.shape)
|
115 |
-
# # print('time_first.shape: ', time_first.shape)
|
116 |
-
# # print('(time_first * a).shape: ', (time_first * a).shape)
|
117 |
-
# # print('time_decay.shape: ', time_decay.shape)
|
118 |
-
# # print('state.shape: ', state.shape)
|
119 |
-
# out = receptance @ (time_first * a + state)
|
120 |
-
# # print('out.shape: ', out.shape)
|
121 |
-
# state = a + time_decay * state
|
122 |
-
# # print('state.shape: ', state.shape)
|
123 |
-
# out = out.reshape(B, H*S)
|
124 |
-
# out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, 1, H*S)
|
125 |
-
# out = out.to(dtype=hidden.dtype) * gate
|
126 |
-
# out = out @ ow
|
127 |
-
|
128 |
|
129 |
return out, state
|
130 |
|
|
|
92 |
time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
|
93 |
lxw = lxw.float()
|
94 |
lxb = lxb.float()
|
|
|
95 |
out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
|
96 |
for t in range(T):
|
97 |
rt = receptance[:,:,t:t+1,:]
|
|
|
105 |
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H*S)
|
106 |
out = out.to(dtype=hidden.dtype) * gate
|
107 |
out = out @ ow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
return out, state
|
110 |
|