Create mingru_lm.py
mingru_lm.py  ADDED  (+161, -0)
@@ -0,0 +1,161 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Linear, Identity, Module


def default(v, d):
    return v if exists(v) else d

def exists(v):
    return v is not None

def heinsen_associative_scan_log(log_coeffs, log_values):
    # Heinsen associative scan: evaluates the recurrence h_t = a_t * h_{t-1} + b_t
    # for all t in parallel, entirely in log space
    # (log_coeffs = log a, log_values = log b, implicit h_0 = 0)
    a_star = log_coeffs.cumsum(dim=1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()

def log_g(x):
    # log of the positivity-enforcing activation g(x) = x + 0.5 for x >= 0, sigmoid(x) otherwise
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))

class MinGRU(Module):
    def __init__(self, dim, expansion_factor=1.):
        super().__init__()
        dim_inner = int(dim * expansion_factor)
        # Separate projections for the candidate hidden state and the gate
        self.to_hidden = Linear(dim, dim_inner, bias=False)
        self.to_gate = Linear(dim, dim_inner, bias=False)
        # Output projection (Identity if no expansion)
        self.to_out = Linear(dim_inner, dim, bias=False) if expansion_factor != 1. else Identity()

    def forward(self, x, prev_hidden=None, return_next_prev_hidden=False):
        # Project the input to the candidate hidden state and the gate
        hidden = self.to_hidden(x)
        gate = self.to_gate(x)
        # Convert to log space for numerical stability
        log_coeffs = -F.softplus(gate)    # log(1 - σ(gate))
        log_z = -F.softplus(-gate)        # log(σ(gate))
        log_tilde_h = log_g(hidden)       # log(g(hidden))
        log_values = log_z + log_tilde_h  # log(z * h_tilde)

        # Prepend the previous hidden state, if given, and pad the coefficients accordingly
        if exists(prev_hidden):
            log_values = torch.cat((log_g(prev_hidden), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))

        # Apply parallel scan in log space
        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -x.shape[1]:]  # Keep only the relevant sequence length

        # Store last hidden state for potential return
        next_prev_hidden = out[:, -1:]

        # Apply output projection
        out = self.to_out(out)

        if not return_next_prev_hidden:
            return out
        return out, next_prev_hidden

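# --- Illustrative reference (an added sketch, not part of the original commit) ---
# With no prev_hidden, the log-space parallel scan above should agree, up to
# floating-point tolerance, with the sequential minGRU recurrence
#     h_t = (1 - z_t) * h_{t-1} + z_t * g(h_tilde_t),   h_0 = 0,
# where z_t = sigmoid(gate_t) and g(x) = x + 0.5 for x >= 0, sigmoid(x) otherwise.
# The naive loop below can be used to sanity-check MinGRU on small inputs.

@torch.no_grad()
def sequential_min_gru_reference(min_gru, x):
    hidden = min_gru.to_hidden(x)
    gate = min_gru.to_gate(x)
    z = gate.sigmoid()
    g = torch.where(hidden >= 0, hidden + 0.5, hidden.sigmoid())
    h = torch.zeros_like(g[:, 0])
    outs = []
    for t in range(x.shape[1]):
        h = (1 - z[:, t]) * h + z[:, t] * g[:, t]
        outs.append(h)
    return min_gru.to_out(torch.stack(outs, dim=1))

# e.g. torch.allclose(sequential_min_gru_reference(model, x), model(x), atol=1e-5)
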
if __name__ == "__main__":
    # Quick sanity check of the MinGRU block on random inputs
    x = torch.rand(2, 256, 512)
    model = MinGRU(dim=512)
    out, next_prev_hidden = model(x, return_next_prev_hidden=True)

    print("out", out[0, 0, :3])
    print("next_prev_hidden", next_prev_hidden[0, 0, :3])
    print("out shape", out.shape)
    print("x shape", x.shape)
    assert x.shape == out.shape

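# --- Illustrative usage (an added sketch, not part of the original commit) ---
# The hidden state returned by MinGRU can be threaded through successive calls,
# so a long sequence can be fed in chunks while carrying state between them.
# Note that forward() passes prev_hidden through log_g again, so this shows the
# calling pattern rather than an exact-equivalence claim.
if __name__ == "__main__":
    chunk_a, chunk_b = torch.rand(2, 128, 512), torch.rand(2, 128, 512)
    stateful_model = MinGRU(dim=512)

    out_a, hidden = stateful_model(chunk_a, return_next_prev_hidden=True)
    out_b, hidden = stateful_model(chunk_b, prev_hidden=hidden, return_next_prev_hidden=True)

    print("chunked outputs:", out_a.shape, out_b.shape)  # (2, 128, 512) each
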
class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, self.dim_inner),
            nn.GELU(),
            nn.Linear(self.dim_inner, dim)
        )

    def forward(self, x):
        return self.net(x)

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * (self.gamma + 1)

class MinGRU_Layers(nn.Module):
    def __init__(self, dim, num_tokens):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        self.rms_norm = RMSNorm(dim)
        self.gru = MinGRU(dim)
        self.ff = FeedForward(dim)

        self.norm = RMSNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, inputs, labels=None, is_first_layer=True, prev_hiddens=None):
        # The first layer embeds token ids; later layers re-embed the argmax of the
        # previous layer's logits
        if is_first_layer:
            x = self.emb(inputs)
        else:
            x = self.emb(inputs.argmax(dim=-1))

        # When hidden states are carried over, only the newest token is processed
        if exists(prev_hiddens):
            x = x[:, -1:]

        next_prev_hiddens = []
        prev_hiddens = iter(default(prev_hiddens, []))

        x = self.rms_norm(x)
        prev_hidden = next(prev_hiddens, None)

        min_gru_out, next_hidden = self.gru(x, prev_hidden, return_next_prev_hidden=True)

        # Residual connections around the MinGRU and feed-forward blocks
        x = min_gru_out + x
        next_prev_hiddens.append(next_hidden)
        x = self.ff(x) + x
        logits = self.to_logits(self.norm(x))

        if labels is not None:
            loss = F.cross_entropy(logits.transpose(1, 2), labels)
        else:
            loss = None

        return loss, logits, next_prev_hiddens

class MinGRU_LM(nn.Module):
    def __init__(self, dim, num_tokens, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([MinGRU_Layers(dim, num_tokens) for _ in range(num_layers)])

    def forward(self, inputs, labels):
        total_loss = 0
        hidden_states = [None] * len(self.layers)
        current_input = inputs

        for i, layer in enumerate(self.layers):
            loss, logits, next_hiddens = layer(
                inputs=current_input,
                labels=labels,
                is_first_layer=(i == 0),
                prev_hiddens=hidden_states[i]
            )

            if loss is not None:
                total_loss += loss

            current_input = logits  # Use the logits as input for the next layer
            hidden_states[i] = next_hiddens

        return total_loss / len(self.layers), logits
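
The file imports torch.optim but never uses it. For completeness, a minimal training step for MinGRU_LM might look like the sketch below; the vocabulary size, dimensions, batch shape, and learning rate are illustrative assumptions, not values taken from this commit.

# Hypothetical end-to-end training step (all hyperparameters are illustrative)
vocab_size, dim, num_layers = 256, 512, 2
lm = MinGRU_LM(dim=dim, num_tokens=vocab_size, num_layers=num_layers)
optimizer = optim.Adam(lm.parameters(), lr=3e-4)

# Toy batch of token ids; the model predicts the next token at each position
seq = torch.randint(0, vocab_size, (2, 129))
inputs, labels = seq[:, :-1], seq[:, 1:]

loss, logits = lm(inputs, labels)  # loss averaged over layers, logits from the last layer
loss.backward()
optimizer.step()
optimizer.zero_grad()
print("loss", loss.item(), "logits", logits.shape)  # logits: (2, 128, vocab_size)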